Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1863c926
Commit
1863c926
authored
Jul 06, 2024
by
zhuwenwen
Browse files
Use triton fa by default
parent
b6247705
Changes
3
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
1334 additions
and
15 deletions
+1334
-15
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+26
-14
vllm/attention/ops/flash_attn_triton_mqa_gqa.py
vllm/attention/ops/flash_attn_triton_mqa_gqa.py
+1307
-0
vllm/envs.py
vllm/envs.py
+1
-1
No files found.
vllm/attention/backends/rocm_flash_attn.py
View file @
1863c926
...
...
@@ -229,9 +229,11 @@ class ROCmFlashAttentionImpl(AttentionImpl):
# NOTE: Allow for switching between Triton and CK. Defaulting to triton.
self
.
use_triton_flash_attn
=
envs
.
VLLM_USE_TRITON_FLASH_ATTN
if
self
.
use_triton_flash_attn
:
from
vllm.attention.ops.triton_flash_attention
import
(
# noqa: F401
triton_attention
)
self
.
attn_func
=
triton_attention
# from vllm.attention.ops.triton_flash_attention import ( # noqa: F401
# triton_attention)
from
vllm.attention.ops.flash_attn_triton_mqa_gqa
import
(
flash_attn_varlen_func
)
self
.
attn_func
=
flash_attn_varlen_func
# triton_attention
logger
.
debug
(
"Using Triton FA in ROCmBackend"
)
else
:
# if not using triton, navi3x/navi21/navi10 do not use flash-attn
...
...
@@ -325,17 +327,27 @@ class ROCmFlashAttentionImpl(AttentionImpl):
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
if
self
.
use_triton_flash_attn
:
out
,
_
=
self
.
attn_func
(
query
,
key
,
value
,
None
,
prefill_meta
.
seq_start_loc
,
prefill_meta
.
seq_start_loc
,
prefill_meta
.
max_prefill_seq_len
,
prefill_meta
.
max_prefill_seq_len
,
True
,
self
.
scale
,
# out, _ = self.attn_func(
# query,
# key,
# value,
# None,
# prefill_meta.seq_start_loc,
# prefill_meta.seq_start_loc,
# prefill_meta.max_prefill_seq_len,
# prefill_meta.max_prefill_seq_len,
# True,
# self.scale,
out
=
self
.
attn_func
(
q
=
query
,
k
=
key
,
v
=
value
,
cu_seqlens_q
=
prefill_meta
.
seq_start_loc
,
cu_seqlens_k
=
prefill_meta
.
seq_start_loc
,
max_seqlens_q
=
prefill_meta
.
max_prefill_seq_len
,
max_seqlens_k
=
prefill_meta
.
max_prefill_seq_len
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
)
elif
self
.
use_naive_attn
:
if
self
.
num_kv_heads
!=
self
.
num_heads
:
...
...
vllm/attention/ops/flash_attn_triton_mqa_gqa.py
0 → 100644
View file @
1863c926
This diff is collapsed.
Click to expand it.
vllm/envs.py
View file @
1863c926
...
...
@@ -130,7 +130,7 @@ environment_variables: Dict[str, Callable[[], Any]] = {
# flag to control if vllm should use triton flash attention
"VLLM_USE_TRITON_FLASH_ATTN"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_TRITON_FLASH_ATTN"
,
"False"
).
lower
()
in
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_TRITON_FLASH_ATTN"
,
"True"
).
lower
()
in
(
"true"
,
"1"
)),
# local rank of the process in the distributed setting, used to determine
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment