Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
82f1ffdf
Commit
82f1ffdf
authored
Sep 09, 2024
by
zhuwenwen
Browse files
add cutlass fa and support bloom nn layout
parent
b5160479
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
114 additions
and
26 deletions
+114
-26
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+53
-25
vllm/envs.py
vllm/envs.py
+6
-0
vllm/model_executor/model_loader/utils.py
vllm/model_executor/model_loader/utils.py
+1
-1
vllm/model_executor/models/bloom.py
vllm/model_executor/models/bloom.py
+54
-0
No files found.
vllm/attention/backends/rocm_flash_attn.py
View file @
82f1ffdf
...
...
@@ -290,7 +290,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
self
.
attn_func_triton
=
flash_attn_varlen_func
from
flash_attn
import
flash_attn_varlen_func
# noqa: F401
self
.
attn_func_c
k
=
flash_attn_varlen_func
self
.
attn_func_c
u
=
flash_attn_varlen_func
logger
.
debug
(
"When SEQ_LEN > 8000, Use Triton FA in ROCmBackend, otherwise Use CK FA"
)
else
:
# from vllm.attention.ops.triton_flash_attention import ( # noqa: F401
...
...
@@ -428,7 +428,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
causal
=
True
,
)
else
:
out
=
self
.
attn_func_ck
(
if
envs
.
VLLM_USE_CL_FLASH_ATTN
:
out
=
self
.
attn_func_cu
(
q
=
query
,
k
=
key
,
v
=
value
,
...
...
@@ -438,6 +439,20 @@ class ROCmFlashAttentionImpl(AttentionImpl):
max_seqlen_k
=
prefill_meta
.
max_prefill_seq_len
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
window_size
=
self
.
sliding_window
,
alibi_slopes
=
self
.
alibi_slopes
,
)
else
:
out
=
self
.
attn_func_cu
(
q
=
query
,
k
=
key
,
v
=
value
,
cu_seqlens_q
=
prefill_meta
.
seq_start_loc
,
cu_seqlens_k
=
prefill_meta
.
seq_start_loc
,
max_seqlens_q
=
prefill_meta
.
max_prefill_seq_len
,
max_seqlens_k
=
prefill_meta
.
max_prefill_seq_len
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
)
else
:
# out = self.attn_func(
...
...
@@ -490,6 +505,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
attn_masks
,
)
else
:
if
envs
.
VLLM_USE_CL_FLASH_ATTN
:
out
=
self
.
attn_func
(
q
=
query
,
k
=
key
,
...
...
@@ -500,8 +516,20 @@ class ROCmFlashAttentionImpl(AttentionImpl):
max_seqlen_k
=
prefill_meta
.
max_prefill_seq_len
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
# window_size=self.sliding_window,
# alibi_slopes=self.alibi_slopes,
window_size
=
self
.
sliding_window
,
alibi_slopes
=
self
.
alibi_slopes
,
)
else
:
out
=
self
.
attn_func
(
q
=
query
,
k
=
key
,
v
=
value
,
cu_seqlens_q
=
prefill_meta
.
seq_start_loc
,
cu_seqlens_k
=
prefill_meta
.
seq_start_loc
,
max_seqlens_q
=
prefill_meta
.
max_prefill_seq_len
,
max_seqlens_k
=
prefill_meta
.
max_prefill_seq_len
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
)
# common code for prefill
...
...
vllm/envs.py
View file @
82f1ffdf
...
...
@@ -12,6 +12,7 @@ if TYPE_CHECKING:
VLLM_NCCL_SO_PATH
:
Optional
[
str
]
=
None
LD_LIBRARY_PATH
:
Optional
[
str
]
=
None
VLLM_USE_TRITON_FLASH_ATTN
:
bool
=
False
VLLM_USE_CL_FLASH_ATTN
:
bool
=
False
VLLM_USE_FLASH_ATTN_AUTO
:
bool
=
False
VLLM_USE_OPT_OP
:
bool
=
False
VLLM_USE_PA_PRINT_PARAM
:
bool
=
False
...
...
@@ -196,6 +197,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_TRITON_FLASH_ATTN"
,
"True"
).
lower
()
in
(
"true"
,
"1"
)),
# flag to control if vllm should use cutlass flash attention
"VLLM_USE_CL_FLASH_ATTN"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_CL_FLASH_ATTN"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
# flag to control vllm to automatically switch between Triton FA and CK FA
"VLLM_USE_FLASH_ATTN_AUTO"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_FLASH_ATTN_AUTO"
,
"True"
).
lower
()
in
...
...
vllm/model_executor/model_loader/utils.py
View file @
82f1ffdf
...
...
@@ -22,7 +22,7 @@ def set_default_torch_dtype(dtype: torch.dtype):
def
get_model_architecture
(
model_config
:
ModelConfig
)
->
Tuple
[
Type
[
nn
.
Module
],
str
]:
architectures
=
getattr
(
model_config
.
hf_config
,
"architectures"
,
[])
support_nn_architectures
=
[
'LlamaForCausalLM'
,
'QWenLMHeadModel'
,
'Qwen2ForCausalLM'
,
'ChatGLMModel'
,
'BaichuanForCausalLM'
]
support_nn_architectures
=
[
'LlamaForCausalLM'
,
'QWenLMHeadModel'
,
'Qwen2ForCausalLM'
,
'ChatGLMModel'
,
'BaichuanForCausalLM'
,
'BloomForCausalLM'
]
use_triton_fa_architectures
=
[
'DeepseekV2ForCausalLM'
]
if
any
(
arch
in
architectures
for
arch
in
support_nn_architectures
):
if
os
.
getenv
(
'LLAMA_NN'
)
!=
'0'
:
...
...
vllm/model_executor/models/bloom.py
View file @
82f1ffdf
...
...
@@ -22,6 +22,8 @@ from typing import Iterable, List, Optional, Tuple
import
torch
from
torch
import
nn
from
transformers
import
BloomConfig
import
os
import
re
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
CacheConfig
...
...
@@ -40,6 +42,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.utils
import
pad_weight
,
gemm_bank_conf
def
_get_alibi_slopes
(
total_num_heads
:
int
)
->
torch
.
Tensor
:
...
...
@@ -113,6 +117,10 @@ class BloomAttention(nn.Module):
alibi_slopes
=
alibi_slopes
,
cache_config
=
cache_config
,
quant_config
=
quant_config
)
self
.
quant_method
=
None
if
quant_config
is
not
None
:
self
.
quant_method
=
quant_config
.
get_name
()
self
.
quant_config
=
quant_config
def
forward
(
self
,
...
...
@@ -285,6 +293,15 @@ class BloomForCausalLM(nn.Module):
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
quant_method
=
None
if
quant_config
is
not
None
:
self
.
quant_method
=
quant_config
.
get_name
()
self
.
quant_config
=
quant_config
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
self
.
use_gemm_pad
=
os
.
environ
.
get
(
'GEMM_PAD'
)
==
'1'
self
.
use_fa_pad
=
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
...
...
@@ -342,3 +359,40 @@ class BloomForCausalLM(nn.Module):
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
if
self
.
use_llama_nn
and
self
.
quant_method
is
None
:
lay_key_words
=
[
"self_attention.query_key_value.weight"
,
"self_attention.dense.weight"
,
"mlp.dense_h_to_4h.weight"
,
"mlp.dense_4h_to_h.weight"
]
combined_words
=
"|"
.
join
(
lay_key_words
)
lay_qkv_words
=
[
"self_attention.query_key_value.weight"
]
qkv_words
=
"|"
.
join
(
lay_qkv_words
)
lay_qkv_bias_words
=
[
"self_attention.query_key_value.bias"
]
qkv_bias_words
=
"|"
.
join
(
lay_qkv_bias_words
)
for
layername
,
weight
in
params_dict
.
items
():
if
self
.
use_fa_pad
and
(
re
.
findall
(
qkv_bias_words
,
layername
)):
weight
.
data
=
pad_weight
(
weight
.
data
,
32
)
matches
=
re
.
findall
(
combined_words
,
layername
)
if
matches
:
if
self
.
use_gemm_pad
and
gemm_bank_conf
(
weight
.
data
.
shape
[
0
]):
weight
.
data
=
pad_weight
(
weight
.
data
,
32
)
if
self
.
use_fa_pad
and
(
re
.
findall
(
qkv_words
,
layername
)):
if
not
gemm_bank_conf
(
weight
.
data
.
shape
[
0
]):
weight
.
data
=
pad_weight
(
weight
.
data
,
32
)
_weight
=
torch
.
zeros_like
(
weight
.
data
)
ori_shape
=
_weight
.
shape
ops
.
trans_w16_gemm
(
_weight
,
weight
.
data
,
_weight
.
shape
[
0
],
_weight
.
shape
[
1
])
weight
.
data
.
copy_
(
_weight
)
weight
.
data
=
weight
.
data
.
reshape
(
ori_shape
[
1
],
-
1
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment