Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0b229519
Commit
0b229519
authored
May 27, 2025
by
王敏
Browse files
[feat]适配sgl moe_fused_gate kernel
parent
1150b65c
Changes
8
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
626 additions
and
16 deletions
+626
-16
CMakeLists.txt
CMakeLists.txt
+2
-1
csrc/moe/moe_fused_gate.cu
csrc/moe/moe_fused_gate.cu
+539
-0
csrc/moe/moe_ops.h
csrc/moe/moe_ops.h
+10
-1
csrc/moe/torch_bindings.cpp
csrc/moe/torch_bindings.cpp
+6
-0
vllm/_custom_ops.py
vllm/_custom_ops.py
+28
-0
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+5
-0
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+34
-13
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+2
-1
No files found.
CMakeLists.txt
View file @
0b229519
...
@@ -621,7 +621,8 @@ target_compile_definitions(_C PRIVATE CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)
...
@@ -621,7 +621,8 @@ target_compile_definitions(_C PRIVATE CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)
set
(
VLLM_MOE_EXT_SRC
set
(
VLLM_MOE_EXT_SRC
"csrc/moe/torch_bindings.cpp"
"csrc/moe/torch_bindings.cpp"
"csrc/moe/moe_align_sum_kernels.cu"
"csrc/moe/moe_align_sum_kernels.cu"
"csrc/moe/topk_softmax_kernels.cu"
)
"csrc/moe/topk_softmax_kernels.cu"
"csrc/moe/moe_fused_gate.cu"
)
if
(
VLLM_GPU_LANG STREQUAL
"CUDA"
)
if
(
VLLM_GPU_LANG STREQUAL
"CUDA"
)
list
(
APPEND VLLM_MOE_EXT_SRC
"csrc/moe/moe_wna16.cu"
)
list
(
APPEND VLLM_MOE_EXT_SRC
"csrc/moe/moe_wna16.cu"
)
...
...
csrc/moe/moe_fused_gate.cu
0 → 100644
View file @
0b229519
This diff is collapsed.
Click to expand it.
csrc/moe/moe_ops.h
View file @
0b229519
...
@@ -28,4 +28,13 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
...
@@ -28,4 +28,13 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
torch
::
Tensor
num_tokens_post_pad
,
int64_t
top_k
,
torch
::
Tensor
num_tokens_post_pad
,
int64_t
top_k
,
int64_t
BLOCK_SIZE_M
,
int64_t
BLOCK_SIZE_N
,
int64_t
BLOCK_SIZE_M
,
int64_t
BLOCK_SIZE_N
,
int64_t
BLOCK_SIZE_K
,
int64_t
bit
);
int64_t
BLOCK_SIZE_K
,
int64_t
bit
);
#endif
#endif
\ No newline at end of file
// Declaration of the fused MoE gating entry point.
// Takes the `input` and `bias` tensors, the grouping parameters
// (num_expert_group, topk_group, topk), the shared-expert fusion count and the
// routed scaling factor, and returns a vector of tensors.
// NOTE(review): the Python caller unpacks the result as
// (topk_weights, topk_ids) -- confirm against the implementation in
// moe_fused_gate.cu, which is not visible here.
std
::
vector
<
torch
::
Tensor
>
moe_fused_gate
(
torch
::
Tensor
&
input
,
torch
::
Tensor
&
bias
,
int64_t
num_expert_group
,
int64_t
topk_group
,
int64_t
topk
,
int64_t
n_share_experts_fusion
,
double
routed_scaling_factor
);
\ No newline at end of file
csrc/moe/torch_bindings.cpp
View file @
0b229519
...
@@ -31,6 +31,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
...
@@ -31,6 +31,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
" Tensor! num_tokens_post_pad) -> ()"
);
" Tensor! num_tokens_post_pad) -> ()"
);
m
.
impl
(
"sgl_moe_align_block_size"
,
torch
::
kCUDA
,
&
sgl_moe_align_block_size
);
m
.
impl
(
"sgl_moe_align_block_size"
,
torch
::
kCUDA
,
&
sgl_moe_align_block_size
);
m
.
def
(
"moe_fused_gate(Tensor input, Tensor bias, int num_expert_group, int topk_group, int topk, int "
"n_share_experts_fusion, float routed_scaling_factor) -> "
"(Tensor[])"
);
m
.
impl
(
"moe_fused_gate"
,
torch
::
kCUDA
,
&
moe_fused_gate
);
#ifndef USE_ROCM
#ifndef USE_ROCM
m
.
def
(
m
.
def
(
"moe_wna16_gemm(Tensor input, Tensor! output, Tensor b_qweight, "
"moe_wna16_gemm(Tensor input, Tensor! output, Tensor b_qweight, "
...
...
vllm/_custom_ops.py
View file @
0b229519
...
@@ -1979,3 +1979,31 @@ def flash_mla_with_kvcache(
...
@@ -1979,3 +1979,31 @@ def flash_mla_with_kvcache(
# torch.ops._C.cutlass_mla_decode(out, q_nope, q_pe, kv_c_and_k_pe_cache,
# torch.ops._C.cutlass_mla_decode(out, q_nope, q_pe, kv_c_and_k_pe_cache,
# seq_lens, page_table, scale)
# seq_lens, page_table, scale)
# return out
# return out
def moe_fused_gate(
    input_tensor,
    bias,
    num_expert_group,
    topk_group,
    topk,
    n_share_experts_fusion=0,
    routed_scaling_factor=0,
):
    """Select top-k experts with the fused, hierarchical (2-level) gate kernel.

    Thin wrapper around ``torch.ops._moe_C.moe_fused_gate``. The kernel splits
    the experts into ``num_expert_group`` groups, uses the sum of the top-2
    expert weights in each group as that group's weight to select expert
    groups, and then selects the top-k experts within the selected groups.

    Kernel limitations (as noted by the kernel authors):
      * The number of experts is taken from the input tensor shape and must
        currently be a power of two.
      * The number of experts must be divisible by ``num_expert_group``.
      * ``num_experts / num_expert_group`` must be <= 32 for now.
    For unsupported cases, use a non-fused grouped top-k implementation
    instead (the original note suggests ``biased_grouped_topk`` in
    ``sglang.srt.layers.moe.topk``).

    Args:
        input_tensor: Gating input forwarded to the kernel.
        bias: Bias tensor forwarded to the kernel.
        num_expert_group: Number of groups the experts are split into.
        topk_group: Forwarded unchanged to the kernel (presumably the number
            of expert groups to keep -- TODO confirm against the kernel).
        topk: Number of experts to select.
        n_share_experts_fusion: If > 0, the last expert is replaced with a
            round-robin shared expert.
        routed_scaling_factor: If > 0, the last expert is scaled by this
            factor. ``None`` is accepted and treated as the default ``0``.

    Returns:
        Whatever ``torch.ops._moe_C.moe_fused_gate`` returns, unmodified.
    """
    # Callers may forward an unset (None) scaling factor, but the op takes a
    # plain, non-optional float; fall back to the declared default instead of
    # failing at dispatch time.
    if routed_scaling_factor is None:
        routed_scaling_factor = 0
    return torch.ops._moe_C.moe_fused_gate(
        input_tensor,
        bias,
        num_expert_group,
        topk_group,
        topk,
        n_share_experts_fusion,
        routed_scaling_factor,
    )
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
0b229519
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
import
functools
import
functools
import
json
import
json
import
os
import
os
import
math
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
import
torch
import
torch
...
@@ -1182,6 +1183,10 @@ def fused_topk(
...
@@ -1182,6 +1183,10 @@ def fused_topk(
return
topk_weights
,
topk_ids
return
topk_weights
,
topk_ids
def is_power_of_two(n):
    """Return True if ``n`` is a positive power of two (1, 2, 4, 8, ...).

    Args:
        n: The number to test. Values <= 0 are never powers of two here.

    Returns:
        bool: Whether ``n`` is a positive power of two.
    """
    if isinstance(n, int):
        # Exact bit test: a positive power of two has exactly one bit set, so
        # clearing its lowest set bit yields zero. Unlike a float log2, this
        # stays correct for ints beyond 2**53 (e.g. 2**53 + 1).
        return n > 0 and (n & (n - 1)) == 0
    # Non-int inputs keep the original floating-point check.
    return n > 0 and math.log2(n).is_integer()
# This is used by the Deepseek-V2 and Deepseek-V3 model
# This is used by the Deepseek-V2 and Deepseek-V3 model
@
torch
.
compile
(
dynamic
=
True
,
backend
=
current_platform
.
simple_compile_backend
)
@
torch
.
compile
(
dynamic
=
True
,
backend
=
current_platform
.
simple_compile_backend
)
def
grouped_topk
(
def
grouped_topk
(
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
0b229519
...
@@ -23,6 +23,7 @@ from vllm.model_executor.utils import set_weight_attrs
...
@@ -23,6 +23,7 @@ from vllm.model_executor.utils import set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.platforms.interface
import
CpuArchEnum
from
vllm.platforms.interface
import
CpuArchEnum
from
vllm.utils
import
direct_register_custom_op
from
vllm.utils
import
direct_register_custom_op
from
vllm
import
_custom_ops
as
ops
if
current_platform
.
is_cuda_alike
():
if
current_platform
.
is_cuda_alike
():
from
.fused_moe
import
fused_experts
from
.fused_moe
import
fused_experts
...
@@ -222,7 +223,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
...
@@ -222,7 +223,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
num_expert_group
=
num_expert_group
,
num_expert_group
=
num_expert_group
,
custom_routing_function
=
custom_routing_function
,
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
)
e_score_correction_bias
=
e_score_correction_bias
,
routed_scaling_factor
=
self
.
routed_scaling_factor
if
hasattr
(
self
,
"routed_scaling_factor"
)
else
None
)
return
fused_experts
(
return
fused_experts
(
hidden_states
=
x
,
hidden_states
=
x
,
...
@@ -436,6 +438,7 @@ class FusedMoE(torch.nn.Module):
...
@@ -436,6 +438,7 @@ class FusedMoE(torch.nn.Module):
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
activation
:
str
=
"silu"
,
routed_scaling_factor
:
Optional
[
float
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
...
@@ -505,6 +508,7 @@ class FusedMoE(torch.nn.Module):
...
@@ -505,6 +508,7 @@ class FusedMoE(torch.nn.Module):
self
.
scoring_func
=
scoring_func
self
.
scoring_func
=
scoring_func
self
.
e_score_correction_bias
=
e_score_correction_bias
self
.
e_score_correction_bias
=
e_score_correction_bias
self
.
activation
=
activation
self
.
activation
=
activation
self
.
routed_scaling_factor
=
routed_scaling_factor
if
self
.
scoring_func
!=
"softmax"
and
not
self
.
use_grouped_topk
:
if
self
.
scoring_func
!=
"softmax"
and
not
self
.
use_grouped_topk
:
raise
ValueError
(
"Only softmax scoring function is supported for "
raise
ValueError
(
"Only softmax scoring function is supported for "
...
@@ -554,6 +558,7 @@ class FusedMoE(torch.nn.Module):
...
@@ -554,6 +558,7 @@ class FusedMoE(torch.nn.Module):
self
.
quant_method
.
create_weights
(
layer
=
self
,
**
moe_quant_params
)
self
.
quant_method
.
create_weights
(
layer
=
self
,
**
moe_quant_params
)
setattr
(
self
.
quant_method
,
"routed_scaling_factor"
,
self
.
routed_scaling_factor
)
from
vllm.two_batch_overlap.two_batch_overlap
import
tbo_all_reduce
from
vllm.two_batch_overlap.two_batch_overlap
import
tbo_all_reduce
self
.
tbo_all_reduce
=
tbo_all_reduce
self
.
tbo_all_reduce
=
tbo_all_reduce
...
@@ -839,23 +844,39 @@ class FusedMoE(torch.nn.Module):
...
@@ -839,23 +844,39 @@ class FusedMoE(torch.nn.Module):
num_expert_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
):
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
routed_scaling_factor
:
Optional
[
float
]
=
None
,):
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
fused_topk
,
grouped_topk
)
fused_topk
,
grouped_topk
,
is_power_of_two
)
# DeepSeek-V2 uses grouped_top_k
# DeepSeek-V2 uses grouped_top_k
if
use_grouped_topk
:
if
use_grouped_topk
:
assert
topk_group
is
not
None
assert
topk_group
is
not
None
assert
num_expert_group
is
not
None
assert
num_expert_group
is
not
None
topk_weights
,
topk_ids
=
grouped_topk
(
if
e_score_correction_bias
is
not
None
\
hidden_states
=
hidden_states
,
and
router_logits
.
shape
[
1
]
//
num_expert_group
<=
32
\
gating_output
=
router_logits
,
and
is_power_of_two
(
e_score_correction_bias
.
shape
[
0
]):
topk
=
top_k
,
renormalize
=
renormalize
,
# The moe_fused_gate kernel currently requires num_experts / num_expert_group to not exceed MAX_VPT = 32. Once the kernel can handle MAX_VPT > 32, this check can be removed.
num_expert_group
=
num_expert_group
,
topk_weights
,
topk_ids
=
ops
.
moe_fused_gate
(
topk_group
=
topk_group
,
router_logits
,
scoring_func
=
scoring_func
,
e_score_correction_bias
,
e_score_correction_bias
=
e_score_correction_bias
)
num_expert_group
,
topk_group
,
top_k
,
routed_scaling_factor
=
routed_scaling_factor
,
n_share_experts_fusion
=
0
,
)
else
:
topk_weights
,
topk_ids
=
grouped_topk
(
hidden_states
=
hidden_states
,
gating_output
=
router_logits
,
topk
=
top_k
,
renormalize
=
renormalize
,
num_expert_group
=
num_expert_group
,
topk_group
=
topk_group
,
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
)
elif
custom_routing_function
is
None
:
elif
custom_routing_function
is
None
:
topk_weights
,
topk_ids
=
fused_topk
(
hidden_states
=
hidden_states
,
topk_weights
,
topk_ids
=
fused_topk
(
hidden_states
=
hidden_states
,
gating_output
=
router_logits
,
gating_output
=
router_logits
,
...
@@ -926,7 +947,7 @@ class FusedMoE(torch.nn.Module):
...
@@ -926,7 +947,7 @@ class FusedMoE(torch.nn.Module):
e_score_correction_bias
=
self
.
e_score_correction_bias
,
e_score_correction_bias
=
self
.
e_score_correction_bias
,
activation
=
self
.
activation
,
activation
=
self
.
activation
,
apply_router_weight_on_input
=
self
.
apply_router_weight_on_input
,
apply_router_weight_on_input
=
self
.
apply_router_weight_on_input
,
use_nn_moe
=
self
.
use_nn_moe
,
use_nn_moe
=
self
.
use_nn_moe
)
)
if
self
.
dp_size
>
1
:
if
self
.
dp_size
>
1
:
...
...
vllm/model_executor/models/deepseek_v2.py
View file @
0b229519
...
@@ -142,7 +142,8 @@ class DeepseekV2MoE(nn.Module):
...
@@ -142,7 +142,8 @@ class DeepseekV2MoE(nn.Module):
topk_group
=
config
.
topk_group
,
topk_group
=
config
.
topk_group
,
prefix
=
f
"
{
prefix
}
.experts"
,
prefix
=
f
"
{
prefix
}
.experts"
,
scoring_func
=
config
.
scoring_func
,
scoring_func
=
config
.
scoring_func
,
e_score_correction_bias
=
self
.
gate
.
e_score_correction_bias
,)
e_score_correction_bias
=
self
.
gate
.
e_score_correction_bias
,
routed_scaling_factor
=
self
.
routed_scaling_factor
)
if
config
.
n_shared_experts
is
not
None
:
if
config
.
n_shared_experts
is
not
None
:
intermediate_size
=
(
config
.
moe_intermediate_size
*
intermediate_size
=
(
config
.
moe_intermediate_size
*
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment