Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
82f1ffdf
Commit
82f1ffdf
authored
Sep 09, 2024
by
zhuwenwen
Browse files
add cutlass fa and support bloom nn layout
parent
b5160479
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
114 additions
and
26 deletions
+114
-26
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+53
-25
vllm/envs.py
vllm/envs.py
+6
-0
vllm/model_executor/model_loader/utils.py
vllm/model_executor/model_loader/utils.py
+1
-1
vllm/model_executor/models/bloom.py
vllm/model_executor/models/bloom.py
+54
-0
No files found.
vllm/attention/backends/rocm_flash_attn.py
View file @
82f1ffdf
...
@@ -290,7 +290,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -290,7 +290,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
self
.
attn_func_triton
=
flash_attn_varlen_func
self
.
attn_func_triton
=
flash_attn_varlen_func
from
flash_attn
import
flash_attn_varlen_func
# noqa: F401
from
flash_attn
import
flash_attn_varlen_func
# noqa: F401
self
.
attn_func_c
k
=
flash_attn_varlen_func
self
.
attn_func_c
u
=
flash_attn_varlen_func
logger
.
debug
(
"When SEQ_LEN > 8000, Use Triton FA in ROCmBackend, otherwise Use CK FA"
)
logger
.
debug
(
"When SEQ_LEN > 8000, Use Triton FA in ROCmBackend, otherwise Use CK FA"
)
else
:
else
:
# from vllm.attention.ops.triton_flash_attention import ( # noqa: F401
# from vllm.attention.ops.triton_flash_attention import ( # noqa: F401
...
@@ -428,17 +428,32 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -428,17 +428,32 @@ class ROCmFlashAttentionImpl(AttentionImpl):
causal
=
True
,
causal
=
True
,
)
)
else
:
else
:
out
=
self
.
attn_func_ck
(
if
envs
.
VLLM_USE_CL_FLASH_ATTN
:
q
=
query
,
out
=
self
.
attn_func_cu
(
k
=
key
,
q
=
query
,
v
=
value
,
k
=
key
,
cu_seqlens_q
=
prefill_meta
.
seq_start_loc
,
v
=
value
,
cu_seqlens_k
=
prefill_meta
.
seq_start_loc
,
cu_seqlens_q
=
prefill_meta
.
seq_start_loc
,
max_seqlen_q
=
prefill_meta
.
max_prefill_seq_len
,
cu_seqlens_k
=
prefill_meta
.
seq_start_loc
,
max_seqlen_k
=
prefill_meta
.
max_prefill_seq_len
,
max_seqlen_q
=
prefill_meta
.
max_prefill_seq_len
,
softmax_scale
=
self
.
scale
,
max_seqlen_k
=
prefill_meta
.
max_prefill_seq_len
,
causal
=
True
,
softmax_scale
=
self
.
scale
,
)
causal
=
True
,
window_size
=
self
.
sliding_window
,
alibi_slopes
=
self
.
alibi_slopes
,
)
else
:
out
=
self
.
attn_func_cu
(
q
=
query
,
k
=
key
,
v
=
value
,
cu_seqlens_q
=
prefill_meta
.
seq_start_loc
,
cu_seqlens_k
=
prefill_meta
.
seq_start_loc
,
max_seqlens_q
=
prefill_meta
.
max_prefill_seq_len
,
max_seqlens_k
=
prefill_meta
.
max_prefill_seq_len
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
)
else
:
else
:
# out = self.attn_func(
# out = self.attn_func(
# query,
# query,
...
@@ -490,19 +505,32 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -490,19 +505,32 @@ class ROCmFlashAttentionImpl(AttentionImpl):
attn_masks
,
attn_masks
,
)
)
else
:
else
:
out
=
self
.
attn_func
(
if
envs
.
VLLM_USE_CL_FLASH_ATTN
:
q
=
query
,
out
=
self
.
attn_func
(
k
=
key
,
q
=
query
,
v
=
value
,
k
=
key
,
cu_seqlens_q
=
prefill_meta
.
seq_start_loc
,
v
=
value
,
cu_seqlens_k
=
prefill_meta
.
seq_start_loc
,
cu_seqlens_q
=
prefill_meta
.
seq_start_loc
,
max_seqlen_q
=
prefill_meta
.
max_prefill_seq_len
,
cu_seqlens_k
=
prefill_meta
.
seq_start_loc
,
max_seqlen_k
=
prefill_meta
.
max_prefill_seq_len
,
max_seqlen_q
=
prefill_meta
.
max_prefill_seq_len
,
softmax_scale
=
self
.
scale
,
max_seqlen_k
=
prefill_meta
.
max_prefill_seq_len
,
causal
=
True
,
softmax_scale
=
self
.
scale
,
# window_size=self.sliding_window,
causal
=
True
,
# alibi_slopes=self.alibi_slopes,
window_size
=
self
.
sliding_window
,
)
alibi_slopes
=
self
.
alibi_slopes
,
)
else
:
out
=
self
.
attn_func
(
q
=
query
,
k
=
key
,
v
=
value
,
cu_seqlens_q
=
prefill_meta
.
seq_start_loc
,
cu_seqlens_k
=
prefill_meta
.
seq_start_loc
,
max_seqlens_q
=
prefill_meta
.
max_prefill_seq_len
,
max_seqlens_k
=
prefill_meta
.
max_prefill_seq_len
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
)
# common code for prefill
# common code for prefill
assert
output
[:
num_prefill_tokens
].
shape
==
out
.
shape
assert
output
[:
num_prefill_tokens
].
shape
==
out
.
shape
...
...
vllm/envs.py
View file @
82f1ffdf
...
@@ -12,6 +12,7 @@ if TYPE_CHECKING:
...
@@ -12,6 +12,7 @@ if TYPE_CHECKING:
VLLM_NCCL_SO_PATH
:
Optional
[
str
]
=
None
VLLM_NCCL_SO_PATH
:
Optional
[
str
]
=
None
LD_LIBRARY_PATH
:
Optional
[
str
]
=
None
LD_LIBRARY_PATH
:
Optional
[
str
]
=
None
VLLM_USE_TRITON_FLASH_ATTN
:
bool
=
False
VLLM_USE_TRITON_FLASH_ATTN
:
bool
=
False
VLLM_USE_CL_FLASH_ATTN
:
bool
=
False
VLLM_USE_FLASH_ATTN_AUTO
:
bool
=
False
VLLM_USE_FLASH_ATTN_AUTO
:
bool
=
False
VLLM_USE_OPT_OP
:
bool
=
False
VLLM_USE_OPT_OP
:
bool
=
False
VLLM_USE_PA_PRINT_PARAM
:
bool
=
False
VLLM_USE_PA_PRINT_PARAM
:
bool
=
False
...
@@ -196,6 +197,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
...
@@ -196,6 +197,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_TRITON_FLASH_ATTN"
,
"True"
).
lower
()
in
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_TRITON_FLASH_ATTN"
,
"True"
).
lower
()
in
(
"true"
,
"1"
)),
(
"true"
,
"1"
)),
# flag to control if vllm should use cutlass flash attention
"VLLM_USE_CL_FLASH_ATTN"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_CL_FLASH_ATTN"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
# flag to control vllm to automatically switch between Triton FA and CK FA
# flag to control vllm to automatically switch between Triton FA and CK FA
"VLLM_USE_FLASH_ATTN_AUTO"
:
"VLLM_USE_FLASH_ATTN_AUTO"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_FLASH_ATTN_AUTO"
,
"True"
).
lower
()
in
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_FLASH_ATTN_AUTO"
,
"True"
).
lower
()
in
...
...
vllm/model_executor/model_loader/utils.py
View file @
82f1ffdf
...
@@ -22,7 +22,7 @@ def set_default_torch_dtype(dtype: torch.dtype):
...
@@ -22,7 +22,7 @@ def set_default_torch_dtype(dtype: torch.dtype):
def
get_model_architecture
(
def
get_model_architecture
(
model_config
:
ModelConfig
)
->
Tuple
[
Type
[
nn
.
Module
],
str
]:
model_config
:
ModelConfig
)
->
Tuple
[
Type
[
nn
.
Module
],
str
]:
architectures
=
getattr
(
model_config
.
hf_config
,
"architectures"
,
[])
architectures
=
getattr
(
model_config
.
hf_config
,
"architectures"
,
[])
support_nn_architectures
=
[
'LlamaForCausalLM'
,
'QWenLMHeadModel'
,
'Qwen2ForCausalLM'
,
'ChatGLMModel'
,
'BaichuanForCausalLM'
]
support_nn_architectures
=
[
'LlamaForCausalLM'
,
'QWenLMHeadModel'
,
'Qwen2ForCausalLM'
,
'ChatGLMModel'
,
'BaichuanForCausalLM'
,
'BloomForCausalLM'
]
use_triton_fa_architectures
=
[
'DeepseekV2ForCausalLM'
]
use_triton_fa_architectures
=
[
'DeepseekV2ForCausalLM'
]
if
any
(
arch
in
architectures
for
arch
in
support_nn_architectures
):
if
any
(
arch
in
architectures
for
arch
in
support_nn_architectures
):
if
os
.
getenv
(
'LLAMA_NN'
)
!=
'0'
:
if
os
.
getenv
(
'LLAMA_NN'
)
!=
'0'
:
...
...
vllm/model_executor/models/bloom.py
View file @
82f1ffdf
...
@@ -22,6 +22,8 @@ from typing import Iterable, List, Optional, Tuple
...
@@ -22,6 +22,8 @@ from typing import Iterable, List, Optional, Tuple
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
transformers
import
BloomConfig
from
transformers
import
BloomConfig
import
os
import
re
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
CacheConfig
from
vllm.config
import
CacheConfig
...
@@ -40,6 +42,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
...
@@ -40,6 +42,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.utils
import
pad_weight
,
gemm_bank_conf
def
_get_alibi_slopes
(
total_num_heads
:
int
)
->
torch
.
Tensor
:
def
_get_alibi_slopes
(
total_num_heads
:
int
)
->
torch
.
Tensor
:
...
@@ -113,6 +117,10 @@ class BloomAttention(nn.Module):
...
@@ -113,6 +117,10 @@ class BloomAttention(nn.Module):
alibi_slopes
=
alibi_slopes
,
alibi_slopes
=
alibi_slopes
,
cache_config
=
cache_config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
)
quant_config
=
quant_config
)
self
.
quant_method
=
None
if
quant_config
is
not
None
:
self
.
quant_method
=
quant_config
.
get_name
()
self
.
quant_config
=
quant_config
def
forward
(
def
forward
(
self
,
self
,
...
@@ -284,6 +292,15 @@ class BloomForCausalLM(nn.Module):
...
@@ -284,6 +292,15 @@ class BloomForCausalLM(nn.Module):
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
sampler
=
Sampler
()
self
.
quant_method
=
None
if
quant_config
is
not
None
:
self
.
quant_method
=
quant_config
.
get_name
()
self
.
quant_config
=
quant_config
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
self
.
use_gemm_pad
=
os
.
environ
.
get
(
'GEMM_PAD'
)
==
'1'
self
.
use_fa_pad
=
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
def
forward
(
def
forward
(
self
,
self
,
...
@@ -342,3 +359,40 @@ class BloomForCausalLM(nn.Module):
...
@@ -342,3 +359,40 @@ class BloomForCausalLM(nn.Module):
weight_loader
=
getattr
(
param
,
"weight_loader"
,
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
weight_loader
(
param
,
loaded_weight
)
if
self
.
use_llama_nn
and
self
.
quant_method
is
None
:
lay_key_words
=
[
"self_attention.query_key_value.weight"
,
"self_attention.dense.weight"
,
"mlp.dense_h_to_4h.weight"
,
"mlp.dense_4h_to_h.weight"
]
combined_words
=
"|"
.
join
(
lay_key_words
)
lay_qkv_words
=
[
"self_attention.query_key_value.weight"
]
qkv_words
=
"|"
.
join
(
lay_qkv_words
)
lay_qkv_bias_words
=
[
"self_attention.query_key_value.bias"
]
qkv_bias_words
=
"|"
.
join
(
lay_qkv_bias_words
)
for
layername
,
weight
in
params_dict
.
items
():
if
self
.
use_fa_pad
and
(
re
.
findall
(
qkv_bias_words
,
layername
)):
weight
.
data
=
pad_weight
(
weight
.
data
,
32
)
matches
=
re
.
findall
(
combined_words
,
layername
)
if
matches
:
if
self
.
use_gemm_pad
and
gemm_bank_conf
(
weight
.
data
.
shape
[
0
]):
weight
.
data
=
pad_weight
(
weight
.
data
,
32
)
if
self
.
use_fa_pad
and
(
re
.
findall
(
qkv_words
,
layername
)):
if
not
gemm_bank_conf
(
weight
.
data
.
shape
[
0
]):
weight
.
data
=
pad_weight
(
weight
.
data
,
32
)
_weight
=
torch
.
zeros_like
(
weight
.
data
)
ori_shape
=
_weight
.
shape
ops
.
trans_w16_gemm
(
_weight
,
weight
.
data
,
_weight
.
shape
[
0
],
_weight
.
shape
[
1
])
weight
.
data
.
copy_
(
_weight
)
weight
.
data
=
weight
.
data
.
reshape
(
ori_shape
[
1
],
-
1
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment