Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8824ae6a
Commit
8824ae6a
authored
Sep 18, 2025
by
王敏
Browse files
merge 092-dev分支近期修改
parents
f9f1887d
c0707728
Changes
36
Hide whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
1222 additions
and
145 deletions
+1222
-145
vllm/model_executor/layers/quantization/slimquant_w4a8_marlin.py
...del_executor/layers/quantization/slimquant_w4a8_marlin.py
+4
-11
vllm/model_executor/layers/quantization/utils/marlin_utils.py
.../model_executor/layers/quantization/utils/marlin_utils.py
+7
-3
vllm/model_executor/layers/quantization/utils/w4a8_utils.py
vllm/model_executor/layers/quantization/utils/w4a8_utils.py
+26
-5
vllm/model_executor/layers/rotary_embedding.py
vllm/model_executor/layers/rotary_embedding.py
+2
-0
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+187
-78
vllm/model_executor/models/keye.py
vllm/model_executor/models/keye.py
+5
-1
vllm/model_executor/models/qwen2_5_vl.py
vllm/model_executor/models/qwen2_5_vl.py
+2
-1
vllm/model_executor/models/qwen2_vl.py
vllm/model_executor/models/qwen2_vl.py
+4
-1
vllm/two_batch_overlap/v1/model_input_split_v1.py
vllm/two_batch_overlap/v1/model_input_split_v1.py
+23
-12
vllm/two_batch_overlap/v1/two_batch_overlap_v1.py
vllm/two_batch_overlap/v1/two_batch_overlap_v1.py
+8
-0
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+100
-24
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+27
-7
vllm/v1/attention/backends/mla/concatv3Tritonfinalv2.py
vllm/v1/attention/backends/mla/concatv3Tritonfinalv2.py
+250
-0
vllm/v1/attention/backends/mla/concatv4_decode_only.py
vllm/v1/attention/backends/mla/concatv4_decode_only.py
+248
-0
vllm/v1/attention/backends/mla/flashmla.py
vllm/v1/attention/backends/mla/flashmla.py
+12
-2
vllm/v1/attention/backends/mla/test_concat.py
vllm/v1/attention/backends/mla/test_concat.py
+317
-0
No files found.
vllm/model_executor/layers/quantization/slimquant_w4a8_marlin.py
View file @
8824ae6a
...
...
@@ -10,7 +10,7 @@ from vllm.model_executor.layers.quantization import QuantizationMethods
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
)
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.layers.quantization.utils.w4a8_utils
import
w4a8_
2_marlin_weight
from
vllm.model_executor.layers.quantization.utils.w4a8_utils
import
w4a8_
weight_repack_impl
from
vllm.model_executor.layers.fused_moe
import
(
FusedMoE
,
FusedMoEMethodBase
,
FusedMoeWeightScaleSupported
)
from
vllm.model_executor.parameter
import
(
ChannelQuantScaleParameter
,
...
...
@@ -22,6 +22,7 @@ try:
except
Exception
:
print
(
"INFO: Please install lmslim if you want to infer the quantitative model of moe.
\n
"
)
class
MarlinMoeWorkspace
:
"""
Singleton manager for device-specific workspace buffers used by w4a8 Marlin-MoE.
...
...
@@ -205,16 +206,8 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
layer
.
w2_weight_scale
.
data
,
requires_grad
=
False
)
w1_marlin_list
=
[]
for
e
in
range
(
layer
.
w13_weight
.
shape
[
0
]):
w1_marlin_in
=
w4a8_2_marlin_weight
(
layer
.
w13_weight
[
e
])
w1_marlin_list
.
append
(
w1_marlin_in
)
layer
.
w13_weight
=
Parameter
(
torch
.
stack
(
w1_marlin_list
,
dim
=
0
),
requires_grad
=
False
)
w2_marlin_list
=
[]
for
e
in
range
(
layer
.
w2_weight
.
shape
[
0
]):
w2_marlin_in
=
w4a8_2_marlin_weight
(
layer
.
w2_weight
[
e
])
w2_marlin_list
.
append
(
w2_marlin_in
)
layer
.
w2_weight
=
Parameter
(
torch
.
stack
(
w2_marlin_list
,
dim
=
0
),
requires_grad
=
False
)
layer
.
w13_weight
=
Parameter
(
w4a8_weight_repack_impl
(
layer
.
w13_weight
),
requires_grad
=
False
)
layer
.
w2_weight
=
Parameter
(
w4a8_weight_repack_impl
(
layer
.
w2_weight
),
requires_grad
=
False
)
def
apply_ep
(
#dp+ep
self
,
...
...
vllm/model_executor/layers/quantization/utils/marlin_utils.py
View file @
8824ae6a
...
...
@@ -176,15 +176,19 @@ def check_moe_marlin_supports_layer(layer: LinearBase, group_size: int) \
supports_router_weight
=
not
layer
.
apply_router_weight_on_input
# moe marlin requires the activation to be silu
supports_activation
=
layer
.
activation
==
"silu"
#暂时只支持bw
device_name
=
torch
.
cuda
.
get_device_properties
(
torch
.
cuda
.
current_device
()).
name
supports_device
=
"BW"
in
device_name
# gate-up: (n, k) = (intermediate_size_per_partition * 2, hidden_size)
# down: (n, k) = (hidden_size, intermediate_size_per_partition)
# moe marlin requires n % 128 == 0 and k % 64 == 0
supports_shape
=
hidden_size
%
128
==
0
and
\
intermediate_size_per_partition
%
max
(
64
,
group_size
)
==
0
supports_group_size
=
group_size
in
[
-
1
,
32
,
64
,
128
]
#暂时只支持64
supports_group_size
=
group_size
in
[
64
]
return
supports_shape
and
supports_group_size
and
\
supports_router_weight
and
supports_activation
supports_router_weight
and
supports_activation
and
supports_device
def
marlin_make_workspace
(
output_size_per_partition
:
int
,
...
...
vllm/model_executor/layers/quantization/utils/w4a8_utils.py
View file @
8824ae6a
...
...
@@ -2,6 +2,12 @@
import
torch
import
numpy
as
np
try
:
from
lightop
import
awq_marlin_repack_w4a8
use_lightop
=
True
except
Exception
:
use_lightop
=
False
def
unpack_int8_to_int4
(
tensor_int8
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
将[N, K//2]大小的torch.int8 Tensor,转换为[N, K]大小的torch.int32 Tensor。
...
...
@@ -54,12 +60,12 @@ def marlin_weights(q_w,weight_perm,k_tile=32,n_tile=64,pack_factor=8):
q_w
=
q_w
.
reshape
((
-
1
,
weight_perm
.
numel
()))[:,
weight_perm
].
reshape
(
q_w
.
shape
)
orig_device
=
q_w
.
device
q_w
=
q_w
.
cpu
().
numpy
().
astype
(
np
.
uint32
)
q_packed
=
np
.
zeros
((
q_w
.
shape
[
0
],
q_w
.
shape
[
1
]
//
pack_factor
),
dtype
=
np
.
uint32
)
q_w
=
q_w
.
contiguous
().
to
(
torch
.
int32
)
M
,
N
=
q_w
.
shape
assert
N
%
pack_factor
==
0
,
f
"size_n (
{
N
}
) must be divisible by pack_factor (
{
pack_factor
}
)"
q_packed
=
torch
.
zeros
((
M
,
N
//
pack_factor
),
dtype
=
torch
.
int32
,
device
=
orig_device
)
for
i
in
range
(
pack_factor
):
q_packed
|=
q_w
[:,
i
::
pack_factor
]
<<
4
*
i
q_packed
=
torch
.
from_numpy
(
q_packed
.
astype
(
np
.
int32
)).
to
(
orig_device
)
q_packed
+=
q_w
[:,
i
::
pack_factor
]
<<
(
4
*
i
)
return
q_packed
...
...
@@ -70,3 +76,18 @@ def w4a8_2_marlin_weight(w4a8_w):
marlin_q_w
=
marlin_weights
(
full_w4a8_w
,
weight_perm
,
k_tile
=
32
,
n_tile
=
64
,
pack_factor
=
8
)
return
marlin_q_w
def
w4a8_weight_repack_impl
(
input
):
if
use_lightop
:
size_batch
=
input
.
shape
[
0
]
size_n
=
input
.
shape
[
1
]
size_k
=
input
.
shape
[
2
]
*
2
output
=
torch
.
zeros
((
size_batch
,
size_k
//
32
,
size_n
*
4
),
device
=
input
.
device
,
dtype
=
torch
.
int32
)
awq_marlin_repack_w4a8
(
input
,
output
,
size_batch
,
size_k
,
size_n
)
else
:
w_marlin_list
=
[]
for
e
in
range
(
input
.
shape
[
0
]):
w_marlin_in
=
w4a8_2_marlin_weight
(
input
[
e
])
w_marlin_list
.
append
(
w_marlin_in
)
output
=
torch
.
stack
(
w_marlin_list
,
dim
=
0
)
return
output
\ No newline at end of file
vllm/model_executor/layers/rotary_embedding.py
View file @
8824ae6a
...
...
@@ -40,6 +40,8 @@ from vllm.platforms import current_platform
if
current_platform
.
is_cuda
():
from
vllm.vllm_flash_attn.layers.rotary
import
apply_rotary_emb
if
current_platform
.
is_rocm
():
from
flash_attn.layers.rotary
import
apply_rotary_emb
def
_rotate_neox
(
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
...
vllm/model_executor/models/deepseek_v2.py
View file @
8824ae6a
...
...
@@ -96,11 +96,21 @@ class DeepseekV2MLP(nn.Module):
"Only silu is supported for now."
)
self
.
act_fn
=
SiluAndMul
()
def
forward
(
self
,
x
):
gate_up
,
_
=
self
.
gate_up_proj
(
x
)
x
=
self
.
act_fn
(
gate_up
)
x
,
_
=
self
.
down_proj
(
x
)
return
x
def
forward
(
self
,
x
,
rms_weight
:
Optional
[
torch
.
Tensor
]
=
None
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
,
update_hd
:
Optional
[
bool
]
=
False
):
if
envs
.
USE_FUSED_RMS_QUANT
:
gate_up
,
new_resi
,
_
=
self
.
gate_up_proj
(
x
,
rms_weight
,
residual
,
update_hd
=
update_hd
)
x
=
self
.
act_fn
(
gate_up
)
x
,
_
=
self
.
down_proj
(
x
)
return
x
,
new_resi
else
:
gate_up
,
_
=
self
.
gate_up_proj
(
x
)
x
=
self
.
act_fn
(
gate_up
)
x
,
_
=
self
.
down_proj
(
x
)
return
x
class
DeepseekV2MoE
(
nn
.
Module
):
...
...
@@ -153,10 +163,10 @@ class DeepseekV2MoE(nn.Module):
self
.
n_local_physical_experts
)
self
.
physical_expert_end
=
(
self
.
physical_expert_start
+
self
.
n_local_physical_experts
)
dp_size
=
get_dp_group
().
world_size
self
.
use_all2all_ep
=
envs
.
VLLM_USE_ALLTOALL_EP
and
dp_size
>
1
and
parallel_config
.
enable_expert_parallel
moe_cls
=
FusedMoE
if
not
self
.
use_all2all_ep
else
EPMoE
self
.
experts
=
moe_cls
(
num_experts
=
config
.
n_routed_experts
,
...
...
@@ -179,8 +189,8 @@ class DeepseekV2MoE(nn.Module):
if
config
.
n_shared_experts
is
not
None
:
intermediate_size
=
(
config
.
moe_intermediate_size
*
config
.
n_shared_experts
)
shared_expert_cls
=
DeepseekV2MLP
if
not
self
.
use_all2all_ep
else
EPSharedExperts
self
.
shared_experts
=
shared_expert_cls
(
#
shared_expert_cls = DeepseekV2MLP if not self.use_all2all_ep else EPSharedExperts
self
.
shared_experts
=
DeepseekV2MLP
(
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
...
...
@@ -195,13 +205,21 @@ class DeepseekV2MoE(nn.Module):
from
vllm.two_batch_overlap.two_batch_overlap
import
tbo_all_reduce
self
.
tbo_all_reduce
=
tbo_all_reduce
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
rms_weight
:
Optional
[
torch
.
Tensor
]
=
None
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
num_tokens
,
hidden_dim
=
hidden_states
.
shape
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_dim
)
if
not
self
.
use_all2all_ep
:
if
self
.
n_shared_experts
is
not
None
:
shared_output
=
self
.
shared_experts
(
hidden_states
)
if
envs
.
USE_FUSED_RMS_QUANT
:
shared_output
,
new_resi
=
self
.
shared_experts
(
hidden_states
,
rms_weight
,
residual
,
update_hd
=
True
)
else
:
shared_output
=
self
.
shared_experts
(
hidden_states
)
# router_logits: (num_tokens, n_experts)
router_logits
,
_
=
self
.
gate
(
hidden_states
)
if
not
self
.
use_all2all_ep
:
...
...
@@ -215,9 +233,9 @@ class DeepseekV2MoE(nn.Module):
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
)
else
:
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
final_hidden_states
,
new_resi
=
self
.
experts
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
)
if
not
self
.
use_all2all_ep
:
if
shared_output
is
not
None
:
if
hidden_states
.
dtype
!=
torch
.
float16
or
self
.
dpsk_fp16_quick
:
...
...
@@ -235,8 +253,10 @@ class DeepseekV2MoE(nn.Module):
final_hidden_states
=
(
self
.
experts
.
maybe_all_reduce_tensor_model_parallel
(
final_hidden_states
))
return
final_hidden_states
.
view
(
num_tokens
,
hidden_dim
)
if
envs
.
USE_FUSED_RMS_QUANT
:
return
final_hidden_states
.
view
(
num_tokens
,
hidden_dim
),
new_resi
else
:
return
final_hidden_states
.
view
(
num_tokens
,
hidden_dim
)
def
yarn_get_mscale
(
scale
:
float
=
1
,
mscale
:
float
=
1
)
->
float
:
...
...
@@ -437,19 +457,36 @@ class DeepseekV2MLAAttention(nn.Module):
self
.
max_position_embeddings
=
max_position_embeddings
if
self
.
q_lora_rank
is
not
None
:
self
.
q_a_proj
=
ReplicatedLinear
(
self
.
hidden_size
,
if
envs
.
USE_FUSED_RMS_QUANT
:
self
.
q_a_proj
=
ReplicatedLinear
(
self
.
hidden_size
,
self
.
q_lora_rank
,
bias
=
False
,
quant_config
=
quant_config
,
eps
=
config
.
rms_norm_eps
,
prefix
=
f
"
{
prefix
}
.q_a_proj"
)
self
.
q_a_layernorm
=
RMSNorm
(
self
.
q_lora_rank
,
eps
=
config
.
rms_norm_eps
)
self
.
q_b_proj
=
ColumnParallelLinear
(
q_lora_rank
,
self
.
q_b_proj
=
ColumnParallelLinear
(
q_lora_rank
,
self
.
num_heads
*
self
.
qk_head_dim
,
bias
=
False
,
quant_config
=
quant_config
,
eps
=
config
.
rms_norm_eps
,
prefix
=
f
"
{
prefix
}
.q_b_proj"
)
else
:
self
.
q_a_proj
=
ReplicatedLinear
(
self
.
hidden_size
,
self
.
q_lora_rank
,
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.q_a_proj"
)
self
.
q_b_proj
=
ColumnParallelLinear
(
q_lora_rank
,
self
.
num_heads
*
self
.
qk_head_dim
,
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.q_b_proj"
)
self
.
q_a_layernorm
=
RMSNorm
(
self
.
q_lora_rank
,
eps
=
config
.
rms_norm_eps
)
else
:
self
.
q_proj
=
ColumnParallelLinear
(
self
.
hidden_size
,
self
.
num_heads
*
...
...
@@ -524,31 +561,60 @@ class DeepseekV2MLAAttention(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
rms_weight
:
Optional
[
torch
.
Tensor
]
=
None
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
if
self
.
q_lora_rank
is
not
None
:
q_c
=
self
.
q_a_proj
(
hidden_states
)[
0
]
q_c
=
self
.
q_a_layernorm
(
q_c
)
q
=
self
.
q_b_proj
(
q_c
)[
0
]
if
envs
.
USE_FUSED_RMS_QUANT
and
rms_weight
is
not
None
:
if
self
.
q_lora_rank
is
not
None
:
q_c
,
new_residual
,
_
,
input_quant_args
=
self
.
q_a_proj
(
hidden_states
,
rms_weight
=
rms_weight
,
residual
=
residual
,
update_hd
=
False
)
q
,
_
,
_
=
self
.
q_b_proj
(
q_c
,
rms_weight
=
self
.
q_a_layernorm
.
weight
.
data
,
residual
=
None
,
update_hd
=
False
)
else
:
q
=
self
.
q_proj
(
hidden_states
)[
0
]
kv_c
,
k_pe
=
self
.
kv_a_proj_with_mqa
(
hidden_states
,
quant_args
=
input_quant_args
,
update_hd
=
False
)[
0
].
split
(
[
self
.
kv_lora_rank
,
self
.
qk_rope_head_dim
],
dim
=-
1
)
kv_c_normed
=
self
.
kv_a_layernorm
(
kv_c
.
contiguous
())
q
=
q
.
view
(
-
1
,
self
.
num_local_heads
,
self
.
qk_head_dim
)
# Add head dim of 1 to k_pe
k_pe
=
k_pe
.
unsqueeze
(
1
)
q
[...,
self
.
qk_nope_head_dim
:],
k_pe
=
self
.
rotary_emb
(
positions
,
q
[...,
self
.
qk_nope_head_dim
:],
k_pe
)
attn_out
=
self
.
mla_attn
(
q
,
kv_c_normed
,
k_pe
,
output_shape
=
(
hidden_states
.
shape
[
0
],
self
.
num_local_heads
*
self
.
v_head_dim
))
return
self
.
o_proj
(
attn_out
)[
0
],
new_residual
else
:
q
=
self
.
q_proj
(
hidden_states
)[
0
]
kv_c
,
k_pe
=
self
.
kv_a_proj_with_mqa
(
hidden_states
)[
0
].
split
(
[
self
.
kv_lora_rank
,
self
.
qk_rope_head_dim
],
dim
=-
1
)
kv_c_normed
=
self
.
kv_a_layernorm
(
kv_c
.
contiguous
())
if
self
.
q_lora_rank
is
not
None
:
q_c
=
self
.
q_a_proj
(
hidden_states
)[
0
]
q_c
=
self
.
q_a_layernorm
(
q_c
)
q
=
self
.
q_b_proj
(
q_c
)[
0
]
else
:
q
=
self
.
q_proj
(
hidden_states
)[
0
]
kv_c
,
k_pe
=
self
.
kv_a_proj_with_mqa
(
hidden_states
)[
0
].
split
(
[
self
.
kv_lora_rank
,
self
.
qk_rope_head_dim
],
dim
=-
1
)
kv_c_normed
=
self
.
kv_a_layernorm
(
kv_c
.
contiguous
())
q
=
q
.
view
(
-
1
,
self
.
num_local_heads
,
self
.
qk_head_dim
)
# Add head dim of 1 to k_pe
k_pe
=
k_pe
.
unsqueeze
(
1
)
q
=
q
.
view
(
-
1
,
self
.
num_local_heads
,
self
.
qk_head_dim
)
# Add head dim of 1 to k_pe
k_pe
=
k_pe
.
unsqueeze
(
1
)
q
[...,
self
.
qk_nope_head_dim
:],
k_pe
=
self
.
rotary_emb
(
positions
,
q
[...,
self
.
qk_nope_head_dim
:],
k_pe
)
q
[...,
self
.
qk_nope_head_dim
:],
k_pe
=
self
.
rotary_emb
(
positions
,
q
[...,
self
.
qk_nope_head_dim
:],
k_pe
)
attn_out
=
self
.
mla_attn
(
q
,
kv_c_normed
,
k_pe
,
output_shape
=
(
hidden_states
.
shape
[
0
],
self
.
num_local_heads
*
self
.
v_head_dim
))
return
self
.
o_proj
(
attn_out
)[
0
]
attn_out
=
self
.
mla_attn
(
q
,
kv_c_normed
,
k_pe
,
output_shape
=
(
hidden_states
.
shape
[
0
],
self
.
num_local_heads
*
self
.
v_head_dim
))
return
self
.
o_proj
(
attn_out
)[
0
]
class
DeepseekV2DecoderLayer
(
nn
.
Module
):
...
...
@@ -623,47 +689,90 @@ class DeepseekV2DecoderLayer(nn.Module):
hidden_states
:
torch
.
Tensor
,
residual
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
# Self Attention
# Fix residual FP16 overflow
residual_fix_overflow
=
False
if
residual
is
None
:
residual
=
hidden_states
hidden_states
=
self
.
input_layernorm
(
hidden_states
)
residual_fix_overflow
=
True
if
envs
.
USE_FUSED_RMS_QUANT
:
# Fix residual FP16 overflow
residual_fix_overflow
=
False
assert
self
.
input_layernorm
.
has_weight
is
True
if
residual
is
None
:
residual
=
hidden_states
hidden_states
,
_
=
self
.
self_attn
(
positions
=
positions
,
hidden_states
=
hidden_states
,
rms_weight
=
self
.
input_layernorm
.
weight
.
data
,
residual
=
None
)
residual_fix_overflow
=
True
else
:
hidden_states
,
new_residual
=
self
.
self_attn
(
positions
=
positions
,
hidden_states
=
hidden_states
,
rms_weight
=
self
.
input_layernorm
.
weight
.
data
,
residual
=
residual
)
residual
=
new_residual
if
hidden_states
.
dtype
==
torch
.
float16
and
not
self
.
dpsk_fp16_quick
:
# rmsnorm, and rmsnorm result would not affect by scale.
hidden_states
*=
1.
/
self
.
routed_scaling_factor
if
self
.
layer_idx
==
0
or
residual_fix_overflow
:
# The residual is shared by all layers, we only scale it on
# first layer.
residual
*=
1.
/
self
.
routed_scaling_factor
hidden_states
,
new_resi
=
self
.
mlp
(
hidden_states
,
self
.
post_attention_layernorm
.
weight
.
data
,
residual
)
if
isinstance
(
self
.
mlp
,
DeepseekV2MLP
)
and
hidden_states
.
dtype
==
torch
.
float16
and
not
self
.
dpsk_fp16_quick
:
# Fix FP16 overflow
# Scaling the DeepseekV2MLP output, it is the input of
# input_layernorm of next decoder layer.
# The scaling of DeepseekV2MOE output would be done in the forward
# of DeepseekV2MOE
hidden_states
*=
1.
/
self
.
routed_scaling_factor
return
hidden_states
,
new_resi
else
:
hidden_states
,
residual
=
self
.
input_layernorm
(
# Self Attention
# Fix residual FP16 overflow
residual_fix_overflow
=
False
if
residual
is
None
:
residual
=
hidden_states
hidden_states
=
self
.
input_layernorm
(
hidden_states
)
residual_fix_overflow
=
True
else
:
hidden_states
,
residual
=
self
.
input_layernorm
(
hidden_states
,
residual
)
hidden_states
=
self
.
self_attn
(
positions
=
positions
,
hidden_states
=
hidden_states
,
)
if
hidden_states
.
dtype
==
torch
.
float16
and
not
self
.
dpsk_fp16_quick
:
# Fix FP16 overflow
# We scale both hidden_states and residual before
# rmsnorm, and rmsnorm result would not affect by scale.
hidden_states
*=
1.
/
self
.
routed_scaling_factor
if
self
.
layer_idx
==
0
or
residual_fix_overflow
:
# The residual is shared by all layers, we only scale it on
# first layer.
residual
*=
1.
/
self
.
routed_scaling_factor
# Fully Connected
hidden_states
,
residual
=
self
.
post_attention_layernorm
(
hidden_states
,
residual
)
hidden_states
=
self
.
self_attn
(
positions
=
positions
,
hidden_states
=
hidden_states
,
)
hidden_states
=
self
.
mlp
(
hidden_states
)
if
isinstance
(
self
.
mlp
,
DeepseekV2MLP
)
and
hidden_states
.
dtype
==
torch
.
float16
and
not
self
.
dpsk_fp16_quick
:
# Fix FP16 overflow
# Scaling the DeepseekV2MLP output, it is the input of
# input_layernorm of next decoder layer.
# The scaling of DeepseekV2MOE output would be done in the forward
# of DeepseekV2MOE
hidden_states
*=
1.
/
self
.
routed_scaling_factor
if
hidden_states
.
dtype
==
torch
.
float16
and
not
self
.
dpsk_fp16_quick
:
# Fix FP16 overflow
# We scale both hidden_states and residual before
# rmsnorm, and rmsnorm result would not affect by scale.
hidden_states
*=
1.
/
self
.
routed_scaling_factor
if
self
.
layer_idx
==
0
or
residual_fix_overflow
:
# The residual is shared by all layers, we only scale it on
# first layer.
residual
*=
1.
/
self
.
routed_scaling_factor
# Fully Connected
hidden_states
,
residual
=
self
.
post_attention_layernorm
(
hidden_states
,
residual
)
hidden_states
=
self
.
mlp
(
hidden_states
)
if
isinstance
(
self
.
mlp
,
DeepseekV2MLP
)
and
hidden_states
.
dtype
==
torch
.
float16
and
not
self
.
dpsk_fp16_quick
:
# Fix FP16 overflow
# Scaling the DeepseekV2MLP output, it is the input of
# input_layernorm of next decoder layer.
# The scaling of DeepseekV2MOE output would be done in the forward
# of DeepseekV2MOE
hidden_states
*=
1.
/
self
.
routed_scaling_factor
return
hidden_states
,
residual
return
hidden_states
,
residual
@
support_torch_compile
...
...
@@ -984,7 +1093,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
# However it's not mapped locally to this rank
# So we simply skip it
continue
if
self
.
use_all2all_ep
:
name
=
name
.
replace
(
ep_moe_shared_experts_keys
,
ep_moe_shared_experts_mapping
[
ep_moe_shared_experts_keys
])
# Skip loading extra bias for GPTQ models.
...
...
vllm/model_executor/models/keye.py
View file @
8824ae6a
...
...
@@ -54,6 +54,7 @@ from .utils import (AutoWeightsLoader, WeightsMapper,
init_vllm_registered_model
,
is_pp_missing_parameter
,
maybe_prefix
,
merge_multimodal_embeddings
)
from
.vision
import
get_vit_attn_backend
from
vllm.platforms
import
current_platform
logger
=
init_logger
(
__name__
)
...
...
@@ -330,7 +331,10 @@ def apply_rotary_pos_emb_flashatt(
cos
=
cos
.
chunk
(
2
,
dim
=-
1
)[
0
].
contiguous
()
sin
=
sin
.
chunk
(
2
,
dim
=-
1
)[
0
].
contiguous
()
from
vllm.vllm_flash_attn.layers.rotary
import
apply_rotary_emb
if
not
current_platform
.
is_rocm
():
from
vllm.vllm_flash_attn.layers.rotary
import
apply_rotary_emb
else
:
from
flash_attn.layers.rotary
import
apply_rotary_emb
q_embed
=
apply_rotary_emb
(
q
.
float
(),
cos
.
float
(),
sin
.
float
()).
type_as
(
q
)
k_embed
=
apply_rotary_emb
(
k
.
float
(),
cos
.
float
(),
sin
.
float
()).
type_as
(
k
)
...
...
vllm/model_executor/models/qwen2_5_vl.py
View file @
8824ae6a
...
...
@@ -436,7 +436,8 @@ class Qwen2_5_VisionPatchEmbed(nn.Module):
L
,
C
=
x
.
shape
x
=
x
.
view
(
L
,
-
1
,
self
.
temporal_patch_size
,
self
.
patch_size
,
self
.
patch_size
)
x
=
x
.
to
(
memory_format
=
torch
.
channels_last_3d
)
if
os
.
environ
.
get
(
'PYTORCH_MIOPEN_SUGGEST_NDHWC'
)
==
'1'
:
x
=
x
.
to
(
memory_format
=
torch
.
channels_last_3d
)
x
=
self
.
proj
(
x
).
view
(
L
,
self
.
hidden_size
)
return
x
...
...
vllm/model_executor/models/qwen2_vl.py
View file @
8824ae6a
...
...
@@ -246,6 +246,8 @@ def apply_rotary_pos_emb_vision(t: torch.Tensor,
apply_rotary_emb
=
apply_rotary_emb_torch
if
current_platform
.
is_cuda
():
from
vllm.vllm_flash_attn.layers.rotary
import
apply_rotary_emb
if
current_platform
.
is_rocm
():
from
flash_attn.layers.rotary
import
apply_rotary_emb
output
=
apply_rotary_emb
(
t_
,
cos
,
sin
).
type_as
(
t
)
return
output
...
...
@@ -464,7 +466,8 @@ class Qwen2VisionPatchEmbed(nn.Module):
L
,
C
=
x
.
shape
x
=
x
.
view
(
L
,
-
1
,
self
.
temporal_patch_size
,
self
.
patch_size
,
self
.
patch_size
)
x
=
x
.
to
(
memory_format
=
torch
.
channels_last_3d
)
if
os
.
environ
.
get
(
'PYTORCH_MIOPEN_SUGGEST_NDHWC'
)
==
'1'
:
x
=
x
.
to
(
memory_format
=
torch
.
channels_last_3d
)
x
=
self
.
proj
(
x
).
view
(
L
,
self
.
embed_dim
)
return
x
...
...
vllm/two_batch_overlap/v1/model_input_split_v1.py
View file @
8824ae6a
...
...
@@ -287,18 +287,29 @@ def tbo_split_and_execute_model(
attn_metadata_left
=
prepare_tbo_atten_metadata
(
runner
,
input_split
.
scheduler_output_left
,
input_split
.
req_ids_left
,
0
)
attn_metadata_right
=
prepare_tbo_atten_metadata
(
runner
,
input_split
.
scheduler_output_right
,
input_split
.
req_ids_right
,
input_split
.
req_num_left
)
model_output
=
tbo_model_executable_v1
(
runner
,
attn_metadata_left
,
attn_metadata_right
,
num_input_tokens_left
,
num_input_tokens_right
,
num_tokens_across_dp
,
input_ids
,
positions
,
intermediate_tensors
,
inputs_embeds
)
finished_sending
,
finished_recving
=
None
,
None
with
set_forward_context
(
attn_metadata
,
runner
.
vllm_config
,
num_tokens
=
num_input_tokens
,
num_tokens_across_dp
=
num_tokens_across_dp
,
skip_cuda_graphs
=
True
):
runner
.
maybe_setup_kv_connector
(
scheduler_output
)
model_output
=
tbo_model_executable_v1
(
runner
,
attn_metadata_left
,
attn_metadata_right
,
num_input_tokens_left
,
num_input_tokens_right
,
num_tokens_across_dp
,
input_ids
,
positions
,
intermediate_tensors
,
inputs_embeds
)
runner
.
maybe_wait_for_kv_save
()
finished_sending
,
finished_recving
=
(
runner
.
get_finished_kv_transfers
(
scheduler_output
))
#finished_sending, finished_recving = None, None
else
:
# Run the decoder.
# Use persistent buffers for CUDA graphs.
...
...
vllm/two_batch_overlap/v1/two_batch_overlap_v1.py
View file @
8824ae6a
...
...
@@ -162,6 +162,14 @@ def init_two_batch_overlap():
tbo_obj_v1
=
TwoBatchOverlap
()
tbo_obj_v1
.
init_tbo_thread
()
def
tbo_maybe_save_kv_layer_to_connector
(
layer_name
,
kv_cache
):
from
vllm.attention.layer
import
maybe_save_kv_layer_to_connector
if
envs
.
VLLM_ENABLE_TBO
and
tbo_obj_v1
!=
None
and
tbo_obj_v1
.
tbo_running
:
tid
=
threading
.
get_ident
()
if
tid
==
tbo_obj_v1
.
left_tid
:
return
maybe_save_kv_layer_to_connector
(
layer_name
,
kv_cache
)
def
tbo_all_reduce_v1
(
obj
):
if
envs
.
VLLM_ENABLE_TBO
and
tbo_obj_v1
!=
None
and
tbo_obj_v1
.
tbo_running
:
tid
=
threading
.
get_ident
()
...
...
vllm/v1/attention/backends/flash_attn.py
View file @
8824ae6a
...
...
@@ -669,30 +669,56 @@ class FlashAttentionImpl(AttentionImpl):
assert
not
use_local_attn
,
(
"Cascade attention does not support local attention."
)
# Cascade attention (rare case).
cascade_attention
(
output
[:
num_actual_tokens
],
query
[:
num_actual_tokens
],
key_cache
,
value_cache
,
cu_query_lens
=
attn_metadata
.
query_start_loc
,
max_query_len
=
attn_metadata
.
max_query_len
,
cu_prefix_query_lens
=
attn_metadata
.
cu_prefix_query_lens
,
prefix_kv_lens
=
attn_metadata
.
prefix_kv_lens
,
suffix_kv_lens
=
attn_metadata
.
suffix_kv_lens
,
max_kv_len
=
attn_metadata
.
max_seq_len
,
softmax_scale
=
self
.
scale
,
alibi_slopes
=
self
.
alibi_slopes
,
sliding_window
=
self
.
sliding_window
,
logits_soft_cap
=
self
.
logits_soft_cap
,
block_table
=
attn_metadata
.
block_table
,
common_prefix_len
=
attn_metadata
.
common_prefix_len
,
fa_version
=
self
.
vllm_flash_attn_version
,
prefix_scheduler_metadata
=
attn_metadata
.
prefix_scheduler_metadata
,
suffix_scheduler_metadata
=
attn_metadata
.
scheduler_metadata
,
q_descale
=
layer
.
_q_scale
,
k_descale
=
layer
.
_k_scale
,
v_descale
=
layer
.
_v_scale
,
)
if
not
current_platform
.
is_rocm
():
cascade_attention
(
output
[:
num_actual_tokens
],
query
[:
num_actual_tokens
],
key_cache
,
value_cache
,
cu_query_lens
=
attn_metadata
.
query_start_loc
,
max_query_len
=
attn_metadata
.
max_query_len
,
cu_prefix_query_lens
=
attn_metadata
.
cu_prefix_query_lens
,
prefix_kv_lens
=
attn_metadata
.
prefix_kv_lens
,
suffix_kv_lens
=
attn_metadata
.
suffix_kv_lens
,
max_kv_len
=
attn_metadata
.
max_seq_len
,
softmax_scale
=
self
.
scale
,
alibi_slopes
=
self
.
alibi_slopes
,
sliding_window
=
self
.
sliding_window
,
logits_soft_cap
=
self
.
logits_soft_cap
,
block_table
=
attn_metadata
.
block_table
,
common_prefix_len
=
attn_metadata
.
common_prefix_len
,
fa_version
=
self
.
vllm_flash_attn_version
,
prefix_scheduler_metadata
=
attn_metadata
.
prefix_scheduler_metadata
,
suffix_scheduler_metadata
=
attn_metadata
.
scheduler_metadata
,
q_descale
=
layer
.
_q_scale
,
k_descale
=
layer
.
_k_scale
,
v_descale
=
layer
.
_v_scale
,
)
else
:
cascade_attention
(
output
[:
num_actual_tokens
],
query
[:
num_actual_tokens
],
key_cache
,
value_cache
,
cu_query_lens
=
attn_metadata
.
query_start_loc
,
max_query_len
=
attn_metadata
.
max_query_len
,
cu_prefix_query_lens
=
attn_metadata
.
cu_prefix_query_lens
,
prefix_kv_lens
=
attn_metadata
.
prefix_kv_lens
,
suffix_kv_lens
=
attn_metadata
.
suffix_kv_lens
,
max_kv_len
=
attn_metadata
.
max_seq_len
,
softmax_scale
=
self
.
scale
,
alibi_slopes
=
self
.
alibi_slopes
,
sliding_window
=
self
.
sliding_window
,
logits_soft_cap
=
self
.
logits_soft_cap
,
block_table
=
attn_metadata
.
block_table
,
common_prefix_len
=
attn_metadata
.
common_prefix_len
,
fa_version
=
2
,
#self.vllm_flash_attn_version,
prefix_scheduler_metadata
=
attn_metadata
.
prefix_scheduler_metadata
,
suffix_scheduler_metadata
=
attn_metadata
.
scheduler_metadata
,
# q_descale=layer._q_scale,
# k_descale=layer._k_scale,
# v_descale=layer._v_scale,
)
return
output
...
...
@@ -825,6 +851,31 @@ def cascade_attention(
v_descale
=
v_descale
.
expand
(
descale_shape
)
if
v_descale
is
not
None
else
None
,
)
else
:
prefix_output
,
prefix_lse
,
_
=
vllm_flash_attn_varlen_func
(
q
=
query
,
k
=
key_cache
,
v
=
value_cache
,
cu_seqlens_q
=
cu_prefix_query_lens
,
seqused_k
=
prefix_kv_lens
,
max_seqlen_q
=
num_tokens
,
max_seqlen_k
=
common_prefix_len
,
softmax_scale
=
softmax_scale
,
causal
=
False
,
window_size
=
sliding_window
,
block_table
=
block_table
[:
1
],
softcap
=
logits_soft_cap
,
return_softmax_lse
=
True
,
scheduler_metadata
=
prefix_scheduler_metadata
,
# fa_version=fa_version,
# q_descale=q_descale.expand(descale_shape)
# if q_descale is not None else None,
# k_descale=k_descale.expand(descale_shape)
# if k_descale is not None else None,
# v_descale=v_descale.expand(descale_shape)
# if v_descale is not None else None,
is_prefix_cache
=
True
,
)
descale_shape
=
(
cu_query_lens
.
shape
[
0
]
-
1
,
key_cache
.
shape
[
-
2
])
...
...
@@ -853,6 +904,31 @@ def cascade_attention(
v_descale
=
v_descale
.
expand
(
descale_shape
)
if
v_descale
is
not
None
else
None
,
)
else
:
suffix_output
,
suffix_lse
,
_
=
vllm_flash_attn_varlen_func
(
q
=
query
,
k
=
key_cache
,
v
=
value_cache
,
cu_seqlens_q
=
cu_query_lens
,
seqused_k
=
suffix_kv_lens
,
max_seqlen_q
=
max_query_len
,
max_seqlen_k
=
max_kv_len
-
common_prefix_len
,
softmax_scale
=
softmax_scale
,
causal
=
True
,
window_size
=
sliding_window
,
block_table
=
block_table
[:,
num_common_kv_blocks
:],
softcap
=
logits_soft_cap
,
return_softmax_lse
=
True
,
scheduler_metadata
=
suffix_scheduler_metadata
,
# fa_version=fa_version,
# q_descale=q_descale.expand(descale_shape)
# if q_descale is not None else None,
# k_descale=k_descale.expand(descale_shape)
# if k_descale is not None else None,
# v_descale=v_descale.expand(descale_shape)
# if v_descale is not None else None,
is_prefix_cache
=
True
,
)
# Merge prefix and suffix outputs, and store the result in output.
merge_attn_states
(
output
,
prefix_output
,
prefix_lse
,
suffix_output
,
...
...
vllm/v1/attention/backends/mla/common.py
View file @
8824ae6a
...
...
@@ -216,6 +216,7 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata
)
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.worker.block_table
import
BlockTable
from
vllm.v1.attention.backends.mla.test_concat
import
lightop_concat_prefill_helper
try
:
from
vllm.vllm_flash_attn
import
flash_attn_varlen_func
...
...
@@ -777,10 +778,13 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
# v with 0s to match the qk head dim for attention backends that do
# not support different headdims
# We don't need to pad V if we are on a hopper system with FA3
self
.
_pad_v
=
self
.
vllm_flash_attn_version
is
None
or
not
(
self
.
vllm_flash_attn_version
==
3
and
current_platform
.
get_device_capability
()[
0
]
==
9
and
torch
.
cuda
.
get_device_properties
(
torch
.
cuda
.
current_device
()).
multi_processor_count
==
120
)
if
not
current_platform
.
is_rocm
():
self
.
_pad_v
=
self
.
vllm_flash_attn_version
is
None
or
not
(
self
.
vllm_flash_attn_version
==
3
and
current_platform
.
get_device_capability
()[
0
]
==
9
)
else
:
self
.
_pad_v
=
torch
.
cuda
.
get_device_properties
(
torch
.
cuda
.
current_device
()).
multi_processor_count
==
120
def
_flash_attn_varlen_diff_headdims
(
self
,
q
,
...
...
@@ -921,8 +925,16 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
k_nope
,
v
=
kv_nope
\
.
split
([
self
.
qk_nope_head_dim
,
self
.
v_head_dim
],
dim
=-
1
)
k
=
torch
.
cat
((
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
))),
dim
=-
1
)
if
envs
.
VLLM_USE_TRITON_CAT
:
if
k_nope
.
shape
[
0
]
>
1024
:
k
=
lightop_concat_prefill_helper
(
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
)),
dim
=
2
)
else
:
k
=
torch
.
cat
((
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
))),
dim
=-
1
)
else
:
k
=
torch
.
cat
((
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
))),
dim
=-
1
)
attn_output
,
attn_softmax_lse
=
\
self
.
_flash_attn_varlen_diff_headdims
(
...
...
@@ -977,7 +989,15 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
k_nope
,
v
=
kv_nope
\
.
split
([
self
.
qk_nope_head_dim
,
self
.
v_head_dim
],
dim
=-
1
)
k
=
torch
.
cat
((
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
))),
dim
=-
1
)
if
envs
.
VLLM_USE_TRITON_CAT
:
if
k_nope
.
shape
[
0
]
>
1024
:
k
=
lightop_concat_prefill_helper
(
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
)),
dim
=
2
)
else
:
k
=
torch
.
cat
((
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
))),
dim
=-
1
)
else
:
k
=
torch
.
cat
((
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
))),
dim
=-
1
)
output
=
self
.
_flash_attn_varlen_diff_headdims
(
q
=
q
,
...
...
vllm/v1/attention/backends/mla/concatv3Tritonfinalv2.py
0 → 100644
View file @
8824ae6a
import
triton
import
triton.language
as
tl
import
torch
from
functools
import
reduce
import
pytest
import
torch
import
math
@
pytest
.
mark
.
parametrize
(
"shape_pair,dim"
,
[
(((
4
,
8
,
512
),
(
4
,
8
,
64
)),
2
),
(((
8
,
8
,
512
),
(
8
,
8
,
64
)),
2
),
(((
16
,
8
,
512
),
(
16
,
8
,
64
)),
2
),
(((
32
,
8
,
512
),
(
32
,
8
,
64
)),
2
),
(((
64
,
8
,
512
),
(
64
,
8
,
64
)),
2
),
(((
128
,
8
,
512
),
(
128
,
8
,
64
)),
2
),
(((
256
,
8
,
512
),
(
256
,
8
,
64
)),
2
),
(((
512
,
8
,
512
),
(
512
,
8
,
64
)),
2
),
(((
672
,
8
,
512
),
(
672
,
8
,
64
)),
2
),
(((
768
,
8
,
512
),
(
768
,
8
,
64
)),
2
),
(((
896
,
8
,
512
),
(
896
,
8
,
64
)),
2
),
(((
1024
,
8
,
512
),
(
1024
,
8
,
64
)),
2
),
(((
4
,
16
,
512
),
(
4
,
16
,
64
)),
2
),
(((
8
,
16
,
512
),
(
8
,
16
,
64
)),
2
),
(((
16
,
16
,
512
),
(
16
,
16
,
64
)),
2
),
(((
32
,
16
,
512
),
(
32
,
16
,
64
)),
2
),
(((
64
,
16
,
512
),
(
64
,
16
,
64
)),
2
),
(((
128
,
16
,
512
),
(
128
,
16
,
64
)),
2
),
(((
256
,
16
,
512
),
(
256
,
16
,
64
)),
2
),
(((
512
,
16
,
512
),
(
512
,
16
,
64
)),
2
),
(((
672
,
16
,
512
),
(
672
,
16
,
64
)),
2
),
(((
768
,
16
,
512
),
(
768
,
16
,
64
)),
2
),
(((
896
,
16
,
512
),
(
896
,
16
,
64
)),
2
),
(((
1024
,
16
,
512
),
(
1024
,
16
,
64
)),
2
),
(((
4
,
32
,
512
),
(
4
,
32
,
64
)),
2
),
(((
8
,
32
,
512
),
(
8
,
32
,
64
)),
2
),
(((
16
,
32
,
512
),
(
16
,
32
,
64
)),
2
),
(((
32
,
32
,
512
),
(
32
,
32
,
64
)),
2
),
(((
64
,
32
,
512
),
(
64
,
32
,
64
)),
2
),
(((
128
,
32
,
512
),
(
128
,
32
,
64
)),
2
),
(((
256
,
32
,
512
),
(
256
,
32
,
64
)),
2
),
(((
512
,
32
,
512
),
(
512
,
32
,
64
)),
2
),
(((
672
,
32
,
512
),
(
672
,
32
,
64
)),
2
),
(((
768
,
32
,
512
),
(
768
,
32
,
64
)),
2
),
(((
896
,
32
,
512
),
(
896
,
32
,
64
)),
2
),
(((
1024
,
32
,
512
),
(
1024
,
32
,
64
)),
2
),
(((
4
,
32
,
128
),
(
4
,
32
,
64
)),
2
),
(((
8
,
32
,
128
),
(
8
,
32
,
64
)),
2
),
(((
16
,
32
,
128
),
(
16
,
32
,
64
)),
2
),
(((
32
,
32
,
128
),
(
32
,
32
,
64
)),
2
),
(((
64
,
32
,
128
),
(
64
,
32
,
64
)),
2
),
(((
128
,
32
,
128
),
(
128
,
32
,
64
)),
2
),
(((
256
,
32
,
128
),
(
256
,
32
,
64
)),
2
),
(((
512
,
32
,
128
),
(
512
,
32
,
64
)),
2
),
(((
672
,
32
,
128
),
(
672
,
32
,
64
)),
2
),
(((
768
,
32
,
128
),
(
768
,
32
,
64
)),
2
),
(((
896
,
32
,
128
),
(
896
,
32
,
64
)),
2
),
(((
1024
,
32
,
128
),
(
1024
,
32
,
64
)),
2
),
])
def
test_concat_Acc
(
shape_pair
,
dim
):
torch
.
manual_seed
(
1
)
shape1
,
shape2
=
shape_pair
x
=
torch
.
randn
(
*
shape1
,
device
=
'cuda'
,
dtype
=
torch
.
bfloat16
)
y
=
torch
.
randn
(
*
shape2
,
device
=
'cuda'
,
dtype
=
torch
.
bfloat16
)
expected
=
torch
.
cat
([
x
,
y
],
dim
=
dim
)
result
=
concat_helper
(
x
,
y
,
dim
=
dim
)
assert
torch
.
allclose
(
result
,
expected
,
rtol
=
1e-5
,
atol
=
1e-5
),
"Mismatch"
@
triton
.
jit
def
concat_kernel_prefill
(
A_ptr
,
B_ptr
,
C_ptr
,
A_section_numel
,
B_section_numel
,
C_section_numel
,
Per_block
,
section_num
,
BLOCK_SIZE
:
tl
.
constexpr
):
block_idx
=
tl
.
program_id
(
0
)
# 获取当前block的索引
for
sub_section_index
in
range
(
Per_block
//
2
):
sub_section_offset
=
block_idx
*
Per_block
+
sub_section_index
*
2
if
sub_section_offset
<=
section_num
-
1
:
C_section_start
=
C_ptr
+
sub_section_offset
*
C_section_numel
A_section_start
=
A_ptr
+
sub_section_offset
*
A_section_numel
B_section_start
=
B_ptr
+
sub_section_offset
*
B_section_numel
Arrange_doubleA
=
tl
.
arange
(
0
,
256
)
mask
=
Arrange_doubleA
<
(
256
)
Arrange2
=
(
tl
.
arange
(
0
,
128
)[
None
,:]
+
tl
.
arange
(
0
,
2
)[:,
None
]).
reshape
(
256
)
val_from_A
=
tl
.
load
(
A_section_start
+
Arrange_doubleA
)
tensorAsn
=
tl
.
full
((
256
,),
0
,
tl
.
int32
)
tensorAsn2
=
tl
.
full
((
256
,),
(
C_section_numel
-
1
),
tl
.
int32
)
tensor_offsets
=
tl
.
where
(
Arrange_doubleA
<
A_section_numel
,
tensorAsn
,
tensorAsn2
)
off
=
Arrange2
+
tensor_offsets
tl
.
store
(
C_section_start
+
off
,
val_from_A
,
mask
=
mask
)
Arrange_doubleB
=
tl
.
arange
(
0
,
128
)
mask
=
Arrange_doubleB
<
(
B_section_numel
*
2
)
val_from_B
=
tl
.
load
(
B_section_start
+
Arrange_doubleB
,
mask
=
mask
)
Arrange3
=
(
tl
.
arange
(
0
,
64
)[
None
,:]
+
tl
.
arange
(
0
,
2
)[:,
None
]).
reshape
(
128
)
tensorAsn
=
tl
.
full
((
128
,),
A_section_numel
,
tl
.
int32
)
tensorAsn2
=
tl
.
full
((
128
,),
(
C_section_numel
+
A_section_numel
-
1
),
tl
.
int32
)
tensor_offsets
=
tl
.
where
(
Arrange_doubleB
<
B_section_numel
,
tensorAsn
,
tensorAsn2
)
tl
.
store
(
C_section_start
+
Arrange3
+
tensor_offsets
,
val_from_B
)
@
triton
.
jit
def
concat_kernel
(
A_ptr
,
B_ptr
,
C_ptr
,
A_section_numel
,
B_section_numel
,
C_section_numel
,
Per_block
,
section_num
,
BLOCK_SIZE
:
tl
.
constexpr
):
block_idx
=
tl
.
program_id
(
0
)
for
sub_section_index
in
range
(
Per_block
):
sub_offset
=
block_idx
*
Per_block
+
sub_section_index
if
sub_offset
<=
section_num
-
1
:
C_ptr_block_start
=
C_ptr
+
sub_offset
*
C_section_numel
A_ptr_block_start
=
A_ptr
+
sub_offset
*
A_section_numel
B_ptr_block_start
=
B_ptr
+
sub_offset
*
B_section_numel
for
offset
in
range
(
0
,
A_section_numel
,
BLOCK_SIZE
):
offset_idx
=
offset
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
mask
=
offset_idx
<
A_section_numel
val_from_A
=
tl
.
load
(
A_ptr_block_start
+
offset_idx
,
mask
=
mask
)
tl
.
store
(
C_ptr_block_start
+
offset_idx
,
val_from_A
,
mask
=
mask
)
for
offset
in
range
(
0
,
B_section_numel
,
BLOCK_SIZE
):
offset_idx
=
offset
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
mask
=
offset_idx
<
B_section_numel
val_from_B
=
tl
.
load
(
B_ptr_block_start
+
offset_idx
,
mask
=
mask
)
tl
.
store
(
C_ptr_block_start
+
A_section_numel
+
offset_idx
,
val_from_B
,
mask
=
mask
)
def
concat_helper
(
A
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
dim
:
int
):
A
=
A
.
contiguous
()
B
=
B
.
contiguous
()
output_shape
=
list
(
A
.
shape
)
output_shape
[
dim
]
=
A
.
shape
[
dim
]
+
B
.
shape
[
dim
]
C
=
torch
.
empty
(
output_shape
,
device
=
A
.
device
,
dtype
=
A
.
dtype
)
if
dim
!=
0
:
block_num
=
reduce
(
lambda
x
,
y
:
x
*
y
,
output_shape
[:
dim
])
Per_block
=
1
unit_offset_A
,
unit_offset_B
,
unit_offset_C
=
A
.
stride
(
dim
-
1
),
B
.
stride
(
dim
-
1
),
C
.
stride
(
dim
-
1
)
#case prefill
if
(
A
.
shape
[
2
]
==
128
and
B
.
shape
[
2
]
==
64
and
A
.
shape
[
0
]
>
16
):
Per_block
=
8
num_blocks
=
math
.
ceil
(
block_num
/
Per_block
)
concat_kernel_prefill
[(
num_blocks
,)](
A
,
B
,
C
,
unit_offset_A
,
unit_offset_B
,
unit_offset_C
,
Per_block
,
block_num
,
BLOCK_SIZE
=
1024
)
return
C
else
:
if
(
A
.
shape
[
1
]
==
8
and
A
.
shape
[
0
]
>
128
)
or
(
A
.
shape
[
1
]
==
16
and
A
.
shape
[
0
]
>
96
)
or
(
A
.
shape
[
1
]
==
32
and
A
.
shape
[
2
]
==
512
and
A
.
shape
[
0
]
>
64
):
Per_block
=
2
num_blocks
=
math
.
ceil
(
block_num
/
Per_block
)
concat_kernel
[(
num_blocks
,)](
A
,
B
,
C
,
unit_offset_A
,
unit_offset_B
,
unit_offset_C
,
Per_block
,
block_num
,
BLOCK_SIZE
=
1024
)
return
C
assert
False
,
"not support"
configs
=
[]
configs
.
append
(
triton
.
testing
.
Benchmark
(
x_names
=
[
'size'
],
x_vals
=
[
4
,
8
,
16
,
32
,
64
,
96
,
128
,
256
,
512
,
768
,
1024
],
x_log
=
True
,
line_arg
=
'provider'
,
line_vals
=
[
'triton'
,
'torch'
],
line_names
=
[
'Triton'
,
'Torch'
],
styles
=
[(
'blue'
,
'-'
),
(
'green'
,
'-'
)],
ylabel
=
's'
,
plot_name
=
'concat-dim2'
,
args
=
{
"dim"
:
2
},
),
)
@
triton
.
testing
.
perf_report
(
configs
)
def
benchmark
(
size
,
provider
,
dim
):
x
=
torch
.
rand
([
size
,
8
,
512
],
device
=
'cuda'
,
dtype
=
torch
.
bfloat16
)
y
=
torch
.
rand
([
size
,
8
,
64
],
device
=
'cuda'
,
dtype
=
torch
.
bfloat16
)
quantiles
=
[
0.5
,
0.2
,
0.8
]
if
provider
==
'torch'
:
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench
(
lambda
:
torch
.
cat
([
x
,
y
],
dim
=
dim
),
quantiles
=
quantiles
)
if
provider
==
'triton'
:
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench
(
lambda
:
concat_helper
(
x
,
y
,
dim
=
dim
),
quantiles
=
quantiles
)
return
(
ms
*
1000
),
(
max_ms
*
1000
),
(
min_ms
*
1000
)
@
triton
.
testing
.
perf_report
(
configs
)
def
benchmark_16
(
size
,
provider
,
dim
):
x
=
torch
.
rand
([
size
,
16
,
512
],
device
=
'cuda'
,
dtype
=
torch
.
bfloat16
)
y
=
torch
.
rand
([
size
,
16
,
64
],
device
=
'cuda'
,
dtype
=
torch
.
bfloat16
)
quantiles
=
[
0.5
,
0.2
,
0.8
]
if
provider
==
'torch'
:
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench
(
lambda
:
torch
.
cat
([
x
,
y
],
dim
=
dim
),
quantiles
=
quantiles
)
if
provider
==
'triton'
:
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench
(
lambda
:
concat_helper
(
x
,
y
,
dim
=
dim
),
quantiles
=
quantiles
)
return
(
ms
*
1000
),
(
max_ms
*
1000
),
(
min_ms
*
1000
)
@
triton
.
testing
.
perf_report
(
configs
)
def
benchmark_32
(
size
,
provider
,
dim
):
x
=
torch
.
rand
([
size
,
32
,
512
],
device
=
'cuda'
,
dtype
=
torch
.
bfloat16
)
y
=
torch
.
rand
([
size
,
32
,
64
],
device
=
'cuda'
,
dtype
=
torch
.
bfloat16
)
quantiles
=
[
0.5
,
0.2
,
0.8
]
if
provider
==
'torch'
:
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench
(
lambda
:
torch
.
cat
([
x
,
y
],
dim
=
dim
),
quantiles
=
quantiles
)
if
provider
==
'triton'
:
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench
(
lambda
:
concat_helper
(
x
,
y
,
dim
=
dim
),
quantiles
=
quantiles
)
return
(
ms
*
1000
),
(
max_ms
*
1000
),
(
min_ms
*
1000
)
@
triton
.
testing
.
perf_report
(
configs
)
def
benchmark_prefill
(
size
,
provider
,
dim
):
x
=
torch
.
rand
([
size
,
32
,
128
],
device
=
'cuda'
,
dtype
=
torch
.
bfloat16
)
y
=
torch
.
rand
([
size
,
32
,
64
],
device
=
'cuda'
,
dtype
=
torch
.
bfloat16
)
quantiles
=
[
0.5
,
0.2
,
0.8
]
if
provider
==
'torch'
:
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench
(
lambda
:
torch
.
cat
([
x
,
y
],
dim
=
dim
),
quantiles
=
quantiles
)
if
provider
==
'triton'
:
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench
(
lambda
:
concat_helper
(
x
,
y
,
dim
=
dim
),
quantiles
=
quantiles
)
return
(
ms
*
1000
),
(
max_ms
*
1000
),
(
min_ms
*
1000
)
if
__name__
==
'__main__'
:
# benchmark.run(save_path="./triton_test_8",print_data=True)
# benchmark_16.run(save_path="./triton_test_16",print_data=True)
# benchmark_32.run(save_path="./triton_test_32",print_data=True)
benchmark_prefill
.
run
(
save_path
=
"./triton_test_prefill"
,
print_data
=
True
)
\ No newline at end of file
vllm/v1/attention/backends/mla/concatv4_decode_only.py
0 → 100644
View file @
8824ae6a
import
triton
import
triton.language
as
tl
import
torch
from
functools
import
reduce
import
pytest
import
torch
import
math
@
pytest
.
mark
.
parametrize
(
"shape_pair,dim"
,
[
(((
4
,
8
,
512
),
(
4
,
8
,
64
)),
2
),
(((
8
,
8
,
512
),
(
8
,
8
,
64
)),
2
),
(((
16
,
8
,
512
),
(
16
,
8
,
64
)),
2
),
(((
32
,
8
,
512
),
(
32
,
8
,
64
)),
2
),
(((
64
,
8
,
512
),
(
64
,
8
,
64
)),
2
),
(((
128
,
8
,
512
),
(
128
,
8
,
64
)),
2
),
(((
256
,
8
,
512
),
(
256
,
8
,
64
)),
2
),
(((
512
,
8
,
512
),
(
512
,
8
,
64
)),
2
),
(((
672
,
8
,
512
),
(
672
,
8
,
64
)),
2
),
(((
768
,
8
,
512
),
(
768
,
8
,
64
)),
2
),
(((
896
,
8
,
512
),
(
896
,
8
,
64
)),
2
),
(((
1024
,
8
,
512
),
(
1024
,
8
,
64
)),
2
),
(((
4
,
16
,
512
),
(
4
,
16
,
64
)),
2
),
(((
8
,
16
,
512
),
(
8
,
16
,
64
)),
2
),
(((
16
,
16
,
512
),
(
16
,
16
,
64
)),
2
),
(((
32
,
16
,
512
),
(
32
,
16
,
64
)),
2
),
(((
64
,
16
,
512
),
(
64
,
16
,
64
)),
2
),
(((
128
,
16
,
512
),
(
128
,
16
,
64
)),
2
),
(((
256
,
16
,
512
),
(
256
,
16
,
64
)),
2
),
(((
512
,
16
,
512
),
(
512
,
16
,
64
)),
2
),
(((
672
,
16
,
512
),
(
672
,
16
,
64
)),
2
),
(((
768
,
16
,
512
),
(
768
,
16
,
64
)),
2
),
(((
896
,
16
,
512
),
(
896
,
16
,
64
)),
2
),
(((
1024
,
16
,
512
),
(
1024
,
16
,
64
)),
2
),
(((
4
,
32
,
512
),
(
4
,
32
,
64
)),
2
),
(((
8
,
32
,
512
),
(
8
,
32
,
64
)),
2
),
(((
16
,
32
,
512
),
(
16
,
32
,
64
)),
2
),
(((
32
,
32
,
512
),
(
32
,
32
,
64
)),
2
),
(((
64
,
32
,
512
),
(
64
,
32
,
64
)),
2
),
(((
128
,
32
,
512
),
(
128
,
32
,
64
)),
2
),
(((
256
,
32
,
512
),
(
256
,
32
,
64
)),
2
),
(((
512
,
32
,
512
),
(
512
,
32
,
64
)),
2
),
(((
672
,
32
,
512
),
(
672
,
32
,
64
)),
2
),
(((
768
,
32
,
512
),
(
768
,
32
,
64
)),
2
),
(((
896
,
32
,
512
),
(
896
,
32
,
64
)),
2
),
(((
1024
,
32
,
512
),
(
1024
,
32
,
64
)),
2
),
])
def
test_concat_Acc
(
shape_pair
,
dim
):
torch
.
manual_seed
(
1
)
shape1
,
shape2
=
shape_pair
M
=
shape1
[
0
]
N
=
shape1
[
1
]
x_sizes
=
[
M
,
N
,
512
]
x_strides
=
[
512
,
512
*
M
,
1
]
x_max_index
=
M
*
N
*
512
x_required_length
=
x_max_index
x_data
=
torch
.
arange
(
x_required_length
,
device
=
'cuda'
).
bfloat16
()
x
=
torch
.
as_strided
(
x_data
,
size
=
x_sizes
,
stride
=
x_strides
)
# print("形状:", x.shape) # [4, 8, 512]
# print("步幅:", x.stride()) # (1536, 192, 1)
y_sizes
=
[
M
,
N
,
64
]
y_strides
=
[
1536
*
(
N
//
8
),
192
,
1
]
y_max_index
=
1536
*
(
N
//
8
)
*
M
y_required_length
=
y_max_index
y_data
=
torch
.
arange
(
y_required_length
,
device
=
'cuda'
).
bfloat16
()
y
=
torch
.
as_strided
(
y_data
,
size
=
y_sizes
,
stride
=
y_strides
)
expected
=
torch
.
cat
([
x
,
y
],
dim
=
dim
)
result
=
concat_helper
(
x
,
y
,
dim
=
dim
)
assert
torch
.
allclose
(
result
,
expected
,
rtol
=
1e-5
,
atol
=
1e-5
),
"Mismatch"
@
triton
.
jit
def
concat_kernel
(
A_ptr
,
B_ptr
,
C_ptr
,
A_section_numel
,
B_section_numel
,
C_section_numel
,
Per_block
,
section_num
,
M
,
N
,
Astride_0
,
Astride_1
,
Astride_2
,
Bstride_0
,
Bstride_1
,
Bstride_2
,
BLOCK_SIZE
:
tl
.
constexpr
):
block_idx
=
tl
.
program_id
(
0
)
for
sub_section_index
in
range
(
Per_block
):
sub_offset
=
block_idx
*
Per_block
+
sub_section_index
M_idx
=
sub_offset
//
N
N_idx
=
sub_offset
%
N
if
sub_offset
<=
section_num
-
1
:
C_ptr_block_start
=
C_ptr
+
sub_offset
*
C_section_numel
A_ptr_block_start
=
A_ptr
+
M_idx
*
Astride_0
+
N_idx
*
Astride_1
B_ptr_block_start
=
B_ptr
+
M_idx
*
Bstride_0
+
N_idx
*
Bstride_1
for
offset
in
range
(
0
,
A_section_numel
,
BLOCK_SIZE
):
offset_idx
=
offset
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
mask
=
offset_idx
<
A_section_numel
val_from_A
=
tl
.
load
(
A_ptr_block_start
+
offset_idx
,
mask
=
mask
)
tl
.
store
(
C_ptr_block_start
+
offset_idx
,
val_from_A
,
mask
=
mask
)
for
offset
in
range
(
0
,
B_section_numel
,
BLOCK_SIZE
):
offset_idx
=
offset
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
mask
=
offset_idx
<
B_section_numel
val_from_B
=
tl
.
load
(
B_ptr_block_start
+
offset_idx
,
mask
=
mask
)
tl
.
store
(
C_ptr_block_start
+
A_section_numel
+
offset_idx
,
val_from_B
,
mask
=
mask
)
def
concat_helper
(
A
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
dim
:
int
):
output_shape
=
list
(
A
.
shape
)
output_shape
[
dim
]
=
A
.
shape
[
dim
]
+
B
.
shape
[
dim
]
C
=
torch
.
empty
(
output_shape
,
device
=
A
.
device
,
dtype
=
A
.
dtype
)
if
dim
!=
0
:
block_num
=
reduce
(
lambda
x
,
y
:
x
*
y
,
output_shape
[:
dim
])
Per_block
=
1
unit_offset_A
,
unit_offset_B
,
unit_offset_C
=
A
.
shape
[
dim
],
B
.
shape
[
dim
],
C
.
shape
[
dim
]
if
(
A
.
shape
[
1
]
==
8
and
A
.
shape
[
0
]
>
512
)
or
(
A
.
shape
[
1
]
==
16
and
A
.
shape
[
0
]
>
256
):
Per_block
=
2
if
(
A
.
shape
[
1
]
==
32
and
A
.
shape
[
2
]
==
512
and
A
.
shape
[
0
]
>
256
):
Per_block
=
8
num_blocks
=
math
.
ceil
(
block_num
/
Per_block
)
concat_kernel
[(
num_blocks
,)](
A
,
B
,
C
,
unit_offset_A
,
unit_offset_B
,
unit_offset_C
,
Per_block
,
block_num
,
output_shape
[
0
],
output_shape
[
1
],
A
.
stride
(
0
),
A
.
stride
(
1
),
A
.
stride
(
2
),
B
.
stride
(
0
),
B
.
stride
(
1
),
B
.
stride
(
2
),
BLOCK_SIZE
=
1024
)
return
C
assert
False
,
"not support"
configs
=
[]
configs
.
append
(
triton
.
testing
.
Benchmark
(
x_names
=
[
'M'
,
'N'
],
x_vals
=
[(
4
,
8
),(
8
,
8
),(
16
,
8
),(
32
,
8
),(
64
,
8
),(
96
,
8
),(
128
,
8
),(
256
,
8
),(
512
,
8
),(
768
,
8
),(
1024
,
8
),
\
(
4
,
16
),(
8
,
16
),(
16
,
16
),(
32
,
16
),(
64
,
16
),(
96
,
16
),(
128
,
16
),(
256
,
16
),(
512
,
16
),(
768
,
16
),(
1024
,
16
),
\
(
4
,
32
),(
8
,
32
),(
16
,
32
),(
32
,
32
),(
64
,
32
),(
96
,
32
),(
128
,
32
),(
256
,
32
),(
512
,
32
),(
768
,
32
),(
1024
,
32
)],
x_log
=
True
,
line_arg
=
'provider'
,
line_vals
=
[
'triton'
,
'torch'
],
line_names
=
[
'Triton'
,
'Torch'
],
styles
=
[(
'blue'
,
'-'
),
(
'green'
,
'-'
)],
ylabel
=
's'
,
plot_name
=
'concat-dim2'
,
args
=
{
"dim"
:
2
},
),
)
@
triton
.
testing
.
perf_report
(
configs
)
def
benchmark
(
M
,
N
,
provider
,
dim
):
x_sizes
=
[
M
,
N
,
512
]
x_strides
=
[
512
,
512
*
M
,
1
]
x_max_index
=
M
*
N
*
512
x_required_length
=
x_max_index
x_data
=
torch
.
arange
(
x_required_length
,
device
=
'cuda'
).
bfloat16
()
x
=
torch
.
as_strided
(
x_data
,
size
=
x_sizes
,
stride
=
x_strides
)
# print("形状:", x.shape) # [M, 8, 512]
# print("步幅:", x.stride()) # (512, 512*M, 1)
y_sizes
=
[
M
,
N
,
64
]
y_strides
=
[
1536
*
(
N
//
8
),
192
,
1
]
y_max_index
=
1536
*
(
N
//
8
)
*
M
y_required_length
=
y_max_index
y_data
=
torch
.
arange
(
y_required_length
,
device
=
'cuda'
).
bfloat16
()
y
=
torch
.
as_strided
(
y_data
,
size
=
y_sizes
,
stride
=
y_strides
)
# print("形状:", y.shape) # [M, 8, 64]
# print("步幅:", y.stride()) # (1536, 192, 1)
quantiles
=
[
0.5
,
0.2
,
0.8
]
if
provider
==
'torch'
:
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench
(
lambda
:
torch
.
cat
([
x
,
y
],
dim
=
dim
),
quantiles
=
quantiles
)
if
provider
==
'triton'
:
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench
(
lambda
:
concat_helper
(
x
,
y
,
dim
=
dim
),
quantiles
=
quantiles
)
return
(
ms
*
1000
),
(
max_ms
*
1000
),
(
min_ms
*
1000
)
# @triton.testing.perf_report(configs)
# def benchmark_16(size, provider, dim):
# x = torch.rand([size,16,512], device='cuda', dtype=torch.bfloat16)
# y = torch.rand([size,16,64], device='cuda', dtype=torch.bfloat16)
# quantiles = [0.5, 0.2, 0.8]
# if provider == 'torch':
# ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.cat([x,y],dim=dim), quantiles=quantiles)
# if provider == 'triton':
# ms, min_ms, max_ms = triton.testing.do_bench(lambda: concat_helper(x, y,dim=dim), quantiles=quantiles)
# return (ms*1000), (max_ms*1000), (min_ms*1000)
# @triton.testing.perf_report(configs)
# def benchmark_32(size, provider, dim):
# x = torch.rand([size,32,512], device='cuda', dtype=torch.bfloat16)
# y = torch.rand([size,32,64], device='cuda', dtype=torch.bfloat16)
# quantiles = [0.5, 0.2, 0.8]
# if provider == 'torch':
# ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.cat([x,y],dim=dim), quantiles=quantiles)
# if provider == 'triton':
# ms, min_ms, max_ms = triton.testing.do_bench(lambda: concat_helper(x, y,dim=dim), quantiles=quantiles)
# return (ms*1000), (max_ms*1000), (min_ms*1000)
# @triton.testing.perf_report(configs)
# def benchmark_prefill(size, provider, dim):
# x = torch.rand([size,32,128], device='cuda', dtype=torch.bfloat16)
# y = torch.rand([size,32,64], device='cuda', dtype=torch.bfloat16)
# quantiles = [0.5, 0.2, 0.8]
# if provider == 'torch':
# ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.cat([x,y],dim=dim), quantiles=quantiles)
# if provider == 'triton':
# ms, min_ms, max_ms = triton.testing.do_bench(lambda: concat_helper(x, y,dim=dim), quantiles=quantiles)
# return (ms*1000), (max_ms*1000), (min_ms*1000)
if
__name__
==
'__main__'
:
benchmark
.
run
(
save_path
=
"./triton_test"
,
print_data
=
True
)
# benchmark_16.run(save_path="./triton_test_16",print_data=True)
# benchmark_32.run(save_path="./triton_test_32",print_data=True)
# benchmark_prefill.run(save_path="./triton_test_prefill",print_data=True)
\ No newline at end of file
vllm/v1/attention/backends/mla/flashmla.py
View file @
8824ae6a
...
...
@@ -19,6 +19,8 @@ from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
MLACommonMetadataBuilder
)
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.worker.block_table
import
BlockTable
from
vllm
import
envs
from
vllm.v1.attention.backends.mla.test_concat
import
concat_helper_decode
logger
=
init_logger
(
__name__
)
...
...
@@ -164,8 +166,16 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
assert
kv_c_and_k_pe_cache
.
numel
()
>
0
assert
attn_metadata
.
decode
is
not
None
q
=
torch
.
cat
([
q_nope
,
q_pe
],
dim
=-
1
)
\
.
unsqueeze
(
1
)
# Add seqlen dim of 1 (decode)
if
envs
.
VLLM_USE_TRITON_CAT
:
if
q_nope
.
shape
[
0
]
<=
1024
:
q
=
concat_helper_decode
(
q_nope
,
q_pe
,
dim
=
2
)
\
.
unsqueeze
(
1
)
else
:
q
=
torch
.
cat
([
q_nope
,
q_pe
],
dim
=-
1
)
\
.
unsqueeze
(
1
)
# Add seqlen dim of 1 (decode)
else
:
q
=
torch
.
cat
([
q_nope
,
q_pe
],
dim
=-
1
)
\
.
unsqueeze
(
1
)
# Add seqlen dim of 1 (decode)
o
,
_
=
flash_mla_with_kvcache
(
q
=
q
,
...
...
vllm/v1/attention/backends/mla/test_concat.py
0 → 100644
View file @
8824ae6a
import
triton
import
triton.language
as
tl
import
torch
from
functools
import
reduce
import
pytest
import
torch
import
math
from
lightop
import
ds_cat
def
test_concat_Acc_prefill
(
shape_pair
,
dim
):
torch
.
manual_seed
(
1
)
shape1
,
shape2
=
shape_pair
M
=
shape1
[
0
]
N
=
shape1
[
1
]
x_sizes
=
[
M
,
N
,
128
]
x_strides
=
[
N
//
8
*
2048
,
256
,
1
]
x_max_index
=
N
//
8
*
2048
*
M
x_required_length
=
x_max_index
x_data
=
torch
.
arange
(
x_required_length
,
device
=
'cuda'
).
bfloat16
()
x
=
torch
.
as_strided
(
x_data
,
size
=
x_sizes
,
stride
=
x_strides
)
y_sizes
=
[
M
,
N
,
64
]
y_strides
=
[
576
,
0
,
1
]
y_max_index
=
576
*
M
y_required_length
=
y_max_index
y_data
=
torch
.
arange
(
y_required_length
,
device
=
'cuda'
).
bfloat16
()
y
=
torch
.
as_strided
(
y_data
,
size
=
y_sizes
,
stride
=
y_strides
)
expected
=
torch
.
cat
([
x
,
y
],
dim
=
dim
)
result
=
concat_prefill_helper_Triton
(
x
,
y
,
dim
=
dim
)
result_lightop
=
lightop_concat_prefill_helper
(
x
,
y
,
dim
=
dim
)
assert
torch
.
allclose
(
result
,
expected
,
rtol
=
1e-5
,
atol
=
1e-5
),
"Mismatch"
# print("精度验证通过")
# print("expected",expected)
# print("result_lightop",result_lightop)
assert
torch
.
allclose
(
result
,
result_lightop
,
rtol
=
1e-5
,
atol
=
1e-5
),
"result_lightop Mismatch Triton error"
assert
torch
.
allclose
(
expected
,
result_lightop
,
rtol
=
1e-5
,
atol
=
1e-5
),
"result_lightop Mismatch torch error"
print
(
"prefill 精度验证通过"
)
def
test_concat_Acc_decode
(
shape_pair
,
dim
):
torch
.
manual_seed
(
1
)
shape1
,
shape2
=
shape_pair
M
=
shape1
[
0
]
N
=
shape1
[
1
]
x_sizes
=
[
M
,
N
,
512
]
x_strides
=
[
512
,
512
*
M
,
1
]
x_max_index
=
M
*
N
*
512
x_required_length
=
x_max_index
x_data
=
torch
.
arange
(
x_required_length
,
device
=
'cuda'
).
bfloat16
()
x
=
torch
.
as_strided
(
x_data
,
size
=
x_sizes
,
stride
=
x_strides
)
# print("形状:", x.shape)
# print("步幅:", x.stride())
y_sizes
=
[
M
,
N
,
64
]
y_strides
=
[
1536
*
(
N
//
8
),
192
,
1
]
y_max_index
=
1536
*
(
N
//
8
)
*
M
y_required_length
=
y_max_index
y_data
=
torch
.
arange
(
y_required_length
,
device
=
'cuda'
).
bfloat16
()
y
=
torch
.
as_strided
(
y_data
,
size
=
y_sizes
,
stride
=
y_strides
)
# print("形状:", y.shape)
# print("步幅:", y.stride())
expected
=
torch
.
cat
([
x
,
y
],
dim
=
dim
)
result
=
concat_helper_decode
(
x
,
y
,
dim
=
dim
)
assert
torch
.
allclose
(
result
,
expected
,
rtol
=
1e-5
,
atol
=
1e-5
),
"Mismatch"
print
(
"decode 精度正常"
)
@
triton
.
jit
def
concat_kernel
(
A_ptr
,
B_ptr
,
C_ptr
,
A_section_numel
,
B_section_numel
,
C_section_numel
,
Per_block_A
,
Per_block_B
,
section_numA
,
section_numB
,
M
,
N
,
Astride_0
,
Astride_1
,
Astride_2
,
Bstride_0
,
Bstride_1
,
Bstride_2
,
BLOCK_SIZE
:
tl
.
constexpr
):
block_idx
=
tl
.
program_id
(
0
)
numA
=
section_numA
//
Per_block_A
if
(
block_idx
<
numA
):
#处理A的block
for
sub_section_index
in
range
(
Per_block_A
):
sub_offset
=
block_idx
*
Per_block_A
+
sub_section_index
if
sub_offset
<=
section_numA
-
1
:
M_idx
=
sub_offset
//
N
N_idx
=
sub_offset
%
N
C_ptr_block_start
=
C_ptr
+
sub_offset
*
C_section_numel
A_ptr_block_start
=
A_ptr
+
M_idx
*
Astride_0
+
N_idx
*
Astride_1
for
offset
in
range
(
0
,
A_section_numel
,
BLOCK_SIZE
):
offset_idx
=
offset
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
mask
=
offset_idx
<
A_section_numel
val_from_A
=
tl
.
load
(
A_ptr_block_start
+
offset_idx
,
mask
=
mask
)
tl
.
store
(
C_ptr_block_start
+
offset_idx
,
val_from_A
,
mask
=
mask
)
else
:
#处理B的block
#shape是1024*8*64,实际上只有1024 * 64 块数据,开了1024/4=256个线程块来处理。每个线程块处理1块连续的数据
#需要注意C的分块也是有M * N 大小的,而这里只有M大小个线程块,每个线程块需要写入N次数据到C中。
for
sub_section_index
in
range
(
Per_block_B
):
sub_offset
=
(
block_idx
-
numA
)
*
Per_block_B
+
sub_section_index
if
sub_offset
<=
section_numB
-
1
:
C_ptr_block_start
=
C_ptr
+
sub_offset
*
N
*
C_section_numel
B_ptr_block_start
=
B_ptr
+
sub_offset
*
Bstride_0
for
offset
in
range
(
0
,
B_section_numel
,
BLOCK_SIZE
):
offset_idx
=
offset
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
mask
=
offset_idx
<
B_section_numel
val_from_B
=
tl
.
load
(
B_ptr_block_start
+
offset_idx
,
mask
=
mask
)
for
idx
in
range
(
0
,
N
,
1
):
tl
.
store
(
C_ptr_block_start
+
idx
*
C_section_numel
+
A_section_numel
+
offset_idx
,
val_from_B
,
mask
=
mask
)
def
concat_prefill_helper_Triton
(
A
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
dim
:
int
):
output_shape
=
list
(
A
.
shape
)
output_shape
[
dim
]
=
A
.
shape
[
dim
]
+
B
.
shape
[
dim
]
#128+64=192
C
=
torch
.
empty
(
output_shape
,
device
=
A
.
device
,
dtype
=
A
.
dtype
)
if
dim
!=
0
:
#分开计算block块A需要
Per_block_A
=
64
Per_block_B
=
1
#128 \64 \192
unit_offset_A
,
unit_offset_B
,
unit_offset_C
=
A
.
shape
[
dim
],
B
.
shape
[
dim
],
C
.
shape
[
dim
]
#A的分块数是:M * N 这里的demo是1024 * 8
block_numA
=
reduce
(
lambda
x
,
y
:
x
*
y
,
output_shape
[:
dim
])
#B的分块数是:M 这里的demo是1024
block_numB
=
output_shape
[
0
]
#A的每个分块可以处理多份数据的读取和写入,这是因为单次的任务量太小。假设这里Per_block = 8 那么A就开启了1024个线程块,每个线程块处理8份数据的读取和写入
#B的每个分块处理1次B的读取和8次C的写入,L2 cache复用率高
block_num
=
block_numA
//
Per_block_A
+
block_numB
//
Per_block_B
num_blocks
=
math
.
ceil
(
block_num
)
concat_kernel
[(
num_blocks
,)](
A
,
B
,
C
,
unit_offset_A
,
unit_offset_B
,
unit_offset_C
,
Per_block_A
,
Per_block_B
,
block_numA
,
block_numB
,
output_shape
[
0
],
output_shape
[
1
],
A
.
stride
(
0
),
A
.
stride
(
1
),
A
.
stride
(
2
),
B
.
stride
(
0
),
B
.
stride
(
1
),
B
.
stride
(
2
),
BLOCK_SIZE
=
1024
)
return
C
assert
False
,
"not support"
def
concat_helper_decode
(
A
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
dim
:
int
):
assert
dim
==
2
,
"tensor dim must be 3 and concat dim must be 2"
output_shape
=
list
(
A
.
shape
)
output_shape
[
dim
]
=
A
.
shape
[
dim
]
+
B
.
shape
[
dim
]
C
=
torch
.
empty
(
output_shape
,
device
=
A
.
device
,
dtype
=
A
.
dtype
)
mode
=
0
if
dim
!=
0
:
ds_cat
(
A
,
B
,
C
,
mode
)
return
C
assert
False
,
"not support"
def
lightop_concat_prefill_helper
(
A
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
dim
:
int
):
assert
dim
==
2
,
"tensor dim must be 3 and concat dim must be 2"
output_shape
=
list
(
A
.
shape
)
output_shape
[
dim
]
=
A
.
shape
[
dim
]
+
B
.
shape
[
dim
]
C
=
torch
.
empty
(
output_shape
,
device
=
A
.
device
,
dtype
=
A
.
dtype
)
mode
=
6
if
dim
!=
0
:
ds_cat
(
A
,
B
,
C
,
mode
)
return
C
assert
False
,
"not support"
configs
=
[]
configs
.
append
(
triton
.
testing
.
Benchmark
(
x_names
=
[
'M'
,
'N'
],
x_vals
=
[(
1024
,
8
),(
2048
,
8
),(
3072
,
8
),(
4096
,
8
),(
6144
,
8
),(
8192
,
8
),
\
(
1024
,
16
),(
2048
,
16
),(
3072
,
16
),(
4096
,
16
),(
6144
,
16
),(
8192
,
16
),
\
(
1024
,
32
),(
2048
,
32
),(
3072
,
32
),(
4096
,
32
),(
6144
,
32
),(
8192
,
32
)
],
x_log
=
True
,
line_arg
=
'provider'
,
line_vals
=
[
'triton'
,
'torch'
,
'lightop'
],
line_names
=
[
'Triton'
,
'Torch'
,
'Lightop'
],
styles
=
[(
'blue'
,
'-'
),
(
'green'
,
'-'
),
(
'yellow'
,
'-'
)],
ylabel
=
's'
,
plot_name
=
'concat-dim2'
,
args
=
{
"dim"
:
2
},
),
)
configs_decode
=
[]
configs_decode
.
append
(
triton
.
testing
.
Benchmark
(
x_names
=
[
'M'
,
'N'
],
x_vals
=
[(
4
,
8
),(
8
,
8
),(
16
,
8
),(
32
,
8
),(
64
,
8
),(
96
,
8
),(
128
,
8
),(
256
,
8
),(
512
,
8
),(
768
,
8
),(
767
,
8
),(
765
,
8
),(
766
,
8
),
\
(
4
,
16
),(
8
,
16
),(
16
,
16
),(
32
,
16
),(
64
,
16
),(
96
,
16
),(
128
,
16
),(
256
,
16
),(
512
,
16
),(
768
,
16
),(
767
,
16
),(
765
,
16
),(
766
,
16
),
\
(
4
,
32
),(
8
,
32
),(
16
,
32
),(
32
,
32
),(
64
,
32
),(
96
,
32
),(
128
,
32
),(
256
,
32
),(
512
,
32
),(
768
,
32
),(
767
,
32
),(
765
,
32
),(
766
,
32
)],
x_log
=
True
,
line_arg
=
'provider'
,
line_vals
=
[
'lightop'
,
'torch'
],
line_names
=
[
'Lightop'
,
'Torch'
],
styles
=
[(
'blue'
,
'-'
),
(
'green'
,
'-'
)],
ylabel
=
's'
,
plot_name
=
'concat-dim2'
,
args
=
{
"dim"
:
2
},
),
)
@
triton
.
testing
.
perf_report
(
configs
)
def
benchmark_prefill
(
M
,
N
,
provider
,
dim
):
x_sizes
=
[
M
,
N
,
128
]
x_strides
=
[
N
//
8
*
2048
,
256
,
1
]
x_max_index
=
N
//
8
*
2048
*
M
x_required_length
=
x_max_index
x_data
=
torch
.
arange
(
x_required_length
,
device
=
'cuda'
).
bfloat16
()
x
=
torch
.
as_strided
(
x_data
,
size
=
x_sizes
,
stride
=
x_strides
)
y_sizes
=
[
M
,
N
,
64
]
y_strides
=
[
576
,
0
,
1
]
y_max_index
=
576
*
M
y_required_length
=
y_max_index
y_data
=
torch
.
arange
(
y_required_length
,
device
=
'cuda'
).
bfloat16
()
y
=
torch
.
as_strided
(
y_data
,
size
=
y_sizes
,
stride
=
y_strides
)
quantiles
=
[
0.5
,
0.2
,
0.8
]
if
provider
==
'torch'
:
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench
(
lambda
:
torch
.
cat
([
x
,
y
],
dim
=
dim
),
quantiles
=
quantiles
)
if
provider
==
'triton'
:
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench
(
lambda
:
concat_prefill_helper_Triton
(
x
,
y
,
dim
=
dim
),
quantiles
=
quantiles
)
if
provider
==
'lightop'
:
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench
(
lambda
:
lightop_concat_prefill_helper
(
x
,
y
,
dim
=
dim
),
quantiles
=
quantiles
)
return
(
ms
*
1000
),
(
max_ms
*
1000
),
(
min_ms
*
1000
)
@
triton
.
testing
.
perf_report
(
configs_decode
)
def
benchmark_decode
(
M
,
N
,
provider
,
dim
):
x_sizes
=
[
M
,
N
,
512
]
x_strides
=
[
512
,
512
*
M
,
1
]
x_max_index
=
M
*
N
*
512
x_required_length
=
x_max_index
x_data
=
torch
.
arange
(
x_required_length
,
device
=
'cuda'
).
bfloat16
()
x
=
torch
.
as_strided
(
x_data
,
size
=
x_sizes
,
stride
=
x_strides
)
# print("形状:", x.shape)
# print("步幅:", x.stride())
y_sizes
=
[
M
,
N
,
64
]
y_strides
=
[
1536
*
(
N
//
8
),
192
,
1
]
y_max_index
=
1536
*
(
N
//
8
)
*
M
y_required_length
=
y_max_index
y_data
=
torch
.
arange
(
y_required_length
,
device
=
'cuda'
).
bfloat16
()
y
=
torch
.
as_strided
(
y_data
,
size
=
y_sizes
,
stride
=
y_strides
)
# print("形状:", y.shape)
# print("步幅:", y.stride())
quantiles
=
[
0.5
,
0.2
,
0.8
]
if
provider
==
'torch'
:
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench
(
lambda
:
torch
.
cat
([
x
,
y
],
dim
=
dim
),
quantiles
=
quantiles
)
if
provider
==
'lightop'
:
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench
(
lambda
:
concat_helper_decode
(
x
,
y
,
dim
=
dim
),
quantiles
=
quantiles
)
return
(
ms
*
1000
),
(
max_ms
*
1000
),
(
min_ms
*
1000
)
if
__name__
==
'__main__'
:
benchmark_prefill
.
run
(
save_path
=
"./triton_test"
,
print_data
=
True
)
test_concat_Acc_prefill
(((
1024
,
8
,
128
),
(
1024
,
8
,
64
)),
2
)
test_concat_Acc_prefill
(((
2048
,
8
,
128
),
(
2048
,
8
,
64
)),
2
)
test_concat_Acc_prefill
(((
4096
,
8
,
128
),
(
4096
,
8
,
64
)),
2
)
test_concat_Acc_prefill
(((
8192
,
8
,
128
),
(
8192
,
8
,
64
)),
2
)
test_concat_Acc_prefill
(((
1024
,
16
,
128
),
(
1024
,
16
,
64
)),
2
)
test_concat_Acc_prefill
(((
2048
,
16
,
128
),
(
2048
,
16
,
64
)),
2
)
test_concat_Acc_prefill
(((
4096
,
16
,
128
),
(
4096
,
16
,
64
)),
2
)
test_concat_Acc_prefill
(((
8192
,
16
,
128
),
(
8192
,
16
,
64
)),
2
)
test_concat_Acc_prefill
(((
1024
,
32
,
128
),
(
1024
,
32
,
64
)),
2
)
test_concat_Acc_prefill
(((
2048
,
32
,
128
),
(
2048
,
32
,
64
)),
2
)
test_concat_Acc_prefill
(((
4096
,
32
,
128
),
(
4096
,
32
,
64
)),
2
)
test_concat_Acc_prefill
(((
8192
,
32
,
128
),
(
8192
,
32
,
64
)),
2
)
benchmark_decode
.
run
(
save_path
=
"./cat_triton_test"
,
print_data
=
True
)
test_concat_Acc_decode
(((
16
,
8
,
512
),
(
16
,
8
,
64
)),
2
)
test_concat_Acc_decode
(((
32
,
8
,
512
),
(
32
,
8
,
64
)),
2
)
test_concat_Acc_decode
(((
128
,
8
,
512
),
(
128
,
8
,
64
)),
2
)
test_concat_Acc_decode
(((
768
,
8
,
512
),
(
768
,
8
,
64
)),
2
)
test_concat_Acc_decode
(((
32
,
16
,
512
),
(
32
,
16
,
64
)),
2
)
test_concat_Acc_decode
(((
32
,
32
,
512
),
(
32
,
32
,
64
)),
2
)
test_concat_Acc_decode
(((
768
,
32
,
512
),
(
768
,
32
,
64
)),
2
)
test_concat_Acc_decode
(((
128
,
32
,
512
),
(
128
,
32
,
64
)),
2
)
test_concat_Acc_decode
(((
512
,
32
,
512
),
(
512
,
32
,
64
)),
2
)
test_concat_Acc_decode
(((
765
,
8
,
512
),
(
765
,
8
,
64
)),
2
)
test_concat_Acc_decode
(((
766
,
8
,
512
),
(
766
,
8
,
64
)),
2
)
test_concat_Acc_decode
(((
767
,
8
,
512
),
(
767
,
8
,
64
)),
2
)
test_concat_Acc_decode
(((
765
,
16
,
512
),
(
765
,
16
,
64
)),
2
)
test_concat_Acc_decode
(((
766
,
16
,
512
),
(
766
,
16
,
64
)),
2
)
test_concat_Acc_decode
(((
765
,
32
,
512
),
(
765
,
32
,
64
)),
2
)
test_concat_Acc_decode
(((
767
,
32
,
512
),
(
767
,
32
,
64
)),
2
)
\ No newline at end of file
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment