Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
752c6ade
Unverified
Commit
752c6ade
authored
Jul 19, 2025
by
Woosuk Kwon
Committed by
GitHub
Jul 19, 2025
Browse files
[V0 Deprecation] Deprecate BlockSparse Attention & Phi3-Small (#21217)
Signed-off-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
881e3cbe
Changes
38
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
41 additions
and
1405 deletions
+41
-1405
.buildkite/scripts/hardware_ci/run-amd-test.sh
.buildkite/scripts/hardware_ci/run-amd-test.sh
+0
-1
docs/models/supported_models.md
docs/models/supported_models.md
+0
-1
tests/kernels/attention/test_blocksparse_attention.py
tests/kernels/attention/test_blocksparse_attention.py
+0
-441
tests/kernels/attention/test_rocm_attention_selector.py
tests/kernels/attention/test_rocm_attention_selector.py
+24
-8
tests/models/registry.py
tests/models/registry.py
+0
-4
vllm/attention/backends/abstract.py
vllm/attention/backends/abstract.py
+0
-1
vllm/attention/backends/blocksparse_attn.py
vllm/attention/backends/blocksparse_attn.py
+0
-466
vllm/attention/backends/differential_flash_attn.py
vllm/attention/backends/differential_flash_attn.py
+0
-4
vllm/attention/backends/dual_chunk_flash_attn.py
vllm/attention/backends/dual_chunk_flash_attn.py
+0
-1
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+1
-5
vllm/attention/backends/flashinfer.py
vllm/attention/backends/flashinfer.py
+0
-1
vllm/attention/backends/flashmla.py
vllm/attention/backends/flashmla.py
+4
-8
vllm/attention/backends/mla/common.py
vllm/attention/backends/mla/common.py
+0
-1
vllm/attention/backends/rocm_aiter_mla.py
vllm/attention/backends/rocm_aiter_mla.py
+4
-8
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+1
-5
vllm/attention/backends/triton_mla.py
vllm/attention/backends/triton_mla.py
+4
-8
vllm/attention/backends/xformers.py
vllm/attention/backends/xformers.py
+1
-5
vllm/attention/layer.py
vllm/attention/layer.py
+2
-4
vllm/attention/ops/blocksparse_attention/__init__.py
vllm/attention/ops/blocksparse_attention/__init__.py
+0
-0
vllm/attention/ops/blocksparse_attention/blocksparse_attention_kernel.py
...ops/blocksparse_attention/blocksparse_attention_kernel.py
+0
-433
No files found.
.buildkite/scripts/hardware_ci/run-amd-test.sh
View file @
752c6ade
...
...
@@ -108,7 +108,6 @@ fi
if
[[
$commands
==
*
" kernels/attention"
*
]]
;
then
commands
=
"
${
commands
}
\
--ignore=kernels/attention/test_attention_selector.py
\
--ignore=kernels/attention/test_blocksparse_attention.py
\
--ignore=kernels/attention/test_encoder_decoder_attn.py
\
--ignore=kernels/attention/test_flash_attn.py
\
--ignore=kernels/attention/test_flashinfer.py
\
...
...
docs/models/supported_models.md
View file @
752c6ade
...
...
@@ -376,7 +376,6 @@ Specified using `--task generate`.
|
`OrionForCausalLM`
| Orion |
`OrionStarAI/Orion-14B-Base`
,
`OrionStarAI/Orion-14B-Chat`
, etc. | | ✅︎ | ✅︎ |
|
`PhiForCausalLM`
| Phi |
`microsoft/phi-1_5`
,
`microsoft/phi-2`
, etc. | ✅︎ | ✅︎ | ✅︎ |
|
`Phi3ForCausalLM`
| Phi-4, Phi-3 |
`microsoft/Phi-4-mini-instruct`
,
`microsoft/Phi-4`
,
`microsoft/Phi-3-mini-4k-instruct`
,
`microsoft/Phi-3-mini-128k-instruct`
,
`microsoft/Phi-3-medium-128k-instruct`
, etc. | ✅︎ | ✅︎ | ✅︎ |
|
`Phi3SmallForCausalLM`
| Phi-3-Small |
`microsoft/Phi-3-small-8k-instruct`
,
`microsoft/Phi-3-small-128k-instruct`
, etc. | | ✅︎ | ✅︎ |
|
`PhiMoEForCausalLM`
| Phi-3.5-MoE |
`microsoft/Phi-3.5-MoE-instruct`
, etc. | ✅︎ | ✅︎ | ✅︎ |
|
`Phi4FlashForCausalLM`
| Phi-4-mini-flash-reasoning |
`microsoft/Phi-4-mini-flash-reasoning`
, etc. | | | |
|
`PersimmonForCausalLM`
| Persimmon |
`adept/persimmon-8b-base`
,
`adept/persimmon-8b-chat`
, etc. | | ✅︎ | ✅︎ |
...
...
tests/kernels/attention/test_blocksparse_attention.py
deleted
100644 → 0
View file @
881e3cbe
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
random
from
typing
import
Optional
import
pytest
import
torch
from
tests.kernels.allclose_default
import
get_default_atol
,
get_default_rtol
from
vllm
import
_custom_ops
as
ops
from
vllm.attention.ops.blocksparse_attention.interface
import
(
LocalStridedBlockSparseAttn
)
from
vllm.platforms
import
current_platform
from
vllm.utils
import
get_max_shared_memory_bytes
# Size in bytes of one float32 element.
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8

# Upper bound on the sequence lengths used by the tests. It is derived from
# get_max_shared_memory_bytes(), so it changes with the compute capability;
# 512 elements are subtracted as a buffer.
# MAX_SEQ_LEN = 2771
MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512

# A large NUM_BLOCKS may exhaust GPU memory; lower it if that happens.
NUM_BLOCKS = 4321  # Arbitrary value for testing
PARTITION_SIZE = 512

# Parametrization values shared by the tests below.
DTYPES = [torch.half, torch.bfloat16]
NUM_GEN_SEQS = [3]  # Arbitrary value for testing
NUM_PREFILL_SEQS = [3]  # Arbitrary value for testing
NUM_HEADS = [(40, 40)]  # (query heads, kv heads); arbitrary value for testing
HEAD_SIZES = [64, 112]
BLOCK_SIZES = [16]
USE_ALIBI = [False, True]
KV_CACHE_DTYPE = ["auto", "fp8"]
SEEDS = [0]
CUDA_DEVICES = ['cuda:0']

# Block-sparse specific parametrization values.
BLOCKSPARSE_LOCAL_BLOCKS = [16]
BLOCKSPARSE_VERT_STRIDES = [8]
BLOCKSPARSE_BLOCK_SIZES = [64]
BLOCKSPARSE_HEADS_SLIDINGS = [2, -1]
BLOCKSPARSE_HOMO_HEADS = [True, False]
def ref_masked_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Compute scaled dot-product attention with an optional additive mask.

    Args:
        query: Tensor indexed as (query position, head, head dim).
        key: Tensor indexed as (key position, head, head dim).
        value: Tensor indexed as (key position, head, head dim).
        scale: Factor multiplied into the raw query/key scores.
        attn_mask: Optional tensor added to the float32 scores before the
            softmax; it must broadcast against (head, query, key).

    Returns:
        The attention output indexed as (query position, head, head dim),
        in the dtype of ``value``.
    """
    # Scores are accumulated in float32 regardless of the input dtype.
    scores = torch.einsum("qhd,khd->hqk", query, key).float()
    scores = scale * scores
    if attn_mask is not None:
        scores = scores + attn_mask.float()

    # Normalise over the key axis, then drop back to the value dtype.
    probs = torch.softmax(scores, dim=-1).to(value.dtype)
    return torch.einsum("hqk,khd->qhd", probs, value)
def ref_single_query_cached_kv_attention(
    output: torch.Tensor,
    query: torch.Tensor,
    num_queries_per_kv: int,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    scale: float,
    alibi_slopes: Optional[torch.Tensor],
    tp_rank: int = 0,
    blocksparse_local_blocks: int = 0,
    blocksparse_vert_stride: int = 1,
    blocksparse_block_size: int = 64,
    blocksparse_head_sliding_step: int = 0,
) -> None:
    """Reference single-query attention over a paged KV cache.

    For every sequence the keys and values are gathered token by token from
    the paged caches, an optional ALiBi bias and a local + vertical-stride
    block-sparse mask are built, and the result of ``ref_masked_attention``
    is written into ``output`` in place.

    Args:
        output: Destination tensor; row ``i`` receives a
            (num_query_heads, head_size) result for sequence ``i``.
        query: One query per sequence, indexed as (sequence, head, head dim).
        num_queries_per_kv: Number of query heads sharing one KV head; keys
            and values are repeated along the head axis when it exceeds 1.
        key_cache: Paged key cache, indexed here as
            ``[block, :, :, offset, :]`` and reshaped to
            (num_kv_heads, head_size) per token.
        value_cache: Paged value cache, indexed here as
            ``[block, :, :, offset]``; dims 1..3 give the number of KV heads,
            the head size and the cache block size.
        block_tables: Per-sequence list of cache block numbers.
        seq_lens: Per-sequence number of cached tokens.
        scale: Softmax scaling factor.
        alibi_slopes: Optional per-head ALiBi slopes, or ``None``.
        tp_rank: Rank used to offset the head index in the sparse pattern.
        blocksparse_local_blocks: Number of most recent sparse blocks that
            are always attended.
        blocksparse_vert_stride: Stride of the additionally attended blocks;
            the block-sparse mask is only built when this is >= 1.
        blocksparse_block_size: Number of tokens per sparse block.
        blocksparse_head_sliding_step: Per-head shift of the strided
            pattern; a negative value slides per KV head instead of per
            query head.
    """
    num_seqs = query.shape[0]
    num_query_heads = query.shape[1]
    num_kv_heads = value_cache.shape[1]
    head_size = value_cache.shape[2]
    block_size = value_cache.shape[3]

    # Pull the index tensors to the host once instead of per token.
    all_block_tables = block_tables.cpu().tolist()
    all_seq_lens = seq_lens.cpu().tolist()

    for seq_idx in range(num_seqs):
        q = query[seq_idx].unsqueeze(0)
        block_table = all_block_tables[seq_idx]
        seq_len = int(all_seq_lens[seq_idx])

        # Gather this sequence's keys/values from the paged caches.
        gathered_keys: list[torch.Tensor] = []
        gathered_values: list[torch.Tensor] = []
        for pos in range(seq_len):
            block_number = int(block_table[pos // block_size])
            block_offset = pos % block_size

            k = key_cache[block_number, :, :, block_offset, :]
            gathered_keys.append(k.reshape(num_kv_heads, head_size))
            gathered_values.append(value_cache[block_number, :, :,
                                               block_offset])
        keys = torch.stack(gathered_keys, dim=0)
        values = torch.stack(gathered_values, dim=0)
        if num_queries_per_kv > 1:
            # Handle MQA and GQA
            keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1)
            values = torch.repeat_interleave(values,
                                             num_queries_per_kv,
                                             dim=1)

        alibi_bias = None
        if alibi_slopes is not None:
            # Create the ALiBi bias used in the paged attention kernel.
            position_ids = torch.arange(seq_len).int()
            alibi_bias = (position_ids - seq_len + 1).float()
            alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(
                1, 1, -1)

        if blocksparse_vert_stride >= 1:
            sparse_block = blocksparse_block_size
            head_step = blocksparse_head_sliding_step
            vert_stride = blocksparse_vert_stride
            # Named `local_blocks` so the builtin `locals` is not shadowed.
            local_blocks = blocksparse_local_blocks
            # Sparse block that contains the single query token (the last
            # position of the sequence).
            q_block = (seq_len - 1) // sparse_block

            # Start fully masked (-inf) and open up the attended blocks.
            attn_mask = q.new_zeros(
                (num_query_heads, 1, seq_len)).float() - torch.inf
            for head in range(num_query_heads):
                if head_step >= 0:
                    # slide with q heads
                    stride_offset = (tp_rank * num_query_heads +
                                     head) * head_step + 1
                else:
                    # slide with kv heads
                    kv_head = head // num_queries_per_kv
                    stride_offset = (tp_rank * num_kv_heads +
                                     kv_head) * (-head_step) + 1
                for k_block in range(q_block + 1):
                    is_local = (q_block - k_block) < local_blocks
                    is_strided = (k_block + stride_offset) % vert_stride == 0
                    if is_local or is_strided:
                        start = k_block * sparse_block
                        stop = min(start + sparse_block, seq_len)
                        attn_mask[head, 0, start:stop] = 0
            if alibi_bias is not None:
                attn_mask += alibi_bias
        else:
            attn_mask = alibi_bias

        out = ref_masked_attention(q, keys, values, scale, attn_mask=attn_mask)
        out = out.view(num_query_heads, head_size)
        output[seq_idx].copy_(out, non_blocking=True)
@pytest.mark.parametrize("version", ["v1", "v2"])
@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("use_alibi", USE_ALIBI)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("blocksparse_local_blocks", BLOCKSPARSE_LOCAL_BLOCKS)
@pytest.mark.parametrize("blocksparse_vert_stride", BLOCKSPARSE_VERT_STRIDES)
@pytest.mark.parametrize("blocksparse_block_size", BLOCKSPARSE_BLOCK_SIZES)
@pytest.mark.parametrize("blocksparse_head_sliding_step",
                         BLOCKSPARSE_HEADS_SLIDINGS)
def test_paged_attention(
    kv_cache_factory,
    version: str,
    num_seqs: int,
    num_heads: tuple[int, int],
    head_size: int,
    use_alibi: bool,
    block_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    seed: int,
    device: str,
    blocksparse_local_blocks: int,
    blocksparse_vert_stride: int,
    blocksparse_block_size: int,
    blocksparse_head_sliding_step: int,
) -> None:
    """Compare the paged attention kernels against the PyTorch reference.

    Random queries, block tables and KV caches are generated, the selected
    kernel (``paged_attention_v1`` or ``paged_attention_v2``) is run with the
    block-sparse parameters, and the output is compared with
    ``ref_single_query_cached_kv_attention``.

    Raises:
        AssertionError: If ``version`` is neither "v1" nor "v2", or if the
            kernel output is not close to the reference output.
    """
    current_platform.seed_everything(seed)
    torch.set_default_device(device)
    scale = float(1.0 / (head_size**0.5))
    num_query_heads, num_kv_heads = num_heads
    query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype)
    query.uniform_(-scale, scale)

    assert num_query_heads % num_kv_heads == 0
    num_queries_per_kv = num_query_heads // num_kv_heads
    alibi_slopes = None
    if use_alibi:
        alibi_slopes = torch.rand(num_query_heads, dtype=torch.float)

    seq_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
    # Force one sequence to the maximum length so the upper bound is covered.
    seq_lens[-1] = MAX_SEQ_LEN
    max_seq_len = max(seq_lens)
    seq_lens = torch.tensor(seq_lens, dtype=torch.int)

    # Create the block tables.
    max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
    block_tables = []
    for _ in range(num_seqs):
        block_table = [
            random.randint(0, NUM_BLOCKS - 1)
            for _ in range(max_num_blocks_per_seq)
        ]
        block_tables.append(block_table)
    block_tables = torch.tensor(block_tables, dtype=torch.int)

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
                                                num_kv_heads, head_size,
                                                kv_cache_dtype, dtype, seed,
                                                device)
    key_cache, value_cache = key_caches[0], value_caches[0]

    # Using default kv_scale
    k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
    tp_rank = 0

    # Call the paged attention kernel.
    output = torch.empty_like(query)
    if version == "v1":
        ops.paged_attention_v1(
            output,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale,
            block_tables,
            seq_lens,
            block_size,
            max_seq_len,
            alibi_slopes,
            kv_cache_dtype,
            k_scale,
            v_scale,
            tp_rank=tp_rank,
            blocksparse_local_blocks=blocksparse_local_blocks,
            blocksparse_vert_stride=blocksparse_vert_stride,
            blocksparse_block_size=blocksparse_block_size,
            blocksparse_head_sliding_step=blocksparse_head_sliding_step,
        )
    elif version == "v2":
        num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
        assert PARTITION_SIZE % block_size == 0
        # `output` has the same shape as `query`, so the scratch buffers are
        # sized from the existing dimensions; the `num_heads` tuple parameter
        # is intentionally not rebound here.
        tmp_output = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions, head_size),
            dtype=output.dtype,
        )
        exp_sums = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions),
            dtype=torch.float32,
        )
        max_logits = torch.empty_like(exp_sums)
        ops.paged_attention_v2(
            output,
            exp_sums,
            max_logits,
            tmp_output,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale,
            block_tables,
            seq_lens,
            block_size,
            max_seq_len,
            alibi_slopes,
            kv_cache_dtype,
            k_scale,
            v_scale,
            tp_rank=tp_rank,
            blocksparse_local_blocks=blocksparse_local_blocks,
            blocksparse_vert_stride=blocksparse_vert_stride,
            blocksparse_block_size=blocksparse_block_size,
            blocksparse_head_sliding_step=blocksparse_head_sliding_step,
        )
    else:
        raise AssertionError(f"Unknown version: {version}")

    # Run the reference implementation.
    if kv_cache_dtype == "fp8":
        # Convert cache data back to dtype.
        x = 16 // torch.tensor([], dtype=dtype).element_size()
        key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x,
                           block_size, x)
        dequantized_key_cache = torch.empty(size=key_cache_shape,
                                            dtype=dtype,
                                            device=device)
        ops.convert_fp8(dequantized_key_cache, key_cache)
        key_cache = dequantized_key_cache

        value_cache_shape = value_cache.shape
        dequantized_value_cache = torch.empty(size=value_cache_shape,
                                              dtype=dtype,
                                              device=device)
        ops.convert_fp8(dequantized_value_cache, value_cache)
        value_cache = dequantized_value_cache

    ref_output = torch.empty_like(query)
    ref_single_query_cached_kv_attention(
        ref_output,
        query,
        num_queries_per_kv,
        key_cache,
        value_cache,
        block_tables,
        seq_lens,
        scale,
        alibi_slopes,
        tp_rank,
        blocksparse_local_blocks,
        blocksparse_vert_stride,
        blocksparse_block_size,
        blocksparse_head_sliding_step,
    )

    # NOTE(woosuk): Due to the kernel-level differences in the two
    # implementations, there is a small numerical difference in the two
    # outputs. Thus, we use a relaxed tolerance for the test.
    # The ROCm defaults used to be overwritten unconditionally right after
    # being computed; they are now kept so the platform tolerances apply.
    if current_platform.is_rocm():
        atol = get_default_atol(output)
        rtol = get_default_rtol(output)
    else:
        atol, rtol = 1e-3, 1e-5

    # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
    # so we use a relaxed tolerance for the test.
    if kv_cache_dtype == "fp8":
        # Only ever loosen the absolute tolerance, never tighten it.
        atol = max(atol, 1e-2)
    torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)
def ref_multi_query_kv_attention(
    cu_seq_lens: list[int],
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Reference causal attention over a batch of packed sequences.

    Args:
        cu_seq_lens: Cumulative sequence lengths; consecutive entries give
            the [start, end) token range of each sequence.
        query: Packed queries, sliced along dim 0 per sequence.
        key: Packed keys, sliced along dim 0 per sequence.
        value: Packed values, sliced along dim 0 per sequence.
        scale: Softmax scaling factor.
        dtype: dtype used to build the causal mask.

    Returns:
        The per-sequence outputs of ``ref_masked_attention`` concatenated
        along dim 0.
    """
    per_seq_outputs = []
    for begin, end in zip(cu_seq_lens, cu_seq_lens[1:]):
        length = end - begin

        # Causal mask: the most negative finite value strictly above the
        # diagonal, zero elsewhere.
        upper = torch.triu(torch.ones(length, length, dtype=dtype),
                           diagonal=1)
        causal_mask = (upper * torch.finfo(dtype).min).to(dtype=dtype)

        tokens = slice(begin, end)
        per_seq_outputs.append(
            ref_masked_attention(
                query[tokens],
                key[tokens],
                value[tokens],
                scale,
                attn_mask=causal_mask,
            ))

    return torch.cat(per_seq_outputs, dim=0)
@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("blocksparse_local_blocks", BLOCKSPARSE_LOCAL_BLOCKS)
@pytest.mark.parametrize("blocksparse_vert_stride", BLOCKSPARSE_VERT_STRIDES)
@pytest.mark.parametrize("blocksparse_block_size", BLOCKSPARSE_BLOCK_SIZES)
@pytest.mark.parametrize("blocksparse_homo_heads", BLOCKSPARSE_HOMO_HEADS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_varlen_blocksparse_attention_prefill(
    num_seqs: int,
    num_heads: tuple[int, int],
    head_size: int,
    blocksparse_local_blocks: int,
    blocksparse_vert_stride: int,
    blocksparse_block_size: int,
    blocksparse_homo_heads: bool,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    """Compare ``LocalStridedBlockSparseAttn`` prefill with the reference.

    Packed variable-length sequences are run through the block-sparse
    operator and through ``ref_multi_query_kv_attention``; the two outputs
    must agree within ``atol=1e-2`` / ``rtol=1e-2``.
    """
    current_platform.seed_everything(seed)
    torch.set_default_device(device)

    # MAX_SEQ_LEN sometimes causes OOM in the reference implementation,
    # so the sequence lengths are capped at 4096 here.
    max_len = min(MAX_SEQ_LEN, 4096)
    seq_lens = random.sample(range(1, max_len), num_seqs)
    num_tokens = sum(seq_lens)
    cu_seq_lens = torch.cumsum(torch.tensor([0] + seq_lens), dim=0)

    num_query_heads, num_kv_heads = num_heads
    assert num_query_heads % num_kv_heads == 0
    num_queries_per_kv = num_query_heads // num_kv_heads
    scale = 1.0 / (head_size**0.5)

    # Query, key and value share one randomly filled buffer, split by heads.
    total_heads = num_query_heads + 2 * num_kv_heads
    qkv = torch.empty(num_tokens, total_heads, head_size, dtype=dtype)
    qkv.uniform_(-scale, scale)
    head_split = [num_query_heads, num_kv_heads, num_kv_heads]
    query, key, value = qkv.split(head_split, dim=1)

    bs_attn_op = LocalStridedBlockSparseAttn(
        num_query_heads,
        max_len,
        local_blocks=blocksparse_local_blocks,
        vert_stride=blocksparse_vert_stride,
        block_size=blocksparse_block_size,
        device=device,
        dtype=dtype,
        homo_head=blocksparse_homo_heads,
    )
    output = bs_attn_op(query,
                        key,
                        value,
                        cu_seq_lens.to(device),
                        sm_scale=scale)

    if num_queries_per_kv > 1:
        # Handle MQA and GQA
        key, value = (torch.repeat_interleave(t, num_queries_per_kv, dim=1)
                      for t in (key, value))

    ref_output = ref_multi_query_kv_attention(
        cu_seq_lens.tolist(),
        query,
        key,
        value,
        scale,
        dtype,
    )
    torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2)
tests/kernels/attention/test_rocm_attention_selector.py
View file @
752c6ade
...
...
@@ -33,8 +33,12 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
# change the attention backend to triton MLA
m
.
setenv
(
STR_BACKEND_ENV_VAR
,
"TRITON_MLA"
)
backend
=
get_attn_backend
(
576
,
torch
.
bfloat16
,
"auto"
,
16
,
False
,
False
,
True
)
backend
=
get_attn_backend
(
576
,
torch
.
bfloat16
,
"auto"
,
16
,
False
,
use_mla
=
True
)
assert
(
backend
.
get_name
()
==
"TRITON_MLA"
or
backend
.
get_name
()
==
"TRITON_MLA_VLLM_V1"
)
...
...
@@ -42,15 +46,23 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
# If use_mla is true
# The selected backend is triton MLA
m
.
setenv
(
STR_BACKEND_ENV_VAR
,
None
)
backend
=
get_attn_backend
(
576
,
torch
.
bfloat16
,
"auto"
,
16
,
False
,
False
,
True
)
backend
=
get_attn_backend
(
576
,
torch
.
bfloat16
,
"auto"
,
16
,
False
,
use_mla
=
True
)
assert
(
backend
.
get_name
()
==
"TRITON_MLA"
or
backend
.
get_name
()
==
"TRITON_MLA_VLLM_V1"
)
# change the attention backend to AITER MLA
m
.
setenv
(
STR_BACKEND_ENV_VAR
,
"ROCM_AITER_MLA"
)
backend
=
get_attn_backend
(
576
,
torch
.
bfloat16
,
"auto"
,
1
,
False
,
False
,
True
)
backend
=
get_attn_backend
(
576
,
torch
.
bfloat16
,
"auto"
,
1
,
False
,
use_mla
=
True
)
assert
(
backend
.
get_name
()
==
"ROCM_AITER_MLA"
or
backend
.
get_name
()
==
"ROCM_AITER_MLA_VLLM_V1"
)
...
...
@@ -60,7 +72,11 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
# The selected backend is ROCM_AITER_MLA
m
.
setenv
(
STR_BACKEND_ENV_VAR
,
None
)
m
.
setenv
(
"VLLM_ROCM_USE_AITER"
,
"1"
)
backend
=
get_attn_backend
(
576
,
torch
.
bfloat16
,
"auto"
,
1
,
False
,
False
,
True
)
backend
=
get_attn_backend
(
576
,
torch
.
bfloat16
,
"auto"
,
1
,
False
,
use_mla
=
True
)
assert
(
backend
.
get_name
()
==
"ROCM_AITER_MLA"
or
backend
.
get_name
()
==
"ROCM_AITER_MLA_VLLM_V1"
)
tests/models/registry.py
View file @
752c6ade
...
...
@@ -247,10 +247,6 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"PersimmonForCausalLM"
:
_HfExamplesInfo
(
"adept/persimmon-8b-chat"
),
"PhiForCausalLM"
:
_HfExamplesInfo
(
"microsoft/phi-2"
),
"Phi3ForCausalLM"
:
_HfExamplesInfo
(
"microsoft/Phi-3-mini-4k-instruct"
),
# Blocksparse attention not supported in V1 yet
"Phi3SmallForCausalLM"
:
_HfExamplesInfo
(
"microsoft/Phi-3-small-8k-instruct"
,
trust_remote_code
=
True
,
v0_only
=
True
),
"Phi4FlashForCausalLM"
:
_HfExamplesInfo
(
"microsoft/Phi-4-mini-flash-reasoning"
,
# noqa: E501
trust_remote_code
=
True
,
v0_only
=
True
,
...
...
vllm/attention/backends/abstract.py
View file @
752c6ade
...
...
@@ -269,7 +269,6 @@ class AttentionImpl(ABC, Generic[T]):
alibi_slopes
:
Optional
[
List
[
float
]]
=
None
,
sliding_window
:
Optional
[
int
]
=
None
,
kv_cache_dtype
:
str
=
"auto"
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
attn_type
:
str
=
AttentionType
.
DECODER
,
kv_sharing_target_layer_name
:
Optional
[
str
]
=
None
,
...
...
vllm/attention/backends/blocksparse_attn.py
deleted
100644 → 0
View file @
881e3cbe
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
dataclasses
import
dataclass
,
field
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
import
torch
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionLayer
,
AttentionMetadata
,
AttentionType
)
from
vllm.attention.backends.utils
import
(
CommonAttentionState
,
CommonMetadataBuilder
)
from
vllm.attention.ops.blocksparse_attention.interface
import
(
LocalStridedBlockSparseAttn
,
get_head_sliding_step
)
from
vllm.attention.ops.paged_attn
import
PagedAttention
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
@dataclass
class BlocksparseParams:
    """Configuration of the local + vertical-stride block-sparse pattern.

    ``head_sliding_step`` and ``active_head_range`` are not constructor
    arguments; they are derived in ``__post_init__`` from the other fields
    and the tensor-parallel world size and rank.
    """

    max_seqlen: int

    # Number of q heads per tensor-parallel rank/partition.
    num_heads: int
    # Number of kv heads per tensor-parallel rank/partition.
    num_kv_heads: int

    # Block size used for blocksparse attention, i.e. the unit in which
    # `local_blocks` and `vert_stride` are expressed.
    block_size: int

    # Number of blocks for local attention, i.e., number of
    # local attended tokens / `sparse_block_size`.
    local_blocks: int

    # Attend to one block per every `vert_stride` blocks; controls the
    # sparsity.
    vert_stride: int

    # Whether to use the same vertical stride offset for all heads, i.e.
    # attend to the same block of tokens on all heads. By default it is
    # False, i.e. attention on the non-local blocks depends on the
    # `head_idx`, that is on blocks satisfying
    # `(block_idx + head_idx * head_sliding_step + 1) % vert_stride == 0`
    # where `head_sliding_step=max(1, int(vert_stride / num_total_heads))`,
    # `block_idx = position_id // sparse_block_size`.
    # See `..ops.blocksparse_attention.utils:get_sparse_attn_mask`
    # for more detail.
    homo_head: bool = False

    # Whether, within a group, every q head attends to the same kv offsets.
    homo_head_group: bool = False

    # Derived from `homo_head` and `homo_head_group`.
    head_sliding_step: int = field(init=False)

    # Range of q heads handled by this TP rank.
    active_head_range: Tuple = field(init=False)

    def __post_init__(self):
        """Validate the pattern sizes and fill in the derived fields."""
        assert self.block_size > 0
        assert self.local_blocks >= 0
        assert self.vert_stride >= 1

        world_size = get_tensor_model_parallel_world_size()
        rank = get_tensor_model_parallel_rank()

        if self.homo_head:
            step = 0
        elif self.homo_head_group:
            # negative indicates sliding along kv heads, i.e., homo q group
            step = -get_head_sliding_step(world_size * self.num_kv_heads,
                                          self.vert_stride)
        else:
            step = get_head_sliding_step(world_size * self.num_heads,
                                         self.vert_stride)
        self.head_sliding_step = step

        first_head = rank * self.num_heads
        self.active_head_range = (first_head, (rank + 1) * self.num_heads)
class BlocksparseFlashAttentionBackend(AttentionBackend):
    """Backend descriptor for block-sparse flash attention.

    Exposes the implementation, metadata, builder and state classes of this
    backend and forwards KV-cache shape and block-movement operations to
    ``PagedAttention``.
    """

    @staticmethod
    def get_name() -> str:
        """Return the identifier string of this backend."""
        return "BLOCK_SPARSE_FLASH_ATTN"

    @staticmethod
    def get_impl_cls() -> Type["BlocksparseFlashAttentionImpl"]:
        """Return the attention implementation class."""
        return BlocksparseFlashAttentionImpl

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        """Return the attention metadata class."""
        return BlocksparseFlashAttentionMetadata

    @staticmethod
    def get_builder_cls() -> Type["BlocksparseFlashAttentionMetadataBuilder"]:
        """Return the metadata builder class."""
        return BlocksparseFlashAttentionMetadataBuilder

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        """Return the attention state class."""
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV cache shape as computed by ``PagedAttention``."""
        cache_shape = PagedAttention.get_kv_cache_shape(
            num_blocks, block_size, num_kv_heads, head_size)
        return cache_shape

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: Dict[int, int],
    ) -> None:
        """Forward the block swap to ``PagedAttention.swap_blocks``."""
        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: Dict[int, List[int]],
    ) -> None:
        """Forward the block copy to ``PagedAttention.copy_blocks``."""
        PagedAttention.copy_blocks(kv_caches, src_to_dists)
@dataclass
class BlocksparseFlashAttentionMetadata(AttentionMetadata):
    """Attention metadata for the block-sparse flash attention backend.

    This is a copy of the metadata used by FlashAttentionBackend, kept
    separately to avoid having to install flash_attn.

    NOTE: Any python object stored here is not updated when it is
    cuda-graph replayed. If you have values that need to be changed
    dynamically, it should be stored in tensor. The tensor has to be
    updated from `CUDAGraphRunner.forward` API.
    """

    # (batch_size,). The sequence length per sequence. Sequence length means
    # the computed tokens + new tokens. None if it is a decoding.
    seq_lens: Optional[List[int]]
    # seq_lens stored as a tensor.
    seq_lens_tensor: Optional[torch.Tensor]

    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ----------------------|
    #                                   |-- query_len ---|

    # Maximum query length in the batch. None for decoding.
    max_query_len: Optional[int]
    # Maximum sequence length among prefill batch. 0 if there are decoding
    # requests only.
    max_prefill_seq_len: int
    # Maximum sequence length among decode batch. 0 if there are prefill
    # requests only.
    max_decode_seq_len: int
    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
    # the batch, used to index into subquery. E.g., if the subquery length
    # is [4, 6], it is [0, 4, 10].
    query_start_loc: Optional[torch.Tensor]
    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence length is
    # [4, 6], it is [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor]
    # (batch_size,) A tensor of context lengths (tokens that are computed
    # so far).
    context_lens_tensor: Optional[torch.Tensor]

    # (batch_size, max_blocks_per_seq).
    # Block addresses per sequence. (Seq id -> list of physical block)
    # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
    # in the kv cache. Each block can contain up to block_size tokens.
    # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
    # captured.
    block_tables: Optional[torch.Tensor]

    # Whether cuda graph is enabled.
    # Cuda-graph is currently enabled for decoding only.
    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool

    # Max number of query tokens among requests in the batch.
    max_decode_query_len: Optional[int] = None

    # Lazily built views returned by `prefill_metadata` / `decode_metadata`.
    _cached_prefill_metadata: Optional[
        "BlocksparseFlashAttentionMetadata"] = None
    _cached_decode_metadata: Optional[
        "BlocksparseFlashAttentionMetadata"] = None

    @property
    def prefill_metadata(
            self) -> Optional["BlocksparseFlashAttentionMetadata"]:
        """Return the metadata restricted to the prefill requests.

        Returns ``None`` when there are no prefills. The result is built on
        first access and cached on the instance.
        """
        n_prefills = self.num_prefills
        if n_prefills == 0:
            return None

        cached = self._cached_prefill_metadata
        if cached is not None:
            return cached

        assert self.seq_lens is not None
        assert self.seq_lens_tensor is not None
        assert self.query_start_loc is not None
        assert self.context_lens_tensor is not None
        assert self.block_tables is not None
        assert self.seq_start_loc is not None

        n_prefill_tokens = self.num_prefill_tokens
        # Cumulative-offset tensors carry one more entry than sequences.
        n_offsets = n_prefills + 1

        prefill = BlocksparseFlashAttentionMetadata(
            num_prefills=n_prefills,
            num_prefill_tokens=n_prefill_tokens,
            num_decode_tokens=0,
            slot_mapping=self.slot_mapping[:n_prefill_tokens],
            multi_modal_placeholder_index_maps=self.
            multi_modal_placeholder_index_maps,
            enable_kv_scales_calculation=self.enable_kv_scales_calculation,
            seq_lens=self.seq_lens[:n_prefills],
            seq_lens_tensor=self.seq_lens_tensor[:n_prefills],
            max_query_len=self.max_query_len,
            max_prefill_seq_len=self.max_prefill_seq_len,
            max_decode_seq_len=0,
            query_start_loc=self.query_start_loc[:n_offsets],
            seq_start_loc=self.seq_start_loc[:n_offsets],
            context_lens_tensor=self.context_lens_tensor[:n_prefills],
            block_tables=self.block_tables[:n_prefills],
            use_cuda_graph=False,
        )
        self._cached_prefill_metadata = prefill
        return prefill

    @property
    def decode_metadata(self) -> Optional["BlocksparseFlashAttentionMetadata"]:
        """Return the metadata restricted to the decode requests.

        Returns ``None`` when there are no decode tokens. The result is built
        on first access and cached on the instance.
        """
        if self.num_decode_tokens == 0:
            return None

        cached = self._cached_decode_metadata
        if cached is not None:
            return cached

        assert self.block_tables is not None
        assert self.seq_lens_tensor is not None

        # Decode entries follow the prefill entries in the batch tensors.
        n_prefills = self.num_prefills

        decode = BlocksparseFlashAttentionMetadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=self.num_decode_tokens,
            slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
            multi_modal_placeholder_index_maps=None,
            enable_kv_scales_calculation=False,
            seq_lens=None,
            seq_lens_tensor=self.seq_lens_tensor[n_prefills:],
            max_query_len=None,
            max_prefill_seq_len=0,
            max_decode_seq_len=self.max_decode_seq_len,
            query_start_loc=None,
            seq_start_loc=None,
            context_lens_tensor=None,
            block_tables=self.block_tables[n_prefills:],
            use_cuda_graph=self.use_cuda_graph,
        )
        self._cached_decode_metadata = decode
        return decode
class BlocksparseFlashAttentionMetadataBuilder(
        CommonMetadataBuilder[BlocksparseFlashAttentionMetadata]):
    """Metadata builder specialised for the block-sparse backend."""

    # Metadata class associated with this builder.
    _metadata_cls = BlocksparseFlashAttentionMetadata
class
BlocksparseFlashAttentionImpl
(
AttentionImpl
):
"""
If the input tensors contain prompt tokens, the layout is as follows:
|<--------------- num_prompt_tokens -------------->|
|<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|
Otherwise, the layout is as follows:
|<------------------ num_generation_tokens (M) ----------------->|
|<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|
Generation tokens can contain padding when cuda-graph is used.
Currently, prompt tokens don't contain any padding.
The prompts might have different lengths, while the generation tokens
always have length 1.
"""
def __init__(
    self,
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int,
    alibi_slopes: Optional[List[float]],
    sliding_window: Optional[int],
    kv_cache_dtype: str,
    blocksparse_params: Optional[Dict[str, Any]] = None,
    logits_soft_cap: Optional[float] = None,
    attn_type: str = AttentionType.DECODER,
    kv_sharing_target_layer_name: Optional[str] = None,
) -> None:
    """Configure the block-sparse attention implementation.

    Args:
        num_heads: Number of query heads handled by this instance.
        head_size: Size of each attention head; must be one of
            ``PagedAttention.get_supported_head_sizes()``.
        scale: Softmax scale; stored as a ``float``.
        num_kv_heads: Number of key/value heads.
        alibi_slopes: Must be ``None``.
        sliding_window: Must be ``None``.
        kv_cache_dtype: KV-cache dtype string, stored as given.
        blocksparse_params: Keyword arguments for ``BlocksparseParams``.
            Required. Missing ``"num_heads"`` / ``"num_kv_heads"`` entries
            are filled in on a private copy; the caller's dict is left
            untouched.
        logits_soft_cap: Must be ``None``.
        attn_type: Only ``AttentionType.DECODER`` is accepted.
        kv_sharing_target_layer_name: Must be ``None``.

    Raises:
        NotImplementedError: If KV sharing is requested or ``attn_type``
            is not ``AttentionType.DECODER``.
        ValueError: If ``blocksparse_params`` is missing, if an unsupported
            feature (alibi, sliding window, logits soft cap) is requested,
            or if ``head_size`` is not supported.
    """
    if kv_sharing_target_layer_name is not None:
        raise NotImplementedError("KV sharing is not supported in V0 "
                                  "BLOCK_SPARSE_FLASH_ATTN Backend.")

    # Explicit raises instead of `assert cond, ValueError(...)`: asserts are
    # stripped under `python -O` and surfaced AssertionError rather than the
    # ValueError they constructed.
    if blocksparse_params is None:
        raise ValueError(
            "blocksparse_params is required for blocksparse attention.")
    if alibi_slopes is not None:
        raise ValueError(
            "Alibi not support for blocksparse flash attention.")
    if sliding_window is not None:
        raise ValueError(
            "sliding_window is invalid for blocksparse attention.")
    if logits_soft_cap is not None:
        raise ValueError(
            "logits_soft_cap is invalid for blocksparse attention.")

    # Fill in the head counts on a copy so the caller's dict is not mutated.
    params = dict(blocksparse_params)
    params.setdefault("num_heads", num_heads)
    params.setdefault("num_kv_heads", num_kv_heads or num_heads)
    self.blocksparse_params = BlocksparseParams(**params)
    self.kv_cache_dtype = kv_cache_dtype

    self.num_heads = num_heads
    self.head_size = head_size
    self.scale = float(scale)
    self.alibi_slopes = alibi_slopes
    # NOTE(review): unlike the params default above, there is no fallback
    # to num_heads here, so a falsy num_kv_heads fails on the division
    # below -- confirm callers always pass a positive int.
    self.num_kv_heads = num_kv_heads

    self.num_queries_per_kv = self.num_heads // self.num_kv_heads

    # Cache the sparsity settings that the decode path passes on each call.
    self.local_blocks = self.blocksparse_params.local_blocks
    self.vert_stride = self.blocksparse_params.vert_stride
    self.sparse_block_size = self.blocksparse_params.block_size
    self.head_sliding_step = self.blocksparse_params.head_sliding_step

    supported_head_sizes = PagedAttention.get_supported_head_sizes()
    if head_size not in supported_head_sizes:
        raise ValueError(
            f"Head size {head_size} is not supported by PagedAttention. "
            f"Supported head sizes are: {supported_head_sizes}.")

    self.tp_size = get_tensor_model_parallel_world_size()
    self.tp_rank = get_tensor_model_parallel_rank()

    # The sparse-attention op is constructed with the head count across all
    # tensor-parallel ranks.
    total_num_heads = num_heads * self.tp_size
    self.bs_attn = LocalStridedBlockSparseAttn(
        total_num_heads,
        self.blocksparse_params.max_seqlen,
        self.blocksparse_params.local_blocks,
        self.blocksparse_params.vert_stride,
        self.blocksparse_params.block_size,
        homo_head=self.blocksparse_params.homo_head,
        active_head_range=self.blocksparse_params.active_head_range,
    )

    if attn_type != AttentionType.DECODER:
        raise NotImplementedError("Encoder self-attention and "
                                  "encoder/decoder cross-attention "
                                  "are not implemented for "
                                  "BlocksparseFlashAttentionImpl")
def forward(
    self,
    layer: AttentionLayer,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_metadata: BlocksparseFlashAttentionMetadata,
    output: Optional[torch.Tensor] = None,
    output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run block-sparse attention for a prefill and/or decode batch.

    Args:
        layer: Attention layer providing the ``_k_scale`` / ``_v_scale``
            values passed to the cache write and the decode kernel.
        query: shape = [num_tokens, num_heads * head_size]
        key: shape = [num_tokens, num_kv_heads * head_size]
        value: shape = [num_tokens, num_kv_heads * head_size]
        kv_cache: shape =
            [2, num_blocks, block_size * num_kv_heads * head_size].
            NOTE: kv_cache will be an empty tensor with shape [0]
            for profiling run.
        attn_metadata: Metadata for attention.
        output: Ignored on input; rebound to the computed result.
        output_scale: Must be ``None`` (fused output quantization is
            not supported).

    Returns:
        shape = [num_tokens, num_heads * head_size]
    """
    if output_scale is not None:
        raise NotImplementedError(
            "fused output quantization is not yet supported"
            " for BlocksparseFlashAttentionImpl")

    # Remember the flat 2-D shape so the result can be restored to it.
    n_tokens, flat_width = query.shape

    # Split the flat hidden dimension into (heads, head_size).
    query = query.view(-1, self.num_heads, self.head_size)
    kv_view = (-1, self.num_kv_heads, self.head_size)
    key = key.view(*kv_view)
    value = value.view(*kv_view)

    if kv_cache.numel() > 0:
        key_cache, value_cache = PagedAttention.split_kv_cache(
            kv_cache, self.num_kv_heads, self.head_size)

        # Store the freshly computed keys/values in the paged cache. With an
        # empty kv_cache (initial memory profiling run) nothing is cached.
        PagedAttention.write_to_paged_cache(
            key,
            value,
            key_cache,
            value_cache,
            attn_metadata.slot_mapping,
            self.kv_cache_dtype,
            layer._k_scale,
            layer._v_scale,
        )

    prefill_meta = attn_metadata.prefill_metadata
    if prefill_meta:
        # Prompt run: plain attention over the prompt tokens. Unfilled
        # block tables mean q and k are both the prompt, of equal length.
        assert not (kv_cache.numel() != 0
                    and prefill_meta.block_tables is not None
                    and prefill_meta.block_tables.numel() != 0), \
            "Does not support prefix-enabled attention."

        output = self.bs_attn(
            q=query,
            k=key,
            v=value,
            cu_seqlens_q=prefill_meta.seq_start_loc,
            cu_seqlens_k=prefill_meta.seq_start_loc,
            sm_scale=self.scale,
        )

    decode_meta = attn_metadata.decode_metadata
    if decode_meta:
        # Decoding run: attend over the paged KV cache.
        output = PagedAttention.forward_decode(
            query,
            key_cache,
            value_cache,
            decode_meta.block_tables,
            decode_meta.seq_lens_tensor,
            self.blocksparse_params.max_seqlen,
            self.kv_cache_dtype,
            self.num_kv_heads,
            self.scale,
            self.alibi_slopes,
            layer._k_scale,
            layer._v_scale,
            tp_rank=self.tp_rank,
            blocksparse_local_blocks=self.local_blocks,
            blocksparse_vert_stride=self.vert_stride,
            blocksparse_block_size=self.sparse_block_size,
            blocksparse_head_sliding_step=self.head_sliding_step,
        )

    assert output is not None
    # Flatten the heads back into the original hidden dimension.
    return output.view(n_tokens, flat_width)
vllm/attention/backends/differential_flash_attn.py
View file @
752c6ade
...
...
@@ -667,7 +667,6 @@ class DifferentialFlashAttentionImpl(AttentionImpl):
alibi_slopes
:
Optional
[
List
[
float
]],
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
attn_type
:
str
=
AttentionType
.
DECODER
,
kv_sharing_target_layer_name
:
Optional
[
str
]
=
None
,
...
...
@@ -680,9 +679,6 @@ class DifferentialFlashAttentionImpl(AttentionImpl):
differential_flash_attention_config
self
.
used_shared_kv_cache
=
kv_sharing_target_layer_name
is
not
None
self
.
kv_sharing_target_layer_name
=
kv_sharing_target_layer_name
if
blocksparse_params
is
not
None
:
raise
ValueError
(
"FlashAttention does not support block-sparse attention."
)
if
use_irope
:
logger
.
warning
(
"Using irope in V0 is not supported yet, it will fall back "
...
...
vllm/attention/backends/dual_chunk_flash_attn.py
View file @
752c6ade
...
...
@@ -287,7 +287,6 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
alibi_slopes
:
Optional
[
List
[
float
]],
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
attn_type
:
str
=
AttentionType
.
DECODER
,
kv_sharing_target_layer_name
:
Optional
[
str
]
=
None
,
...
...
vllm/attention/backends/flash_attn.py
View file @
752c6ade
...
...
@@ -4,7 +4,7 @@
from
collections
import
defaultdict
from
dataclasses
import
dataclass
from
itertools
import
accumulate
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Optional
,
Tuple
,
Type
import
torch
...
...
@@ -615,7 +615,6 @@ class FlashAttentionImpl(AttentionImpl):
alibi_slopes
:
Optional
[
List
[
float
]],
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
attn_type
:
str
=
AttentionType
.
DECODER
,
kv_sharing_target_layer_name
:
Optional
[
str
]
=
None
,
...
...
@@ -624,9 +623,6 @@ class FlashAttentionImpl(AttentionImpl):
if
kv_sharing_target_layer_name
is
not
None
:
raise
NotImplementedError
(
"KV sharing is not supported in V0 "
"FLASH_ATTN backend."
)
if
blocksparse_params
is
not
None
:
raise
ValueError
(
"FlashAttention does not support block-sparse attention."
)
if
use_irope
:
logger
.
warning
(
"Using irope in V0 is not supported yet, it will fall back "
...
...
vllm/attention/backends/flashinfer.py
View file @
752c6ade
...
...
@@ -999,7 +999,6 @@ class FlashInferImpl(AttentionImpl):
alibi_slopes
:
Optional
[
List
[
float
]],
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
attn_type
:
str
=
AttentionType
.
DECODER
,
kv_sharing_target_layer_name
:
Optional
[
str
]
=
None
,
...
...
vllm/attention/backends/flashmla.py
View file @
752c6ade
...
...
@@ -3,7 +3,7 @@
from
contextlib
import
contextmanager
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
from
typing
import
TYPE_CHECKING
,
List
,
Optional
,
Tuple
,
Type
import
torch
...
...
@@ -181,7 +181,6 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
alibi_slopes
:
Optional
[
List
[
float
]],
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]],
logits_soft_cap
:
Optional
[
float
],
attn_type
:
str
,
kv_sharing_target_layer_name
:
Optional
[
str
]
=
None
,
...
...
@@ -189,20 +188,17 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
**
mla_args
)
->
None
:
super
().
__init__
(
num_heads
,
head_size
,
scale
,
num_kv_heads
,
alibi_slopes
,
sliding_window
,
kv_cache_dtype
,
blocksparse_params
,
logits_soft_cap
,
attn_type
,
logits_soft_cap
,
attn_type
,
kv_sharing_target_layer_name
,
**
mla_args
)
assert
is_flashmla_supported
(),
\
"FlashMLA is not supported on this device"
unsupported_features
=
[
alibi_slopes
,
sliding_window
,
blocksparse_params
,
logits_soft_cap
]
unsupported_features
=
[
alibi_slopes
,
sliding_window
,
logits_soft_cap
]
if
any
(
unsupported_features
):
raise
NotImplementedError
(
"FlashMLAImpl does not support one of the following: "
"alibi_slopes, sliding_window, blocksparse_params, "
"logits_soft_cap"
)
"alibi_slopes, sliding_window, logits_soft_cap"
)
if
attn_type
!=
AttentionType
.
DECODER
:
raise
NotImplementedError
(
"Encoder self-attention and "
...
...
vllm/attention/backends/mla/common.py
View file @
752c6ade
...
...
@@ -997,7 +997,6 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
alibi_slopes
:
Optional
[
List
[
float
]],
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]],
logits_soft_cap
:
Optional
[
float
],
attn_type
:
str
,
kv_sharing_target_layer_name
:
Optional
[
str
],
...
...
vllm/attention/backends/rocm_aiter_mla.py
View file @
752c6ade
...
...
@@ -3,7 +3,7 @@
from
contextlib
import
contextmanager
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
,
Type
,
Union
from
typing
import
TYPE_CHECKING
,
Optional
,
Type
,
Union
import
torch
...
...
@@ -367,7 +367,6 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
alibi_slopes
:
Optional
[
list
[
float
]],
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
dict
[
str
,
Any
]],
logits_soft_cap
:
Optional
[
float
],
attn_type
:
str
,
kv_sharing_target_layer_name
:
Optional
[
str
],
...
...
@@ -375,17 +374,14 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
**
mla_args
)
->
None
:
super
().
__init__
(
num_heads
,
head_size
,
scale
,
num_kv_heads
,
alibi_slopes
,
sliding_window
,
kv_cache_dtype
,
blocksparse_params
,
logits_soft_cap
,
attn_type
,
logits_soft_cap
,
attn_type
,
kv_sharing_target_layer_name
,
**
mla_args
)
unsupported_features
=
[
alibi_slopes
,
sliding_window
,
blocksparse_params
,
logits_soft_cap
]
unsupported_features
=
[
alibi_slopes
,
sliding_window
,
logits_soft_cap
]
if
any
(
unsupported_features
):
raise
NotImplementedError
(
"Aiter MLA does not support one of the following: "
"alibi_slopes, sliding_window, blocksparse_params, "
"logits_soft_cap"
)
"alibi_slopes, sliding_window, logits_soft_cap"
)
from
aiter
import
flash_attn_varlen_func
self
.
flash_attn_varlen_func
=
flash_attn_varlen_func
...
...
vllm/attention/backends/rocm_flash_attn.py
View file @
752c6ade
...
...
@@ -4,7 +4,7 @@
import
itertools
from
dataclasses
import
dataclass
from
functools
import
cache
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
from
typing
import
TYPE_CHECKING
,
List
,
Optional
,
Tuple
,
Type
import
torch
...
...
@@ -494,7 +494,6 @@ class ROCmFlashAttentionImpl(AttentionImpl):
alibi_slopes
:
Optional
[
List
[
float
]],
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
attn_type
:
str
=
AttentionType
.
DECODER
,
kv_sharing_target_layer_name
:
Optional
[
str
]
=
None
,
...
...
@@ -507,9 +506,6 @@ class ROCmFlashAttentionImpl(AttentionImpl):
logger
.
warning_once
(
"Using irope in ROCm Flash Attention is not supported yet, it "
"will fail back to global attention for long context."
)
if
blocksparse_params
is
not
None
:
raise
ValueError
(
"ROCmFlashAttention does not support blocksparse attention."
)
if
use_irope
:
logger
.
warning
(
"Using irope in V0 is not supported yet, it will fall back "
...
...
vllm/attention/backends/triton_mla.py
View file @
752c6ade
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Type
from
typing
import
List
,
Optional
,
Type
import
torch
...
...
@@ -35,7 +35,6 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
alibi_slopes
:
Optional
[
List
[
float
]],
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]],
logits_soft_cap
:
Optional
[
float
],
attn_type
:
str
,
kv_sharing_target_layer_name
:
Optional
[
str
],
...
...
@@ -43,17 +42,14 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
**
mla_args
)
->
None
:
super
().
__init__
(
num_heads
,
head_size
,
scale
,
num_kv_heads
,
alibi_slopes
,
sliding_window
,
kv_cache_dtype
,
blocksparse_params
,
logits_soft_cap
,
attn_type
,
logits_soft_cap
,
attn_type
,
kv_sharing_target_layer_name
,
**
mla_args
)
unsupported_features
=
[
alibi_slopes
,
sliding_window
,
blocksparse_params
,
logits_soft_cap
]
unsupported_features
=
[
alibi_slopes
,
sliding_window
,
logits_soft_cap
]
if
any
(
unsupported_features
):
raise
NotImplementedError
(
"TritonMLAImpl does not support one of the following: "
"alibi_slopes, sliding_window, blocksparse_params, "
"logits_soft_cap"
)
"alibi_slopes, sliding_window, logits_soft_cap"
)
if
attn_type
!=
AttentionType
.
DECODER
:
raise
NotImplementedError
(
"Encoder self-attention and "
...
...
vllm/attention/backends/xformers.py
View file @
752c6ade
...
...
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with xFormers and PagedAttention."""
from
dataclasses
import
dataclass
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Type
import
torch
from
xformers
import
ops
as
xops
...
...
@@ -387,7 +387,6 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
alibi_slopes
:
Optional
[
List
[
float
]],
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
attn_type
:
str
=
AttentionType
.
DECODER
,
kv_sharing_target_layer_name
:
Optional
[
str
]
=
None
,
...
...
@@ -396,9 +395,6 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
if
kv_sharing_target_layer_name
is
not
None
:
raise
NotImplementedError
(
"KV sharing is not supported in V0 "
"XFORMERS backend."
)
if
blocksparse_params
is
not
None
:
raise
ValueError
(
"XFormers does not support block-sparse attention."
)
if
logits_soft_cap
is
not
None
:
logger
.
warning_once
(
"XFormers does not support logits soft cap. "
"Outputs may be slightly off."
)
...
...
vllm/attention/layer.py
View file @
752c6ade
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer."""
from
typing
import
Any
,
Dict
,
List
,
Optional
from
typing
import
List
,
Optional
import
torch
import
torch.nn
as
nn
...
...
@@ -74,7 +74,6 @@ class Attention(nn.Module):
alibi_slopes
:
Optional
[
List
[
float
]]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
per_layer_sliding_window
:
Optional
[
int
]
=
None
,
use_mla
:
bool
=
False
,
...
...
@@ -163,12 +162,11 @@ class Attention(nn.Module):
kv_cache_dtype
,
block_size
,
is_attention_free
,
blocksparse_params
is
not
None
,
use_mla
=
use_mla
)
impl_cls
=
attn_backend
.
get_impl_cls
()
self
.
impl
=
impl_cls
(
num_heads
,
head_size
,
scale
,
num_kv_heads
,
alibi_slopes
,
sliding_window
,
kv_cache_dtype
,
blocksparse_params
,
logits_soft_cap
,
attn_type
,
logits_soft_cap
,
attn_type
,
kv_sharing_target_layer_name
,
**
extra_impl_args
)
self
.
backend
=
backend_name_to_enum
(
attn_backend
.
get_name
())
self
.
dtype
=
dtype
...
...
vllm/attention/ops/blocksparse_attention/__init__.py
deleted
100644 → 0
View file @
881e3cbe
vllm/attention/ops/blocksparse_attention/blocksparse_attention_kernel.py
deleted
100644 → 0
View file @
881e3cbe
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
torch
from
vllm.triton_utils
import
tl
,
triton
def blocksparse_flash_attn_varlen_fwd(
        q,
        k,
        v,  # (#tokens, n_heads, head_size)
        cu_seqlens_k,
        cu_seqlens_q,
        sm_scale,
        sparse_layout,
        *,
        block_size=64,
        q_block_size=None,
        max_seqlen=None):
    """Launch the block-sparse variable-length attention forward kernel.

    Args:
        q, k, v: 3-D tensors laid out as (#tokens, n_heads, head_size).
            ``k`` and ``v`` must have identical shapes, and the number of
            query heads must be a multiple of the number of key heads.
        cu_seqlens_k: 1-D tensor of cumulative key sequence offsets; its
            length minus one is taken as the batch size.
        cu_seqlens_q: Cumulative query sequence offsets, or ``None``. When
            ``None`` it is inferred: one query token per sequence if
            ``q.size(0) == batch_size`` (decoding), or ``cu_seqlens_k`` if
            ``q`` and ``k`` have the same number of tokens (prefilling).
        sm_scale: Softmax scale passed to the kernel.
        sparse_layout: ``list``/``tuple`` of two tensors
            ``(crow_indices, col_indices)`` describing the sparse layout.
        block_size: Key block size (kernel ``BLOCK_N``).
        q_block_size: Query block size (kernel ``BLOCK_M``); defaults to
            ``block_size``.
        max_seqlen: If truthy, the longest key sequence must not exceed it.

    Returns:
        A new tensor with the same shape as ``q`` holding the attention
        output.

    Raises:
        ValueError: If ``cu_seqlens_q`` is ``None`` and cannot be inferred
            from the shapes of ``q`` and ``k``.
    """
    # split q to blocks

    assert isinstance(sparse_layout, (list, tuple))

    _, n_heads, head_size = q.shape
    batch_size = cu_seqlens_k.size(0) - 1
    q_block_size = q_block_size or block_size

    assert q.dim() == k.dim() == v.dim() == 3
    assert q.size(1) % k.size(1) == 0
    assert q.size(2) == k.size(2)
    # TODO(linxihui): allow k, v to have different head_size
    assert k.shape == v.shape
    assert cu_seqlens_k.dim() == 1

    q_k_ratio = q.size(1) // k.size(1)

    if cu_seqlens_q is None:
        if q.size(0) == batch_size:  # decoding only
            cu_seqlens_q = torch.arange(
                0,
                batch_size + 1,
                dtype=cu_seqlens_k.dtype,
                device=cu_seqlens_k.device,
            )
        elif q.size(0) == k.size(0):
            cu_seqlens_q = cu_seqlens_k
        else:
            # Adjacent string literals rather than a backslash continuation
            # inside one literal, which embedded the source indentation in
            # the runtime message.
            raise ValueError(
                "cu_seqlens_q must be specified "
                "if it is a mix of prefilling and decoding.")
    else:
        assert cu_seqlens_k.size(0) == cu_seqlens_q.size(0)

    # switch to use cpu to avoid too many kernel launches when iterated over
    q_lens = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).cpu()
    k_lens = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).cpu()

    assert torch.logical_or(q_lens == 1, k_lens == q_lens).all(), (
        "length of q should either be 1 (decoding) or same as k (prefilling)."
    )

    if max_seqlen:
        assert k_lens.max() <= max_seqlen

    # Number of query blocks per sequence (ceil division), converted once to
    # plain Python ints so the comprehensions below do not walk 0-d tensors.
    n_blocks = ((q_lens + q_block_size - 1) // q_block_size).tolist()

    # One entry per query block: the sequence it belongs to ...
    q_batch_ids = torch.tensor(
        [i for i, n in enumerate(n_blocks) for _ in range(n)],
        dtype=cu_seqlens_q.dtype,
        device=cu_seqlens_q.device,
    )
    # ... and its starting token offset within that sequence.
    q_start_sids = torch.tensor(
        [i * q_block_size for n in n_blocks for i in range(n)],
        dtype=cu_seqlens_q.dtype,
        device=cu_seqlens_q.device,
    )

    out = q.new_empty(q.shape)
    cu_seqlens_q = cu_seqlens_q.contiguous()
    cu_seqlens_k = cu_seqlens_k.contiguous()

    layout_crow_indices, layout_col_indices = sparse_layout
    block_d = triton.next_power_of_2(head_size)

    decoding_only = (q_lens == 1).all().item()
    # One program per (query block, head).
    grid = (len(q_start_sids), n_heads, 1)

    _fwd_kernel_batch_inference[grid](
        q,
        k,
        v,
        out,
        sm_scale,
        cu_seqlens_q[:-1],
        cu_seqlens_q[1:],
        cu_seqlens_k[:-1],
        cu_seqlens_k[1:],
        q_batch_ids,
        q_start_sids,
        # The leading 0 before each stride tuple fills the kernel's batch
        # stride argument, which is unused because HAS_BATCH_DIM is False.
        0,
        *q.stride(),
        0,
        *k.stride(),
        0,
        *v.stride(),
        0,
        *out.stride(),
        layout_crow_indices,
        layout_col_indices,
        *layout_crow_indices.stride(),
        *layout_col_indices.stride(),
        q_k_ratio,
        HAS_BATCH_DIM=False,
        D_HEAD=head_size,
        BLOCK_M=q_block_size,
        BLOCK_N=block_size,
        BLOCK_D=block_d,
        BLOCK_M_LOADING=(16 if decoding_only else
                         q_block_size),  # smaller for decoding
        EVEN_D=block_d == head_size,
        num_warps=1 if decoding_only else 4,
        num_stages=3)

    return out
@triton.jit
def _fwd_kernel_inner(
    acc,
    l_i,
    m_i,
    q,
    Q,
    k_block_col_idx,
    layout_col_ptr,
    layout_col_stride_h,
    layout_col_stride_m,
    k_ptrs,
    v_ptrs,
    off_h,
    offs_m,
    offs_n,
    offs_d,
    stride_kt,
    stride_vt,
    sm_scale,
    k_seqlen,
    past_len,
    LAST_K_BLOCK: tl.constexpr,
    BLOCK_M_LOADING: tl.constexpr,
    BLOCK_N: tl.constexpr,
    D_HEAD: tl.constexpr,
    EVEN_D: tl.constexpr,
    M_LT_N: tl.constexpr,
):
    """Process one key/value block of the sparse layout for a query block.

    Looks up the key block id stored at ``k_block_col_idx`` in the column
    index array, loads that block of K and V, and folds it into the running
    softmax state.

    Args:
        acc, l_i, m_i: Running output accumulator, softmax denominator and
            row maximum; updated values are returned.
        q: Query tile already loaded by the caller.
        Q: Query pointer; used here only for its element type.
        k_block_col_idx: Index into the column index array for head
            ``off_h``.
        k_ptrs, v_ptrs: Base pointers for the K/V tiles at block 0; the
            block offset is added here.
        offs_m, offs_n, offs_d: Row, column and head-dim offsets.
        sm_scale: Scale applied to ``q @ k``.
        k_seqlen: Key sequence length, used to mask the last block.
        past_len: Offset added to query offsets in the causal comparison.
        LAST_K_BLOCK: When set, K/V loads are masked against ``k_seqlen``.
        EVEN_D: When not set, loads are also masked against ``D_HEAD``.
        M_LT_N: Forces the causal mask even when not on the last block.

    Returns:
        The updated ``(acc, l_i, m_i)``.
    """
    # Translate the layout entry into the starting key offset of the block.
    k_block_id = tl.load(layout_col_ptr + off_h * layout_col_stride_h +
                         k_block_col_idx * layout_col_stride_m).to(tl.int32)
    start_n = k_block_id * BLOCK_N
    # K is addressed with head-dim along rows (offs_d[:, None]) and key
    # position along columns (offs_n[None, :]).
    if LAST_K_BLOCK:
        # The last block may run past the end of the sequence: mask it.
        if EVEN_D:
            k = tl.load(
                k_ptrs + start_n * stride_kt,
                mask=offs_n[None, :] + start_n < k_seqlen,
                other=0.0,
            )
        else:
            k = tl.load(
                k_ptrs + start_n * stride_kt,
                mask=(offs_n[None, :] + start_n < k_seqlen) &
                (offs_d[:, None] < D_HEAD),
                other=0.0,
            )
    else:
        if EVEN_D:
            k = tl.load(k_ptrs + start_n * stride_kt)
        else:
            k = tl.load(k_ptrs + start_n * stride_kt,
                        mask=offs_d[:, None] < D_HEAD,
                        other=0.0)

    # Scaled attention scores for this (query block, key block) pair.
    qk = tl.zeros([BLOCK_M_LOADING, BLOCK_N], dtype=tl.float32)
    qk += tl.dot(q, k)
    qk *= sm_scale

    # the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N
    if LAST_K_BLOCK | M_LT_N:
        # Causal mask: add -inf where the key offset exceeds the query
        # offset shifted by past_len.
        qk += tl.where(
            offs_m[:, None] + past_len >= (start_n + offs_n[None, :]),
            0,
            float("-inf"),
        )

    # flash-attn2
    # Online softmax in base 2: take the new row maximum, then rescale the
    # previous accumulator and denominator by alpha before adding this block.
    m_ij = tl.maximum(m_i, tl.max(qk, 1))
    p = tl.math.exp2(qk - m_ij[:, None])
    l_ij = tl.sum(p, 1)
    alpha = tl.math.exp2(m_i - m_ij)
    acc = acc * alpha[:, None]
    # update m_i
    m_i = m_ij
    l_i = l_i * alpha + l_ij

    # Cast the probabilities to Q's element type before the second dot.
    p = p.to(Q.dtype.element_ty)
    # update acc
    # V is addressed with key position along rows and head-dim along columns.
    if LAST_K_BLOCK:
        if EVEN_D:
            v = tl.load(
                v_ptrs + start_n * stride_vt,
                mask=offs_n[:, None] + start_n < k_seqlen,
                other=0.0,
            )
        else:
            v = tl.load(
                v_ptrs + start_n * stride_vt,
                mask=(offs_n[:, None] + start_n < k_seqlen) &
                (offs_d[None, :] < D_HEAD),
                other=0.0,
            )
    else:
        if EVEN_D:
            v = tl.load(v_ptrs + start_n * stride_vt)
        else:
            v = tl.load(v_ptrs + start_n * stride_vt,
                        mask=offs_d[None, :] < D_HEAD,
                        other=0.0)

    acc += tl.dot(p, v)

    return acc, l_i, m_i
# M_LT_N is derived from the launch arguments rather than passed explicitly.
@triton.heuristics({
    "M_LT_N":
    lambda kwargs: kwargs["BLOCK_M"] < kwargs["BLOCK_N"],
})
@triton.jit
def _fwd_kernel_batch_inference(
    Q,
    K,
    V,
    Out,
    sm_scale,
    q_batch_starts,
    q_batch_ends,
    k_batch_starts,
    k_batch_ends,
    q_batch_ids,
    q_start_sids,
    stride_qb,
    stride_qt,
    stride_qh,
    stride_qd,
    stride_kb,
    stride_kt,
    stride_kh,
    stride_kd,
    stride_vb,
    stride_vt,
    stride_vh,
    stride_vd,
    stride_ob,
    stride_ot,
    stride_oh,
    stride_od,
    layout_crow_ptr,
    layout_col_ptr,
    layout_crow_stride_h,
    layout_crow_stride_m,
    layout_col_stride_h,
    layout_col_stride_m,
    q_k_ratio,
    HAS_BATCH_DIM: tl.constexpr,
    D_HEAD: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_D: tl.constexpr,
    BLOCK_M_LOADING: tl.constexpr,
    EVEN_D: tl.constexpr,
    M_LT_N: tl.constexpr,
):
    """
    Block-sparse attention forward for one (query block, head) program.

    Program id 0 selects the query block and program id 1 the query head;
    the key/value head is ``off_h // q_k_ratio``. The program walks the key
    blocks listed for its query block row in the sparse layout
    (``layout_crow_ptr`` / ``layout_col_ptr``), accumulating with
    ``_fwd_kernel_inner``, then normalizes and stores the result to ``Out``.

    NOTATION:
    pid: position id
    sid: storage id
    sbid: storage block id
    pbid: position block id
    offs_m, offs_n: storage offsets of m-dim(q, row) and n-dim(k, col)

    TODO(linxihui):
    Optimize grouped-attn
    """
    off_zm = tl.program_id(0)
    off_h = tl.program_id(1)

    # Several query heads share one key/value head.
    off_h_for_kv = off_h // q_k_ratio

    if HAS_BATCH_DIM:
        # Batch index comes from program id 2; program id 0 is the q block.
        off_z = tl.program_id(2)
        Q += off_z * stride_qb
        K += off_z * stride_kb
        V += off_z * stride_vb
        Out += off_z * stride_ob
        start_m = off_zm
        q_start_sid = start_m * BLOCK_M  # always 0 for decoding
    else:
        # Program id 0 indexes the flattened list of q blocks; look up which
        # sequence it belongs to and where it starts within that sequence.
        off_z = tl.load(q_batch_ids + off_zm).to(tl.int32)  # [0, 0, 0, 1]
        q_start_sid = tl.load(q_start_sids + off_zm)
        start_m = q_start_sid // BLOCK_M  # q_sbid

    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M_LOADING)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_D)

    # Per-sequence start offsets and lengths for q and k.
    q_cu_start = tl.load(q_batch_starts + off_z).to(tl.int32)
    q_seqlen = tl.load(q_batch_ends + off_z).to(tl.int32) - q_cu_start
    k_cu_start = tl.load(k_batch_starts + off_z).to(tl.int32)
    k_seqlen = tl.load(k_batch_ends + off_z).to(tl.int32) - k_cu_start
    # Number of key tokens in excess of the query tokens.
    past_len = k_seqlen - q_seqlen

    # Advance the base pointers to this sequence and head.
    Q += q_cu_start * stride_qt + off_h * stride_qh
    K += k_cu_start * stride_kt + off_h_for_kv * stride_kh
    V += k_cu_start * stride_vt + off_h_for_kv * stride_vh
    Out += q_cu_start * stride_ot + off_h * stride_oh

    # Block row of the sparse layout for this q block, shifted by past_len.
    q_pbid = (past_len + q_start_sid) // BLOCK_M

    # Load the query tile; rows past q_seqlen (and head-dim columns past
    # D_HEAD when BLOCK_D is padded) read as 0.
    if EVEN_D:
        q = tl.load(
            Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd,
            mask=offs_m[:, None] < q_seqlen,
            other=0.0,
        )
    else:
        q = tl.load(
            Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd,
            mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD),
            other=0.0,
        )

    sparse_crow_ptr = (layout_crow_ptr + off_h * layout_crow_stride_h +
                       q_pbid * layout_crow_stride_m)

    # TODO(linxihui): load at once, with any Triton version
    # that supports `tl.split`, e.g., Triton 3.0
    # Two consecutive row-pointer entries give the [start, end) range of
    # column entries for this block row.
    k_block_start = tl.load(sparse_crow_ptr).to(tl.int32)
    k_block_end = tl.load(sparse_crow_ptr + 1).to(tl.int32)

    # Running softmax state: row max (-inf), denominator (0), output (0).
    m_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M_LOADING, BLOCK_D], dtype=tl.float32)

    k_ptrs = K + offs_n[None, :] * stride_kt + offs_d[:, None] * stride_kd
    v_ptrs = V + offs_n[:, None] * stride_vt + offs_d[None, :] * stride_vd

    sm_scale *= (
        1.44269504  # log2(e) = 1/ln(2): exp2/log2 are used instead of exp/log
    )

    # All blocks but the last are processed with LAST_K_BLOCK=False.
    for k_block_col_idx in range(k_block_start, k_block_end - 1):
        acc, l_i, m_i = _fwd_kernel_inner(
            acc,
            l_i,
            m_i,
            q,
            Q,
            k_block_col_idx,
            layout_col_ptr,
            layout_col_stride_h,
            layout_col_stride_m,
            k_ptrs,
            v_ptrs,
            off_h,
            offs_m,
            offs_n,
            offs_d,
            stride_kt,
            stride_vt,
            sm_scale,
            k_seqlen,
            past_len,
            False,
            BLOCK_M_LOADING,
            BLOCK_N,
            D_HEAD,
            EVEN_D,
            M_LT_N,
        )

    # The final block is processed with LAST_K_BLOCK=True, which enables the
    # sequence-length masking of K/V and the causal mask.
    acc, l_i, m_i = _fwd_kernel_inner(
        acc,
        l_i,
        m_i,
        q,
        Q,
        k_block_end - 1,
        layout_col_ptr,
        layout_col_stride_h,
        layout_col_stride_m,
        k_ptrs,
        v_ptrs,
        off_h,
        offs_m,
        offs_n,
        offs_d,
        stride_kt,
        stride_vt,
        sm_scale,
        k_seqlen,
        past_len,
        True,
        BLOCK_M_LOADING,
        BLOCK_N,
        D_HEAD,
        EVEN_D,
        M_LT_N,
    )

    # flash-attn 2
    # NOTE: m_i is not read again after this update within this kernel.
    m_i += tl.math.log2(l_i)
    # Normalize the accumulator by the softmax denominator.
    acc = acc / l_i[:, None]

    # write output
    # Stores use the same row / head-dim masks as the query load.
    if EVEN_D:
        tl.store(
            Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od,
            acc,
            mask=offs_m[:, None] < q_seqlen,
        )
    else:
        tl.store(
            Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od,
            acc,
            mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD),
        )
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment