Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8e56ae5e
Commit
8e56ae5e
authored
Jul 25, 2024
by
zhuwenwen
Browse files
skip window_size and alibi_slopes because ck fa is not supported, and add triton fa
parent
0e640807
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
29 additions
and
16 deletions
+29
-16
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+29
-16
No files found.
vllm/attention/backends/rocm_flash_attn.py
View file @
8e56ae5e
...
...
@@ -271,11 +271,12 @@ class ROCmFlashAttentionImpl(AttentionImpl):
# NOTE: Allow for switching between Triton and CK. Defaulting to triton.
self
.
use_triton_flash_attn
=
envs
.
VLLM_USE_TRITON_FLASH_ATTN
if
self
.
use_triton_flash_attn
:
from
vllm.attention.ops.triton_flash_attention
import
(
# noqa: F401
triton_attention
)
# from vllm.attention.ops.flash_attn_triton_mqa_gqa import (
# flash_attn_varlen_func)
self
.
attn_func
=
triton_attention
# flash_attn_varlen_func
# from vllm.attention.ops.triton_flash_attention import ( # noqa: F401
# triton_attention)
# self.attn_func = triton_attention
from
vllm.attention.ops.flash_attn_triton_mqa_gqa
import
(
flash_attn_varlen_func
)
self
.
attn_func
=
flash_attn_varlen_func
logger
.
debug
(
"Using Triton FA in ROCmBackend"
)
if
self
.
sliding_window
!=
(
-
1
,
-
1
):
logger
.
warning
(
"ROCm Triton FA does not currently support "
...
...
@@ -391,17 +392,29 @@ class ROCmFlashAttentionImpl(AttentionImpl):
query
.
dtype
,
attn_metadata
.
seq_lens
,
make_attn_mask
=
False
)
# type: ignore
# out = self.attn_func(
# query,
# key,
# value,
# prefill_meta.seq_lens,
# num_tokens,
# self.num_heads,
# self.head_size,
# self.scale,
# attn_masks,
# )
out
=
self
.
attn_func
(
query
,
key
,
value
,
prefill_meta
.
seq_
lens
,
num_tokens
,
self
.
num_heads
,
self
.
head_size
,
self
.
scale
,
attn_masks
,
q
=
query
,
k
=
key
,
v
=
value
,
cu_seqlens_q
=
prefill_meta
.
seq_
start_loc
,
cu_seqlens_k
=
prefill_meta
.
seq_start_loc
,
max_seqlens_q
=
prefill_meta
.
max_prefill_seq_len
,
max_seqlens_k
=
prefill_meta
.
max_prefill_seq_len
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
)
elif
self
.
use_naive_attn
:
if
self
.
num_kv_heads
!=
self
.
num_heads
:
# Interleave for MQA workaround.
...
...
@@ -439,8 +452,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
max_seqlen_k
=
prefill_meta
.
max_prefill_seq_len
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
window_size
=
self
.
sliding_window
,
alibi_slopes
=
self
.
alibi_slopes
,
#
window_size=self.sliding_window,
#
alibi_slopes=self.alibi_slopes,
)
# common code for prefill
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment