Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
408f663a
Commit
408f663a
authored
Sep 13, 2024
by
zhuwenwen
Browse files
remove the automatic switching strategy of fa
parent
aa1e273a
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
32 additions
and
121 deletions
+32
-121
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+23
-75
vllm/attention/selector.py
vllm/attention/selector.py
+7
-6
vllm/envs.py
vllm/envs.py
+2
-8
vllm/model_executor/model_loader/utils.py
vllm/model_executor/model_loader/utils.py
+0
-5
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+0
-27
No files found.
vllm/attention/backends/rocm_flash_attn.py
View file @
408f663a
...
...
@@ -281,31 +281,19 @@ class ROCmFlashAttentionImpl(AttentionImpl):
self
.
use_naive_attn
=
False
# NOTE: Allow for switching between Triton and CK. Defaulting to triton.
self
.
use_triton_flash_attn
=
envs
.
VLLM_USE_TRITON_FLASH_ATTN
# NOTE: Allow automatic switching between Triton and CK. Defaulting to triton when seqlen > 8000
self
.
use_flash_attn_auto
=
envs
.
VLLM_USE_FLASH_ATTN_AUTO
if
self
.
use_triton_flash_attn
:
if
self
.
use_flash_attn_auto
:
from
vllm.attention.ops.flash_attn_triton_mqa_gqa
import
(
# from vllm.attention.ops.triton_flash_attention import ( # noqa: F401
# triton_attention)
from
vllm.attention.ops.flash_attn_triton_mqa_gqa
import
(
flash_attn_varlen_func
)
self
.
attn_func_triton
=
flash_attn_varlen_func
from
flash_attn
import
flash_attn_varlen_func
# noqa: F401
self
.
attn_func_cu
=
flash_attn_varlen_func
logger
.
debug
(
"When SEQ_LEN > 8000, Use Triton FA in ROCmBackend, otherwise Use CK FA"
)
else
:
# from vllm.attention.ops.triton_flash_attention import ( # noqa: F401
# triton_attention)
from
vllm.attention.ops.flash_attn_triton_mqa_gqa
import
(
flash_attn_varlen_func
)
self
.
attn_func
=
flash_attn_varlen_func
# triton_attention
logger
.
debug
(
"Using Triton FA in ROCmBackend"
)
if
self
.
sliding_window
!=
(
-
1
,
-
1
):
logger
.
warning
(
"ROCm Triton FA does not currently support "
"sliding window attention. If using half "
"precision, please try using the ROCm CK "
"FA backend instead by setting the env var "
"`VLLM_USE_TRITON_FLASH_ATTN=0`"
)
self
.
attn_func
=
flash_attn_varlen_func
# triton_attention
logger
.
debug
(
"Using Triton FA in ROCmBackend"
)
if
self
.
sliding_window
!=
(
-
1
,
-
1
):
logger
.
warning
(
"ROCm Triton FA does not currently support "
"sliding window attention. If using half "
"precision, please try using the ROCm CK "
"FA backend instead by setting the env var "
"`VLLM_USE_TRITON_FLASH_ATTN=0`"
)
else
:
# if not using triton, navi3x/navi21/navi10 do not use flash-attn
# either
...
...
@@ -414,47 +402,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
query
.
dtype
,
attn_metadata
.
seq_lens
,
make_attn_mask
=
False
)
# type: ignore
if
self
.
use_flash_attn_auto
:
if
prefill_meta
.
max_prefill_seq_len
>
8000
:
out
=
self
.
attn_func_triton
(
q
=
query
,
k
=
key
,
v
=
value
,
cu_seqlens_q
=
prefill_meta
.
seq_start_loc
,
cu_seqlens_k
=
prefill_meta
.
seq_start_loc
,
max_seqlens_q
=
prefill_meta
.
max_prefill_seq_len
,
max_seqlens_k
=
prefill_meta
.
max_prefill_seq_len
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
)
else
:
if
envs
.
VLLM_USE_CL_FLASH_ATTN
:
out
=
self
.
attn_func_cu
(
q
=
query
,
k
=
key
,
v
=
value
,
cu_seqlens_q
=
prefill_meta
.
seq_start_loc
,
cu_seqlens_k
=
prefill_meta
.
seq_start_loc
,
max_seqlen_q
=
prefill_meta
.
max_prefill_seq_len
,
max_seqlen_k
=
prefill_meta
.
max_prefill_seq_len
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
window_size
=
self
.
sliding_window
,
alibi_slopes
=
self
.
alibi_slopes
,
)
else
:
out
=
self
.
attn_func_cu
(
q
=
query
,
k
=
key
,
v
=
value
,
cu_seqlens_q
=
prefill_meta
.
seq_start_loc
,
cu_seqlens_k
=
prefill_meta
.
seq_start_loc
,
max_seqlen_q
=
prefill_meta
.
max_prefill_seq_len
,
max_seqlen_k
=
prefill_meta
.
max_prefill_seq_len
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
)
else
:
# out = self.attn_func(
# query,
# key,
...
...
@@ -466,17 +414,17 @@ class ROCmFlashAttentionImpl(AttentionImpl):
# self.scale,
# attn_masks,
# )
out
=
self
.
attn_func
(
q
=
query
,
k
=
key
,
v
=
value
,
cu_seqlens_q
=
prefill_meta
.
seq_start_loc
,
cu_seqlens_k
=
prefill_meta
.
seq_start_loc
,
max_seqlens_q
=
prefill_meta
.
max_prefill_seq_len
,
max_seqlens_k
=
prefill_meta
.
max_prefill_seq_len
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
)
out
=
self
.
attn_func
(
q
=
query
,
k
=
key
,
v
=
value
,
cu_seqlens_q
=
prefill_meta
.
seq_start_loc
,
cu_seqlens_k
=
prefill_meta
.
seq_start_loc
,
max_seqlens_q
=
prefill_meta
.
max_prefill_seq_len
,
max_seqlens_k
=
prefill_meta
.
max_prefill_seq_len
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
)
elif
self
.
use_naive_attn
:
if
self
.
num_kv_heads
!=
self
.
num_heads
:
...
...
vllm/attention/selector.py
View file @
408f663a
...
...
@@ -202,12 +202,13 @@ def which_attn_to_use(
# AMD GPUs.
selected_backend
=
(
_Backend
.
ROCM_FLASH
if
selected_backend
==
_Backend
.
FLASH_ATTN
else
selected_backend
)
# if selected_backend == _Backend.ROCM_FLASH:
# if current_platform.get_device_capability()[0] != 9:
# # not Instinct series GPUs.
# logger.info("flash_attn is not supported on NAVI GPUs.")
# else:
# logger.info("%s is not supported in AMD GPUs.", selected_backend)
if
selected_backend
==
_Backend
.
ROCM_FLASH
:
# if current_platform.get_device_capability()[0] != 9:
if
torch
.
cuda
.
get_device_capability
()[
0
]
!=
9
:
# not Instinct series GPUs.
logger
.
info
(
"flash_attn is not supported on NAVI GPUs."
)
else
:
logger
.
info
(
"%s is not supported in AMD GPUs."
,
selected_backend
)
return
_Backend
.
ROCM_FLASH
# FlashAttn in NVIDIA GPUs.
...
...
vllm/envs.py
View file @
408f663a
...
...
@@ -13,7 +13,6 @@ if TYPE_CHECKING:
LD_LIBRARY_PATH
:
Optional
[
str
]
=
None
VLLM_USE_TRITON_FLASH_ATTN
:
bool
=
False
VLLM_USE_CL_FLASH_ATTN
:
bool
=
False
VLLM_USE_FLASH_ATTN_AUTO
:
bool
=
False
VLLM_USE_OPT_OP
:
bool
=
False
VLLM_USE_PA_PRINT_PARAM
:
bool
=
False
LOCAL_RANK
:
int
=
0
...
...
@@ -196,17 +195,12 @@ environment_variables: Dict[str, Callable[[], Any]] = {
# flag to control if vllm should use triton flash attention
"VLLM_USE_TRITON_FLASH_ATTN"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_TRITON_FLASH_ATTN"
,
"True"
).
lower
()
in
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_TRITON_FLASH_ATTN"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
# flag to control if vllm should use cutlass flash attention
"VLLM_USE_CL_FLASH_ATTN"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_CL_FLASH_ATTN"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
# flag to control vllm to automatically switch between Triton FA and CK FA
"VLLM_USE_FLASH_ATTN_AUTO"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_FLASH_ATTN_AUTO"
,
"True"
).
lower
()
in
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_CL_FLASH_ATTN"
,
"True"
).
lower
()
in
(
"true"
,
"1"
)),
# flag to control vllm to use optimized kernels
...
...
vllm/model_executor/model_loader/utils.py
View file @
408f663a
...
...
@@ -23,7 +23,6 @@ def get_model_architecture(
model_config
:
ModelConfig
)
->
Tuple
[
Type
[
nn
.
Module
],
str
]:
architectures
=
getattr
(
model_config
.
hf_config
,
"architectures"
,
[])
support_nn_architectures
=
[
'LlamaForCausalLM'
,
'QWenLMHeadModel'
,
'Qwen2ForCausalLM'
,
'ChatGLMModel'
,
'BaichuanForCausalLM'
,
'BloomForCausalLM'
]
use_triton_fa_architectures
=
[
'DeepseekV2ForCausalLM'
]
if
any
(
arch
in
architectures
for
arch
in
support_nn_architectures
):
if
os
.
getenv
(
'LLAMA_NN'
)
!=
'0'
:
os
.
environ
[
'LLAMA_NN'
]
=
'1'
...
...
@@ -35,10 +34,6 @@ def get_model_architecture(
os
.
environ
[
'LLAMA_NN'
]
=
'0'
os
.
environ
[
'GEMM_PAD'
]
=
'0'
os
.
environ
[
'FA_PAD'
]
=
'0'
if
any
(
arch
in
architectures
for
arch
in
use_triton_fa_architectures
):
os
.
environ
[
'VLLM_USE_TRITON_FLASH_ATTN'
]
=
'1'
os
.
environ
[
'VLLM_USE_FLASH_ATTN_AUTO'
]
=
'0'
# Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack.
...
...
vllm/worker/model_runner.py
View file @
408f663a
...
...
@@ -1179,33 +1179,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
max_num_seqs
=
1
batch_size
=
0
import
vllm.envs
as
envs
if
envs
.
VLLM_USE_FLASH_ATTN_AUTO
:
for
group_id
in
range
(
1
):
if
max_num_batched_tokens
>=
8000
:
seq_len
=
8000
else
:
seq_len
=
max_num_batched_tokens
batch_size
+=
seq_len
seq_data
,
dummy_multi_modal_data
=
INPUT_REGISTRY
\
.
dummy_data_for_profiling
(
self
.
model_config
,
seq_len
,
self
.
mm_registry
)
seq
=
SequenceGroupMetadata
(
request_id
=
str
(
group_id
),
is_prompt
=
True
,
seq_data
=
{
group_id
:
seq_data
},
sampling_params
=
sampling_params
,
block_tables
=
None
,
lora_request
=
dummy_lora_requests_per_seq
[
group_id
]
if
dummy_lora_requests_per_seq
else
None
,
multi_modal_data
=
dummy_multi_modal_data
,
)
seqs
.
append
(
seq
)
max_num_batched_tokens
-=
seq_len
for
group_id
in
range
(
max_num_seqs
):
seq_len
=
(
max_num_batched_tokens
//
max_num_seqs
+
(
group_id
<
max_num_batched_tokens
%
max_num_seqs
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment