Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1e2ac05c
Commit
1e2ac05c
authored
Mar 12, 2026
by
chenhw5
Browse files
add DeepGEMM SBO for DeepEP LL
parent
18459e7a
Changes
10
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
251 additions
and
47 deletions
+251
-47
vllm/envs.py
vllm/envs.py
+11
-0
vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
.../model_executor/layers/fused_moe/batched_deep_gemm_moe.py
+34
-12
vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py
...l_executor/layers/fused_moe/deepep_ht_prepare_finalize.py
+3
-0
vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
...l_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
+140
-14
vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
...r/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
+2
-0
vllm/model_executor/layers/fused_moe/fused_batched_moe.py
vllm/model_executor/layers/fused_moe/fused_batched_moe.py
+1
-0
vllm/model_executor/layers/fused_moe/modular_kernel.py
vllm/model_executor/layers/fused_moe/modular_kernel.py
+55
-21
vllm/model_executor/layers/fused_moe/mori_prepare_finalize.py
.../model_executor/layers/fused_moe/mori_prepare_finalize.py
+1
-0
vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py
.../model_executor/layers/fused_moe/pplx_prepare_finalize.py
+3
-0
vllm/model_executor/layers/fused_moe/prepare_finalize.py
vllm/model_executor/layers/fused_moe/prepare_finalize.py
+1
-0
No files found.
vllm/envs.py
View file @
1e2ac05c
...
...
@@ -271,6 +271,8 @@ if TYPE_CHECKING:
VLLM_HAS_CONTEXT_DEFAULT
:
bool
=
False
VLLM_USE_NN
:
bool
=
False
VLLM_ENABLE_TBO
:
bool
=
False
# Whether to use single batch overlapping (SBO) for MoE with DeepEP low-latency.
VLLM_EP_USE_SBO
:
bool
=
False
VLLM_ENABLE_MOE_FUSED_GATE
:
bool
=
False
VLLM_USE_FLASH_ATTN_PA
:
bool
=
False
VLLM_USE_APEX_RN
:
bool
=
False
...
...
@@ -1229,6 +1231,15 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_DEEPEPLL_NVFP4_DISPATCH"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_DEEPEPLL_NVFP4_DISPATCH"
,
"0"
))
),
# Whether to use single batch overlapping optimization
"VLLM_EP_USE_SBO"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_EP_USE_SBO"
,
"0"
))),
# Whether to turn on the outlines cache for V0
# This cache is unbounded and on disk, so it's not safe to use in
# an environment with potentially malicious users.
"VLLM_V0_USE_OUTLINES_CACHE"
:
lambda
:
os
.
environ
.
get
(
"VLLM_V0_USE_OUTLINES_CACHE"
,
"0"
)
==
"1"
,
# Whether to turn on the outlines cache for V1
# This cache is unbounded and on disk, so it's not safe to use in
# an environment with potentially malicious users.
...
...
vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
View file @
1e2ac05c
...
...
@@ -34,13 +34,13 @@ from vllm.utils.math_utils import cdiv, round_up
from
vllm.utils.import_utils
import
has_deep_gemm
from
lightop
import
fuse_silu_mul_quant_ep
from
lightop
import
fuse_silu_mul_quant_ep
,
fuse_silu_mul_fp8_quant_ep
if
has_deep_gemm
():
from
deep_gemm
import
m_grouped_w8a8_gemm_nt_masked
from
deep_gemm
import
m_grouped_w8a8_gemm_nt_masked
,
m_grouped_fp8_gemm_nt_masked
else
:
from
lightop
import
m_grouped_w8a8_gemm_nt_masked
from
lightop
import
m_grouped_w8a8_gemm_nt_masked
,
m_grouped_fp8_gemm_nt_masked
from
typing
import
Any
logger
=
init_logger
(
__name__
)
...
...
@@ -415,6 +415,8 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
expert_tokens_meta
:
mk
.
ExpertTokensMetadata
|
None
,
apply_router_weight_on_input
:
bool
,
use_nn_moe
:
bool
|
None
=
False
,
w2_gemm_overlap_args
:
Any
=
None
,
meta_overlap_args
:
dict
[
str
,
Any
]
|
None
=
None
,
):
assert
expert_tokens_meta
is
not
None
expert_num_tokens
=
expert_tokens_meta
.
expert_num_tokens
...
...
@@ -443,7 +445,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
)
if
self
.
quant_config
.
use_fp8_w8a16
or
self
.
quant_config
.
use_fp8_w8a8
:
fp8_
m_grouped_gemm_nt_masked
(
m_grouped_
fp8_
gemm_nt_masked
(
(
a1q
,
a1q_scale
),
(
w1
,
self
.
w1_scale
),
workspace1
,
...
...
@@ -451,20 +453,40 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
expected_m
,
)
quant_scale_fmt
=
DeepGemmQuantScaleFMT
.
from_oracle
()
a2q
,
a2q_scale
=
persistent_masked_m_silu_mul_quant
(
workspace1
,
expert_num_tokens
,
quant_scale_fmt
=
quant_scale_fmt
,
# ---- SiLU + quant (corresponds to SGLang's fuse_silu_mul_fp8_quant_ep) ----
# workspace1: [E, max_num_tokens, N]; apply silu*up within each expert and quantize the result to fp8
q_a2_all
,
q_a2_scale
=
fuse_silu_mul_fp8_quant_ep
(
input
=
workspace1
,
fp8type
=
0
,
# consistent with the convention used by deepgemm
tokens_per_expert
=
expert_num_tokens
,
)
fp8_m_grouped_gemm_nt_masked
(
(
a2q
,
a2q_scale
),
# When using SBO, we record event here to indicate
# that the signal tensor and input to deepep ll combine
# are ready
enable_overlap
=
w2_gemm_overlap_args
is
not
None
signal
=
w2_gemm_overlap_args
.
signal
if
enable_overlap
else
None
if
enable_overlap
:
w2_gemm_overlap_args
.
start_event
.
record
()
block_m
,
threshold
=
m_grouped_fp8_gemm_nt_masked
(
(
q_a2_all
,
q_a2_scale
),
(
w2
,
self
.
w2_scale
),
output
,
expert_num_tokens
,
expected_m
,
enable_overlap
,
signal
,
)
# return meta_overlap_args to DeepEP combine.
if
meta_overlap_args
is
not
None
:
if
block_m
is
not
None
:
meta_overlap_args
[
"block_m"
]
=
block_m
if
threshold
is
not
None
:
meta_overlap_args
[
"threshold"
]
=
threshold
elif
self
.
quant_config
.
use_int8_w8a8
:
m_grouped_w8a8_gemm_nt_masked
((
a1q
,
a1q_scale
),
(
w1
,
self
.
w1_scale
),
...
...
vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py
View file @
1e2ac05c
...
...
@@ -254,6 +254,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
num_experts
:
int
,
local_num_experts
:
int
,
expert_map
:
torch
.
Tensor
|
None
,
apply_router_weight_on_input
:
bool
,
quant_config
:
FusedMoEQuantConfig
,
...
...
@@ -299,6 +300,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
num_experts
:
int
,
local_num_experts
:
int
,
expert_map
:
torch
.
Tensor
|
None
,
apply_router_weight_on_input
:
bool
,
quant_config
:
FusedMoEQuantConfig
,
...
...
@@ -308,6 +310,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_weights
,
topk_ids
,
num_experts
,
local_num_experts
,
expert_map
,
apply_router_weight_on_input
,
quant_config
,
...
...
vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
View file @
1e2ac05c
...
...
@@ -23,6 +23,11 @@ from vllm.v1.worker.ubatching import (
dbo_maybe_run_recv_hook
,
)
from
contextlib
import
nullcontext
from
dataclasses
import
dataclass
from
typing
import
Any
alt_stream
=
torch
.
cuda
.
Stream
()
logger
=
init_logger
(
__name__
)
# DeepEP kernels quantize dispatch inputs in 128 element chunks.
...
...
@@ -31,6 +36,42 @@ DEEPEP_QUANT_BLOCK_SHAPE = [DEEPEP_QUANT_BLOCK_SIZE, DEEPEP_QUANT_BLOCK_SIZE]
logger
=
init_logger
(
__name__
)
@
dataclass
class
CombineOverlapArgs
:
# Whether to use overlap for w2 gemm and deepep ll combine
overlap
:
bool
# we launch deepep ll combine on this stream, which
# is different from the default compute stream
stream
:
torch
.
cuda
.
Stream
# We record this wait even in the compute stream between
# silu_mul_fp4_quantize and w2 gemm.
# And we wait for this even before deepep ll combine on the
# combine stream to ensure signal tensors have been allocated
wait_event
:
torch
.
cuda
.
Event
# Number of CU used for combine kernel, currently hardcoded to be 32
num_sms
:
int
# The signal tensor is shared by the w2 gemm and deepep ll combine.
# w2 gemm atomic_add to the tensor to signal deepep combine can start
# send data
signal
:
torch
.
Tensor
|
None
=
None
#
block_m
:
int
=
64
# Set to the number of CU used by W2 gemm, which is a persistent kernel
# So when all CU has completed the computation for an expert,
# combine kernel can start to send data for this expert
threshold
:
int
=
32
@dataclass
class W2GemmOverlapArgs:
    """Arguments handed to the w2 GEMM when SBO overlap is enabled."""

    # Number of CUs the w2 GEMM may use.
    num_sms: int

    # Signal tensor shared with the DeepEP LL combine (the same tensor that
    # is stored in CombineOverlapArgs.signal).
    signal: torch.Tensor

    # Event recorded on the compute stream before the w2 GEMM; it is the
    # same event object as CombineOverlapArgs.wait_event.
    start_event: torch.cuda.Event
def
dequant_fp8
(
expert_x_fp8
:
torch
.
Tensor
,
expert_x_scales
:
torch
.
Tensor
...
...
@@ -122,6 +163,12 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
# time. This setting is handled by post_init_setup.
self
.
use_ue8m0_dispatch
=
False
# SBO: DeepEP LL overlap configuration
self
.
combine_overlap_args
:
CombineOverlapArgs
|
None
=
None
self
.
meta_overlap_args
:
dict
[
str
,
Any
]
|
None
=
None
self
.
packed_recv_count
:
torch
.
Tensor
|
None
=
None
def
post_init_setup
(
self
,
fused_experts
:
mk
.
FusedMoEPermuteExpertsUnpermute
):
if
not
fused_experts
.
supports_packed_ue8m0_act_scales
():
# Early exit.
...
...
@@ -247,6 +294,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
num_experts
:
int
,
local_num_experts
:
int
,
expert_map
:
torch
.
Tensor
|
None
,
apply_router_weight_on_input
:
bool
,
quant_config
:
FusedMoEQuantConfig
,
...
...
@@ -317,6 +365,10 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
)
self
.
handles
[
a2a_idx
]
=
handle
# We need to pass w2_gemm_overlap_args to the MoE implementation,
# so return it as an output parameter
w2_gemm_overlap_args
=
self
.
_create_sbo_args
(
local_num_experts
,
a1
.
device
)
return
(
hook
,
lambda
:
self
.
_receiver
(
...
...
@@ -326,6 +378,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
a1
.
dtype
,
quant_config
,
),
w2_gemm_overlap_args
,
)
def
_receiver
(
...
...
@@ -341,7 +394,8 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_tokens_meta
=
mk
.
ExpertTokensMetadata
(
expert_num_tokens
=
expert_num_tokens
,
expert_num_tokens_cpu
=
None
)
# SBO: save the latest expert_num_tokens as packed_recv_count; the DeepEP LL combine needs it when SBO is enabled.
self
.
packed_recv_count
=
expert_num_tokens
return
expert_x
,
expert_x_scale
,
expert_tokens_meta
,
None
,
None
def
prepare
(
...
...
@@ -350,15 +404,17 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
num_experts
:
int
,
local_num_experts
:
int
,
expert_map
:
torch
.
Tensor
|
None
,
apply_router_weight_on_input
:
bool
,
quant_config
:
FusedMoEQuantConfig
,
)
->
mk
.
PrepareResultType
:
hook
,
receiver
=
self
.
prepare_async
(
hook
,
receiver
,
_
=
self
.
prepare_async
(
a1
,
topk_weights
,
topk_ids
,
num_experts
,
local_num_experts
,
expert_map
,
apply_router_weight_on_input
,
quant_config
,
...
...
@@ -393,6 +449,28 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
combine_topk_ids
=
self
.
_map_global_to_physical_ids
(
topk_ids
)
# TODO (varun) : Enable zero copy mode
dbo_maybe_run_recv_hook
()
ctx
=
nullcontext
()
if
self
.
combine_overlap_args
is
not
None
:
# For SBO, we need to wait for compute stream
# to have completed signal tensor allocation
self
.
combine_overlap_args
.
stream
.
wait_event
(
self
.
combine_overlap_args
.
wait_event
)
# And we launch ll combine phase 1 in a separate stream
# for overlapping
ctx
=
torch
.
cuda
.
stream
(
self
.
combine_overlap_args
.
stream
)
overlap_args_dict
=
dict
(
overlap
=
self
.
combine_overlap_args
.
overlap
,
packed_recv_count
=
self
.
packed_recv_count
,
comp_signal
=
self
.
combine_overlap_args
.
signal
,
block_m
=
self
.
meta_overlap_args
[
"block_m"
],
threshold
=
self
.
meta_overlap_args
[
"threshold"
],
num_sms
=
self
.
combine_overlap_args
.
num_sms
,
)
else
:
overlap_args_dict
=
{}
with
ctx
:
_
,
_
,
recv_hook
=
self
.
buffer
.
low_latency_combine
(
fused_expert_output
,
combine_topk_ids
,
...
...
@@ -402,8 +480,12 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
zero_copy
=
False
,
return_recv_hook
=
do_recv_hook
,
out
=
output
,
**
overlap_args_dict
,
)
if
self
.
combine_overlap_args
is
not
None
:
return
recv_hook
,
lambda
:
self
.
_sbo_wait_stream
()
else
:
return
recv_hook
,
lambda
:
None
def
finalize_async
(
...
...
@@ -443,3 +525,47 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
weight_and_reduce_impl
,
do_async
=
False
,
)
def _create_sbo_args(
    self, local_num_experts: int, device: torch.device
) -> W2GemmOverlapArgs | None:
    """Build the per-dispatch single-batch-overlap (SBO) arguments.

    Resets ``self.combine_overlap_args`` and ``self.meta_overlap_args`` on
    every call so that no state from a previous dispatch is reused.

    Args:
        local_num_experts: Number of experts hosted on this rank; the
            signal tensor holds one counter per local expert.
        device: Device on which the signal tensor is allocated and whose
            multiprocessor count is queried.

    Returns:
        The arguments to forward to the w2 GEMM, or ``None`` when
        ``VLLM_EP_USE_SBO`` is disabled.
    """
    # Always drop the combine arguments of the previous dispatch first.
    self.combine_overlap_args = None

    if not envs.VLLM_EP_USE_SBO:
        # SBO disabled: nothing to hand to the w2 GEMM or to combine.
        self.meta_overlap_args = None
        return None

    # Start from an empty dict every time to avoid using stale values.
    self.meta_overlap_args = {}

    # Split the device's multiprocessors between communication (combine)
    # and computation (w2 GEMM).
    total_num_sms = torch.cuda.get_device_properties(
        device=device
    ).multi_processor_count
    communicate_num_sms = 32
    compute_num_sms = total_num_sms - communicate_num_sms

    # The event and the signal tensor are shared by both argument objects.
    combine_wait_event = torch.cuda.Event()
    combine_signal = torch.zeros(
        local_num_experts, dtype=torch.uint32, device=device
    )

    # `overlap` is a required dataclass field, so it has to be supplied at
    # construction time; assigning it afterwards is too late because the
    # constructor would already have raised TypeError.
    self.combine_overlap_args = CombineOverlapArgs(
        overlap=True,
        stream=alt_stream,
        wait_event=combine_wait_event,
        num_sms=communicate_num_sms,
        signal=combine_signal,
    )

    return W2GemmOverlapArgs(
        num_sms=compute_num_sms,
        signal=combine_signal,
        start_event=combine_wait_event,
    )
def
_sbo_wait_stream
(
self
)
->
None
:
# When SBO enabled, ll combine phase 2 is still launched
# on the main compute stream, but we need to wait for
# ll combine 1 to complete
if
self
.
combine_overlap_args
is
not
None
:
torch
.
cuda
.
current_stream
().
wait_stream
(
self
.
combine_overlap_args
.
stream
)
\ No newline at end of file
vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
View file @
1e2ac05c
...
...
@@ -96,6 +96,7 @@ class FlashInferAllToAllMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFina
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
num_experts
:
int
,
local_num_experts
:
int
,
expert_map
:
torch
.
Tensor
|
None
,
apply_router_weight_on_input
:
bool
,
quant_config
:
FusedMoEQuantConfig
,
...
...
@@ -177,6 +178,7 @@ class FlashInferAllGatherMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFin
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
num_experts
:
int
,
local_num_experts
:
int
,
expert_map
:
torch
.
Tensor
|
None
,
apply_router_weight_on_input
:
bool
,
quant_config
:
FusedMoEQuantConfig
,
...
...
vllm/model_executor/layers/fused_moe/fused_batched_moe.py
View file @
1e2ac05c
...
...
@@ -530,6 +530,7 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
num_experts
:
int
,
local_num_experts
:
int
,
expert_map
:
torch
.
Tensor
|
None
,
apply_router_weight_on_input
:
bool
,
quant_config
:
FusedMoEQuantConfig
,
...
...
vllm/model_executor/layers/fused_moe/modular_kernel.py
View file @
1e2ac05c
...
...
@@ -34,6 +34,9 @@ from vllm.v1.worker.ubatching import (
dbo_yield
,
)
from
vllm.v1.worker.workspace
import
current_workspace_manager
from
vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe
import
(
BatchedDeepGemmExperts
,
)
logger
=
init_logger
(
__name__
)
...
...
@@ -177,6 +180,7 @@ class FusedMoEPrepareAndFinalize(ABC):
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
num_experts
:
int
,
local_num_experts
:
int
,
expert_map
:
torch
.
Tensor
|
None
,
apply_router_weight_on_input
:
bool
,
quant_config
:
FusedMoEQuantConfig
,
...
...
@@ -217,6 +221,7 @@ class FusedMoEPrepareAndFinalize(ABC):
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
num_experts
:
int
,
local_num_experts
:
int
,
expert_map
:
torch
.
Tensor
|
None
,
apply_router_weight_on_input
:
bool
,
quant_config
:
FusedMoEQuantConfig
,
...
...
@@ -1059,6 +1064,7 @@ class FusedMoEModularKernel(torch.nn.Module):
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
global_num_experts
:
int
,
local_num_experts
:
int
,
expert_map
:
torch
.
Tensor
|
None
,
apply_router_weight_on_input
:
bool
,
)
->
tuple
[
...
...
@@ -1072,6 +1078,7 @@ class FusedMoEModularKernel(torch.nn.Module):
The _prepare method is a wrapper around self.prepare_finalize.prepare
that handles DBO and async.
"""
w2_gemm_overlap_args
=
None
if
not
self
.
prepare_finalize
.
supports_async
():
# We shouldn't be running an a2a kernel that doesn't
# support async prepare/finalize
...
...
@@ -1089,6 +1096,7 @@ class FusedMoEModularKernel(torch.nn.Module):
topk_weights
,
topk_ids
,
global_num_experts
,
local_num_experts
,
expert_map
,
apply_router_weight_on_input
,
self
.
fused_experts
.
quant_config
,
...
...
@@ -1101,6 +1109,7 @@ class FusedMoEModularKernel(torch.nn.Module):
topk_weights
,
topk_ids
,
global_num_experts
,
local_num_experts
,
expert_map
,
apply_router_weight_on_input
,
self
.
fused_experts
.
quant_config
,
...
...
@@ -1109,7 +1118,7 @@ class FusedMoEModularKernel(torch.nn.Module):
# TODO(lucas): refactor this in the alternative schedules followup
# currently unpack if we have hook + receiver pair or just
# receiver (see finalize_async docstring)
hook
,
receiver
=
(
hook
,
receiver
,
w2_gemm_overlap_args
=
(
prepare_ret
if
isinstance
(
prepare_ret
,
tuple
)
else
(
None
,
prepare_ret
)
)
...
...
@@ -1137,7 +1146,7 @@ class FusedMoEModularKernel(torch.nn.Module):
topk_weights
if
_expert_topk_weights
is
None
else
_expert_topk_weights
)
return
a1q
,
a1q_scale
,
expert_tokens_meta
,
topk_ids
,
topk_weights
return
a1q
,
a1q_scale
,
expert_tokens_meta
,
topk_ids
,
topk_weights
,
w2_gemm_overlap_args
def
_fused_experts
(
self
,
...
...
@@ -1155,6 +1164,7 @@ class FusedMoEModularKernel(torch.nn.Module):
apply_router_weight_on_input
:
bool
,
expert_tokens_meta
:
ExpertTokensMetadata
|
None
,
use_nn_moe
:
bool
|
None
=
False
,
w2_gemm_overlap_args
=
None
,
)
->
torch
.
Tensor
:
_
,
M_full
,
N
,
K
,
top_k
=
self
.
fused_experts
.
moe_problem_size
(
a1q
,
w1
,
w2
,
topk_ids
...
...
@@ -1217,6 +1227,28 @@ class FusedMoEModularKernel(torch.nn.Module):
fused_out
,
chunk_idx
,
num_chunks
,
CHUNK_SIZE
,
M_full
)
if
isinstance
(
self
.
fused_experts
,
BatchedDeepGemmExperts
):
self
.
fused_experts
.
apply
(
output
=
c_fused_out
,
hidden_states
=
a1q
[
s
:
e
],
w1
=
w1
,
w2
=
w2
,
topk_weights
=
topk_weights
[
s
:
e
],
topk_ids
=
topk_ids
[
s
:
e
],
activation
=
activation
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
a1q_scale
=
_slice_scales
(
a1q_scale
,
s
,
e
),
a2_scale
=
_slice_scales
(
self
.
fused_experts
.
a2_scale
,
s
,
e
),
workspace13
=
workspace13
,
workspace2
=
workspace2
,
expert_tokens_meta
=
c_expert_tokens_meta
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
use_nn_moe
=
use_nn_moe
,
w2_gemm_overlap_args
=
w2_gemm_overlap_args
,
meta_overlap_args
=
self
.
prepare_finalize
.
meta_overlap_args
,
)
else
:
self
.
fused_experts
.
apply
(
output
=
c_fused_out
,
hidden_states
=
a1q
[
s
:
e
],
...
...
@@ -1365,11 +1397,12 @@ class FusedMoEModularKernel(torch.nn.Module):
if
global_num_experts
==
-
1
:
global_num_experts
=
local_num_experts
a1q
,
a1q_scale
,
expert_tokens_meta
,
topk_ids
,
topk_weights
=
self
.
_prepare
(
a1q
,
a1q_scale
,
expert_tokens_meta
,
topk_ids
,
topk_weights
,
w2_gemm_overlap_args
=
self
.
_prepare
(
hidden_states
,
topk_weights
,
topk_ids
,
global_num_experts
,
local_num_experts
,
expert_map
,
apply_router_weight_on_input
,
)
...
...
@@ -1389,6 +1422,7 @@ class FusedMoEModularKernel(torch.nn.Module):
apply_router_weight_on_input
=
apply_router_weight_on_input
,
expert_tokens_meta
=
expert_tokens_meta
,
use_nn_moe
=
use_nn_moe
,
w2_gemm_overlap_args
=
w2_gemm_overlap_args
,
)
return
self
.
_finalize
(
...
...
vllm/model_executor/layers/fused_moe/mori_prepare_finalize.py
View file @
1e2ac05c
...
...
@@ -55,6 +55,7 @@ class MoriPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
num_experts
:
int
,
local_num_experts
:
int
,
expert_map
:
torch
.
Tensor
|
None
,
apply_router_weight_on_input
:
bool
,
quant_config
:
FusedMoEQuantConfig
,
...
...
vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py
View file @
1e2ac05c
...
...
@@ -103,6 +103,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
num_experts
:
int
,
local_num_experts
:
int
,
expert_map
:
torch
.
Tensor
|
None
,
apply_router_weight_on_input
:
bool
,
quant_config
:
FusedMoEQuantConfig
,
...
...
@@ -271,6 +272,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
num_experts
:
int
,
local_num_experts
:
int
,
expert_map
:
torch
.
Tensor
|
None
,
apply_router_weight_on_input
:
bool
,
quant_config
:
FusedMoEQuantConfig
,
...
...
@@ -280,6 +282,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_weights
,
topk_ids
,
num_experts
,
local_num_experts
,
expert_map
,
apply_router_weight_on_input
,
quant_config
,
...
...
vllm/model_executor/layers/fused_moe/prepare_finalize.py
View file @
1e2ac05c
...
...
@@ -39,6 +39,7 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
num_experts
:
int
,
local_num_experts
:
int
,
expert_map
:
torch
.
Tensor
|
None
,
apply_router_weight_on_input
:
bool
,
quant_config
:
FusedMoEQuantConfig
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment