Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
bdd33b3f
Commit
bdd33b3f
authored
Jan 30, 2026
by
zhuwenwen
Browse files
update fa interface and kvcache
add prepare_so_files to prepare so
parent
63053820
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
271 additions
and
210 deletions
+271
-210
README.md
README.md
+3
-0
setup.py
setup.py
+34
-0
vllm/attention/layer.py
vllm/attention/layer.py
+4
-1
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+230
-209
No files found.
README.md
View file @
bdd33b3f
...
@@ -91,6 +91,9 @@ python3 setup.py install (若调试,可使用python3 setup.py develop)
...
@@ -91,6 +91,9 @@ python3 setup.py install (若调试,可使用python3 setup.py develop)
```
```
若需要添加git号,设置环境变量: export ADD_GIT_VERSION=1
若需要添加git号,设置环境变量: export ADD_GIT_VERSION=1
3.
跳过编译(适用于未改变csrc目录kernel并多次编译情况)
将编译后的so文件拷贝至csrc目录,并设置环境变量: export SKIP_VLLM_BUILD=1
#### 运行基础环境准备
#### 运行基础环境准备
1、使用上面基于光源pytorch2.9.0基础镜像环境
1、使用上面基于光源pytorch2.9.0基础镜像环境
...
...
setup.py
View file @
bdd33b3f
...
@@ -13,6 +13,8 @@ import sys
...
@@ -13,6 +13,8 @@ import sys
import
sysconfig
import
sysconfig
from
pathlib
import
Path
from
pathlib
import
Path
from
shutil
import
which
from
shutil
import
which
import
tarfile
import
shutil
import
torch
import
torch
from
packaging.version
import
Version
,
parse
from
packaging.version
import
Version
,
parse
...
@@ -36,6 +38,37 @@ skip_vllm_build = False
...
@@ -36,6 +38,37 @@ skip_vllm_build = False
if
int
(
os
.
environ
.
get
(
'SKIP_VLLM_BUILD'
,
'0'
))
==
1
:
if
int
(
os
.
environ
.
get
(
'SKIP_VLLM_BUILD'
,
'0'
))
==
1
:
skip_vllm_build
=
True
skip_vllm_build
=
True
def
prepare_so_files
():
source_dir
=
"csrc/so.tar.gz"
target_dir
=
"vllm"
if
not
os
.
path
.
exists
(
source_dir
):
print
(
f
"Warning:
{
source_dir
}
not found, skipping extraction"
)
return
print
(
f
"Preparing C extension files from
{
source_dir
}
..."
)
temp_dir
=
"temp_so_extract"
os
.
makedirs
(
temp_dir
,
exist_ok
=
True
)
try
:
with
tarfile
.
open
(
source_dir
,
"r:*"
)
as
tar
:
tar
.
extractall
(
temp_dir
)
for
root
,
dirs
,
files
in
os
.
walk
(
temp_dir
):
for
file
in
files
:
if
file
in
[
"_C.abi3.so"
,
"_moe_C.abi3.so"
,
"cumem_allocator.abi3.so"
]:
src_path
=
os
.
path
.
join
(
root
,
file
)
dst_path
=
os
.
path
.
join
(
target_dir
,
file
)
os
.
makedirs
(
os
.
path
.
dirname
(
dst_path
),
exist_ok
=
True
)
shutil
.
copy2
(
src_path
,
dst_path
)
print
(
f
"Copied
{
file
}
to
{
dst_path
}
"
)
finally
:
if
os
.
path
.
exists
(
temp_dir
):
shutil
.
rmtree
(
temp_dir
)
def
load_module_from_path
(
module_name
,
path
):
def
load_module_from_path
(
module_name
,
path
):
spec
=
importlib
.
util
.
spec_from_file_location
(
module_name
,
path
)
spec
=
importlib
.
util
.
spec_from_file_location
(
module_name
,
path
)
module
=
importlib
.
util
.
module_from_spec
(
spec
)
module
=
importlib
.
util
.
module_from_spec
(
spec
)
...
@@ -1109,6 +1142,7 @@ if _build_custom_ops():
...
@@ -1109,6 +1142,7 @@ if _build_custom_ops():
ext_modules
.
append
(
CMakeExtension
(
name
=
"vllm._C"
))
ext_modules
.
append
(
CMakeExtension
(
name
=
"vllm._C"
))
if
skip_vllm_build
:
if
skip_vllm_build
:
prepare_so_files
()
package_data
=
{
package_data
=
{
"vllm"
:
[
"vllm"
:
[
"py.typed"
,
"py.typed"
,
...
...
vllm/attention/layer.py
View file @
bdd33b3f
...
@@ -848,6 +848,9 @@ def unified_kv_cache_update(
...
@@ -848,6 +848,9 @@ def unified_kv_cache_update(
layer_slot_mapping
,
layer_slot_mapping
,
)
)
if
current_platform
.
is_rocm
():
return
torch
.
empty
(
0
,
device
=
key
.
device
,
dtype
=
key
.
dtype
)
else
:
return
torch
.
empty
(
0
,
device
=
kv_cache
.
device
,
dtype
=
kv_cache
.
dtype
)
return
torch
.
empty
(
0
,
device
=
kv_cache
.
device
,
dtype
=
kv_cache
.
dtype
)
...
...
vllm/v1/attention/backends/flash_attn.py
View file @
bdd33b3f
...
@@ -27,18 +27,18 @@ from vllm.v1.attention.ops.merge_attn_states import merge_attn_states
...
@@ -27,18 +27,18 @@ from vllm.v1.attention.ops.merge_attn_states import merge_attn_states
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
if
is_flash_attn_varlen_func_available
():
if
is_flash_attn_varlen_func_available
():
if
not
current_platform
.
is_rocm
():
if
current_platform
.
is_rocm
():
from
vllm.v1.attention.backends.fa_utils
import
(
from
vllm.v1.attention.backends.fa_utils
import
(
flash_attn_supports_sinks
,
flash_attn_supports_sinks
,
flash_attn_varlen_func
,
vllm_flash_attn_varlen_func
,
get_scheduler_metadata
,
reshape_and_cache_cuda
,
reshape_and_cache_flash
,
)
)
else
:
else
:
from
vllm.v1.attention.backends.fa_utils
import
(
from
vllm.v1.attention.backends.fa_utils
import
(
flash_attn_supports_sinks
,
flash_attn_supports_sinks
,
vllm_flash_attn_varlen_func
,
flash_attn_varlen_func
,
reshape_and_cache_cuda
,
get_scheduler_metadata
,
reshape_and_cache_flash
,
)
)
from
vllm.config
import
VllmConfig
,
get_current_vllm_config
,
get_layers_from_vllm_config
from
vllm.config
import
VllmConfig
,
get_current_vllm_config
,
get_layers_from_vllm_config
...
@@ -113,7 +113,7 @@ class FlashAttentionBackend(AttentionBackend):
...
@@ -113,7 +113,7 @@ class FlashAttentionBackend(AttentionBackend):
def
get_builder_cls
()
->
type
[
"FlashAttentionMetadataBuilder"
]:
def
get_builder_cls
()
->
type
[
"FlashAttentionMetadataBuilder"
]:
return
FlashAttentionMetadataBuilder
return
FlashAttentionMetadataBuilder
if
not
current_platform
.
is_rocm
():
if
current_platform
.
is_rocm
():
@
staticmethod
@
staticmethod
def
get_kv_cache_shape
(
def
get_kv_cache_shape
(
num_blocks
:
int
,
num_blocks
:
int
,
...
@@ -121,31 +121,36 @@ class FlashAttentionBackend(AttentionBackend):
...
@@ -121,31 +121,36 @@ class FlashAttentionBackend(AttentionBackend):
num_kv_heads
:
int
,
num_kv_heads
:
int
,
head_size
:
int
,
head_size
:
int
,
cache_dtype_str
:
str
=
"auto"
,
cache_dtype_str
:
str
=
"auto"
,
)
->
tuple
[
int
,
...]:
)
->
tuple
[
tuple
[
int
,
...]
,
tuple
[
int
,
...]]
:
if
block_size
%
16
!=
0
:
if
block_size
%
16
!=
0
:
raise
ValueError
(
"Block size must be a multiple of 16."
)
raise
ValueError
(
"Block size must be a multiple of 16."
)
return
(
2
,
num_blocks
,
block_size
,
num_kv_heads
,
head_size
)
return
(
(
num_blocks
,
num_kv_heads
,
block_size
,
head_size
),
(
num_blocks
,
num_kv_heads
,
head_size
,
block_size
),
)
@
staticmethod
@
staticmethod
def
get_kv_cache_stride_order
(
def
get_kv_cache_stride_order
(
include_num_layers_dimension
:
bool
=
False
,
include_num_layers_dimension
:
bool
=
False
,
)
->
tuple
[
int
,
...]:
)
->
tuple
[
tuple
[
int
,
...]
,
tuple
[
int
,
...]]
:
# `stride_order` indicates the permutation that gets
# `stride_order` indicates the permutation that gets
# us from `get_kv_cache_shape` to the actual memory layout we want.
# us from `get_kv_cache_shape` to the actual memory layout we want.
cache_layout
=
get_kv_cache_layout
()
cache_layout
=
get_kv_cache_layout
()
if
cache_layout
==
"NHD"
and
include_num_layers_dimension
:
if
cache_layout
==
"NHD"
and
include_num_layers_dimension
:
# (num_blocks, num_layers,
2,
block_size, num_kv_heads, head_size)
# (num_blocks, num_layers, block_size, num_kv_heads, head_size)
return
(
2
,
0
,
1
,
3
,
4
,
5
)
return
(
1
,
0
,
3
,
2
,
5
),
(
1
,
0
,
4
,
2
,
3
)
elif
cache_layout
==
"NHD"
:
elif
cache_layout
==
"NHD"
:
stride_order
=
(
0
,
1
,
2
,
3
,
4
)
key_stride_order
=
(
0
,
1
,
2
,
3
)
value_stride_order
=
(
0
,
1
,
2
,
3
)
elif
cache_layout
==
"HND"
and
include_num_layers_dimension
:
elif
cache_layout
==
"HND"
and
include_num_layers_dimension
:
# (num_blocks, num_kv_heads, num_layers,
2,
block_size, head_size)
# (num_blocks, num_kv_heads, num_layers, block_size, head_size)
return
(
2
,
4
,
0
,
1
,
3
,
5
)
return
(
1
,
2
,
0
,
3
,
4
)
,
(
1
,
2
,
0
,
4
,
3
)
elif
cache_layout
==
"HND"
:
elif
cache_layout
==
"HND"
:
stride_order
=
(
0
,
1
,
3
,
2
,
4
)
key_stride_order
=
(
0
,
1
,
2
,
3
)
value_stride_order
=
(
0
,
1
,
3
,
2
)
else
:
else
:
raise
ValueError
(
f
"Unknown cache layout format
{
cache_layout
}
."
)
raise
ValueError
(
f
"Unknown cache layout format
{
cache_layout
}
."
)
return
stride_order
return
key_stride_order
,
value_
stride_order
else
:
else
:
@
staticmethod
@
staticmethod
def
get_kv_cache_shape
(
def
get_kv_cache_shape
(
...
@@ -154,36 +159,32 @@ class FlashAttentionBackend(AttentionBackend):
...
@@ -154,36 +159,32 @@ class FlashAttentionBackend(AttentionBackend):
num_kv_heads
:
int
,
num_kv_heads
:
int
,
head_size
:
int
,
head_size
:
int
,
cache_dtype_str
:
str
=
"auto"
,
cache_dtype_str
:
str
=
"auto"
,
)
->
tuple
[
tuple
[
int
,
...]
,
tuple
[
int
,
...]]
:
)
->
tuple
[
int
,
...]:
if
block_size
%
16
!=
0
:
if
block_size
%
16
!=
0
:
raise
ValueError
(
"Block size must be a multiple of 16."
)
raise
ValueError
(
"Block size must be a multiple of 16."
)
return
(
return
(
2
,
num_blocks
,
block_size
,
num_kv_heads
,
head_size
)
(
num_blocks
,
num_kv_heads
,
block_size
,
head_size
),
(
num_blocks
,
num_kv_heads
,
head_size
,
block_size
),
)
@
staticmethod
@
staticmethod
def
get_kv_cache_stride_order
(
def
get_kv_cache_stride_order
(
include_num_layers_dimension
:
bool
=
False
,
include_num_layers_dimension
:
bool
=
False
,
)
->
tuple
[
tuple
[
int
,
...]
,
tuple
[
int
,
...]]
:
)
->
tuple
[
int
,
...]:
# `stride_order` indicates the permutation that gets
# `stride_order` indicates the permutation that gets
# us from `get_kv_cache_shape` to the actual memory layout we want.
# us from `get_kv_cache_shape` to the actual memory layout we want.
cache_layout
=
get_kv_cache_layout
()
cache_layout
=
get_kv_cache_layout
()
if
cache_layout
==
"NHD"
and
include_num_layers_dimension
:
if
cache_layout
==
"NHD"
and
include_num_layers_dimension
:
# (num_blocks, num_layers, block_size, num_kv_heads, head_size)
# (num_blocks, num_layers,
2,
block_size, num_kv_heads, head_size)
return
(
1
,
0
,
3
,
2
,
5
)
,
(
1
,
0
,
4
,
2
,
3
)
return
(
2
,
0
,
1
,
3
,
4
,
5
)
elif
cache_layout
==
"NHD"
:
elif
cache_layout
==
"NHD"
:
key_stride_order
=
(
0
,
1
,
2
,
3
)
stride_order
=
(
0
,
1
,
2
,
3
,
4
)
value_stride_order
=
(
0
,
1
,
2
,
3
)
elif
cache_layout
==
"HND"
and
include_num_layers_dimension
:
elif
cache_layout
==
"HND"
and
include_num_layers_dimension
:
# (num_blocks, num_kv_heads, num_layers, block_size, head_size)
# (num_blocks, num_kv_heads, num_layers,
2,
block_size, head_size)
return
(
1
,
2
,
0
,
3
,
4
),
(
1
,
2
,
0
,
4
,
3
)
return
(
2
,
4
,
0
,
1
,
3
,
5
)
elif
cache_layout
==
"HND"
:
elif
cache_layout
==
"HND"
:
key_stride_order
=
(
0
,
1
,
2
,
3
)
stride_order
=
(
0
,
1
,
3
,
2
,
4
)
value_stride_order
=
(
0
,
1
,
3
,
2
)
else
:
else
:
raise
ValueError
(
f
"Unknown cache layout format
{
cache_layout
}
."
)
raise
ValueError
(
f
"Unknown cache layout format
{
cache_layout
}
."
)
return
key_stride_order
,
value_stride_order
return
stride_order
@
staticmethod
@
staticmethod
def
get_fp8_dtype_for_flashattn
(
kv_cache_dtype
:
str
)
->
torch
.
dtype
:
def
get_fp8_dtype_for_flashattn
(
kv_cache_dtype
:
str
)
->
torch
.
dtype
:
...
@@ -724,10 +725,10 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -724,10 +725,10 @@ class FlashAttentionImpl(AttentionImpl):
)
)
# For decoder and cross-attention, use KV cache as before
# For decoder and cross-attention, use KV cache as before
if
not
current_platform
.
is_rocm
():
if
current_platform
.
is_rocm
():
key_cache
,
value_cache
=
kv_cache
.
unbind
(
0
)
else
:
key_cache
,
value_cache
=
kv_cache
key_cache
,
value_cache
=
kv_cache
else
:
key_cache
,
value_cache
=
kv_cache
.
unbind
(
0
)
if
self
.
kv_cache_dtype
.
startswith
(
"fp8"
):
if
self
.
kv_cache_dtype
.
startswith
(
"fp8"
):
# queries are quantized in the attention layer
# queries are quantized in the attention layer
...
@@ -745,7 +746,11 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -745,7 +746,11 @@ class FlashAttentionImpl(AttentionImpl):
block_table
=
attn_metadata
.
block_table
block_table
=
attn_metadata
.
block_table
scheduler_metadata
=
attn_metadata
.
scheduler_metadata
scheduler_metadata
=
attn_metadata
.
scheduler_metadata
if
not
current_platform
.
is_rocm
():
if
current_platform
.
is_rocm
():
q_descale
=
None
k_descale
=
layer
.
_k_scale
v_descale
=
layer
.
_v_scale
else
:
descale_shape
=
(
cu_seqlens_q
.
shape
[
0
]
-
1
,
self
.
num_kv_heads
)
descale_shape
=
(
cu_seqlens_q
.
shape
[
0
]
-
1
,
self
.
num_kv_heads
)
q_descale
=
layer
.
_q_scale
.
expand
(
descale_shape
)
q_descale
=
layer
.
_q_scale
.
expand
(
descale_shape
)
...
@@ -772,8 +777,13 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -772,8 +777,13 @@ class FlashAttentionImpl(AttentionImpl):
if
self
.
sliding_window
is
not
None
if
self
.
sliding_window
is
not
None
else
None
else
None
)
)
if
not
current_platform
.
is_rocm
():
if
current_platform
.
is_rocm
():
flash_attn_varlen_func
(
if
envs
.
VLLM_USE_PA_PRINT_PARAM
:
print
(
"PA SIZE:"
)
print
(
f
"q.shape =
{
query
[:
num_actual_tokens
].
shape
}
, key_cache.shape =
{
key_cache
.
shape
}
, value_cache.shape =
{
value_cache
.
shape
}
"
)
print
(
f
"cu_seqlens_q.shape =
{
cu_seqlens_q
.
shape
}
, max_seqlen_q =
{
max_seqlen_q
}
, seqused_k.shape =
{
seqused_k
.
shape
}
, max_seqlen_k =
{
max_seqlen_k
}
"
)
print
(
f
"softmax_scale =
{
self
.
scale
:.
3
f
}
, alibi_slopes =
{
self
.
alibi_slopes
}
, window_size =
{
self
.
sliding_window
}
, block_tables.shape =
{
block_table
.
shape
}
, softcap =
{
self
.
logits_soft_cap
}
, scheduler_metadata =
{
scheduler_metadata
}
"
)
vllm_flash_attn_varlen_func
(
q
=
query
[:
num_actual_tokens
],
q
=
query
[:
num_actual_tokens
],
k
=
key_cache
,
k
=
key_cache
,
v
=
value_cache
,
v
=
value_cache
,
...
@@ -793,16 +803,12 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -793,16 +803,12 @@ class FlashAttentionImpl(AttentionImpl):
q_descale
=
q_descale
,
q_descale
=
q_descale
,
k_descale
=
k_descale
,
k_descale
=
k_descale
,
v_descale
=
v_descale
,
v_descale
=
v_descale
,
num_splits
=
attn_metadata
.
max_num_splits
,
#
num_splits=attn_metadata.max_num_splits,
s_aux
=
self
.
sinks
,
s_aux
=
self
.
sinks
,
is_prefix_cache
=
True
,
)
)
else
:
else
:
if
envs
.
VLLM_USE_PA_PRINT_PARAM
:
flash_attn_varlen_func
(
print
(
"PA SIZE:"
)
print
(
f
"q.shape =
{
query
[:
num_actual_tokens
].
shape
}
, key_cache.shape =
{
key_cache
.
shape
}
, value_cache.shape =
{
value_cache
.
shape
}
"
)
print
(
f
"cu_seqlens_q.shape =
{
cu_seqlens_q
.
shape
}
, max_seqlen_q =
{
max_seqlen_q
}
, seqused_k.shape =
{
seqused_k
.
shape
}
, max_seqlen_k =
{
max_seqlen_k
}
"
)
print
(
f
"softmax_scale =
{
self
.
scale
:.
3
f
}
, alibi_slopes =
{
self
.
alibi_slopes
}
, window_size =
{
self
.
sliding_window
}
, block_tables.shape =
{
block_table
.
shape
}
, softcap =
{
self
.
logits_soft_cap
}
, scheduler_metadata =
{
scheduler_metadata
}
"
)
vllm_flash_attn_varlen_func
(
q
=
query
[:
num_actual_tokens
],
q
=
query
[:
num_actual_tokens
],
k
=
key_cache
,
k
=
key_cache
,
v
=
value_cache
,
v
=
value_cache
,
...
@@ -818,21 +824,16 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -818,21 +824,16 @@ class FlashAttentionImpl(AttentionImpl):
block_table
=
block_table
,
block_table
=
block_table
,
softcap
=
self
.
logits_soft_cap
,
softcap
=
self
.
logits_soft_cap
,
scheduler_metadata
=
scheduler_metadata
,
scheduler_metadata
=
scheduler_metadata
,
# fa_version=self.vllm_flash_attn_version,
fa_version
=
self
.
vllm_flash_attn_version
,
# q_descale=q_descale,
q_descale
=
q_descale
,
# k_descale=k_descale,
k_descale
=
k_descale
,
# v_descale=v_descale,
v_descale
=
v_descale
,
q_descale
=
None
,
num_splits
=
attn_metadata
.
max_num_splits
,
k_descale
=
layer
.
_k_scale
,
v_descale
=
layer
.
_v_scale
,
# num_splits=attn_metadata.max_num_splits,
s_aux
=
self
.
sinks
,
s_aux
=
self
.
sinks
,
is_prefix_cache
=
True
,
)
)
return
output
return
output
# Cascade attention (rare case).
# Cascade attention (rare case).
if
not
current_platform
.
is_rocm
():
cascade_attention
(
cascade_attention
(
output
[:
num_actual_tokens
],
output
[:
num_actual_tokens
],
query
[:
num_actual_tokens
],
query
[:
num_actual_tokens
],
...
@@ -854,36 +855,7 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -854,36 +855,7 @@ class FlashAttentionImpl(AttentionImpl):
fa_version
=
self
.
vllm_flash_attn_version
,
fa_version
=
self
.
vllm_flash_attn_version
,
prefix_scheduler_metadata
=
attn_metadata
.
prefix_scheduler_metadata
,
prefix_scheduler_metadata
=
attn_metadata
.
prefix_scheduler_metadata
,
suffix_scheduler_metadata
=
attn_metadata
.
scheduler_metadata
,
suffix_scheduler_metadata
=
attn_metadata
.
scheduler_metadata
,
q_descale
=
layer
.
_q_scale
,
q_descale
=
None
if
current_platform
.
is_rocm
()
else
layer
.
_q_scale
,
k_descale
=
layer
.
_k_scale
,
v_descale
=
layer
.
_v_scale
,
s_aux
=
self
.
sinks
,
)
else
:
cascade_attention
(
output
[:
num_actual_tokens
],
query
[:
num_actual_tokens
],
key_cache
,
value_cache
,
cu_query_lens
=
attn_metadata
.
query_start_loc
,
max_query_len
=
attn_metadata
.
max_query_len
,
cu_prefix_query_lens
=
attn_metadata
.
cu_prefix_query_lens
,
prefix_kv_lens
=
attn_metadata
.
prefix_kv_lens
,
suffix_kv_lens
=
attn_metadata
.
suffix_kv_lens
,
max_kv_len
=
attn_metadata
.
max_seq_len
,
softmax_scale
=
self
.
scale
,
alibi_slopes
=
self
.
alibi_slopes
,
sliding_window
=
self
.
sliding_window
,
logits_soft_cap
=
self
.
logits_soft_cap
,
block_table
=
attn_metadata
.
block_table
,
common_prefix_len
=
attn_metadata
.
common_prefix_len
,
fa_version
=
2
,
#self.vllm_flash_attn_version,
prefix_scheduler_metadata
=
attn_metadata
.
prefix_scheduler_metadata
,
suffix_scheduler_metadata
=
attn_metadata
.
scheduler_metadata
,
# q_descale=layer._q_scale,
# k_descale=layer._k_scale,
# v_descale=layer._v_scale,
q_descale
=
None
,
k_descale
=
layer
.
_k_scale
,
k_descale
=
layer
.
_k_scale
,
v_descale
=
layer
.
_v_scale
,
v_descale
=
layer
.
_v_scale
,
s_aux
=
self
.
sinks
,
s_aux
=
self
.
sinks
,
...
@@ -913,10 +885,10 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -913,10 +885,10 @@ class FlashAttentionImpl(AttentionImpl):
):
):
return
return
if
not
current_platform
.
is_rocm
():
if
current_platform
.
is_rocm
():
key_cache
,
value_cache
=
kv_cache
.
unbind
(
0
)
else
:
key_cache
,
value_cache
=
kv_cache
key_cache
,
value_cache
=
kv_cache
else
:
key_cache
,
value_cache
=
kv_cache
.
unbind
(
0
)
# Reshape the input keys and values and store them in the cache.
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
# Skip this if sharing KV cache with an earlier attention layer.
...
@@ -925,8 +897,10 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -925,8 +897,10 @@ class FlashAttentionImpl(AttentionImpl):
# and value[:num_actual_tokens] because the reshape_and_cache_flash
# and value[:num_actual_tokens] because the reshape_and_cache_flash
# op uses the slot_mapping's shape to determine the number of
# op uses the slot_mapping's shape to determine the number of
# actual tokens.
# actual tokens.
if
not
current_platform
.
is_rocm
():
if
current_platform
.
is_rocm
():
reshape_and_cache_flash
(
if
envs
.
VLLM_USE_OPT_RESHAPE_AND_CACHE
and
key
.
dtype
==
value
.
dtype
==
torch
.
float16
:
from
lightop
import
reshape_and_cache_cuda
reshape_and_cache_cuda
(
key
,
key
,
value
,
value
,
key_cache
,
key_cache
,
...
@@ -934,11 +908,10 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -934,11 +908,10 @@ class FlashAttentionImpl(AttentionImpl):
slot_mapping
,
slot_mapping
,
self
.
kv_cache_dtype
,
self
.
kv_cache_dtype
,
layer
.
_k_scale
,
layer
.
_k_scale
,
layer
.
_v_scale
,
layer
.
_v_scale
)
)
else
:
else
:
if
envs
.
VLLM_USE_OPT_RESHAPE_AND_CACHE
and
key
.
dtype
==
value
.
dtype
==
torch
.
float16
:
from
vllm.v1.attention.backends.fa_utils
import
reshape_and_cache_cuda
from
lightop
import
reshape_and_cache_cuda
reshape_and_cache_cuda
(
reshape_and_cache_cuda
(
key
,
key
,
value
,
value
,
...
@@ -947,11 +920,10 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -947,11 +920,10 @@ class FlashAttentionImpl(AttentionImpl):
slot_mapping
,
slot_mapping
,
self
.
kv_cache_dtype
,
self
.
kv_cache_dtype
,
layer
.
_k_scale
,
layer
.
_k_scale
,
layer
.
_v_scale
layer
.
_v_scale
,
)
)
else
:
else
:
from
vllm.v1.attention.backends.fa_utils
import
reshape_and_cache_cuda
reshape_and_cache_flash
(
reshape_and_cache_cuda
(
key
,
key
,
value
,
value
,
key_cache
,
key_cache
,
...
@@ -989,6 +961,31 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -989,6 +961,31 @@ class FlashAttentionImpl(AttentionImpl):
sliding_window_size
=
(
sliding_window_size
=
(
list
(
self
.
sliding_window
)
if
self
.
sliding_window
is
not
None
else
None
list
(
self
.
sliding_window
)
if
self
.
sliding_window
is
not
None
else
None
)
)
if
current_platform
.
is_rocm
():
context_attn_out
,
context_lse
=
vllm_flash_attn_varlen_func
(
q
=
query_across_dcp
,
k
=
key_cache
,
v
=
value_cache
,
out
=
None
,
cu_seqlens_q
=
cu_seqlens_q
,
max_seqlen_q
=
max_seqlen_q
,
seqused_k
=
attn_metadata
.
dcp_context_kv_lens
,
max_seqlen_k
=
attn_metadata
.
max_dcp_context_kv_len
,
softmax_scale
=
self
.
scale
,
causal
=
False
,
alibi_slopes
=
self
.
alibi_slopes
,
window_size
=
sliding_window_size
,
block_table
=
block_table
,
softcap
=
self
.
logits_soft_cap
,
return_softmax_lse
=
True
,
scheduler_metadata
=
attn_metadata
.
scheduler_metadata
,
fa_version
=
self
.
vllm_flash_attn_version
,
q_descale
=
q_descale
,
k_descale
=
k_descale
,
v_descale
=
v_descale
,
is_prefix_cache
=
True
,
)
else
:
context_attn_out
,
context_lse
=
flash_attn_varlen_func
(
context_attn_out
,
context_lse
=
flash_attn_varlen_func
(
q
=
query_across_dcp
,
q
=
query_across_dcp
,
k
=
key_cache
,
k
=
key_cache
,
...
@@ -1020,6 +1017,28 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -1020,6 +1017,28 @@ class FlashAttentionImpl(AttentionImpl):
)
)
context_lse_cor
=
context_lse_cor
.
transpose
(
0
,
1
).
contiguous
()
context_lse_cor
=
context_lse_cor
.
transpose
(
0
,
1
).
contiguous
()
if
current_platform
.
is_rocm
():
query_attn_out
,
query_lse
=
vllm_flash_attn_varlen_func
(
q
=
query
,
k
=
key
,
v
=
value
,
out
=
None
,
cu_seqlens_q
=
cu_seqlens_q
,
max_seqlen_q
=
max_seqlen_q
,
cu_seqlens_k
=
cu_seqlens_q
,
max_seqlen_k
=
max_seqlen_q
,
softmax_scale
=
self
.
scale
,
causal
=
attn_metadata
.
causal
,
alibi_slopes
=
self
.
alibi_slopes
,
window_size
=
sliding_window_size
,
softcap
=
self
.
logits_soft_cap
,
return_softmax_lse
=
True
,
fa_version
=
self
.
vllm_flash_attn_version
,
q_descale
=
q_descale
,
k_descale
=
k_descale
,
v_descale
=
v_descale
,
)
else
:
query_attn_out
,
query_lse
=
flash_attn_varlen_func
(
query_attn_out
,
query_lse
=
flash_attn_varlen_func
(
q
=
query
,
q
=
query
,
k
=
key
,
k
=
key
,
...
@@ -1040,6 +1059,7 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -1040,6 +1059,7 @@ class FlashAttentionImpl(AttentionImpl):
k_descale
=
k_descale
,
k_descale
=
k_descale
,
v_descale
=
v_descale
,
v_descale
=
v_descale
,
)
)
assert
context_attn_out_cor
.
shape
==
query_attn_out
.
shape
assert
context_attn_out_cor
.
shape
==
query_attn_out
.
shape
assert
context_lse_cor
.
shape
==
query_lse
.
shape
assert
context_lse_cor
.
shape
==
query_lse
.
shape
merge_attn_states
(
merge_attn_states
(
...
@@ -1094,8 +1114,8 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -1094,8 +1114,8 @@ class FlashAttentionImpl(AttentionImpl):
sliding_window_size
=
(
sliding_window_size
=
(
list
(
self
.
sliding_window
)
if
self
.
sliding_window
is
not
None
else
None
list
(
self
.
sliding_window
)
if
self
.
sliding_window
is
not
None
else
None
)
)
if
not
current_platform
.
is_rocm
():
if
current_platform
.
is_rocm
():
flash_attn_varlen_func
(
vllm_
flash_attn_varlen_func
(
q
=
query
,
q
=
query
,
k
=
key
,
k
=
key
,
v
=
value
,
v
=
value
,
...
@@ -1109,14 +1129,18 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -1109,14 +1129,18 @@ class FlashAttentionImpl(AttentionImpl):
alibi_slopes
=
self
.
alibi_slopes
,
alibi_slopes
=
self
.
alibi_slopes
,
window_size
=
sliding_window_size
,
window_size
=
sliding_window_size
,
softcap
=
self
.
logits_soft_cap
,
softcap
=
self
.
logits_soft_cap
,
fa_version
=
self
.
vllm_flash_attn_version
,
# fa_version=self.vllm_flash_attn_version,
q_descale
=
layer
.
_q_scale
.
expand
(
descale_shape
),
# q_descale=layer._q_scale.expand(descale_shape),
k_descale
=
layer
.
_k_scale
.
expand
(
descale_shape
),
# k_descale=layer._k_scale.expand(descale_shape),
v_descale
=
layer
.
_v_scale
.
expand
(
descale_shape
),
# v_descale=layer._v_scale.expand(descale_shape),
num_splits
=
1
if
self
.
batch_invariant_enabled
else
0
,
q_descale
=
None
,
k_descale
=
layer
.
_k_scale
,
v_descale
=
layer
.
_v_scale
,
# num_splits=1 if self.batch_invariant_enabled else 0,
is_prefix_cache
=
False
,
)
)
else
:
else
:
vllm_
flash_attn_varlen_func
(
flash_attn_varlen_func
(
q
=
query
,
q
=
query
,
k
=
key
,
k
=
key
,
v
=
value
,
v
=
value
,
...
@@ -1130,15 +1154,11 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -1130,15 +1154,11 @@ class FlashAttentionImpl(AttentionImpl):
alibi_slopes
=
self
.
alibi_slopes
,
alibi_slopes
=
self
.
alibi_slopes
,
window_size
=
sliding_window_size
,
window_size
=
sliding_window_size
,
softcap
=
self
.
logits_soft_cap
,
softcap
=
self
.
logits_soft_cap
,
# fa_version=self.vllm_flash_attn_version,
fa_version
=
self
.
vllm_flash_attn_version
,
# q_descale=layer._q_scale.expand(descale_shape),
q_descale
=
layer
.
_q_scale
.
expand
(
descale_shape
),
# k_descale=layer._k_scale.expand(descale_shape),
k_descale
=
layer
.
_k_scale
.
expand
(
descale_shape
),
# v_descale=layer._v_scale.expand(descale_shape),
v_descale
=
layer
.
_v_scale
.
expand
(
descale_shape
),
q_descale
=
None
,
num_splits
=
1
if
self
.
batch_invariant_enabled
else
0
,
k_descale
=
layer
.
_k_scale
,
v_descale
=
layer
.
_v_scale
,
# num_splits=1 if self.batch_invariant_enabled else 0,
is_prefix_cache
=
False
,
)
)
return
output
return
output
...
@@ -1259,11 +1279,12 @@ def cascade_attention(
...
@@ -1259,11 +1279,12 @@ def cascade_attention(
assert
common_prefix_len
%
block_size
==
0
assert
common_prefix_len
%
block_size
==
0
num_common_kv_blocks
=
common_prefix_len
//
block_size
num_common_kv_blocks
=
common_prefix_len
//
block_size
assert
num_common_kv_blocks
>
0
assert
num_common_kv_blocks
>
0
if
not
current_platform
.
is_rocm
():
descale_shape
=
(
cu_prefix_query_lens
.
shape
[
0
]
-
1
,
key_cache
.
shape
[
-
2
])
descale_shape
=
(
cu_prefix_query_lens
.
shape
[
0
]
-
1
,
key_cache
.
shape
[
-
2
])
# Process shared prefix.
# Process shared prefix.
if
not
current_platform
.
is_rocm
():
if
current_platform
.
is_rocm
():
prefix_output
,
prefix_lse
=
flash_attn_varlen_func
(
prefix_output
,
prefix_lse
,
_
=
vllm_
flash_attn_varlen_func
(
q
=
query
,
q
=
query
,
k
=
key_cache
,
k
=
key_cache
,
v
=
value_cache
,
v
=
value_cache
,
...
@@ -1279,16 +1300,17 @@ def cascade_attention(
...
@@ -1279,16 +1300,17 @@ def cascade_attention(
return_softmax_lse
=
True
,
return_softmax_lse
=
True
,
scheduler_metadata
=
prefix_scheduler_metadata
,
scheduler_metadata
=
prefix_scheduler_metadata
,
fa_version
=
fa_version
,
fa_version
=
fa_version
,
q_descale
=
q_descale
.
expand
(
descale_shape
)
if
q_descale
is
not
None
else
None
,
q_descale
=
q_descale
if
q_descale
is
not
None
else
None
,
k_descale
=
k_descale
.
expand
(
descale_shape
)
if
k_descale
is
not
None
else
None
,
k_descale
=
k_descale
if
k_descale
is
not
None
else
None
,
v_descale
=
v_descale
.
expand
(
descale_shape
)
if
v_descale
is
not
None
else
None
,
v_descale
=
v_descale
if
v_descale
is
not
None
else
None
,
# s_aux is incorporated into prefix_lse inside the GPU kernel,
# s_aux is incorporated into prefix_lse inside the GPU kernel,
# enabling its effect during the final attention merge.
# enabling its effect during the final attention merge.
s_aux
=
s_aux
,
s_aux
=
s_aux
,
num_splits
=
1
if
vllm_is_batch_invariant
()
else
max_num_splits
,
# num_splits=1 if vllm_is_batch_invariant() else max_num_splits,
is_prefix_cache
=
True
,
)
)
else
:
else
:
prefix_output
,
prefix_lse
,
_
=
vllm_
flash_attn_varlen_func
(
prefix_output
,
prefix_lse
=
flash_attn_varlen_func
(
q
=
query
,
q
=
query
,
k
=
key_cache
,
k
=
key_cache
,
v
=
value_cache
,
v
=
value_cache
,
...
@@ -1303,22 +1325,21 @@ def cascade_attention(
...
@@ -1303,22 +1325,21 @@ def cascade_attention(
softcap
=
logits_soft_cap
,
softcap
=
logits_soft_cap
,
return_softmax_lse
=
True
,
return_softmax_lse
=
True
,
scheduler_metadata
=
prefix_scheduler_metadata
,
scheduler_metadata
=
prefix_scheduler_metadata
,
#
fa_version=fa_version,
fa_version
=
fa_version
,
q_descale
=
q_descale
.
expand
(
descale_shape
)
if
q_descale
is
not
None
else
None
,
q_descale
=
q_descale
.
expand
(
descale_shape
)
if
q_descale
is
not
None
else
None
,
k_descale
=
k_descale
.
expand
(
descale_shape
)
if
k_descale
is
not
None
else
None
,
k_descale
=
k_descale
.
expand
(
descale_shape
)
if
k_descale
is
not
None
else
None
,
v_descale
=
v_descale
.
expand
(
descale_shape
)
if
v_descale
is
not
None
else
None
,
v_descale
=
v_descale
.
expand
(
descale_shape
)
if
v_descale
is
not
None
else
None
,
# s_aux is incorporated into prefix_lse inside the GPU kernel,
# s_aux is incorporated into prefix_lse inside the GPU kernel,
# enabling its effect during the final attention merge.
# enabling its effect during the final attention merge.
s_aux
=
s_aux
,
s_aux
=
s_aux
,
# num_splits=1 if vllm_is_batch_invariant() else max_num_splits,
num_splits
=
1
if
vllm_is_batch_invariant
()
else
max_num_splits
,
is_prefix_cache
=
True
,
)
)
descale_shape
=
(
cu_query_lens
.
shape
[
0
]
-
1
,
key_cache
.
shape
[
-
2
])
descale_shape
=
(
cu_query_lens
.
shape
[
0
]
-
1
,
key_cache
.
shape
[
-
2
])
# Process suffix per query.
# Process suffix per query.
if
not
current_platform
.
is_rocm
():
if
current_platform
.
is_rocm
():
suffix_output
,
suffix_lse
=
flash_attn_varlen_func
(
suffix_output
,
suffix_lse
,
_
=
vllm_
flash_attn_varlen_func
(
q
=
query
,
q
=
query
,
k
=
key_cache
,
k
=
key_cache
,
v
=
value_cache
,
v
=
value_cache
,
...
@@ -1334,13 +1355,14 @@ def cascade_attention(
...
@@ -1334,13 +1355,14 @@ def cascade_attention(
return_softmax_lse
=
True
,
return_softmax_lse
=
True
,
scheduler_metadata
=
suffix_scheduler_metadata
,
scheduler_metadata
=
suffix_scheduler_metadata
,
fa_version
=
fa_version
,
fa_version
=
fa_version
,
q_descale
=
q_descale
.
expand
(
descale_shape
)
if
q_descale
is
not
None
else
None
,
q_descale
=
q_descale
if
q_descale
is
not
None
else
None
,
k_descale
=
k_descale
.
expand
(
descale_shape
)
if
k_descale
is
not
None
else
None
,
k_descale
=
k_descale
if
k_descale
is
not
None
else
None
,
v_descale
=
v_descale
.
expand
(
descale_shape
)
if
v_descale
is
not
None
else
None
,
v_descale
=
v_descale
if
v_descale
is
not
None
else
None
,
num_splits
=
1
if
vllm_is_batch_invariant
()
else
max_num_splits
,
# num_splits=1 if vllm_is_batch_invariant() else max_num_splits,
is_prefix_cache
=
True
,
)
)
else
:
else
:
suffix_output
,
suffix_lse
,
_
=
vllm_
flash_attn_varlen_func
(
suffix_output
,
suffix_lse
=
flash_attn_varlen_func
(
q
=
query
,
q
=
query
,
k
=
key_cache
,
k
=
key_cache
,
v
=
value_cache
,
v
=
value_cache
,
...
@@ -1355,12 +1377,11 @@ def cascade_attention(
...
@@ -1355,12 +1377,11 @@ def cascade_attention(
softcap
=
logits_soft_cap
,
softcap
=
logits_soft_cap
,
return_softmax_lse
=
True
,
return_softmax_lse
=
True
,
scheduler_metadata
=
suffix_scheduler_metadata
,
scheduler_metadata
=
suffix_scheduler_metadata
,
#
fa_version=fa_version,
fa_version
=
fa_version
,
q_descale
=
q_descale
.
expand
(
descale_shape
)
if
q_descale
is
not
None
else
None
,
q_descale
=
q_descale
.
expand
(
descale_shape
)
if
q_descale
is
not
None
else
None
,
k_descale
=
k_descale
.
expand
(
descale_shape
)
if
k_descale
is
not
None
else
None
,
k_descale
=
k_descale
.
expand
(
descale_shape
)
if
k_descale
is
not
None
else
None
,
v_descale
=
v_descale
.
expand
(
descale_shape
)
if
v_descale
is
not
None
else
None
,
v_descale
=
v_descale
.
expand
(
descale_shape
)
if
v_descale
is
not
None
else
None
,
# num_splits=1 if vllm_is_batch_invariant() else max_num_splits,
num_splits
=
1
if
vllm_is_batch_invariant
()
else
max_num_splits
,
is_prefix_cache
=
True
,
)
)
# Merge prefix and suffix outputs, and store the result in output.
# Merge prefix and suffix outputs, and store the result in output.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment