Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8d2cac26
Commit
8d2cac26
authored
Sep 24, 2025
by
zhuwenwen
Browse files
[kernel] add lightop's moe_sum(mul+add) fusion operator for deepseek
[FIX] 修复mtp和VLLM_USE_TRITON_CAT不能一起开的bug
parent
5086453d
Changes
10
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
134 additions
and
50 deletions
+134
-50
vllm/envs.py
vllm/envs.py
+7
-7
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+63
-13
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+21
-10
vllm/model_executor/layers/fused_moe/moe_align_block_size.py
vllm/model_executor/layers/fused_moe/moe_align_block_size.py
+3
-2
vllm/model_executor/layers/quantization/slimquant_w4a8_marlin.py
...del_executor/layers/quantization/slimquant_w4a8_marlin.py
+3
-0
vllm/model_executor/model_loader/utils.py
vllm/model_executor/model_loader/utils.py
+6
-0
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+17
-11
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+5
-3
vllm/v1/attention/backends/mla/flashmla.py
vllm/v1/attention/backends/mla/flashmla.py
+5
-3
vllm/v1/attention/backends/mla/test_concat.py
vllm/v1/attention/backends/mla/test_concat.py
+4
-1
No files found.
vllm/envs.py
View file @
8d2cac26
...
...
@@ -164,10 +164,10 @@ if TYPE_CHECKING:
VLLM_USE_FLASH_ATTN_PA
:
bool
=
False
VLLM_USE_APEX_RN
:
bool
=
False
VLLM_USE_GLOBAL_CACHE13
:
bool
=
False
VLLM_USE_LIGHT_OP
:
bool
=
False
VLLM_USE_TRITON_CAT
:
bool
=
False
USE_FUSED_RMS_QUANT
:
bool
=
False
VLLM_USE_LIGHTOP
:
bool
=
False
VLLM_USE_OPT_CAT
:
bool
=
False
VLLM_USE_MERGE_ATTN_STATES_OPT
:
bool
=
False
USE_FUSED_RMS_QUANT
:
bool
=
False
USE_FUSED_SILU_MUL_QUANT
:
bool
=
False
def
get_default_cache_root
():
...
...
@@ -1095,12 +1095,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_GLOBAL_CACHE13"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
# vLLM will use lightop kernels for MoE ops (e.g. moe_sum, moe_fused_gate, moe_align_block_size)
"VLLM_USE_LIGHT
_
OP"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_LIGHT
_
OP"
,
"
Tru
e"
).
lower
()
in
"VLLM_USE_LIGHTOP"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_LIGHTOP"
,
"
Fals
e"
).
lower
()
in
(
"true"
,
"1"
)),
# vLLM will use the optimized concat (cat) helpers in the MLA attention backends
"VLLM_USE_
TRITON
_CAT"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_
TRITON
_CAT"
,
"True"
).
lower
()
in
"VLLM_USE_
OPT
_CAT"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_
OPT
_CAT"
,
"True"
).
lower
()
in
(
"true"
,
"1"
)),
# vLLM will use the optimized merge_attn_states implementation, not the Triton one
"VLLM_USE_MERGE_ATTN_STATES_OPT"
:
...
...
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
8d2cac26
...
...
@@ -43,9 +43,17 @@ from vllm.utils import direct_register_custom_op
# from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
if
envs
.
VLLM_USE_LIGHTOP
:
from
lightop
import
op
os
.
environ
[
'DPSK_FP16_QUICK'
]
=
os
.
environ
.
get
(
'DPSK_FP16_QUICK'
,
'0'
)
dpsk_fp16_quick
=
os
.
environ
.
get
(
'DPSK_FP16_QUICK'
)
==
'1'
logger
=
init_logger
(
__name__
)
if
envs
.
VLLM_USE_GLOBAL_CACHE13
:
moe_cache_singleton
=
None
def
get_moe_cache
(
top_k_num
,
N
,
K
,
device
,
dtype
):
global
moe_cache_singleton
if
moe_cache_singleton
is
None
:
...
...
@@ -1257,13 +1265,15 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
)
->
None
:
use_nn_moe
:
Optional
[
bool
]
=
False
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
routed_scaling_factor
:
Optional
[
float
]
=
None
)
->
None
:
fused_experts_impl
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
True
,
activation
,
apply_router_weight_on_input
,
use_fp8_w8a8
,
use_int8_w8a8
,
use_int8_w8a16
,
use_int4_w4a16
,
use_int4_w4a8
,
per_channel_quant
,
global_num_experts
,
expert_map
,
w1_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
block_shape
,
use_nn_moe
)
block_shape
,
use_nn_moe
,
shared_output
,
routed_scaling_factor
)
def
inplace_fused_experts_fake
(
...
...
@@ -1289,7 +1299,9 @@ def inplace_fused_experts_fake(
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
)
->
None
:
use_nn_moe
:
Optional
[
bool
]
=
False
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
routed_scaling_factor
:
Optional
[
float
]
=
None
)
->
None
:
pass
...
...
@@ -1325,14 +1337,16 @@ def outplace_fused_experts(
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
)
->
torch
.
Tensor
:
use_nn_moe
:
Optional
[
bool
]
=
False
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
routed_scaling_factor
:
Optional
[
float
]
=
None
)
->
torch
.
Tensor
:
return
fused_experts_impl
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
False
,
activation
,
apply_router_weight_on_input
,
use_fp8_w8a8
,
use_int8_w8a8
,
use_int8_w8a16
,
use_int4_w4a16
,
use_int4_w4a8
,
per_channel_quant
,
global_num_experts
,
expert_map
,
w1_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
block_shape
,
use_nn_moe
)
block_shape
,
use_nn_moe
,
shared_output
,
routed_scaling_factor
)
def
outplace_fused_experts_fake
(
...
...
@@ -1357,7 +1371,9 @@ def outplace_fused_experts_fake(
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
)
->
torch
.
Tensor
:
use_nn_moe
:
Optional
[
bool
]
=
False
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
routed_scaling_factor
:
Optional
[
float
]
=
None
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
hidden_states
)
...
...
@@ -1414,7 +1430,9 @@ def fused_experts(
block_shape
:
Optional
[
List
[
int
]]
=
None
,
allow_deep_gemm
:
bool
=
False
,
allow_cutlass_block_scaled_grouped_gemm
:
bool
=
False
,
use_nn_moe
:
Optional
[
bool
]
=
False
)
->
torch
.
Tensor
:
use_nn_moe
:
Optional
[
bool
]
=
False
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
routed_scaling_factor
:
Optional
[
float
]
=
None
)
->
torch
.
Tensor
:
# For now, disable DeepGemm for small N (<= 512) until better
# permute/unpermute ops are available.
N
=
w1
.
size
(
1
)
...
...
@@ -1472,7 +1490,9 @@ def fused_experts(
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
,
block_shape
=
block_shape
,
use_nn_moe
=
use_nn_moe
)
use_nn_moe
=
use_nn_moe
,
shared_output
=
shared_output
,
routed_scaling_factor
=
routed_scaling_factor
)
def
fused_experts_impl
(
...
...
@@ -1500,6 +1520,8 @@ def fused_experts_impl(
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
routed_scaling_factor
:
Optional
[
float
]
=
None
,
)
->
torch
.
Tensor
:
num_tokens
=
hidden_states
.
size
(
0
)
if
use_nn_moe
:
...
...
@@ -1544,7 +1566,9 @@ def fused_experts_impl(
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
,
block_shape
=
block_shape
,
use_nn_moe
=
False
use_nn_moe
=
False
,
shared_output
=
shared_output
,
routed_scaling_factor
=
routed_scaling_factor
)
elif
use_int4_w4a8
is
True
:
return
fused_experts_impl_w4a8
(
hidden_states
=
hidden_states
,
...
...
@@ -1571,7 +1595,9 @@ def fused_experts_impl(
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
,
block_shape
=
block_shape
,
use_nn_moe
=
False
use_nn_moe
=
False
,
shared_output
=
shared_output
,
routed_scaling_factor
=
routed_scaling_factor
)
#
...
...
@@ -1744,9 +1770,29 @@ def fused_experts_impl(
block_shape
=
block_shape
,
use_nn_moe
=
use_nn_moe
)
if
envs
.
VLLM_USE_LIGHTOP
and
not
dpsk_fp16_quick
:
if
shared_output
is
not
None
:
op
.
moe_sum
(
intermediate_cache3
.
view
(
*
intermediate_cache3
.
size
()),
out_hidden_states
[
begin_chunk_idx
:
end_chunk_idx
],
shared_output
[
begin_chunk_idx
:
end_chunk_idx
],
routed_scaling_factor
)
# else:
# ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()),
# out_hidden_states[begin_chunk_idx:end_chunk_idx])
# if shared_output is not None:
# if hidden_states.dtype != torch.float16 or dpsk_fp16_quick:
# out_hidden_states[begin_chunk_idx:end_chunk_idx] = out_hidden_states[begin_chunk_idx:end_chunk_idx] * routed_scaling_factor + shared_output[begin_chunk_idx:end_chunk_idx]
# else:
# # Fix FP16 overflow
# # See DeepseekV2DecoderLayer for more details.
# out_hidden_states[begin_chunk_idx:end_chunk_idx] + shared_output[begin_chunk_idx:end_chunk_idx] * (1. / routed_scaling_factor)
# else:
# if hidden_states.dtype != torch.float16 or dpsk_fp16_quick:
# ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()),
# out_hidden_states[begin_chunk_idx:end_chunk_idx]) * routed_scaling_factor
else
:
ops
.
moe_sum
(
intermediate_cache3
.
view
(
*
intermediate_cache3
.
size
()),
out_hidden_states
[
begin_chunk_idx
:
end_chunk_idx
])
return
out_hidden_states
...
...
@@ -1779,6 +1825,8 @@ def fused_moe(
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
routed_scaling_factor
:
Optional
[
float
]
=
None
,
)
->
torch
.
Tensor
:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
...
...
@@ -1864,7 +1912,9 @@ def fused_moe(
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
,
block_shape
=
block_shape
,
use_nn_moe
=
use_nn_moe
)
use_nn_moe
=
use_nn_moe
,
shared_output
=
shared_output
,
routed_scaling_factor
=
routed_scaling_factor
)
class
TritonExperts
(
mk
.
FusedMoEPermuteExpertsUnpermute
):
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
8d2cac26
...
...
@@ -42,7 +42,9 @@ from vllm.platforms.interface import CpuArchEnum
from
vllm.utils
import
direct_register_custom_op
,
has_deep_ep
,
has_pplx
from
vllm
import
_custom_ops
as
ops
from
lightop
import
op
if
envs
.
VLLM_USE_LIGHTOP
:
from
lightop
import
op
as
op
if
current_platform
.
is_cuda_alike
():
from
.fused_batched_moe
import
BatchedTritonExperts
...
...
@@ -222,6 +224,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
expert_load_view
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_to_physical_map
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_replica_count
:
Optional
[
torch
.
Tensor
]
=
None
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
raise
NotImplementedError
...
...
@@ -373,6 +376,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
expert_load_view
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_to_physical_map
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_replica_count
:
Optional
[
torch
.
Tensor
]
=
None
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
routed_scaling_factor
:
Optional
[
float
]
=
None
,
use_fused_gate
:
Optional
[
bool
]
=
False
,
...
...
@@ -397,6 +401,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
e_score_correction_bias
=
e_score_correction_bias
,
activation
=
activation
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
shared_output
=
shared_output
,
use_nn_moe
=
use_nn_moe
,
routed_scaling_factor
=
routed_scaling_factor
,
use_fused_gate
=
use_fused_gate
)
...
...
@@ -418,6 +423,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
routed_scaling_factor
:
Optional
[
float
]
=
None
,
use_fused_gate
:
Optional
[
bool
]
=
False
,
...
...
@@ -460,7 +466,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
apply_router_weight_on_input
=
apply_router_weight_on_input
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
use_nn_moe
=
use_nn_moe
use_nn_moe
=
use_nn_moe
,
shared_output
=
shared_output
,
routed_scaling_factor
=
routed_scaling_factor
)
def
forward_cpu
(
...
...
@@ -1278,7 +1286,7 @@ class FusedMoE(torch.nn.Module):
assert
topk_group
is
not
None
assert
num_expert_group
is
not
None
if
use_fused_gate
:
if
envs
.
VLLM_USE_LIGHT
_
OP
:
if
envs
.
VLLM_USE_LIGHTOP
:
topk_weights
,
topk_ids
=
op
.
moe_fused_gate
(
router_logits
,
e_score_correction_bias
,
...
...
@@ -1427,13 +1435,14 @@ class FusedMoE(torch.nn.Module):
return
tensor_model_parallel_all_reduce
(
final_hidden_states
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
):
router_logits
:
torch
.
Tensor
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
):
# TODO: Once the OOM issue for the TPU backend is resolved, we will
# switch to using the moe_forward custom op.
if
current_platform
.
is_tpu
():
return
self
.
forward_impl
(
hidden_states
,
router_logits
)
else
:
return
torch
.
ops
.
vllm
.
moe_forward
(
hidden_states
,
router_logits
,
return
torch
.
ops
.
vllm
.
moe_forward
(
hidden_states
,
router_logits
,
shared_output
,
self
.
layer_name
)
def
forward_impl_chunked
(
self
,
full_hidden_states
:
torch
.
Tensor
,
...
...
@@ -1513,7 +1522,8 @@ class FusedMoE(torch.nn.Module):
return
full_final_hidden_states
def
forward_impl
(
self
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
):
router_logits
:
torch
.
Tensor
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
):
assert
self
.
quant_method
is
not
None
if
(
self
.
moe_parallel_config
.
use_pplx_kernels
or
self
.
moe_parallel_config
.
use_deepep_ll_kernels
):
...
...
@@ -1547,6 +1557,7 @@ class FusedMoE(torch.nn.Module):
expert_load_view
=
self
.
expert_load_view
,
logical_to_physical_map
=
self
.
logical_to_physical_map
,
logical_replica_count
=
self
.
logical_replica_count
,
shared_output
=
shared_output
,
use_nn_moe
=
self
.
use_nn_moe
,
routed_scaling_factor
=
self
.
routed_scaling_factor
,
use_fused_gate
=
self
.
use_fused_gate
...
...
@@ -1620,16 +1631,16 @@ class FusedMoE(torch.nn.Module):
def
moe_forward
(
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
layer_name
:
str
)
->
torch
.
Tensor
:
layer_name
:
str
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
forward_context
:
ForwardContext
=
get_forward_context
()
self
=
forward_context
.
no_compile_layers
[
layer_name
]
assert
self
.
quant_method
is
not
None
return
self
.
forward_impl
(
hidden_states
,
router_logits
)
return
self
.
forward_impl
(
hidden_states
,
router_logits
,
shared_output
)
def
moe_forward_fake
(
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
layer_name
:
str
)
->
torch
.
Tensor
:
layer_name
:
str
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
hidden_states
)
...
...
vllm/model_executor/layers/fused_moe/moe_align_block_size.py
View file @
8d2cac26
...
...
@@ -9,7 +9,8 @@ from vllm.triton_utils import tl, triton
from
vllm.utils
import
cdiv
,
round_up
import
vllm.envs
as
envs
from
lightop
import
op
if
envs
.
VLLM_USE_LIGHTOP
:
from
lightop
import
op
as
op
@
triton
.
jit
...
...
@@ -232,7 +233,7 @@ def moe_align_block_size(
dtype
=
torch
.
int32
,
device
=
topk_ids
.
device
)
if
envs
.
VLLM_USE_LIGHT
_
OP
:
if
envs
.
VLLM_USE_LIGHTOP
:
op
.
moe_align_block_size
(
topk_ids
,
num_experts
,
block_size
,
sorted_ids
,
expert_ids
,
num_tokens_post_pad
,
None
)
else
:
...
...
vllm/model_executor/layers/quantization/slimquant_w4a8_marlin.py
View file @
8d2cac26
...
...
@@ -230,6 +230,7 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
use_nn_moe
:
Optional
[
bool
]
=
False
,
routed_scaling_factor
:
Optional
[
float
]
=
None
,
use_fused_gate
:
Optional
[
bool
]
=
False
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
**
_
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe
import
fused_experts
...
...
@@ -272,4 +273,6 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
a1_scale
=
layer
.
w13_input_scale
,
a2_scale
=
layer
.
w2_input_scale
,
use_nn_moe
=
use_nn_moe
,
shared_output
=
shared_output
,
routed_scaling_factor
=
routed_scaling_factor
,
)
vllm/model_executor/model_loader/utils.py
View file @
8d2cac26
...
...
@@ -238,10 +238,16 @@ def get_model_architecture(
os
.
environ
[
'LLAMA_NN'
]
=
'0'
else
:
os
.
environ
[
'LLAMA_NN'
]
=
'1'
if
(
architectures
==
[
'BloomForCausalLM'
]
or
architectures
==
[
'FalconForCausalLM'
])
or
os
.
getenv
(
'LM_NN'
)
==
'0'
:
os
.
environ
[
'LM_NN'
]
=
'0'
else
:
os
.
environ
[
'LM_NN'
]
=
'1'
if
(
architectures
==
[
'DeepseekV3ForCausalLM'
]
or
architectures
==
[
'DeepSeekMTPModel'
]):
if
not
envs
.
is_set
(
"VLLM_USE_LIGHTOP"
):
os
.
environ
[
'VLLM_USE_LIGHTOP'
]
=
'1'
if
os
.
getenv
(
'GEMM_PAD'
)
!=
'1'
:
os
.
environ
[
'GEMM_PAD'
]
=
'0'
if
os
.
getenv
(
'FA_PAD'
)
!=
'1'
:
...
...
vllm/model_executor/models/deepseek_v2.py
View file @
8d2cac26
...
...
@@ -213,6 +213,12 @@ class DeepseekV2MoE(nn.Module):
# router_logits: (num_tokens, n_experts)
router_logits
,
_
=
self
.
gate
(
hidden_states
)
if
envs
.
VLLM_USE_LIGHTOP
and
not
self
.
dpsk_fp16_quick
:
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
,
shared_output
=
shared_output
)
else
:
if
hidden_states
.
dtype
!=
torch
.
float16
or
self
.
dpsk_fp16_quick
:
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
...
...
vllm/v1/attention/backends/mla/common.py
View file @
8d2cac26
...
...
@@ -216,7 +216,9 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata
)
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.worker.block_table
import
BlockTable
from
vllm.v1.attention.backends.mla.test_concat
import
lightop_concat_prefill_helper
if
envs
.
VLLM_USE_OPT_CAT
:
from
vllm.v1.attention.backends.mla.test_concat
import
lightop_concat_prefill_helper
try
:
from
vllm.vllm_flash_attn
import
flash_attn_varlen_func
...
...
@@ -928,7 +930,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
k_nope
,
v
=
kv_nope
\
.
split
([
self
.
qk_nope_head_dim
,
self
.
v_head_dim
],
dim
=-
1
)
if
envs
.
VLLM_USE_
TRITON
_CAT
:
if
envs
.
VLLM_USE_
OPT
_CAT
:
if
k_nope
.
shape
[
0
]
>
1024
:
k
=
lightop_concat_prefill_helper
(
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
)),
dim
=
2
)
...
...
@@ -993,7 +995,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
k_nope
,
v
=
kv_nope
\
.
split
([
self
.
qk_nope_head_dim
,
self
.
v_head_dim
],
dim
=-
1
)
if
envs
.
VLLM_USE_
TRITON
_CAT
:
if
envs
.
VLLM_USE_
OPT
_CAT
:
if
k_nope
.
shape
[
0
]
>
1024
:
k
=
lightop_concat_prefill_helper
(
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
)),
dim
=
2
)
...
...
vllm/v1/attention/backends/mla/flashmla.py
View file @
8d2cac26
...
...
@@ -20,7 +20,9 @@ from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.worker.block_table
import
BlockTable
from
vllm
import
envs
from
vllm.v1.attention.backends.mla.test_concat
import
concat_helper_decode
if
envs
.
VLLM_USE_OPT_CAT
:
from
vllm.v1.attention.backends.mla.test_concat
import
concat_helper_decode
logger
=
init_logger
(
__name__
)
...
...
@@ -166,8 +168,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
assert
kv_c_and_k_pe_cache
.
numel
()
>
0
assert
attn_metadata
.
decode
is
not
None
if
envs
.
VLLM_USE_
TRITON
_CAT
:
if
q_nope
.
shape
[
0
]
<
=
1024
:
if
envs
.
VLLM_USE_
OPT
_CAT
:
if
q_nope
.
shape
[
0
]
<
1024
:
q
=
concat_helper_decode
(
q_nope
,
q_pe
,
dim
=
2
)
\
.
unsqueeze
(
1
)
else
:
...
...
vllm/v1/attention/backends/mla/test_concat.py
View file @
8d2cac26
...
...
@@ -5,7 +5,10 @@ from functools import reduce
import
pytest
import
torch
import
math
from
lightop
import
ds_cat
import
vllm.envs
as
envs
if
envs
.
VLLM_USE_LIGHTOP
:
from
lightop
import
ds_cat
def
test_concat_Acc_prefill
(
shape_pair
,
dim
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment