Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
bdc70899
Commit
bdc70899
authored
May 13, 2025
by
zhuwenwen
Browse files
support cutlass prefix-cache
parent
09e372e7
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
71 additions
and
17 deletions
+71
-17
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+65
-17
vllm/envs.py
vllm/envs.py
+6
-0
No files found.
vllm/attention/backends/rocm_flash_attn.py
View file @
bdc70899
...
@@ -575,6 +575,10 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -575,6 +575,10 @@ class ROCmFlashAttentionImpl(AttentionImpl):
try
:
try
:
from
flash_attn
import
flash_attn_varlen_func
# noqa: F401
from
flash_attn
import
flash_attn_varlen_func
# noqa: F401
self
.
fa_attn_func
=
flash_attn_varlen_func
self
.
fa_attn_func
=
flash_attn_varlen_func
if
not
envs
.
VLLM_USE_TRITON_PREFIX_FLASH_ATTN
:
from
flash_attn
import
vllm_flash_attn_varlen_func
self
.
fa_prefix_attn_func
=
vllm_flash_attn_varlen_func
logger
.
debug
(
"Using CUTLASS FA in ROCmBackend"
)
logger
.
debug
(
"Using CUTLASS FA in ROCmBackend"
)
except
ModuleNotFoundError
:
except
ModuleNotFoundError
:
self
.
use_naive_attn
=
True
self
.
use_naive_attn
=
True
...
@@ -843,24 +847,68 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -843,24 +847,68 @@ class ROCmFlashAttentionImpl(AttentionImpl):
else
:
else
:
# prefix-enabled attention -
# prefix-enabled attention -
# not applicable for encoder-only models
# not applicable for encoder-only models
version_key
=
triton_key
()
# if not envs.VLLM_USE_TRITON_PREFIX_FLASH_ATTN:
if
self
.
attn_type
!=
AttentionType
.
ENCODER_ONLY
:
# self.fa_prefix_attn_func = vllm_flash_attn_varlen_func
output
[:
num_prefill_tokens
]
=
paged_attn
.
forward_prefix
(
if
envs
.
VLLM_USE_TRITON_PREFIX_FLASH_ATTN
:
query
,
version_key
=
triton_key
()
key
,
if
self
.
attn_type
!=
AttentionType
.
ENCODER_ONLY
:
value
,
output
[:
num_prefill_tokens
]
=
paged_attn
.
forward_prefix
(
self
.
kv_cache_dtype
,
query
,
key_cache
,
key
,
value_cache
,
value
,
prefill_meta
.
block_tables
,
self
.
kv_cache_dtype
,
prefill_meta
.
query_start_loc
,
key_cache
,
prefill_meta
.
seq_lens_tensor
,
value_cache
,
prefill_meta
.
max_query_len
,
prefill_meta
.
block_tables
,
self
.
alibi_slopes
,
prefill_meta
.
query_start_loc
,
self
.
sliding_window
[
0
],
prefill_meta
.
seq_lens_tensor
,
layer
.
_k_scale
,
prefill_meta
.
max_query_len
,
layer
.
_v_scale
,
self
.
alibi_slopes
,
self
.
sliding_window
[
0
],
layer
.
_k_scale
,
layer
.
_v_scale
,
)
else
:
assert
self
.
attn_type
!=
AttentionType
.
ENCODER_ONLY
,
(
"Only decoder-only models support prefix caching"
)
assert
prefill_meta
.
seq_lens
is
not
None
assert
prefill_meta
.
query_start_loc
is
not
None
max_seq_len
=
max
(
prefill_meta
.
seq_lens
)
descale_shape
=
(
prefill_meta
.
query_start_loc
.
shape
[
0
]
-
1
,
key
.
shape
[
1
])
'''
k_cache
triton: [GPU blocks, num_kv_heads, head_size // x, block_size, x] --->
cutlass: num_blocks x page_block_size x num_heads_k x head_size i
'''
num_blocks
,
num_kv_heads
,
head_size_div_x
,
block_size
,
x
=
key_cache
.
shape
head_size
=
head_size_div_x
*
x
key_cache_flash
=
key_cache
.
permute
(
0
,
3
,
1
,
2
,
4
)
# [num_blocks, block_size, num_kv_heads, head_size//x, x]
key_cache_flash
=
key_cache_flash
.
reshape
(
num_blocks
,
block_size
,
num_kv_heads
,
head_size
)
# value_cache
value_cache_flash
=
value_cache
.
permute
(
0
,
3
,
1
,
2
)
# [num_blocks, block_size, num_kv_heads, head_size]
output
[:
num_prefill_tokens
]
=
self
.
fa_prefix_attn_func
(
# noqa
q
=
query
,
k
=
key_cache_flash
,
v
=
value_cache_flash
,
cu_seqlens_q
=
prefill_meta
.
query_start_loc
,
max_seqlen_q
=
prefill_meta
.
max_query_len
,
seqused_k
=
prefill_meta
.
seq_lens_tensor
,
max_seqlen_k
=
max_seq_len
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
window_size
=
self
.
sliding_window
,
alibi_slopes
=
self
.
alibi_slopes
,
block_table
=
prefill_meta
.
block_tables
,
softcap
=
self
.
logits_soft_cap
,
q_descale
=
layer
.
_q_scale
.
expand
(
descale_shape
),
k_descale
=
layer
.
_k_scale
.
expand
(
descale_shape
),
v_descale
=
layer
.
_v_scale
.
expand
(
descale_shape
),
)
)
# Skip decode phase for encoder-only models
# Skip decode phase for encoder-only models
if
(
decode_meta
:
=
attn_metadata
.
decode_metadata
)
and
(
if
(
decode_meta
:
=
attn_metadata
.
decode_metadata
)
and
(
self
.
attn_type
!=
AttentionType
.
ENCODER_ONLY
):
self
.
attn_type
!=
AttentionType
.
ENCODER_ONLY
):
...
...
vllm/envs.py
View file @
bdc70899
...
@@ -17,6 +17,7 @@ if TYPE_CHECKING:
...
@@ -17,6 +17,7 @@ if TYPE_CHECKING:
VLLM_NCCL_SO_PATH
:
Optional
[
str
]
=
None
VLLM_NCCL_SO_PATH
:
Optional
[
str
]
=
None
LD_LIBRARY_PATH
:
Optional
[
str
]
=
None
LD_LIBRARY_PATH
:
Optional
[
str
]
=
None
VLLM_USE_TRITON_FLASH_ATTN
:
bool
=
False
VLLM_USE_TRITON_FLASH_ATTN
:
bool
=
False
VLLM_USE_TRITON_PREFIX_FLASH_ATTN
:
bool
=
False
VLLM_USE_TRITON_OPT_MLA
:
bool
=
False
VLLM_USE_TRITON_OPT_MLA
:
bool
=
False
VLLM_USE_FLASH_MLA
:
bool
=
False
VLLM_USE_FLASH_MLA
:
bool
=
False
VLLM_USE_OPT_OP
:
bool
=
False
VLLM_USE_OPT_OP
:
bool
=
False
...
@@ -272,6 +273,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
...
@@ -272,6 +273,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_TRITON_FLASH_ATTN"
,
"False"
).
lower
()
in
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_TRITON_FLASH_ATTN"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
(
"true"
,
"1"
)),
# flag to control if vllm should use triton prefix flash attention
"VLLM_USE_TRITON_PREFIX_FLASH_ATTN"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_TRITON_PREFIX_FLASH_ATTN"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
# flag to control vllm to use optimized kernels
# flag to control vllm to use optimized kernels
"VLLM_USE_OPT_OP"
:
"VLLM_USE_OPT_OP"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_OPT_OP"
,
"True"
).
lower
()
in
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_OPT_OP"
,
"True"
).
lower
()
in
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment