Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
bd363067
Commit
bd363067
authored
Jun 05, 2025
by
lizhigong
Browse files
Merge branch 'v0.8.5.post1-dev' into v0.8.5-zero_overhead
parents
87ef4618
d36deb1a
Changes
106
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
231 additions
and
349 deletions
+231
-349
vllm/model_executor/models/llama.py
vllm/model_executor/models/llama.py
+2
-95
vllm/model_executor/models/qwen.py
vllm/model_executor/models/qwen.py
+1
-48
vllm/model_executor/models/qwen2.py
vllm/model_executor/models/qwen2.py
+0
-90
vllm/platforms/rocm.py
vllm/platforms/rocm.py
+105
-22
vllm/utils.py
vllm/utils.py
+4
-1
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+119
-93
No files found.
vllm/model_executor/models/llama.py
View file @
bd363067
...
...
@@ -50,7 +50,6 @@ from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader
,
maybe_remap_kv_scale_name
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
W8a8GetCacheJSON
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.utils
import
(
AutoWeightsLoader
,
PPMissingLayer
,
extract_layer_index
,
...
...
@@ -356,14 +355,12 @@ class LlamaModel(nn.Module):
if
quant_config
is
not
None
:
self
.
quant_method
=
quant_config
.
get_name
()
self
.
quant_config
=
quant_config
self
.
tritonsingleton
=
W8a8GetCacheJSON
()
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
# self.use_lm_nn = os.environ.get('LM_NN') == '1'
self
.
use_gemm_pad
=
os
.
environ
.
get
(
'GEMM_PAD'
)
==
'1'
self
.
use_fa_pad
=
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
self
.
use_awq_pad
=
os
.
environ
.
get
(
'AWQ_PAD'
)
==
'1'
self
.
w8a8_strategy
=
int
(
os
.
getenv
(
'W8A8_SUPPORT_METHODS'
,
'1'
))
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
embed_tokens
(
input_ids
)
...
...
@@ -517,97 +514,7 @@ class LlamaModel(nn.Module):
else
:
os
.
environ
[
'LM_NN'
]
=
'0'
os
.
environ
[
'LLAMA_NN'
]
=
'0'
# if self.quant_method == "awq" and not envs.VLLM_USE_TRITON_AWQ:
# lay_key_words = [
# "self_attn.qkv_proj.qweight",
# "self_attn.o_proj.qweight",
# "mlp.gate_up_proj.qweight",
# "mlp.down_proj.qweight"
# ]
# combined_words = "|".join(lay_key_words)
# for layername in loaded_params:
# weight = params_dict[layername]
# matches = re.findall(combined_words, layername)
# if matches:
# qweight =params_dict[layername]
# qzeros=params_dict[layername.replace("qweight", "qzeros")]
# scales=params_dict[layername.replace("qweight", "scales")]
# zeros_and_scalse =params_dict[layername.replace("qweight", "zeros_and_scales")]
# group_size= self.quant_config.group_size
# dim_n = scales.data.shape[1]
# dim_k = qweight.data.shape[0]
# pad_group=2
# _qw, _sz=ops.convert_s4(qweight,qzeros,scales,int(group_size))
# sz = ops.sz_permute(_sz).reshape(-1,dim_n)
# zeros_and_scalse.data.copy_(sz)
# qweight.data.copy_(_qw)
# #reshape
# zeros_and_scalse.data=zeros_and_scalse.reshape(dim_n,-1) #[k/greop_size,n]------>[n,k/group_size]
# qweight.data=qweight.data.reshape(dim_n,-1) #[k,n/8]---->[n,k/8]
# if dim_k % 4096==0 and self.use_awq_pad:
# zeros_and_scalse_pad= torch.zeros(dim_n,pad_group,dtype=torch.int32).cuda()
# zeros_and_scalse.data=torch.cat((zeros_and_scalse.data,zeros_and_scalse_pad),dim=1).contiguous()
# qweight_pad= torch.zeros(dim_n,int(group_size//4),dtype=torch.int32).cuda()
# qweight.data=torch.cat((qweight.data,qweight_pad),dim=1).contiguous()
#当为triton支持推理的时候不能进行处理
if
self
.
quant_method
==
"compressed_tensors"
:
lay_key_words
=
[
"self_attn.qkv_proj.weight"
,
"self_attn.o_proj.weight"
,
"mlp.gate_up_proj.weight"
,
"mlp.down_proj.weight"
,
]
combined_words
=
"|"
.
join
(
lay_key_words
)
weight_shapes
=
[]
all_json
=
{}
matched_key_words
=
set
()
for
layername
,
weight
in
params_dict
.
items
():
matches
=
re
.
findall
(
combined_words
,
layername
)
if
matches
and
"scale"
not
in
layername
:
weight_data
=
params_dict
[
layername
]
n
=
weight_data
.
shape
[
0
]
# k=weight_data.shape[1]
# #判断当前size是否在优化的范围内,假如存在则走triton,假如不存在则走rocblas
# json_file=self.tritonsingleton.get_w8a8json_name(n,k)
#rocblas和cutlass目前都需要weight做处理,但是triton不用
if
self
.
w8a8_strategy
!=
1
:
_weight
=
weight_data
.
T
.
contiguous
().
reshape
(
n
,
-
1
)
weight_data
.
data
.
copy_
(
_weight
)
#下面是针对模型记录模型出现k和n值
elif
len
(
matched_key_words
)
<
4
and
matches
[
0
]
not
in
matched_key_words
:
matched_key_words
.
add
(
matches
[
0
])
k
=
weight_data
.
shape
[
1
]
weight_shapes
.
append
({
n
,
k
})
json_file
=
self
.
tritonsingleton
.
get_w8a8json_name
(
n
,
k
)
configs_dict
=
self
.
tritonsingleton
.
get_triton_cache
(
json_file
,
n
,
k
)
if
configs_dict
:
all_json
.
update
(
configs_dict
)
if
self
.
w8a8_strategy
==
1
:
self
.
tritonsingleton
.
triton_json_dict
.
append
(
all_json
)
#找到的所有config都进行一次warmup
for
key
,
value
in
all_json
.
items
():
m
=
int
(
key
.
split
(
'_'
)[
0
])
n
=
int
(
key
.
split
(
'_'
)[
1
])
k
=
int
(
key
.
split
(
'_'
)[
2
])
ops
.
triton_int8_gemm_helper
(
m
=
m
,
n
=
n
,
k
=
k
,
per_token_act_quant
=
True
,
per_out_channel_weight_quant
=
True
,
use_bias
=
False
,
best_config
=
value
)
return
loaded_params
...
...
vllm/model_executor/models/qwen.py
View file @
bd363067
...
...
@@ -39,7 +39,6 @@ from .utils import (is_pp_missing_parameter,
maybe_prefix
)
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.utils
import
pad_weight
,
gemm_bank_conf
from
vllm.utils
import
W8a8GetCacheJSON
class
QWenMLP
(
nn
.
Module
):
...
...
@@ -291,13 +290,11 @@ class QWenBaseModel(nn.Module):
if
quant_config
is
not
None
:
self
.
quant_method
=
quant_config
.
get_name
()
self
.
quant_config
=
quant_config
self
.
tritonsingleton
=
W8a8GetCacheJSON
()
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
self
.
use_gemm_pad
=
os
.
environ
.
get
(
'GEMM_PAD'
)
==
'1'
self
.
use_fa_pad
=
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
self
.
use_awq_pad
=
os
.
environ
.
get
(
'AWQ_PAD'
)
==
'1'
self
.
w8a8_strategy
=
int
(
os
.
getenv
(
'W8A8_SUPPORT_METHODS'
,
'1'
))
def
compute_logits
(
self
,
...
...
@@ -384,51 +381,7 @@ class QWenBaseModel(nn.Module):
weight
.
data
.
copy_
(
_weight
)
weight
.
data
=
weight
.
data
.
reshape
(
ori_shape
[
1
],
-
1
)
# if self.quant_method == "awq":
# os.environ['LM_NN'] = '0'
# lay_key_words = [
# "attn.c_attn.qweight",
# "attn.c_proj.qweight",
# "mlp.gate_up_proj.qweight",
# "mlp.c_proj.qweight"
# ]
# combined_words = "|".join(lay_key_words)
# for layername in loaded_params:
# weight = params_dict[layername]
# matches = re.findall(combined_words, layername)
# if matches:
# qweight =params_dict[layername]
# qzeros=params_dict[layername.replace("qweight", "qzeros")]
# scales=params_dict[layername.replace("qweight", "scales")]
# zeros_and_scalse =params_dict[layername.replace("qweight", "zeros_and_scales")]
# group_size= self.quant_config.group_size
# dim_n = scales.data.shape[1]
# dim_k = qweight.data.shape[0]
# pad_group=2
# _qw, _sz=ops.convert_s4(qweight,qzeros,scales,int(group_size))
# sz = ops.sz_permute(_sz).reshape(-1,dim_n)
# zeros_and_scalse.data.copy_(sz)
# qweight.data.copy_(_qw)
# #reshape
# zeros_and_scalse.data=zeros_and_scalse.reshape(dim_n,-1) #[k/greop_size,n]------>[n,k/group_size]
# qweight.data=qweight.data.reshape(dim_n,-1) #[k,n/8]---->[n,k/8]
# if dim_k % 4096==0 and self.use_awq_pad:
# zeros_and_scalse_pad= torch.zeros(dim_n,pad_group,dtype=torch.int32).cuda()
# zeros_and_scalse.data=torch.cat((zeros_and_scalse.data,zeros_and_scalse_pad),dim=1).contiguous()
# qweight_pad= torch.zeros(dim_n,int(group_size//4),dtype=torch.int32).cuda()
# qweight.data=torch.cat((qweight.data,qweight_pad),dim=1).contiguous()
if
self
.
quant_method
==
"compressed_tensors"
:
os
.
environ
[
'LM_NN'
]
=
'0'
lay_key_words
=
[
"attn.c_attn.weight"
,
...
...
vllm/model_executor/models/qwen2.py
View file @
bd363067
...
...
@@ -63,7 +63,6 @@ from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
maybe_prefix
)
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.utils
import
pad_weight
,
gemm_bank_conf
from
vllm.utils
import
W8a8GetCacheJSON
logger
=
init_logger
(
__name__
)
...
...
@@ -338,13 +337,11 @@ class Qwen2Model(nn.Module):
if
quant_config
is
not
None
:
self
.
quant_method
=
quant_config
.
get_name
()
self
.
quant_config
=
quant_config
self
.
tritonsingleton
=
W8a8GetCacheJSON
()
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
self
.
use_gemm_pad
=
os
.
environ
.
get
(
'GEMM_PAD'
)
==
'1'
self
.
use_fa_pad
=
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
self
.
use_awq_pad
=
os
.
environ
.
get
(
'AWQ_PAD'
)
==
'1'
self
.
w8a8_strategy
=
int
(
os
.
getenv
(
'W8A8_SUPPORT_METHODS'
,
'1'
))
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
embed_tokens
(
input_ids
)
...
...
@@ -485,93 +482,6 @@ class Qwen2Model(nn.Module):
else
:
os
.
environ
[
'LM_NN'
]
=
'0'
os
.
environ
[
'LLAMA_NN'
]
=
'0'
# if self.quant_method == "awq" and not envs.VLLM_USE_TRITON_AWQ:
# lay_key_words = [
# "self_attn.qkv_proj.qweight",
# "self_attn.o_proj.qweight",
# "mlp.gate_up_proj.qweight",
# "mlp.down_proj.qweight"
# ]
# combined_words = "|".join(lay_key_words)
# for layername in loaded_params:
# weight = params_dict[layername]
# matches = re.findall(combined_words, layername)
# if matches:
# qweight =params_dict[layername]
# qzeros=params_dict[layername.replace("qweight", "qzeros")]
# scales=params_dict[layername.replace("qweight", "scales")]
# zeros_and_scalse =params_dict[layername.replace("qweight", "zeros_and_scales")]
# group_size= self.quant_config.group_size
# dim_n = scales.data.shape[1]
# dim_k = qweight.data.shape[0]
# pad_group=2
# _qw, _sz=ops.convert_s4(qweight,qzeros,scales,int(group_size))
# sz = ops.sz_permute(_sz).reshape(-1,dim_n)
# zeros_and_scalse.data.copy_(sz)
# qweight.data.copy_(_qw)
# #reshape
# zeros_and_scalse.data=zeros_and_scalse.reshape(dim_n,-1) #[k/greop_size,n]------>[n,k/group_size]
# qweight.data=qweight.data.reshape(dim_n,-1) #[k,n/8]---->[n,k/8]
# if dim_k % 4096==0 and self.use_awq_pad:
# zeros_and_scalse_pad= torch.zeros(dim_n,pad_group,dtype=torch.int32).cuda()
# zeros_and_scalse.data=torch.cat((zeros_and_scalse.data,zeros_and_scalse_pad),dim=1).contiguous()
# qweight_pad= torch.zeros(dim_n,int(group_size//4),dtype=torch.int32).cuda()
# qweight.data=torch.cat((qweight.data,qweight_pad),dim=1).contiguous()
if
self
.
quant_method
==
"compressed_tensors"
:
lay_key_words
=
[
"self_attn.qkv_proj.weight"
,
"self_attn.o_proj.weight"
,
"mlp.gate_up_proj.weight"
,
"mlp.down_proj.weight"
,
]
combined_words
=
"|"
.
join
(
lay_key_words
)
weight_shapes
=
[]
all_json
=
{}
matched_key_words
=
set
()
for
layername
in
loaded_params
:
weight
=
params_dict
[
layername
]
matches
=
re
.
findall
(
combined_words
,
layername
)
if
matches
and
"scale"
not
in
layername
:
weight_data
=
params_dict
[
layername
]
n
=
weight_data
.
shape
[
0
]
#rocblas和cutlass目前都需要weight做处理,但是triton不用
if
self
.
w8a8_strategy
!=
1
:
_weight
=
weight_data
.
T
.
contiguous
().
reshape
(
n
,
-
1
)
weight_data
.
data
.
copy_
(
_weight
)
#下面是针对模型记录模型出现k和n值
elif
len
(
matched_key_words
)
<
4
and
matches
[
0
]
not
in
matched_key_words
:
matched_key_words
.
add
(
matches
[
0
])
k
=
weight_data
.
shape
[
1
]
weight_shapes
.
append
({
n
,
k
})
json_file
=
self
.
tritonsingleton
.
get_w8a8json_name
(
n
,
k
)
configs_dict
=
self
.
tritonsingleton
.
get_triton_cache
(
json_file
,
n
,
k
)
if
configs_dict
:
all_json
.
update
(
configs_dict
)
if
self
.
w8a8_strategy
==
1
:
self
.
tritonsingleton
.
triton_json_dict
.
append
(
all_json
)
#找到的所有config都进行一次warmup
for
key
,
value
in
all_json
.
items
():
m
=
int
(
key
.
split
(
'_'
)[
0
])
n
=
int
(
key
.
split
(
'_'
)[
1
])
k
=
int
(
key
.
split
(
'_'
)[
2
])
ops
.
triton_int8_gemm_helper
(
m
=
m
,
n
=
n
,
k
=
k
,
per_token_act_quant
=
True
,
per_out_channel_weight_quant
=
True
,
use_bias
=
False
,
best_config
=
value
)
return
loaded_params
...
...
vllm/platforms/rocm.py
View file @
bd363067
...
...
@@ -110,15 +110,16 @@ def use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
# rocm custom page attention not support on gfx1*
# custom paged attn always supported on V0. On V1, requires sliding window
# disabled due to observed numerical discrepancy.
return
(
on_mi250_mi300
()
and
(
not
envs
.
VLLM_USE_V1
or
sliding_window
==
0
or
sliding_window
==
(
-
1
,
-
1
))
and
(
qtype
==
torch
.
half
or
qtype
==
torch
.
bfloat16
)
and
(
head_size
==
64
or
head_size
==
128
)
and
(
block_size
==
16
or
block_size
==
32
)
and
(
gqa_ratio
>=
1
and
gqa_ratio
<=
16
)
and
max_seq_len
<=
32768
and
(
envs
.
VLLM_ROCM_CUSTOM_PAGED_ATTN
)
and
not
(
envs
.
VLLM_ROCM_USE_AITER_PAGED_ATTN
and
envs
.
VLLM_ROCM_USE_AITER
))
return
False
# return (on_mi250_mi300() and (not envs.VLLM_USE_V1 or sliding_window == 0
# or sliding_window == (-1, -1))
# and (qtype == torch.half or qtype == torch.bfloat16)
# and (head_size == 64 or head_size == 128)
# and (block_size == 16 or block_size == 32)
# and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768
# and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
# and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN
# and envs.VLLM_ROCM_USE_AITER))
class
RocmPlatform
(
Platform
):
...
...
@@ -205,20 +206,102 @@ class RocmPlatform(Platform):
# f" The selected backend, {selected_backend.name},"
# f"is not MLA type while requested for MLA backend.")
selected_backend
=
(
_Backend
.
ROCM_FLASH
if
selected_backend
==
_Backend
.
FLASH_ATTN
else
selected_backend
)
if
envs
.
VLLM_USE_V1
:
logger
.
info
(
"Using Triton Attention backend on V1 engine."
)
return
(
"vllm.v1.attention.backends."
"triton_attn.TritonAttentionBackend"
)
if
selected_backend
==
_Backend
.
ROCM_FLASH
:
if
not
cls
.
has_device_capability
(
90
):
# not Instinct series GPUs.
logger
.
info
(
"flash_attn is not supported on NAVI GPUs."
)
if
envs
.
VLLM_FLASH_ATTN_BACKEND
:
if
use_v1
:
if
selected_backend
==
_Backend
.
FLASHINFER
:
raise
ValueError
(
"FlashInfer backend on V1 engine is not supported"
)
if
selected_backend
==
_Backend
.
TRITON_ATTN_VLLM_V1
:
logger
.
info_once
(
"Using Triton backend on V1 engine."
)
return
(
"vllm.v1.attention.backends."
"triton_attn.TritonAttentionBackend"
)
if
cls
.
has_device_capability
(
80
):
logger
.
info_once
(
"Using Flash Attention backend on V1 engine."
)
return
(
"vllm.v1.attention.backends."
"flash_attn.FlashAttentionBackend"
)
if
selected_backend
==
_Backend
.
FLASHINFER
:
raise
ValueError
(
"FlashInfer backend is not supported"
)
elif
selected_backend
==
_Backend
.
XFORMERS
:
raise
ValueError
(
"XFormers backend is not supported"
)
elif
selected_backend
==
_Backend
.
FLASH_ATTN
:
pass
elif
selected_backend
:
raise
ValueError
(
f
"Invalid attention backend for
{
cls
.
device_name
}
, "
f
"with use_v1:
{
use_v1
}
use_mla:
{
use_mla
}
"
)
target_backend
=
_Backend
.
FLASH_ATTN
if
not
cls
.
has_device_capability
(
80
):
# Volta and Turing NVIDIA GPUs.
logger
.
info
(
"Cannot use FlashAttention-2 backend for Volta and Turing "
"GPUs."
)
raise
ValueError
(
"XFormers backend is not supported"
)
elif
dtype
not
in
(
torch
.
float16
,
torch
.
bfloat16
):
logger
.
info
(
"Cannot use FlashAttention-2 backend for dtype other than "
"torch.float16 or torch.bfloat16."
)
# raise ValueError("XFormers backend is not supported")
pass
elif
block_size
%
16
!=
0
:
logger
.
info
(
"Cannot use FlashAttention-2 backend for block size not "
"divisible by 16."
)
raise
ValueError
(
"XFormers backend is not supported"
)
# FlashAttn is valid for the model, checking if the package is
# installed.
if
target_backend
==
_Backend
.
FLASH_ATTN
:
try
:
import
flash_attn
# noqa: F401
from
vllm.attention.backends.flash_attn
import
(
# noqa: F401
FlashAttentionBackend
,
flash_attn_supports_fp8
)
supported_sizes
=
\
FlashAttentionBackend
.
get_supported_head_sizes
()
if
head_size
not
in
supported_sizes
:
logger
.
info
(
"Cannot use FlashAttention-2 backend for head size %d."
,
head_size
)
raise
ValueError
(
"XFormers backend is not supported"
)
fp8_kv_cache
=
(
kv_cache_dtype
is
not
None
and
kv_cache_dtype
.
startswith
(
"fp8"
))
if
(
fp8_kv_cache
and
not
flash_attn_supports_fp8
()):
logger
.
info
(
"Cannot use FlashAttention backend for FP8 KV cache."
)
logger
.
warning
(
"Please use FlashInfer backend with FP8 KV Cache for "
"better performance by setting environment variable "
"VLLM_ATTENTION_BACKEND=FLASHINFER"
)
raise
ValueError
(
"XFormers backend is not supported"
)
except
ImportError
:
logger
.
info
(
"Cannot use FlashAttention-2 backend because the "
"flash_attn package is not found. "
"Make sure that flash_attn was built and installed "
"(on by default)."
)
raise
ValueError
(
"XFormers backend is not supported"
)
if
target_backend
==
_Backend
.
XFORMERS
:
raise
ValueError
(
"XFormers backend is not supported"
)
logger
.
info
(
"Using Flash Attention backend."
)
return
"vllm.attention.backends.flash_attn.FlashAttentionBackend"
else
:
logger
.
info
(
"%s is not supported in AMD GPUs."
,
selected_backend
)
logger
.
info
(
"Using ROCmFlashAttention backend."
)
return
"vllm.attention.backends.rocm_flash_attn.ROCmFlashAttentionBackend"
# noqa: E501
selected_backend
=
(
_Backend
.
ROCM_FLASH
if
selected_backend
==
_Backend
.
FLASH_ATTN
else
selected_backend
)
if
envs
.
VLLM_USE_V1
:
logger
.
info
(
"Using Triton Attention backend on V1 engine."
)
return
(
"vllm.v1.attention.backends."
"triton_attn.TritonAttentionBackend"
)
if
selected_backend
==
_Backend
.
ROCM_FLASH
:
if
not
cls
.
has_device_capability
(
90
):
# not Instinct series GPUs.
logger
.
info
(
"flash_attn is not supported on NAVI GPUs."
)
else
:
logger
.
info
(
"%s is not supported in AMD GPUs."
,
selected_backend
)
logger
.
info
(
"Using ROCmFlashAttention backend."
)
return
"vllm.attention.backends.rocm_flash_attn.ROCmFlashAttentionBackend"
# noqa: E501
@
classmethod
@
lru_cache
(
maxsize
=
8
)
...
...
vllm/utils.py
View file @
bd363067
...
...
@@ -1725,7 +1725,9 @@ class W8a8GetCacheJSON:
json_folder_path
=
current_folder_path
+
'/../lmslim/configs/w8a8'
self
.
triton_json_dir
=
(
os
.
getenv
(
'TRITON_JSON_DIR'
,
json_folder_path
))
self
.
triton_json_dict
=
[]
self
.
triton_json_dict
=
{}
self
.
triton_json_list
=
[]
self
.
weight_shapes
=
[]
def
getspec_config
(
self
,
configs_dict
,
M
,
N
,
K
):
if
f
"
{
M
}
_
{
N
}
_
{
K
}
"
in
configs_dict
:
...
...
@@ -1823,6 +1825,7 @@ class W8a8GetCacheJSON:
'kpack'
:
int
(
sub_value
[
"kpack"
]),
'num_stages'
:
int
(
sub_value
[
'num_stages'
]),
'num_warps'
:
int
(
sub_value
[
'num_warps'
]),
'enable_mmacfuse'
:
int
(
sub_value
[
'enable_mmacfuse'
]),
}
configs_dict
[
configs_key
]
=
configs_value
return
configs_dict
...
...
vllm/v1/attention/backends/flash_attn.py
View file @
bd363067
...
...
@@ -24,9 +24,11 @@ if TYPE_CHECKING:
from
vllm.v1.worker.gpu_input_batch
import
InputBatch
from
vllm.v1.worker.gpu_model_runner
import
GPUModelRunner
if
current_platform
.
is_
cuda
():
if
not
current_platform
.
is_
rocm
():
from
vllm.vllm_flash_attn
import
(
flash_attn_varlen_func
,
get_scheduler_metadata
)
else
:
from
flash_attn
import
flash_attn_varlen_func
,
vllm_flash_attn_varlen_func
logger
=
init_logger
(
__name__
)
...
...
@@ -603,57 +605,79 @@ class FlashAttentionImpl(AttentionImpl):
descale_shape
=
(
cu_seqlens_q
.
shape
[
0
]
-
1
,
key
.
shape
[
1
])
flash_attn_varlen_func
(
q
=
query
[:
num_actual_tokens
],
k
=
key_cache
,
v
=
value_cache
,
out
=
output
[:
num_actual_tokens
],
cu_seqlens_q
=
cu_seqlens_q
,
max_seqlen_q
=
max_seqlen_q
,
seqused_k
=
seqused_k
,
max_seqlen_k
=
max_seqlen_k
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
alibi_slopes
=
self
.
alibi_slopes
,
window_size
=
self
.
sliding_window
,
block_table
=
block_table
,
softcap
=
self
.
logits_soft_cap
,
scheduler_metadata
=
scheduler_metadata
,
fa_version
=
self
.
vllm_flash_attn_version
,
q_descale
=
layer
.
_q_scale
.
expand
(
descale_shape
),
k_descale
=
layer
.
_k_scale
.
expand
(
descale_shape
),
v_descale
=
layer
.
_v_scale
.
expand
(
descale_shape
),
)
if
not
current_platform
.
is_rocm
():
flash_attn_varlen_func
(
q
=
query
[:
num_actual_tokens
],
k
=
key_cache
,
v
=
value_cache
,
out
=
output
[:
num_actual_tokens
],
cu_seqlens_q
=
cu_seqlens_q
,
max_seqlen_q
=
max_seqlen_q
,
seqused_k
=
seqused_k
,
max_seqlen_k
=
max_seqlen_k
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
alibi_slopes
=
self
.
alibi_slopes
,
window_size
=
self
.
sliding_window
,
block_table
=
block_table
,
softcap
=
self
.
logits_soft_cap
,
scheduler_metadata
=
scheduler_metadata
,
fa_version
=
self
.
vllm_flash_attn_version
,
q_descale
=
layer
.
_q_scale
.
expand
(
descale_shape
),
k_descale
=
layer
.
_k_scale
.
expand
(
descale_shape
),
v_descale
=
layer
.
_v_scale
.
expand
(
descale_shape
),
)
else
:
vllm_flash_attn_varlen_func
(
q
=
query
[:
num_actual_tokens
],
k
=
key_cache
,
v
=
value_cache
,
out
=
output
[:
num_actual_tokens
],
cu_seqlens_q
=
cu_seqlens_q
,
max_seqlen_q
=
max_seqlen_q
,
seqused_k
=
seqused_k
,
max_seqlen_k
=
max_seqlen_k
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
alibi_slopes
=
self
.
alibi_slopes
,
window_size
=
self
.
sliding_window
,
block_table
=
block_table
,
softcap
=
self
.
logits_soft_cap
,
# scheduler_metadata=scheduler_metadata,
)
return
output
assert
not
use_local_attn
,
(
"Cascade attention does not support local attention."
)
# Cascade attention (rare case).
cascade_attention
(
output
[:
num_actual_tokens
],
query
[:
num_actual_tokens
],
key_cache
,
value_cache
,
cu_query_lens
=
attn_metadata
.
query_start_loc
,
max_query_len
=
attn_metadata
.
max_query_len
,
cu_prefix_query_lens
=
attn_metadata
.
cu_prefix_query_lens
,
prefix_kv_lens
=
attn_metadata
.
prefix_kv_lens
,
suffix_kv_lens
=
attn_metadata
.
suffix_kv_lens
,
max_kv_len
=
attn_metadata
.
max_seq_len
,
softmax_scale
=
self
.
scale
,
alibi_slopes
=
self
.
alibi_slopes
,
sliding_window
=
self
.
sliding_window
,
logits_soft_cap
=
self
.
logits_soft_cap
,
block_table
=
attn_metadata
.
block_table
,
common_prefix_len
=
attn_metadata
.
common_prefix_len
,
fa_version
=
self
.
vllm_flash_attn_version
,
prefix_scheduler_metadata
=
attn_metadata
.
prefix_scheduler_metadata
,
suffix_scheduler_metadata
=
attn_metadata
.
scheduler_metadata
,
q_descale
=
layer
.
_q_scale
,
k_descale
=
layer
.
_k_scale
,
v_descale
=
layer
.
_v_scale
,
)
return
output
if
not
current_platform
.
is_rocm
():
cascade_attention
(
output
[:
num_actual_tokens
],
query
[:
num_actual_tokens
],
key_cache
,
value_cache
,
cu_query_lens
=
attn_metadata
.
query_start_loc
,
max_query_len
=
attn_metadata
.
max_query_len
,
cu_prefix_query_lens
=
attn_metadata
.
cu_prefix_query_lens
,
prefix_kv_lens
=
attn_metadata
.
prefix_kv_lens
,
suffix_kv_lens
=
attn_metadata
.
suffix_kv_lens
,
max_kv_len
=
attn_metadata
.
max_seq_len
,
softmax_scale
=
self
.
scale
,
alibi_slopes
=
self
.
alibi_slopes
,
sliding_window
=
self
.
sliding_window
,
logits_soft_cap
=
self
.
logits_soft_cap
,
block_table
=
attn_metadata
.
block_table
,
common_prefix_len
=
attn_metadata
.
common_prefix_len
,
fa_version
=
self
.
vllm_flash_attn_version
,
prefix_scheduler_metadata
=
attn_metadata
.
prefix_scheduler_metadata
,
suffix_scheduler_metadata
=
attn_metadata
.
scheduler_metadata
,
q_descale
=
layer
.
_q_scale
,
k_descale
=
layer
.
_k_scale
,
v_descale
=
layer
.
_v_scale
,
)
return
output
else
:
raise
ValueError
(
"cascade attention is not supported on rocm"
)
def
use_cascade_attention
(
...
...
@@ -761,56 +785,58 @@ def cascade_attention(
descale_shape
=
(
cu_prefix_query_lens
.
shape
[
0
]
-
1
,
key_cache
.
shape
[
-
2
])
# Process shared prefix.
prefix_output
,
prefix_lse
=
flash_attn_varlen_func
(
q
=
query
,
k
=
key_cache
,
v
=
value_cache
,
cu_seqlens_q
=
cu_prefix_query_lens
,
seqused_k
=
prefix_kv_lens
,
max_seqlen_q
=
num_tokens
,
max_seqlen_k
=
common_prefix_len
,
softmax_scale
=
softmax_scale
,
causal
=
False
,
window_size
=
sliding_window
,
block_table
=
block_table
[:
1
],
softcap
=
logits_soft_cap
,
return_softmax_lse
=
True
,
scheduler_metadata
=
prefix_scheduler_metadata
,
fa_version
=
fa_version
,
q_descale
=
q_descale
.
expand
(
descale_shape
)
if
q_descale
is
not
None
else
None
,
k_descale
=
k_descale
.
expand
(
descale_shape
)
if
k_descale
is
not
None
else
None
,
v_descale
=
v_descale
.
expand
(
descale_shape
)
if
v_descale
is
not
None
else
None
,
)
if
not
current_platform
.
is_rocm
():
prefix_output
,
prefix_lse
=
flash_attn_varlen_func
(
q
=
query
,
k
=
key_cache
,
v
=
value_cache
,
cu_seqlens_q
=
cu_prefix_query_lens
,
seqused_k
=
prefix_kv_lens
,
max_seqlen_q
=
num_tokens
,
max_seqlen_k
=
common_prefix_len
,
softmax_scale
=
softmax_scale
,
causal
=
False
,
window_size
=
sliding_window
,
block_table
=
block_table
[:
1
],
softcap
=
logits_soft_cap
,
return_softmax_lse
=
True
,
scheduler_metadata
=
prefix_scheduler_metadata
,
fa_version
=
fa_version
,
q_descale
=
q_descale
.
expand
(
descale_shape
)
if
q_descale
is
not
None
else
None
,
k_descale
=
k_descale
.
expand
(
descale_shape
)
if
k_descale
is
not
None
else
None
,
v_descale
=
v_descale
.
expand
(
descale_shape
)
if
v_descale
is
not
None
else
None
,
)
descale_shape
=
(
cu_query_lens
.
shape
[
0
]
-
1
,
key_cache
.
shape
[
-
2
])
# Process suffix per query.
suffix_output
,
suffix_lse
=
flash_attn_varlen_func
(
q
=
query
,
k
=
key_cache
,
v
=
value_cache
,
cu_seqlens_q
=
cu_query_lens
,
seqused_k
=
suffix_kv_lens
,
max_seqlen_q
=
max_query_len
,
max_seqlen_k
=
max_kv_len
-
common_prefix_len
,
softmax_scale
=
softmax_scale
,
causal
=
True
,
window_size
=
sliding_window
,
block_table
=
block_table
[:,
num_common_kv_blocks
:],
softcap
=
logits_soft_cap
,
return_softmax_lse
=
True
,
scheduler_metadata
=
suffix_scheduler_metadata
,
fa_version
=
fa_version
,
q_descale
=
q_descale
.
expand
(
descale_shape
)
if
q_descale
is
not
None
else
None
,
k_descale
=
k_descale
.
expand
(
descale_shape
)
if
k_descale
is
not
None
else
None
,
v_descale
=
v_descale
.
expand
(
descale_shape
)
if
v_descale
is
not
None
else
None
,
)
if
not
current_platform
.
is_rocm
():
suffix_output
,
suffix_lse
=
flash_attn_varlen_func
(
q
=
query
,
k
=
key_cache
,
v
=
value_cache
,
cu_seqlens_q
=
cu_query_lens
,
seqused_k
=
suffix_kv_lens
,
max_seqlen_q
=
max_query_len
,
max_seqlen_k
=
max_kv_len
-
common_prefix_len
,
softmax_scale
=
softmax_scale
,
causal
=
True
,
window_size
=
sliding_window
,
block_table
=
block_table
[:,
num_common_kv_blocks
:],
softcap
=
logits_soft_cap
,
return_softmax_lse
=
True
,
scheduler_metadata
=
suffix_scheduler_metadata
,
fa_version
=
fa_version
,
q_descale
=
q_descale
.
expand
(
descale_shape
)
if
q_descale
is
not
None
else
None
,
k_descale
=
k_descale
.
expand
(
descale_shape
)
if
k_descale
is
not
None
else
None
,
v_descale
=
v_descale
.
expand
(
descale_shape
)
if
v_descale
is
not
None
else
None
,
)
# Merge prefix and suffix outputs, and store the result in output.
merge_attn_states
(
output
,
prefix_output
,
prefix_lse
,
suffix_output
,
...
...
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment