Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b40f2ffc
Commit
b40f2ffc
authored
Aug 19, 2024
by
zhuwenwen
Browse files
Add fa pad conditions and automatic switching strategy
parent
4d821524
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
79 additions
and
14 deletions
+79
-14
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+3
-3
vllm/envs.py
vllm/envs.py
+3
-2
vllm/model_executor/models/baichuan.py
vllm/model_executor/models/baichuan.py
+13
-1
vllm/model_executor/models/chatglm.py
vllm/model_executor/models/chatglm.py
+12
-1
vllm/model_executor/models/llama.py
vllm/model_executor/models/llama.py
+8
-3
vllm/model_executor/models/qwen.py
vllm/model_executor/models/qwen.py
+7
-2
vllm/model_executor/models/qwen2.py
vllm/model_executor/models/qwen2.py
+7
-2
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+26
-0
No files found.
vllm/attention/backends/rocm_flash_attn.py
View file @
b40f2ffc
...
@@ -276,7 +276,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -276,7 +276,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
self
.
use_naive_attn
=
False
self
.
use_naive_attn
=
False
# NOTE: Allow for switching between Triton and CK. Defaulting to triton.
# NOTE: Allow for switching between Triton and CK. Defaulting to triton.
self
.
use_triton_flash_attn
=
envs
.
VLLM_USE_TRITON_FLASH_ATTN
self
.
use_triton_flash_attn
=
envs
.
VLLM_USE_TRITON_FLASH_ATTN
# NOTE: Allow automatic switching between Triton and CK. Defaulting to triton when seqlen > 8
192
# NOTE: Allow automatic switching between Triton and CK. Defaulting to triton when seqlen > 8
000
self
.
use_flash_attn_auto
=
envs
.
VLLM_USE_FLASH_ATTN_AUTO
self
.
use_flash_attn_auto
=
envs
.
VLLM_USE_FLASH_ATTN_AUTO
if
self
.
use_triton_flash_attn
:
if
self
.
use_triton_flash_attn
:
if
self
.
use_flash_attn_auto
:
if
self
.
use_flash_attn_auto
:
...
@@ -286,7 +286,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -286,7 +286,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
from
flash_attn
import
flash_attn_varlen_func
# noqa: F401
from
flash_attn
import
flash_attn_varlen_func
# noqa: F401
self
.
attn_func_ck
=
flash_attn_varlen_func
self
.
attn_func_ck
=
flash_attn_varlen_func
logger
.
debug
(
"When SEQ_LEN > 8
192
, Use Triton FA in ROCmBackend, otherwise Use CK FA"
)
logger
.
debug
(
"When SEQ_LEN > 8
000
, Use Triton FA in ROCmBackend, otherwise Use CK FA"
)
else
:
else
:
# from vllm.attention.ops.triton_flash_attention import ( # noqa: F401
# from vllm.attention.ops.triton_flash_attention import ( # noqa: F401
# triton_attention)
# triton_attention)
...
@@ -410,7 +410,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -410,7 +410,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
attn_metadata
.
seq_lens
,
attn_metadata
.
seq_lens
,
make_attn_mask
=
False
)
# type: ignore
make_attn_mask
=
False
)
# type: ignore
if
self
.
use_flash_attn_auto
:
if
self
.
use_flash_attn_auto
:
if
prefill_meta
.
max_prefill_seq_len
>
8
192
:
if
prefill_meta
.
max_prefill_seq_len
>
8
000
:
out
=
self
.
attn_func_triton
(
out
=
self
.
attn_func_triton
(
q
=
query
,
q
=
query
,
k
=
key
,
k
=
key
,
...
...
vllm/envs.py
View file @
b40f2ffc
...
@@ -11,6 +11,7 @@ if TYPE_CHECKING:
...
@@ -11,6 +11,7 @@ if TYPE_CHECKING:
VLLM_NCCL_SO_PATH
:
Optional
[
str
]
=
None
VLLM_NCCL_SO_PATH
:
Optional
[
str
]
=
None
LD_LIBRARY_PATH
:
Optional
[
str
]
=
None
LD_LIBRARY_PATH
:
Optional
[
str
]
=
None
VLLM_USE_TRITON_FLASH_ATTN
:
bool
=
False
VLLM_USE_TRITON_FLASH_ATTN
:
bool
=
False
VLLM_USE_FLASH_ATTN_AUTO
:
bool
=
False
LOCAL_RANK
:
int
=
0
LOCAL_RANK
:
int
=
0
CUDA_VISIBLE_DEVICES
:
Optional
[
str
]
=
None
CUDA_VISIBLE_DEVICES
:
Optional
[
str
]
=
None
VLLM_ENGINE_ITERATION_TIMEOUT_S
:
int
=
60
VLLM_ENGINE_ITERATION_TIMEOUT_S
:
int
=
60
...
@@ -178,12 +179,12 @@ environment_variables: Dict[str, Callable[[], Any]] = {
...
@@ -178,12 +179,12 @@ environment_variables: Dict[str, Callable[[], Any]] = {
# flag to control if vllm should use triton flash attention
# flag to control if vllm should use triton flash attention
"VLLM_USE_TRITON_FLASH_ATTN"
:
"VLLM_USE_TRITON_FLASH_ATTN"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_TRITON_FLASH_ATTN"
,
"
Fals
e"
).
lower
()
in
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_TRITON_FLASH_ATTN"
,
"
Tru
e"
).
lower
()
in
(
"true"
,
"1"
)),
(
"true"
,
"1"
)),
# flag to control vllm to automatically switch between Triton FA and CK FA
# flag to control vllm to automatically switch between Triton FA and CK FA
"VLLM_USE_FLASH_ATTN_AUTO"
:
"VLLM_USE_FLASH_ATTN_AUTO"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_FLASH_ATTN_AUTO"
,
"
Fals
e"
).
lower
()
in
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_FLASH_ATTN_AUTO"
,
"
Tru
e"
).
lower
()
in
(
"true"
,
"1"
)),
(
"true"
,
"1"
)),
# Internal flag to enable Dynamo graph capture
# Internal flag to enable Dynamo graph capture
...
...
vllm/model_executor/models/baichuan.py
View file @
b40f2ffc
...
@@ -178,6 +178,12 @@ class BaiChuanAttention(nn.Module):
...
@@ -178,6 +178,12 @@ class BaiChuanAttention(nn.Module):
cache_config
=
cache_config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
)
quant_config
=
quant_config
)
self
.
quant_method
=
None
if
quant_config
is
not
None
:
self
.
quant_method
=
quant_config
.
get_name
()
self
.
quant_config
=
quant_config
def
forward
(
def
forward
(
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
...
@@ -186,7 +192,7 @@ class BaiChuanAttention(nn.Module):
...
@@ -186,7 +192,7 @@ class BaiChuanAttention(nn.Module):
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
W_pack
(
hidden_states
)
qkv
,
_
=
self
.
W_pack
(
hidden_states
)
if
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
:
if
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
and
self
.
quant_method
is
None
:
qkv
=
qkv
[...,:
-
32
]
qkv
=
qkv
[...,:
-
32
]
q
,
k
,
v
=
qkv
.
chunk
(
chunks
=
3
,
dim
=-
1
)
q
,
k
,
v
=
qkv
.
chunk
(
chunks
=
3
,
dim
=-
1
)
if
self
.
postion_embedding
!=
"ALIBI"
:
if
self
.
postion_embedding
!=
"ALIBI"
:
...
@@ -341,6 +347,12 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA):
...
@@ -341,6 +347,12 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA):
quant_config
=
quant_config
)
quant_config
=
quant_config
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
sampler
=
Sampler
()
self
.
quant_method
=
None
if
quant_config
is
not
None
:
self
.
quant_method
=
quant_config
.
get_name
()
self
.
quant_config
=
quant_config
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
self
.
use_gemm_pad
=
os
.
environ
.
get
(
'GEMM_PAD'
)
==
'1'
self
.
use_gemm_pad
=
os
.
environ
.
get
(
'GEMM_PAD'
)
==
'1'
self
.
use_fa_pad
=
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
self
.
use_fa_pad
=
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
...
...
vllm/model_executor/models/chatglm.py
View file @
b40f2ffc
...
@@ -100,6 +100,11 @@ class GLMAttention(nn.Module):
...
@@ -100,6 +100,11 @@ class GLMAttention(nn.Module):
cache_config
=
cache_config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
)
quant_config
=
quant_config
)
self
.
quant_method
=
None
if
quant_config
is
not
None
:
self
.
quant_method
=
quant_config
.
get_name
()
self
.
quant_config
=
quant_config
def
forward
(
def
forward
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
...
@@ -108,7 +113,7 @@ class GLMAttention(nn.Module):
...
@@ -108,7 +113,7 @@ class GLMAttention(nn.Module):
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
query_key_value
(
hidden_states
)
qkv
,
_
=
self
.
query_key_value
(
hidden_states
)
if
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
:
if
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
and
self
.
quant_method
is
None
:
qkv
=
qkv
[...,:
-
32
]
qkv
=
qkv
[...,:
-
32
]
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
position_ids
,
q
,
k
)
q
,
k
=
self
.
rotary_emb
(
position_ids
,
q
,
k
)
...
@@ -366,6 +371,12 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
...
@@ -366,6 +371,12 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
self
.
lm_head
=
self
.
transformer
.
output_layer
self
.
lm_head
=
self
.
transformer
.
output_layer
self
.
logits_processor
=
LogitsProcessor
(
config
.
padded_vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
padded_vocab_size
)
self
.
sampler
=
Sampler
()
self
.
sampler
=
Sampler
()
self
.
quant_method
=
None
if
quant_config
is
not
None
:
self
.
quant_method
=
quant_config
.
get_name
()
self
.
quant_config
=
quant_config
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
self
.
use_gemm_pad
=
os
.
environ
.
get
(
'GEMM_PAD'
)
==
'1'
self
.
use_gemm_pad
=
os
.
environ
.
get
(
'GEMM_PAD'
)
==
'1'
self
.
use_fa_pad
=
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
self
.
use_fa_pad
=
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
...
...
vllm/model_executor/models/llama.py
View file @
b40f2ffc
...
@@ -167,6 +167,11 @@ class LlamaAttention(nn.Module):
...
@@ -167,6 +167,11 @@ class LlamaAttention(nn.Module):
cache_config
=
cache_config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
)
quant_config
=
quant_config
)
self
.
quant_method
=
None
if
quant_config
is
not
None
:
self
.
quant_method
=
quant_config
.
get_name
()
self
.
quant_config
=
quant_config
def
forward
(
def
forward
(
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
...
@@ -175,7 +180,7 @@ class LlamaAttention(nn.Module):
...
@@ -175,7 +180,7 @@ class LlamaAttention(nn.Module):
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
if
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
:
if
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
and
self
.
quant_method
is
None
:
qkv
=
qkv
[...,:
-
32
]
qkv
=
qkv
[...,:
-
32
]
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
...
@@ -417,8 +422,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
...
@@ -417,8 +422,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
self
.
sampler
=
Sampler
()
self
.
sampler
=
Sampler
()
else
:
else
:
self
.
lm_head
=
PPMissingLayer
()
self
.
lm_head
=
PPMissingLayer
()
self
.
quant_method
=
None
self
.
quant_method
=
None
if
quant_config
is
not
None
:
if
quant_config
is
not
None
:
self
.
quant_method
=
quant_config
.
get_name
()
self
.
quant_method
=
quant_config
.
get_name
()
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
...
...
vllm/model_executor/models/qwen.py
View file @
b40f2ffc
...
@@ -117,6 +117,11 @@ class QWenAttention(nn.Module):
...
@@ -117,6 +117,11 @@ class QWenAttention(nn.Module):
cache_config
=
cache_config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
)
quant_config
=
quant_config
)
self
.
quant_method
=
None
if
quant_config
is
not
None
:
self
.
quant_method
=
quant_config
.
get_name
()
self
.
quant_config
=
quant_config
def
forward
(
def
forward
(
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
...
@@ -125,7 +130,7 @@ class QWenAttention(nn.Module):
...
@@ -125,7 +130,7 @@ class QWenAttention(nn.Module):
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
c_attn
(
hidden_states
)
qkv
,
_
=
self
.
c_attn
(
hidden_states
)
if
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
:
if
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
and
self
.
quant_method
is
None
:
qkv
=
qkv
[...,:
-
32
]
qkv
=
qkv
[...,:
-
32
]
q
,
k
,
v
=
qkv
.
chunk
(
chunks
=
3
,
dim
=-
1
)
q
,
k
,
v
=
qkv
.
chunk
(
chunks
=
3
,
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
...
...
vllm/model_executor/models/qwen2.py
View file @
b40f2ffc
...
@@ -149,6 +149,11 @@ class Qwen2Attention(nn.Module):
...
@@ -149,6 +149,11 @@ class Qwen2Attention(nn.Module):
cache_config
=
cache_config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
)
quant_config
=
quant_config
)
self
.
quant_method
=
None
if
quant_config
is
not
None
:
self
.
quant_method
=
quant_config
.
get_name
()
self
.
quant_config
=
quant_config
def
forward
(
def
forward
(
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
...
@@ -157,7 +162,7 @@ class Qwen2Attention(nn.Module):
...
@@ -157,7 +162,7 @@ class Qwen2Attention(nn.Module):
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
if
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
:
if
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
and
self
.
quant_method
is
None
:
qkv
=
qkv
[...,:
-
32
]
qkv
=
qkv
[...,:
-
32
]
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
...
...
vllm/worker/model_runner.py
View file @
b40f2ffc
...
@@ -900,6 +900,32 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
...
@@ -900,6 +900,32 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
max_num_seqs
=
1
max_num_seqs
=
1
batch_size
=
0
batch_size
=
0
import
vllm.envs
as
envs
if
envs
.
VLLM_USE_FLASH_ATTN_AUTO
:
for
group_id
in
range
(
1
):
seq_len
=
8000
batch_size
+=
seq_len
seq_data
,
dummy_multi_modal_data
=
INPUT_REGISTRY
\
.
dummy_data_for_profiling
(
model_config
,
seq_len
)
# Having more tokens is over-conservative but otherwise fine
assert
len
(
seq_data
.
prompt_token_ids
)
>=
seq_len
,
(
f
"Expected at least
{
seq_len
}
dummy tokens for profiling, "
f
"but got:
{
len
(
seq_data
.
prompt_token_ids
)
}
"
)
seq
=
SequenceGroupMetadata
(
request_id
=
str
(
group_id
),
is_prompt
=
True
,
seq_data
=
{
group_id
:
seq_data
},
sampling_params
=
sampling_params
,
block_tables
=
None
,
lora_request
=
dummy_lora_requests_per_seq
[
group_id
]
if
dummy_lora_requests_per_seq
else
None
,
multi_modal_data
=
dummy_multi_modal_data
,
)
seqs
.
append
(
seq
)
for
group_id
in
range
(
max_num_seqs
):
for
group_id
in
range
(
max_num_seqs
):
seq_len
=
(
max_num_batched_tokens
//
max_num_seqs
+
seq_len
=
(
max_num_batched_tokens
//
max_num_seqs
+
(
group_id
<
max_num_batched_tokens
%
max_num_seqs
))
(
group_id
<
max_num_batched_tokens
%
max_num_seqs
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment