Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
bdd33b3f
Commit
bdd33b3f
authored
Jan 30, 2026
by
zhuwenwen
Browse files
update fa interface and kvcache
add prepare_so_files to prepare so
parent
63053820
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
271 additions
and
210 deletions
+271
-210
README.md
README.md
+3
-0
setup.py
setup.py
+34
-0
vllm/attention/layer.py
vllm/attention/layer.py
+4
-1
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+230
-209
No files found.
README.md
View file @
bdd33b3f
...
...
@@ -91,6 +91,9 @@ python3 setup.py install (若调试,可使用python3 setup.py develop)
```
若需要添加git号,设置环境变量: export ADD_GIT_VERSION=1
3.
跳过编译(适用于未改变csrc目录kernel并多次编译情况)
将编译后的so文件拷贝至csrc目录,并设置环境变量: export SKIP_VLLM_BUILD=1
#### 运行基础环境准备
1、使用上面基于光源pytorch2.9.0基础镜像环境
...
...
setup.py
View file @
bdd33b3f
...
...
@@ -13,6 +13,8 @@ import sys
import
sysconfig
from
pathlib
import
Path
from
shutil
import
which
import
tarfile
import
shutil
import
torch
from
packaging.version
import
Version
,
parse
...
...
@@ -36,6 +38,37 @@ skip_vllm_build = False
if
int
(
os
.
environ
.
get
(
'SKIP_VLLM_BUILD'
,
'0'
))
==
1
:
skip_vllm_build
=
True
def
prepare_so_files
():
source_dir
=
"csrc/so.tar.gz"
target_dir
=
"vllm"
if
not
os
.
path
.
exists
(
source_dir
):
print
(
f
"Warning:
{
source_dir
}
not found, skipping extraction"
)
return
print
(
f
"Preparing C extension files from
{
source_dir
}
..."
)
temp_dir
=
"temp_so_extract"
os
.
makedirs
(
temp_dir
,
exist_ok
=
True
)
try
:
with
tarfile
.
open
(
source_dir
,
"r:*"
)
as
tar
:
tar
.
extractall
(
temp_dir
)
for
root
,
dirs
,
files
in
os
.
walk
(
temp_dir
):
for
file
in
files
:
if
file
in
[
"_C.abi3.so"
,
"_moe_C.abi3.so"
,
"cumem_allocator.abi3.so"
]:
src_path
=
os
.
path
.
join
(
root
,
file
)
dst_path
=
os
.
path
.
join
(
target_dir
,
file
)
os
.
makedirs
(
os
.
path
.
dirname
(
dst_path
),
exist_ok
=
True
)
shutil
.
copy2
(
src_path
,
dst_path
)
print
(
f
"Copied
{
file
}
to
{
dst_path
}
"
)
finally
:
if
os
.
path
.
exists
(
temp_dir
):
shutil
.
rmtree
(
temp_dir
)
def
load_module_from_path
(
module_name
,
path
):
spec
=
importlib
.
util
.
spec_from_file_location
(
module_name
,
path
)
module
=
importlib
.
util
.
module_from_spec
(
spec
)
...
...
@@ -1109,6 +1142,7 @@ if _build_custom_ops():
ext_modules
.
append
(
CMakeExtension
(
name
=
"vllm._C"
))
if
skip_vllm_build
:
prepare_so_files
()
package_data
=
{
"vllm"
:
[
"py.typed"
,
...
...
vllm/attention/layer.py
View file @
bdd33b3f
...
...
@@ -848,7 +848,10 @@ def unified_kv_cache_update(
layer_slot_mapping
,
)
return
torch
.
empty
(
0
,
device
=
kv_cache
.
device
,
dtype
=
kv_cache
.
dtype
)
if
current_platform
.
is_rocm
():
return
torch
.
empty
(
0
,
device
=
key
.
device
,
dtype
=
key
.
dtype
)
else
:
return
torch
.
empty
(
0
,
device
=
kv_cache
.
device
,
dtype
=
kv_cache
.
dtype
)
def
unified_kv_cache_update_fake
(
...
...
vllm/v1/attention/backends/flash_attn.py
View file @
bdd33b3f
...
...
@@ -27,18 +27,18 @@ from vllm.v1.attention.ops.merge_attn_states import merge_attn_states
from
vllm.platforms
import
current_platform
if
is_flash_attn_varlen_func_available
():
if
not
current_platform
.
is_rocm
():
if
current_platform
.
is_rocm
():
from
vllm.v1.attention.backends.fa_utils
import
(
flash_attn_supports_sinks
,
flash_attn_varlen_func
,
get_scheduler_metadata
,
reshape_and_cache_flash
,
vllm_flash_attn_varlen_func
,
reshape_and_cache_cuda
,
)
else
:
from
vllm.v1.attention.backends.fa_utils
import
(
flash_attn_supports_sinks
,
vllm_flash_attn_varlen_func
,
reshape_and_cache_cuda
,
flash_attn_varlen_func
,
get_scheduler_metadata
,
reshape_and_cache_flash
,
)
from
vllm.config
import
VllmConfig
,
get_current_vllm_config
,
get_layers_from_vllm_config
...
...
@@ -113,7 +113,7 @@ class FlashAttentionBackend(AttentionBackend):
def
get_builder_cls
()
->
type
[
"FlashAttentionMetadataBuilder"
]:
return
FlashAttentionMetadataBuilder
if
not
current_platform
.
is_rocm
():
if
current_platform
.
is_rocm
():
@
staticmethod
def
get_kv_cache_shape
(
num_blocks
:
int
,
...
...
@@ -121,31 +121,36 @@ class FlashAttentionBackend(AttentionBackend):
num_kv_heads
:
int
,
head_size
:
int
,
cache_dtype_str
:
str
=
"auto"
,
)
->
tuple
[
int
,
...]:
)
->
tuple
[
tuple
[
int
,
...]
,
tuple
[
int
,
...]]
:
if
block_size
%
16
!=
0
:
raise
ValueError
(
"Block size must be a multiple of 16."
)
return
(
2
,
num_blocks
,
block_size
,
num_kv_heads
,
head_size
)
return
(
(
num_blocks
,
num_kv_heads
,
block_size
,
head_size
),
(
num_blocks
,
num_kv_heads
,
head_size
,
block_size
),
)
@
staticmethod
def
get_kv_cache_stride_order
(
include_num_layers_dimension
:
bool
=
False
,
)
->
tuple
[
int
,
...]:
)
->
tuple
[
tuple
[
int
,
...]
,
tuple
[
int
,
...]]
:
# `stride_order` indicates the permutation that gets
# us from `get_kv_cache_shape` to the actual memory layout we want.
cache_layout
=
get_kv_cache_layout
()
if
cache_layout
==
"NHD"
and
include_num_layers_dimension
:
# (num_blocks, num_layers,
2,
block_size, num_kv_heads, head_size)
return
(
2
,
0
,
1
,
3
,
4
,
5
)
# (num_blocks, num_layers, block_size, num_kv_heads, head_size)
return
(
1
,
0
,
3
,
2
,
5
),
(
1
,
0
,
4
,
2
,
3
)
elif
cache_layout
==
"NHD"
:
stride_order
=
(
0
,
1
,
2
,
3
,
4
)
key_stride_order
=
(
0
,
1
,
2
,
3
)
value_stride_order
=
(
0
,
1
,
2
,
3
)
elif
cache_layout
==
"HND"
and
include_num_layers_dimension
:
# (num_blocks, num_kv_heads, num_layers,
2,
block_size, head_size)
return
(
2
,
4
,
0
,
1
,
3
,
5
)
# (num_blocks, num_kv_heads, num_layers, block_size, head_size)
return
(
1
,
2
,
0
,
3
,
4
)
,
(
1
,
2
,
0
,
4
,
3
)
elif
cache_layout
==
"HND"
:
stride_order
=
(
0
,
1
,
3
,
2
,
4
)
key_stride_order
=
(
0
,
1
,
2
,
3
)
value_stride_order
=
(
0
,
1
,
3
,
2
)
else
:
raise
ValueError
(
f
"Unknown cache layout format
{
cache_layout
}
."
)
return
stride_order
return
key_stride_order
,
value_
stride_order
else
:
@
staticmethod
def
get_kv_cache_shape
(
...
...
@@ -154,36 +159,32 @@ class FlashAttentionBackend(AttentionBackend):
num_kv_heads
:
int
,
head_size
:
int
,
cache_dtype_str
:
str
=
"auto"
,
)
->
tuple
[
tuple
[
int
,
...]
,
tuple
[
int
,
...]]
:
)
->
tuple
[
int
,
...]:
if
block_size
%
16
!=
0
:
raise
ValueError
(
"Block size must be a multiple of 16."
)
return
(
(
num_blocks
,
num_kv_heads
,
block_size
,
head_size
),
(
num_blocks
,
num_kv_heads
,
head_size
,
block_size
),
)
return
(
2
,
num_blocks
,
block_size
,
num_kv_heads
,
head_size
)
@
staticmethod
def
get_kv_cache_stride_order
(
include_num_layers_dimension
:
bool
=
False
,
)
->
tuple
[
tuple
[
int
,
...]
,
tuple
[
int
,
...]]
:
)
->
tuple
[
int
,
...]:
# `stride_order` indicates the permutation that gets
# us from `get_kv_cache_shape` to the actual memory layout we want.
cache_layout
=
get_kv_cache_layout
()
if
cache_layout
==
"NHD"
and
include_num_layers_dimension
:
# (num_blocks, num_layers, block_size, num_kv_heads, head_size)
return
(
1
,
0
,
3
,
2
,
5
)
,
(
1
,
0
,
4
,
2
,
3
)
# (num_blocks, num_layers,
2,
block_size, num_kv_heads, head_size)
return
(
2
,
0
,
1
,
3
,
4
,
5
)
elif
cache_layout
==
"NHD"
:
key_stride_order
=
(
0
,
1
,
2
,
3
)
value_stride_order
=
(
0
,
1
,
2
,
3
)
stride_order
=
(
0
,
1
,
2
,
3
,
4
)
elif
cache_layout
==
"HND"
and
include_num_layers_dimension
:
# (num_blocks, num_kv_heads, num_layers, block_size, head_size)
return
(
1
,
2
,
0
,
3
,
4
),
(
1
,
2
,
0
,
4
,
3
)
# (num_blocks, num_kv_heads, num_layers,
2,
block_size, head_size)
return
(
2
,
4
,
0
,
1
,
3
,
5
)
elif
cache_layout
==
"HND"
:
key_stride_order
=
(
0
,
1
,
2
,
3
)
value_stride_order
=
(
0
,
1
,
3
,
2
)
stride_order
=
(
0
,
1
,
3
,
2
,
4
)
else
:
raise
ValueError
(
f
"Unknown cache layout format
{
cache_layout
}
."
)
return
key_stride_order
,
value_stride_order
return
stride_order
@
staticmethod
def
get_fp8_dtype_for_flashattn
(
kv_cache_dtype
:
str
)
->
torch
.
dtype
:
...
...
@@ -724,10 +725,10 @@ class FlashAttentionImpl(AttentionImpl):
)
# For decoder and cross-attention, use KV cache as before
if
not
current_platform
.
is_rocm
():
key_cache
,
value_cache
=
kv_cache
.
unbind
(
0
)
else
:
if
current_platform
.
is_rocm
():
key_cache
,
value_cache
=
kv_cache
else
:
key_cache
,
value_cache
=
kv_cache
.
unbind
(
0
)
if
self
.
kv_cache_dtype
.
startswith
(
"fp8"
):
# queries are quantized in the attention layer
...
...
@@ -745,12 +746,16 @@ class FlashAttentionImpl(AttentionImpl):
block_table
=
attn_metadata
.
block_table
scheduler_metadata
=
attn_metadata
.
scheduler_metadata
if
not
current_platform
.
is_rocm
():
if
current_platform
.
is_rocm
():
q_descale
=
None
k_descale
=
layer
.
_k_scale
v_descale
=
layer
.
_v_scale
else
:
descale_shape
=
(
cu_seqlens_q
.
shape
[
0
]
-
1
,
self
.
num_kv_heads
)
q_descale
=
layer
.
_q_scale
.
expand
(
descale_shape
)
k_descale
=
layer
.
_k_scale
.
expand
(
descale_shape
)
v_descale
=
layer
.
_v_scale
.
expand
(
descale_shape
)
q_descale
=
layer
.
_q_scale
.
expand
(
descale_shape
)
k_descale
=
layer
.
_k_scale
.
expand
(
descale_shape
)
v_descale
=
layer
.
_v_scale
.
expand
(
descale_shape
)
if
self
.
dcp_world_size
>
1
:
self
.
_forward_with_dcp
(
...
...
@@ -772,8 +777,13 @@ class FlashAttentionImpl(AttentionImpl):
if
self
.
sliding_window
is
not
None
else
None
)
if
not
current_platform
.
is_rocm
():
flash_attn_varlen_func
(
if
current_platform
.
is_rocm
():
if
envs
.
VLLM_USE_PA_PRINT_PARAM
:
print
(
"PA SIZE:"
)
print
(
f
"q.shape =
{
query
[:
num_actual_tokens
].
shape
}
, key_cache.shape =
{
key_cache
.
shape
}
, value_cache.shape =
{
value_cache
.
shape
}
"
)
print
(
f
"cu_seqlens_q.shape =
{
cu_seqlens_q
.
shape
}
, max_seqlen_q =
{
max_seqlen_q
}
, seqused_k.shape =
{
seqused_k
.
shape
}
, max_seqlen_k =
{
max_seqlen_k
}
"
)
print
(
f
"softmax_scale =
{
self
.
scale
:.
3
f
}
, alibi_slopes =
{
self
.
alibi_slopes
}
, window_size =
{
self
.
sliding_window
}
, block_tables.shape =
{
block_table
.
shape
}
, softcap =
{
self
.
logits_soft_cap
}
, scheduler_metadata =
{
scheduler_metadata
}
"
)
vllm_flash_attn_varlen_func
(
q
=
query
[:
num_actual_tokens
],
k
=
key_cache
,
v
=
value_cache
,
...
...
@@ -793,16 +803,12 @@ class FlashAttentionImpl(AttentionImpl):
q_descale
=
q_descale
,
k_descale
=
k_descale
,
v_descale
=
v_descale
,
num_splits
=
attn_metadata
.
max_num_splits
,
#
num_splits=attn_metadata.max_num_splits,
s_aux
=
self
.
sinks
,
is_prefix_cache
=
True
,
)
else
:
if
envs
.
VLLM_USE_PA_PRINT_PARAM
:
print
(
"PA SIZE:"
)
print
(
f
"q.shape =
{
query
[:
num_actual_tokens
].
shape
}
, key_cache.shape =
{
key_cache
.
shape
}
, value_cache.shape =
{
value_cache
.
shape
}
"
)
print
(
f
"cu_seqlens_q.shape =
{
cu_seqlens_q
.
shape
}
, max_seqlen_q =
{
max_seqlen_q
}
, seqused_k.shape =
{
seqused_k
.
shape
}
, max_seqlen_k =
{
max_seqlen_k
}
"
)
print
(
f
"softmax_scale =
{
self
.
scale
:.
3
f
}
, alibi_slopes =
{
self
.
alibi_slopes
}
, window_size =
{
self
.
sliding_window
}
, block_tables.shape =
{
block_table
.
shape
}
, softcap =
{
self
.
logits_soft_cap
}
, scheduler_metadata =
{
scheduler_metadata
}
"
)
vllm_flash_attn_varlen_func
(
flash_attn_varlen_func
(
q
=
query
[:
num_actual_tokens
],
k
=
key_cache
,
v
=
value_cache
,
...
...
@@ -818,76 +824,42 @@ class FlashAttentionImpl(AttentionImpl):
block_table
=
block_table
,
softcap
=
self
.
logits_soft_cap
,
scheduler_metadata
=
scheduler_metadata
,
# fa_version=self.vllm_flash_attn_version,
# q_descale=q_descale,
# k_descale=k_descale,
# v_descale=v_descale,
q_descale
=
None
,
k_descale
=
layer
.
_k_scale
,
v_descale
=
layer
.
_v_scale
,
# num_splits=attn_metadata.max_num_splits,
fa_version
=
self
.
vllm_flash_attn_version
,
q_descale
=
q_descale
,
k_descale
=
k_descale
,
v_descale
=
v_descale
,
num_splits
=
attn_metadata
.
max_num_splits
,
s_aux
=
self
.
sinks
,
is_prefix_cache
=
True
,
)
return
output
# Cascade attention (rare case).
if
not
current_platform
.
is_rocm
():
cascade_attention
(
output
[:
num_actual_tokens
],
query
[:
num_actual_tokens
],
key_cache
,
value_cache
,
cu_query_lens
=
attn_metadata
.
query_start_loc
,
max_query_len
=
attn_metadata
.
max_query_len
,
cu_prefix_query_lens
=
attn_metadata
.
cu_prefix_query_lens
,
prefix_kv_lens
=
attn_metadata
.
prefix_kv_lens
,
suffix_kv_lens
=
attn_metadata
.
suffix_kv_lens
,
max_kv_len
=
attn_metadata
.
max_seq_len
,
softmax_scale
=
self
.
scale
,
alibi_slopes
=
self
.
alibi_slopes
,
sliding_window
=
self
.
sliding_window
,
logits_soft_cap
=
self
.
logits_soft_cap
,
block_table
=
attn_metadata
.
block_table
,
common_prefix_len
=
attn_metadata
.
common_prefix_len
,
max_num_splits
=
attn_metadata
.
max_num_splits
,
fa_version
=
self
.
vllm_flash_attn_version
,
prefix_scheduler_metadata
=
attn_metadata
.
prefix_scheduler_metadata
,
suffix_scheduler_metadata
=
attn_metadata
.
scheduler_metadata
,
q_descale
=
layer
.
_q_scale
,
k_descale
=
layer
.
_k_scale
,
v_descale
=
layer
.
_v_scale
,
s_aux
=
self
.
sinks
,
)
else
:
cascade_attention
(
output
[:
num_actual_tokens
],
query
[:
num_actual_tokens
],
key_cache
,
value_cache
,
cu_query_lens
=
attn_metadata
.
query_start_loc
,
max_query_len
=
attn_metadata
.
max_query_len
,
cu_prefix_query_lens
=
attn_metadata
.
cu_prefix_query_lens
,
prefix_kv_lens
=
attn_metadata
.
prefix_kv_lens
,
suffix_kv_lens
=
attn_metadata
.
suffix_kv_lens
,
max_kv_len
=
attn_metadata
.
max_seq_len
,
softmax_scale
=
self
.
scale
,
alibi_slopes
=
self
.
alibi_slopes
,
sliding_window
=
self
.
sliding_window
,
logits_soft_cap
=
self
.
logits_soft_cap
,
block_table
=
attn_metadata
.
block_table
,
common_prefix_len
=
attn_metadata
.
common_prefix_len
,
fa_version
=
2
,
#self.vllm_flash_attn_version,
prefix_scheduler_metadata
=
attn_metadata
.
prefix_scheduler_metadata
,
suffix_scheduler_metadata
=
attn_metadata
.
scheduler_metadata
,
# q_descale=layer._q_scale,
# k_descale=layer._k_scale,
# v_descale=layer._v_scale,
q_descale
=
None
,
k_descale
=
layer
.
_k_scale
,
v_descale
=
layer
.
_v_scale
,
s_aux
=
self
.
sinks
,
)
cascade_attention
(
output
[:
num_actual_tokens
],
query
[:
num_actual_tokens
],
key_cache
,
value_cache
,
cu_query_lens
=
attn_metadata
.
query_start_loc
,
max_query_len
=
attn_metadata
.
max_query_len
,
cu_prefix_query_lens
=
attn_metadata
.
cu_prefix_query_lens
,
prefix_kv_lens
=
attn_metadata
.
prefix_kv_lens
,
suffix_kv_lens
=
attn_metadata
.
suffix_kv_lens
,
max_kv_len
=
attn_metadata
.
max_seq_len
,
softmax_scale
=
self
.
scale
,
alibi_slopes
=
self
.
alibi_slopes
,
sliding_window
=
self
.
sliding_window
,
logits_soft_cap
=
self
.
logits_soft_cap
,
block_table
=
attn_metadata
.
block_table
,
common_prefix_len
=
attn_metadata
.
common_prefix_len
,
max_num_splits
=
attn_metadata
.
max_num_splits
,
fa_version
=
self
.
vllm_flash_attn_version
,
prefix_scheduler_metadata
=
attn_metadata
.
prefix_scheduler_metadata
,
suffix_scheduler_metadata
=
attn_metadata
.
scheduler_metadata
,
q_descale
=
None
if
current_platform
.
is_rocm
()
else
layer
.
_q_scale
,
k_descale
=
layer
.
_k_scale
,
v_descale
=
layer
.
_v_scale
,
s_aux
=
self
.
sinks
,
)
return
output
def
do_kv_cache_update
(
...
...
@@ -913,10 +885,10 @@ class FlashAttentionImpl(AttentionImpl):
):
return
if
not
current_platform
.
is_rocm
():
key_cache
,
value_cache
=
kv_cache
.
unbind
(
0
)
else
:
if
current_platform
.
is_rocm
():
key_cache
,
value_cache
=
kv_cache
else
:
key_cache
,
value_cache
=
kv_cache
.
unbind
(
0
)
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
...
...
@@ -925,18 +897,7 @@ class FlashAttentionImpl(AttentionImpl):
# and value[:num_actual_tokens] because the reshape_and_cache_flash
# op uses the slot_mapping's shape to determine the number of
# actual tokens.
if
not
current_platform
.
is_rocm
():
reshape_and_cache_flash
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
self
.
kv_cache_dtype
,
layer
.
_k_scale
,
layer
.
_v_scale
,
)
else
:
if
current_platform
.
is_rocm
():
if
envs
.
VLLM_USE_OPT_RESHAPE_AND_CACHE
and
key
.
dtype
==
value
.
dtype
==
torch
.
float16
:
from
lightop
import
reshape_and_cache_cuda
reshape_and_cache_cuda
(
...
...
@@ -961,6 +922,17 @@ class FlashAttentionImpl(AttentionImpl):
layer
.
_k_scale
,
layer
.
_v_scale
,
)
else
:
reshape_and_cache_flash
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
self
.
kv_cache_dtype
,
layer
.
_k_scale
,
layer
.
_v_scale
,
)
def
_forward_with_dcp
(
...
...
@@ -989,28 +961,53 @@ class FlashAttentionImpl(AttentionImpl):
sliding_window_size
=
(
list
(
self
.
sliding_window
)
if
self
.
sliding_window
is
not
None
else
None
)
context_attn_out
,
context_lse
=
flash_attn_varlen_func
(
q
=
query_across_dcp
,
k
=
key_cache
,
v
=
value_cache
,
out
=
None
,
cu_seqlens_q
=
cu_seqlens_q
,
max_seqlen_q
=
max_seqlen_q
,
seqused_k
=
attn_metadata
.
dcp_context_kv_lens
,
max_seqlen_k
=
attn_metadata
.
max_dcp_context_kv_len
,
softmax_scale
=
self
.
scale
,
causal
=
False
,
alibi_slopes
=
self
.
alibi_slopes
,
window_size
=
sliding_window_size
,
block_table
=
block_table
,
softcap
=
self
.
logits_soft_cap
,
return_softmax_lse
=
True
,
scheduler_metadata
=
attn_metadata
.
scheduler_metadata
,
fa_version
=
self
.
vllm_flash_attn_version
,
q_descale
=
q_descale
,
k_descale
=
k_descale
,
v_descale
=
v_descale
,
)
if
current_platform
.
is_rocm
():
context_attn_out
,
context_lse
=
vllm_flash_attn_varlen_func
(
q
=
query_across_dcp
,
k
=
key_cache
,
v
=
value_cache
,
out
=
None
,
cu_seqlens_q
=
cu_seqlens_q
,
max_seqlen_q
=
max_seqlen_q
,
seqused_k
=
attn_metadata
.
dcp_context_kv_lens
,
max_seqlen_k
=
attn_metadata
.
max_dcp_context_kv_len
,
softmax_scale
=
self
.
scale
,
causal
=
False
,
alibi_slopes
=
self
.
alibi_slopes
,
window_size
=
sliding_window_size
,
block_table
=
block_table
,
softcap
=
self
.
logits_soft_cap
,
return_softmax_lse
=
True
,
scheduler_metadata
=
attn_metadata
.
scheduler_metadata
,
fa_version
=
self
.
vllm_flash_attn_version
,
q_descale
=
q_descale
,
k_descale
=
k_descale
,
v_descale
=
v_descale
,
is_prefix_cache
=
True
,
)
else
:
context_attn_out
,
context_lse
=
flash_attn_varlen_func
(
q
=
query_across_dcp
,
k
=
key_cache
,
v
=
value_cache
,
out
=
None
,
cu_seqlens_q
=
cu_seqlens_q
,
max_seqlen_q
=
max_seqlen_q
,
seqused_k
=
attn_metadata
.
dcp_context_kv_lens
,
max_seqlen_k
=
attn_metadata
.
max_dcp_context_kv_len
,
softmax_scale
=
self
.
scale
,
causal
=
False
,
alibi_slopes
=
self
.
alibi_slopes
,
window_size
=
sliding_window_size
,
block_table
=
block_table
,
softcap
=
self
.
logits_soft_cap
,
return_softmax_lse
=
True
,
scheduler_metadata
=
attn_metadata
.
scheduler_metadata
,
fa_version
=
self
.
vllm_flash_attn_version
,
q_descale
=
q_descale
,
k_descale
=
k_descale
,
v_descale
=
v_descale
,
)
# FA returns LSE in shape [ H, B ] but cp_lse_ag_out_rs wants [ B, H ]
context_attn_out_cor
,
context_lse_cor
=
cp_lse_ag_out_rs
(
context_attn_out
,
...
...
@@ -1020,26 +1017,49 @@ class FlashAttentionImpl(AttentionImpl):
)
context_lse_cor
=
context_lse_cor
.
transpose
(
0
,
1
).
contiguous
()
query_attn_out
,
query_lse
=
flash_attn_varlen_func
(
q
=
query
,
k
=
key
,
v
=
value
,
out
=
None
,
cu_seqlens_q
=
cu_seqlens_q
,
max_seqlen_q
=
max_seqlen_q
,
cu_seqlens_k
=
cu_seqlens_q
,
max_seqlen_k
=
max_seqlen_q
,
softmax_scale
=
self
.
scale
,
causal
=
attn_metadata
.
causal
,
alibi_slopes
=
self
.
alibi_slopes
,
window_size
=
sliding_window_size
,
softcap
=
self
.
logits_soft_cap
,
return_softmax_lse
=
True
,
fa_version
=
self
.
vllm_flash_attn_version
,
q_descale
=
q_descale
,
k_descale
=
k_descale
,
v_descale
=
v_descale
,
)
if
current_platform
.
is_rocm
():
query_attn_out
,
query_lse
=
vllm_flash_attn_varlen_func
(
q
=
query
,
k
=
key
,
v
=
value
,
out
=
None
,
cu_seqlens_q
=
cu_seqlens_q
,
max_seqlen_q
=
max_seqlen_q
,
cu_seqlens_k
=
cu_seqlens_q
,
max_seqlen_k
=
max_seqlen_q
,
softmax_scale
=
self
.
scale
,
causal
=
attn_metadata
.
causal
,
alibi_slopes
=
self
.
alibi_slopes
,
window_size
=
sliding_window_size
,
softcap
=
self
.
logits_soft_cap
,
return_softmax_lse
=
True
,
fa_version
=
self
.
vllm_flash_attn_version
,
q_descale
=
q_descale
,
k_descale
=
k_descale
,
v_descale
=
v_descale
,
)
else
:
query_attn_out
,
query_lse
=
flash_attn_varlen_func
(
q
=
query
,
k
=
key
,
v
=
value
,
out
=
None
,
cu_seqlens_q
=
cu_seqlens_q
,
max_seqlen_q
=
max_seqlen_q
,
cu_seqlens_k
=
cu_seqlens_q
,
max_seqlen_k
=
max_seqlen_q
,
softmax_scale
=
self
.
scale
,
causal
=
attn_metadata
.
causal
,
alibi_slopes
=
self
.
alibi_slopes
,
window_size
=
sliding_window_size
,
softcap
=
self
.
logits_soft_cap
,
return_softmax_lse
=
True
,
fa_version
=
self
.
vllm_flash_attn_version
,
q_descale
=
q_descale
,
k_descale
=
k_descale
,
v_descale
=
v_descale
,
)
assert
context_attn_out_cor
.
shape
==
query_attn_out
.
shape
assert
context_lse_cor
.
shape
==
query_lse
.
shape
merge_attn_states
(
...
...
@@ -1094,8 +1114,8 @@ class FlashAttentionImpl(AttentionImpl):
sliding_window_size
=
(
list
(
self
.
sliding_window
)
if
self
.
sliding_window
is
not
None
else
None
)
if
not
current_platform
.
is_rocm
():
flash_attn_varlen_func
(
if
current_platform
.
is_rocm
():
vllm_
flash_attn_varlen_func
(
q
=
query
,
k
=
key
,
v
=
value
,
...
...
@@ -1109,14 +1129,18 @@ class FlashAttentionImpl(AttentionImpl):
alibi_slopes
=
self
.
alibi_slopes
,
window_size
=
sliding_window_size
,
softcap
=
self
.
logits_soft_cap
,
fa_version
=
self
.
vllm_flash_attn_version
,
q_descale
=
layer
.
_q_scale
.
expand
(
descale_shape
),
k_descale
=
layer
.
_k_scale
.
expand
(
descale_shape
),
v_descale
=
layer
.
_v_scale
.
expand
(
descale_shape
),
num_splits
=
1
if
self
.
batch_invariant_enabled
else
0
,
# fa_version=self.vllm_flash_attn_version,
# q_descale=layer._q_scale.expand(descale_shape),
# k_descale=layer._k_scale.expand(descale_shape),
# v_descale=layer._v_scale.expand(descale_shape),
q_descale
=
None
,
k_descale
=
layer
.
_k_scale
,
v_descale
=
layer
.
_v_scale
,
# num_splits=1 if self.batch_invariant_enabled else 0,
is_prefix_cache
=
False
,
)
else
:
vllm_
flash_attn_varlen_func
(
flash_attn_varlen_func
(
q
=
query
,
k
=
key
,
v
=
value
,
...
...
@@ -1130,15 +1154,11 @@ class FlashAttentionImpl(AttentionImpl):
alibi_slopes
=
self
.
alibi_slopes
,
window_size
=
sliding_window_size
,
softcap
=
self
.
logits_soft_cap
,
# fa_version=self.vllm_flash_attn_version,
# q_descale=layer._q_scale.expand(descale_shape),
# k_descale=layer._k_scale.expand(descale_shape),
# v_descale=layer._v_scale.expand(descale_shape),
q_descale
=
None
,
k_descale
=
layer
.
_k_scale
,
v_descale
=
layer
.
_v_scale
,
# num_splits=1 if self.batch_invariant_enabled else 0,
is_prefix_cache
=
False
,
fa_version
=
self
.
vllm_flash_attn_version
,
q_descale
=
layer
.
_q_scale
.
expand
(
descale_shape
),
k_descale
=
layer
.
_k_scale
.
expand
(
descale_shape
),
v_descale
=
layer
.
_v_scale
.
expand
(
descale_shape
),
num_splits
=
1
if
self
.
batch_invariant_enabled
else
0
,
)
return
output
...
...
@@ -1259,11 +1279,12 @@ def cascade_attention(
assert
common_prefix_len
%
block_size
==
0
num_common_kv_blocks
=
common_prefix_len
//
block_size
assert
num_common_kv_blocks
>
0
descale_shape
=
(
cu_prefix_query_lens
.
shape
[
0
]
-
1
,
key_cache
.
shape
[
-
2
])
if
not
current_platform
.
is_rocm
():
descale_shape
=
(
cu_prefix_query_lens
.
shape
[
0
]
-
1
,
key_cache
.
shape
[
-
2
])
# Process shared prefix.
if
not
current_platform
.
is_rocm
():
prefix_output
,
prefix_lse
=
flash_attn_varlen_func
(
if
current_platform
.
is_rocm
():
prefix_output
,
prefix_lse
,
_
=
vllm_
flash_attn_varlen_func
(
q
=
query
,
k
=
key_cache
,
v
=
value_cache
,
...
...
@@ -1279,16 +1300,17 @@ def cascade_attention(
return_softmax_lse
=
True
,
scheduler_metadata
=
prefix_scheduler_metadata
,
fa_version
=
fa_version
,
q_descale
=
q_descale
.
expand
(
descale_shape
)
if
q_descale
is
not
None
else
None
,
k_descale
=
k_descale
.
expand
(
descale_shape
)
if
k_descale
is
not
None
else
None
,
v_descale
=
v_descale
.
expand
(
descale_shape
)
if
v_descale
is
not
None
else
None
,
q_descale
=
q_descale
if
q_descale
is
not
None
else
None
,
k_descale
=
k_descale
if
k_descale
is
not
None
else
None
,
v_descale
=
v_descale
if
v_descale
is
not
None
else
None
,
# s_aux is incorporated into prefix_lse inside the GPU kernel,
# enabling its effect during the final attention merge.
s_aux
=
s_aux
,
num_splits
=
1
if
vllm_is_batch_invariant
()
else
max_num_splits
,
# num_splits=1 if vllm_is_batch_invariant() else max_num_splits,
is_prefix_cache
=
True
,
)
else
:
prefix_output
,
prefix_lse
,
_
=
vllm_
flash_attn_varlen_func
(
prefix_output
,
prefix_lse
=
flash_attn_varlen_func
(
q
=
query
,
k
=
key_cache
,
v
=
value_cache
,
...
...
@@ -1303,22 +1325,21 @@ def cascade_attention(
softcap
=
logits_soft_cap
,
return_softmax_lse
=
True
,
scheduler_metadata
=
prefix_scheduler_metadata
,
#
fa_version=fa_version,
fa_version
=
fa_version
,
q_descale
=
q_descale
.
expand
(
descale_shape
)
if
q_descale
is
not
None
else
None
,
k_descale
=
k_descale
.
expand
(
descale_shape
)
if
k_descale
is
not
None
else
None
,
v_descale
=
v_descale
.
expand
(
descale_shape
)
if
v_descale
is
not
None
else
None
,
# s_aux is incorporated into prefix_lse inside the GPU kernel,
# enabling its effect during the final attention merge.
s_aux
=
s_aux
,
# num_splits=1 if vllm_is_batch_invariant() else max_num_splits,
is_prefix_cache
=
True
,
num_splits
=
1
if
vllm_is_batch_invariant
()
else
max_num_splits
,
)
descale_shape
=
(
cu_query_lens
.
shape
[
0
]
-
1
,
key_cache
.
shape
[
-
2
])
# Process suffix per query.
if
not
current_platform
.
is_rocm
():
suffix_output
,
suffix_lse
=
flash_attn_varlen_func
(
if
current_platform
.
is_rocm
():
suffix_output
,
suffix_lse
,
_
=
vllm_
flash_attn_varlen_func
(
q
=
query
,
k
=
key_cache
,
v
=
value_cache
,
...
...
@@ -1334,13 +1355,14 @@ def cascade_attention(
return_softmax_lse
=
True
,
scheduler_metadata
=
suffix_scheduler_metadata
,
fa_version
=
fa_version
,
q_descale
=
q_descale
.
expand
(
descale_shape
)
if
q_descale
is
not
None
else
None
,
k_descale
=
k_descale
.
expand
(
descale_shape
)
if
k_descale
is
not
None
else
None
,
v_descale
=
v_descale
.
expand
(
descale_shape
)
if
v_descale
is
not
None
else
None
,
num_splits
=
1
if
vllm_is_batch_invariant
()
else
max_num_splits
,
q_descale
=
q_descale
if
q_descale
is
not
None
else
None
,
k_descale
=
k_descale
if
k_descale
is
not
None
else
None
,
v_descale
=
v_descale
if
v_descale
is
not
None
else
None
,
# num_splits=1 if vllm_is_batch_invariant() else max_num_splits,
is_prefix_cache
=
True
,
)
else
:
suffix_output
,
suffix_lse
,
_
=
vllm_
flash_attn_varlen_func
(
suffix_output
,
suffix_lse
=
flash_attn_varlen_func
(
q
=
query
,
k
=
key_cache
,
v
=
value_cache
,
...
...
@@ -1355,12 +1377,11 @@ def cascade_attention(
softcap
=
logits_soft_cap
,
return_softmax_lse
=
True
,
scheduler_metadata
=
suffix_scheduler_metadata
,
#
fa_version=fa_version,
fa_version
=
fa_version
,
q_descale
=
q_descale
.
expand
(
descale_shape
)
if
q_descale
is
not
None
else
None
,
k_descale
=
k_descale
.
expand
(
descale_shape
)
if
k_descale
is
not
None
else
None
,
v_descale
=
v_descale
.
expand
(
descale_shape
)
if
v_descale
is
not
None
else
None
,
# num_splits=1 if vllm_is_batch_invariant() else max_num_splits,
is_prefix_cache
=
True
,
num_splits
=
1
if
vllm_is_batch_invariant
()
else
max_num_splits
,
)
# Merge prefix and suffix outputs, and store the result in output.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment