Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
fa973559
Commit
fa973559
authored
Aug 19, 2024
by
zhuwenwen
Browse files
Add fa pad conditions and automatic switching strategy
parent
a528f350
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
75 additions
and
13 deletions
+75
-13
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+2
-2
vllm/envs.py
vllm/envs.py
+3
-2
vllm/model_executor/models/baichuan.py
vllm/model_executor/models/baichuan.py
+12
-1
vllm/model_executor/models/chatglm.py
vllm/model_executor/models/chatglm.py
+12
-1
vllm/model_executor/models/llama.py
vllm/model_executor/models/llama.py
+8
-3
vllm/model_executor/models/qwen.py
vllm/model_executor/models/qwen.py
+7
-2
vllm/model_executor/models/qwen2.py
vllm/model_executor/models/qwen2.py
+7
-2
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+24
-0
No files found.
vllm/attention/backends/rocm_flash_attn.py
View file @
fa973559
...
...
@@ -228,7 +228,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
self
.
use_naive_attn
=
False
# NOTE: Allow for switching between Triton and CK. Defaulting to triton.
self
.
use_triton_flash_attn
=
envs
.
VLLM_USE_TRITON_FLASH_ATTN
# NOTE: Allow automatic switching between Triton and CK. Defaulting to triton when seqlen > 8192
# NOTE: Allow automatic switching between Triton and CK. Defaulting to triton when seqlen >= 8000
self
.
use_flash_attn_auto
=
envs
.
VLLM_USE_FLASH_ATTN_AUTO
if
self
.
use_triton_flash_attn
:
if
self
.
use_flash_attn_auto
:
...
...
@@ -340,7 +340,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
# prompt, and they have the same length.
if
self
.
use_triton_flash_attn
:
if
self
.
use_flash_attn_auto
:
if
prefill_meta
.
max_prefill_seq_len
>
8
192
:
if
prefill_meta
.
max_prefill_seq_len
>
=
8
000
:
out
=
self
.
attn_func_triton
(
q
=
query
,
k
=
key
,
...
...
vllm/envs.py
View file @
fa973559
...
...
@@ -9,6 +9,7 @@ if TYPE_CHECKING:
VLLM_NCCL_SO_PATH
:
Optional
[
str
]
=
None
LD_LIBRARY_PATH
:
Optional
[
str
]
=
None
VLLM_USE_TRITON_FLASH_ATTN
:
bool
=
False
VLLM_USE_FLASH_ATTN_AUTO
:
bool
=
False
LOCAL_RANK
:
int
=
0
CUDA_VISIBLE_DEVICES
:
Optional
[
str
]
=
None
VLLM_ENGINE_ITERATION_TIMEOUT_S
:
int
=
60
...
...
@@ -130,12 +131,12 @@ environment_variables: Dict[str, Callable[[], Any]] = {
# flag to control if vllm should use triton flash attention
"VLLM_USE_TRITON_FLASH_ATTN"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_TRITON_FLASH_ATTN"
,
"
Fals
e"
).
lower
()
in
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_TRITON_FLASH_ATTN"
,
"
Tru
e"
).
lower
()
in
(
"true"
,
"1"
)),
# flag to control vllm to automatically switch between Triton FA and CK FA
"VLLM_USE_FLASH_ATTN_AUTO"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_FLASH_ATTN_AUTO"
,
"
Fals
e"
).
lower
()
in
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_FLASH_ATTN_AUTO"
,
"
Tru
e"
).
lower
()
in
(
"true"
,
"1"
)),
# local rank of the process in the distributed setting, used to determine
...
...
vllm/model_executor/models/baichuan.py
View file @
fa973559
...
...
@@ -174,6 +174,11 @@ class BaiChuanAttention(nn.Module):
self
.
scaling
,
cache_config
=
cache_config
,
quant_config
=
quant_config
)
self
.
quant_method
=
None
if
quant_config
is
not
None
:
self
.
quant_method
=
quant_config
.
get_name
()
self
.
quant_config
=
quant_config
def
forward
(
self
,
...
...
@@ -183,7 +188,7 @@ class BaiChuanAttention(nn.Module):
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
W_pack
(
hidden_states
)
if
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
:
if
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
and
self
.
quant_method
is
None
:
qkv
=
qkv
[...,:
-
32
]
q
,
k
,
v
=
qkv
.
chunk
(
chunks
=
3
,
dim
=-
1
)
if
self
.
postion_embedding
!=
"ALIBI"
:
...
...
@@ -333,6 +338,12 @@ class BaiChuanBaseForCausalLM(nn.Module):
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
quant_method
=
None
if
quant_config
is
not
None
:
self
.
quant_method
=
quant_config
.
get_name
()
self
.
quant_config
=
quant_config
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
self
.
use_gemm_pad
=
os
.
environ
.
get
(
'GEMM_PAD'
)
==
'1'
self
.
use_fa_pad
=
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
...
...
vllm/model_executor/models/chatglm.py
View file @
fa973559
...
...
@@ -97,6 +97,11 @@ class GLMAttention(nn.Module):
num_kv_heads
=
self
.
num_kv_heads
,
cache_config
=
cache_config
,
quant_config
=
quant_config
)
self
.
quant_method
=
None
if
quant_config
is
not
None
:
self
.
quant_method
=
quant_config
.
get_name
()
self
.
quant_config
=
quant_config
def
forward
(
self
,
...
...
@@ -106,7 +111,7 @@ class GLMAttention(nn.Module):
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
query_key_value
(
hidden_states
)
if
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
:
if
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
and
self
.
quant_method
is
None
:
qkv
=
qkv
[...,:
-
32
]
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
position_ids
,
q
,
k
)
...
...
@@ -360,6 +365,12 @@ class ChatGLMForCausalLM(nn.Module):
self
.
lm_head_weight
=
self
.
transformer
.
output_layer
.
weight
self
.
logits_processor
=
LogitsProcessor
(
config
.
padded_vocab_size
)
self
.
sampler
=
Sampler
()
self
.
quant_method
=
None
if
quant_config
is
not
None
:
self
.
quant_method
=
quant_config
.
get_name
()
self
.
quant_config
=
quant_config
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
self
.
use_gemm_pad
=
os
.
environ
.
get
(
'GEMM_PAD'
)
==
'1'
self
.
use_fa_pad
=
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
...
...
vllm/model_executor/models/llama.py
View file @
fa973559
...
...
@@ -152,6 +152,11 @@ class LlamaAttention(nn.Module):
num_kv_heads
=
self
.
num_kv_heads
,
cache_config
=
cache_config
,
quant_config
=
quant_config
)
self
.
quant_method
=
None
if
quant_config
is
not
None
:
self
.
quant_method
=
quant_config
.
get_name
()
self
.
quant_config
=
quant_config
def
forward
(
self
,
...
...
@@ -161,7 +166,7 @@ class LlamaAttention(nn.Module):
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
if
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
:
if
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
and
self
.
quant_method
is
None
:
qkv
=
qkv
[...,:
-
32
]
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
...
...
@@ -367,8 +372,8 @@ class LlamaForCausalLM(nn.Module):
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
config
.
vocab_size
,
logit_scale
)
self
.
sampler
=
Sampler
()
self
.
quant_method
=
None
self
.
quant_method
=
None
if
quant_config
is
not
None
:
self
.
quant_method
=
quant_config
.
get_name
()
self
.
quant_config
=
quant_config
...
...
vllm/model_executor/models/qwen.py
View file @
fa973559
...
...
@@ -114,6 +114,11 @@ class QWenAttention(nn.Module):
self
.
scaling
,
cache_config
=
cache_config
,
quant_config
=
quant_config
)
self
.
quant_method
=
None
if
quant_config
is
not
None
:
self
.
quant_method
=
quant_config
.
get_name
()
self
.
quant_config
=
quant_config
def
forward
(
self
,
...
...
@@ -123,7 +128,7 @@ class QWenAttention(nn.Module):
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
c_attn
(
hidden_states
)
if
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
:
if
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
and
self
.
quant_method
is
None
:
qkv
=
qkv
[...,:
-
32
]
q
,
k
,
v
=
qkv
.
chunk
(
chunks
=
3
,
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
...
...
@@ -246,7 +251,7 @@ class QWenLMHeadModel(nn.Module):
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
quant_method
=
None
self
.
quant_method
=
None
if
quant_config
is
not
None
:
self
.
quant_method
=
quant_config
.
get_name
()
self
.
quant_config
=
quant_config
...
...
vllm/model_executor/models/qwen2.py
View file @
fa973559
...
...
@@ -144,6 +144,11 @@ class Qwen2Attention(nn.Module):
num_kv_heads
=
self
.
num_kv_heads
,
cache_config
=
cache_config
,
quant_config
=
quant_config
)
self
.
quant_method
=
None
if
quant_config
is
not
None
:
self
.
quant_method
=
quant_config
.
get_name
()
self
.
quant_config
=
quant_config
def
forward
(
self
,
...
...
@@ -153,7 +158,7 @@ class Qwen2Attention(nn.Module):
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
if
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
:
if
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
and
self
.
quant_method
is
None
:
qkv
=
qkv
[...,:
-
32
]
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
...
...
@@ -327,7 +332,7 @@ class Qwen2ForCausalLM(nn.Module):
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
quant_method
=
None
self
.
quant_method
=
None
if
quant_config
is
not
None
:
self
.
quant_method
=
quant_config
.
get_name
()
self
.
quant_config
=
quant_config
...
...
vllm/worker/model_runner.py
View file @
fa973559
...
...
@@ -804,6 +804,30 @@ class ModelRunner:
max_num_seqs
=
min
(
max_num_seqs
,
int
(
max_num_batched_tokens
/
vlm_config
.
image_feature_size
))
import
vllm.envs
as
envs
if
envs
.
VLLM_USE_FLASH_ATTN_AUTO
:
for
group_id
in
range
(
1
):
seq_len
=
8000
if
vlm_config
is
None
:
seq_data
=
SequenceData
([
0
]
*
seq_len
)
dummy_multi_modal_data
=
None
else
:
seq_data
,
dummy_multi_modal_data
=
MULTIMODAL_REGISTRY
\
.
dummy_data_for_profiling
(
seq_len
,
model_config
,
vlm_config
)
seq
=
SequenceGroupMetadata
(
request_id
=
str
(
group_id
),
is_prompt
=
True
,
seq_data
=
{
group_id
:
seq_data
},
sampling_params
=
sampling_params
,
block_tables
=
None
,
lora_request
=
dummy_lora_requests_per_seq
[
group_id
]
if
dummy_lora_requests_per_seq
else
None
,
multi_modal_data
=
dummy_multi_modal_data
,
)
seqs
.
append
(
seq
)
for
group_id
in
range
(
max_num_seqs
):
seq_len
=
(
max_num_batched_tokens
//
max_num_seqs
+
(
group_id
<
max_num_batched_tokens
%
max_num_seqs
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment