Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
da364615
Unverified
Commit
da364615
authored
Oct 08, 2025
by
bnellnm
Committed by
GitHub
Oct 08, 2025
Browse files
[Kernels] Modular kernel refactor (#24812)
Signed-off-by:
Bill Nell
<
bnell@redhat.com
>
parent
f08919b7
Changes
22
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
592 additions
and
490 deletions
+592
-490
tests/kernels/moe/modular_kernel_tools/common.py
tests/kernels/moe/modular_kernel_tools/common.py
+23
-14
tests/kernels/moe/modular_kernel_tools/make_feature_matrix.py
...s/kernels/moe/modular_kernel_tools/make_feature_matrix.py
+1
-1
tests/kernels/moe/modular_kernel_tools/mk_objects.py
tests/kernels/moe/modular_kernel_tools/mk_objects.py
+12
-14
tests/kernels/moe/test_modular_kernel_combinations.py
tests/kernels/moe/test_modular_kernel_combinations.py
+96
-48
vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
.../model_executor/layers/fused_moe/batched_deep_gemm_moe.py
+5
-10
vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py
...cutor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py
+4
-7
vllm/model_executor/layers/fused_moe/cutlass_moe.py
vllm/model_executor/layers/fused_moe/cutlass_moe.py
+21
-36
vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
+2
-4
vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py
...l_executor/layers/fused_moe/deepep_ht_prepare_finalize.py
+3
-0
vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
...l_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
+3
-0
vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
...model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
+4
-8
vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
...r/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
+8
-0
vllm/model_executor/layers/fused_moe/fused_batched_moe.py
vllm/model_executor/layers/fused_moe/fused_batched_moe.py
+7
-10
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
+2
-4
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+2
-4
vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
...l_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
+2
-4
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+37
-53
vllm/model_executor/layers/fused_moe/modular_kernel.py
vllm/model_executor/layers/fused_moe/modular_kernel.py
+354
-273
vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py
.../model_executor/layers/fused_moe/pplx_prepare_finalize.py
+3
-0
vllm/model_executor/layers/fused_moe/prepare_finalize.py
vllm/model_executor/layers/fused_moe/prepare_finalize.py
+3
-0
No files found.
tests/kernels/moe/modular_kernel_tools/common.py
View file @
da364615
...
...
@@ -209,18 +209,18 @@ class Config:
info
=
prepare_finalize_info
(
self
.
prepare_finalize_type
)
return
info
.
backend
def
is_valid
(
self
):
def
is_valid
(
self
)
->
tuple
[
bool
,
Optional
[
str
]]
:
# Check prepare-finalize and fused-experts compatibility
if
self
.
is_batched_prepare_finalize
():
if
not
self
.
is_batched_fused_experts
():
return
False
return
False
,
"Mismatched format."
else
:
if
not
self
.
is_standard_fused_experts
():
return
False
return
False
,
"Mismatched format."
use_chunking
=
self
.
fused_moe_chunk_size
is
not
None
if
use_chunking
and
not
self
.
is_fe_supports_chunking
():
return
False
return
False
,
"Chunking not supported."
# Check quantization sanity
if
(
...
...
@@ -229,7 +229,7 @@ class Config:
+
int
(
self
.
quant_block_shape
is
not
None
)
)
>
1
:
# invalid quant config
return
False
return
False
,
f
"Bad quant_config
{
self
.
quant_config
}
."
# check type support
if
self
.
quant_dtype
is
None
:
...
...
@@ -237,34 +237,43 @@ class Config:
self
.
dtype
not
in
self
.
pf_supported_types
()
or
self
.
dtype
not
in
self
.
fe_supported_types
()
):
return
False
return
False
,
(
f
"Unsupported type
{
self
.
dtype
}
not in "
f
"
{
self
.
pf_supported_types
()
}
and "
f
"
{
self
.
fe_supported_types
()
}
."
)
else
:
if
(
self
.
quant_dtype
not
in
self
.
pf_supported_types
()
or
self
.
quant_dtype
not
in
self
.
fe_supported_types
()
):
return
False
return
False
,
(
f
"Unsupported quant type
{
self
.
quant_dtype
}
"
f
"not in
{
self
.
pf_supported_types
()
}
and "
f
"
{
self
.
fe_supported_types
()
}
."
)
# Check block quanization support
is_block_quatized
=
self
.
quant_block_shape
is
not
None
if
is_block_quatized
and
self
.
quant_dtype
is
None
:
return
False
return
False
,
"No block quantization support."
if
is_block_quatized
and
not
self
.
is_block_quant_supported
():
return
False
return
False
,
"Mismatched block quantization support."
# deep_gemm only works with block-quantized
if
self
.
needs_deep_gemm
()
and
not
is_block_quatized
:
return
False
return
False
,
"Needs DeepGEMM but not block quantized."
# Check dependencies (turn into asserts?)
if
self
.
needs_deep_ep
()
and
not
has_deep_ep
():
return
False
return
False
,
"Needs DeepEP, but DeepEP not available."
if
self
.
needs_deep_gemm
()
and
not
has_deep_gemm
():
return
False
return
False
,
"Needs DeepGEMM, but DeepGEMM not available."
if
self
.
needs_pplx
()
and
not
has_pplx
():
# noqa: SIM103
return
False
return
False
,
"Needs PPLX, but PPLX not available."
return
True
return
True
,
None
@
dataclass
...
...
tests/kernels/moe/modular_kernel_tools/make_feature_matrix.py
View file @
da364615
...
...
@@ -140,7 +140,7 @@ def make_feature_matrix(csv_file_path: str):
)
success
=
None
if
config
.
is_valid
():
if
config
.
is_valid
()
[
0
]
:
print
(
f
"Running config :
{
config
.
describe
()
}
..."
)
try
:
weights
:
WeightTensors
=
WeightTensors
.
make
(
config
)
...
...
tests/kernels/moe/modular_kernel_tools/mk_objects.py
View file @
da364615
...
...
@@ -244,7 +244,7 @@ if has_flashinfer_cutlass_fused_moe() and current_platform.has_device_capability
register_prepare_and_finalize
(
FlashInferCutlassMoEPrepareAndFinalize
,
standard_format
,
nvfp4_types
,
nvfp4_types
+
fp8_types
,
blocked_quantization_support
=
True
,
backend
=
None
,
force_multigpu
=
True
,
...
...
@@ -254,7 +254,7 @@ if has_flashinfer_cutlass_fused_moe() and current_platform.has_device_capability
register_experts
(
FlashInferExperts
,
standard_format
,
nvfp4_types
,
nvfp4_types
+
fp8_types
,
blocked_quantization_support
=
True
,
supports_chunking
=
True
,
# Note: this is a hack to get it to run for now
...
...
@@ -274,7 +274,6 @@ if has_deep_gemm() and is_deep_gemm_supported():
needs_matching_quant
=
False
,
needs_deep_gemm
=
True
,
)
(
register_experts
(
DeepGemmExperts
,
standard_format
,
...
...
@@ -284,7 +283,6 @@ if has_deep_gemm() and is_deep_gemm_supported():
supports_expert_map
=
True
,
needs_matching_quant
=
False
,
needs_deep_gemm
=
True
,
),
)
register_experts
(
BatchedTritonOrDeepGemmExperts
,
...
...
@@ -464,7 +462,7 @@ def make_fused_experts(
print
(
f
"Making BatchedTritonOrDeepGemmExperts
{
kwargs
}
..."
)
experts
=
BatchedTritonOrDeepGemmExperts
(
**
kwargs
)
elif
fused_experts_type
==
DeepGemmExperts
:
print
(
"Making DeepGemmExperts {quant_config} ..."
)
print
(
f
"Making DeepGemmExperts
{
quant_config
}
..."
)
experts
=
DeepGemmExperts
(
quant_config
)
elif
fused_experts_type
==
TritonExperts
:
kwargs
=
quant_kwargs
...
...
tests/kernels/moe/test_modular_kernel_combinations.py
View file @
da364615
...
...
@@ -5,7 +5,7 @@ import copy
import
textwrap
import
traceback
from
itertools
import
product
from
typing
import
Optional
from
typing
import
Any
,
Optional
import
pytest
import
torch
...
...
@@ -13,10 +13,9 @@ import torch
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm.config
import
VllmConfig
,
set_current_vllm_config
from
vllm.platforms
import
current_platform
from
vllm.utils
import
has_deep_ep
,
has_deep_gemm
,
has_pplx
from
vllm.utils
import
cuda_device_count_stateless
,
has_deep_ep
,
has_deep_gemm
,
has_pplx
from
vllm.utils.flashinfer
import
has_flashinfer_cutlass_fused_moe
from
...utils
import
multi_gpu_test
from
.modular_kernel_tools.common
import
(
Config
,
RankTensors
,
...
...
@@ -132,7 +131,8 @@ def rank_worker(
def
run
(
config
:
Config
,
verbose
:
bool
):
assert
config
.
is_valid
()
assert
config
.
is_valid
()[
0
]
assert
not
is_nyi_config
(
config
)
weights
:
WeightTensors
=
WeightTensors
.
make
(
config
)
...
...
@@ -168,17 +168,77 @@ def is_nyi_config(config: Config) -> bool:
return
not
info
.
supports_expert_map
@
pytest
.
mark
.
parametrize
(
"k"
,
Ks
)
@
pytest
.
mark
.
parametrize
(
"n"
,
Ns
)
@
pytest
.
mark
.
parametrize
(
"e"
,
Es
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPEs
)
@
pytest
.
mark
.
parametrize
(
"quant_config"
,
MK_QUANT_CONFIGS
)
def
generate_valid_test_cases
(
world_size
:
int
,
prepare_finalize_types
)
->
list
[
tuple
[
Any
,
...]]:
cases
=
[]
total
=
0
for
k
,
n
,
e
,
dtype
,
quant_config
,
combination
,
chunk_size
in
product
(
Ks
,
Ns
,
Es
,
DTYPEs
,
MK_QUANT_CONFIGS
,
product
(
prepare_finalize_types
,
MK_FUSED_EXPERT_TYPES
),
FUSED_MOE_CHUNK_SIZEs
,
):
total
=
total
+
1
config
=
Config
(
Ms
=
Ms
,
K
=
k
,
N
=
n
,
E
=
e
,
topks
=
TOPKs
,
dtype
=
dtype
,
quant_config
=
quant_config
,
prepare_finalize_type
=
combination
[
0
],
fused_experts_type
=
combination
[
1
],
fused_moe_chunk_size
=
chunk_size
,
world_size
=
world_size
,
)
# TODO(bnell): figure out how to get verbose flag here.
verbose
=
False
# pytestconfig.getoption('verbose') > 0
valid
,
reason
=
config
.
is_valid
()
if
not
valid
:
if
verbose
:
print
(
f
"Test config
{
config
}
is not valid:
{
reason
}
"
)
continue
if
is_nyi_config
(
config
):
if
verbose
:
print
(
f
"Test config
{
config
}
is nyi."
)
continue
cases
.
append
(
(
k
,
n
,
e
,
dtype
,
quant_config
,
combination
[
0
],
combination
[
1
],
chunk_size
,
world_size
,
)
)
print
(
f
"
{
len
(
cases
)
}
of
{
total
}
valid configs generated."
)
return
cases
@
pytest
.
mark
.
parametrize
(
"combination"
,
product
(
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES
,
MK_FUSED_EXPERT_TYPES
)
"k,n,e,dtype,quant_config,prepare_finalize_type,fused_experts_type,chunk_size,world_size"
,
generate_valid_test_cases
(
world_size
=
2
,
prepare_finalize_types
=
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES
),
)
@
pytest
.
mark
.
parametrize
(
"fused_moe_chunk_size"
,
FUSED_MOE_CHUNK_SIZEs
)
@
pytest
.
mark
.
parametrize
(
"world_size"
,
[
2
])
@
multi_gpu_test
(
num_gpus
=
2
)
@
meets_multi_gpu_requirements
def
test_modular_kernel_combinations_multigpu
(
k
:
int
,
...
...
@@ -186,13 +246,19 @@ def test_modular_kernel_combinations_multigpu(
e
:
int
,
dtype
:
torch
.
dtype
,
quant_config
:
Optional
[
TestMoEQuantConfig
],
combination
:
tuple
[
mk
.
FusedMoEPrepareAndFinalize
,
mk
.
FusedMoEPermuteExpertsUnpermute
],
fused_moe_chunk_size
:
Optional
[
int
],
prepare_finalize_type
:
mk
.
FusedMoEPrepareAndFinalize
,
fused_experts_type
:
mk
.
FusedMoEPermuteExpertsUnpermute
,
chunk_size
:
Optional
[
int
],
world_size
:
int
,
pytestconfig
,
):
if
cuda_device_count_stateless
()
<
world_size
:
pytest
.
skip
(
f
"Not enough GPUs available to run, got "
f
"
{
cuda_device_count_stateless
()
}
exepected "
f
"
{
world_size
}
."
)
config
=
Config
(
Ms
=
Ms
,
K
=
k
,
...
...
@@ -201,42 +267,30 @@ def test_modular_kernel_combinations_multigpu(
topks
=
TOPKs
,
dtype
=
dtype
,
quant_config
=
quant_config
,
prepare_finalize_type
=
combination
[
0
]
,
fused_experts_type
=
combination
[
1
]
,
fused_moe_chunk_size
=
fused_moe_
chunk_size
,
prepare_finalize_type
=
prepare_finalize_type
,
fused_experts_type
=
fused_experts_type
,
fused_moe_chunk_size
=
chunk_size
,
world_size
=
world_size
,
)
if
not
config
.
is_valid
():
pytest
.
skip
(
f
"Tests config
{
config
}
is not valid. Skipping ..."
)
if
is_nyi_config
(
config
):
pytest
.
skip
(
f
"Tests config
{
config
}
is nyi. Skipping ..."
)
verbosity
=
pytestconfig
.
getoption
(
"verbose"
)
run
(
config
,
verbosity
>
0
)
@
pytest
.
mark
.
parametrize
(
"k"
,
Ks
)
@
pytest
.
mark
.
parametrize
(
"n"
,
Ns
)
@
pytest
.
mark
.
parametrize
(
"e"
,
Es
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPEs
)
@
pytest
.
mark
.
parametrize
(
"quant_config"
,
MK_QUANT_CONFIGS
)
@
pytest
.
mark
.
parametrize
(
"combination"
,
product
(
MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES
,
MK_FUSED_EXPERT_TYPES
)
"k,n,e,dtype,quant_config,prepare_finalize_type,fused_experts_type,chunk_size,world_size"
,
generate_valid_test_cases
(
world_size
=
1
,
prepare_finalize_types
=
MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES
),
)
@
pytest
.
mark
.
parametrize
(
"fused_moe_chunk_size"
,
FUSED_MOE_CHUNK_SIZEs
)
@
pytest
.
mark
.
parametrize
(
"world_size"
,
[
1
])
def
test_modular_kernel_combinations_singlegpu
(
k
:
int
,
n
:
int
,
e
:
int
,
dtype
:
torch
.
dtype
,
quant_config
:
Optional
[
TestMoEQuantConfig
],
combination
:
tuple
[
mk
.
FusedMoEPrepareAndFinalize
,
mk
.
FusedMoEPermuteExpertsUnpermute
],
fused_moe_chunk_size
:
Optional
[
int
],
prepare_finalize_type
:
mk
.
FusedMoEPrepareAndFinalize
,
fused_experts_type
:
mk
.
FusedMoEPermuteExpertsUnpermute
,
chunk_size
:
Optional
[
int
],
world_size
:
int
,
pytestconfig
,
):
...
...
@@ -248,18 +302,12 @@ def test_modular_kernel_combinations_singlegpu(
topks
=
TOPKs
,
dtype
=
dtype
,
quant_config
=
quant_config
,
prepare_finalize_type
=
combination
[
0
]
,
fused_experts_type
=
combination
[
1
]
,
fused_moe_chunk_size
=
fused_moe_
chunk_size
,
prepare_finalize_type
=
prepare_finalize_type
,
fused_experts_type
=
fused_experts_type
,
fused_moe_chunk_size
=
chunk_size
,
world_size
=
world_size
,
)
if
not
config
.
is_valid
():
pytest
.
skip
(
f
"Tests config
{
config
}
is not valid. Skipping ..."
)
if
is_nyi_config
(
config
):
pytest
.
skip
(
f
"Tests config
{
config
}
is nyi. Skipping ..."
)
verbosity
=
pytestconfig
.
getoption
(
"verbose"
)
run
(
config
,
verbosity
>
0
)
...
...
vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
View file @
da364615
...
...
@@ -247,29 +247,24 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
def
workspace_shapes
(
self
,
a
:
torch
.
Tensor
,
aq
:
torch
.
Tensor
,
M
:
int
,
N
:
int
,
K
:
int
,
topk
:
int
,
global_num_experts
:
int
,
local_num_experts
:
int
,
expert_tokens_metadata
:
Optional
[
mk
.
ExpertTokensMetadata
],
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...],
torch
.
dtype
]:
assert
a
.
dim
()
==
2
expert_tokens_meta
:
Optional
[
mk
.
ExpertTokensMetadata
],
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]]:
# FIXME (varun): We should be able to dispatch only from the leader
# DP ranks in the case of TP > 1. At the moment, all the Ranks
# end up sending their tokens. This needs to be fixed.
num_dispatchers
=
self
.
num_dispatchers
num_experts
=
local_num_experts
max_num_tokens
=
(
a
.
size
(
0
)
if
self
.
max_num_tokens
is
None
else
self
.
max_num_tokens
)
max_num_tokens
=
M
if
self
.
max_num_tokens
is
None
else
self
.
max_num_tokens
workspace13
=
(
num_experts
,
max_num_tokens
*
num_dispatchers
,
max
(
K
,
N
))
workspace2
=
(
num_experts
,
max_num_tokens
*
num_dispatchers
,
(
N
//
2
))
output
=
(
num_experts
,
max_num_tokens
*
num_dispatchers
,
K
)
return
(
workspace13
,
workspace2
,
output
,
a
.
dtype
)
return
(
workspace13
,
workspace2
,
output
)
def
apply
(
self
,
...
...
@@ -300,7 +295,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
assert
w2
.
size
(
1
)
==
K
E
,
max_num_tokens
,
N
,
K
,
top_k_num
=
self
.
moe_problem_size
(
E
,
max_num_tokens
,
N
,
K
,
_
=
self
.
moe_problem_size
(
hidden_states
,
w1
,
w2
,
topk_ids
)
...
...
vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py
View file @
da364615
...
...
@@ -99,10 +99,11 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
assert
bte_war
is
not
None
return
bte_war
def
workspace_dtype
(
self
,
act_dtype
:
torch
.
dtype
)
->
torch
.
dtype
:
return
act_dtype
def
workspace_shapes
(
self
,
a
:
torch
.
Tensor
,
aq
:
torch
.
Tensor
,
M
:
int
,
N
:
int
,
K
:
int
,
...
...
@@ -110,15 +111,13 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts
:
int
,
local_num_experts
:
int
,
expert_tokens_metadata
:
Optional
[
mk
.
ExpertTokensMetadata
],
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]
,
torch
.
dtype
]:
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]]:
# Note: the deep gemm workspaces are strictly larger than the triton
# workspaces so we can be pessimistic here and allocate for DeepGemm
# even if we fall back to triton later, e.g. if expert maps are set.
if
self
.
allow_deep_gemm
:
assert
self
.
batched_deep_gemm_experts
is
not
None
return
self
.
batched_deep_gemm_experts
.
workspace_shapes
(
a
,
aq
,
M
,
N
,
K
,
...
...
@@ -130,8 +129,6 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
else
:
assert
self
.
batched_triton_experts
is
not
None
return
self
.
batched_triton_experts
.
workspace_shapes
(
a
,
aq
,
M
,
N
,
K
,
...
...
vllm/model_executor/layers/fused_moe/cutlass_moe.py
View file @
da364615
...
...
@@ -366,10 +366,11 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
# topk weights and reduction are fused in moe_unpermute cuda kernel
return
TopKWeightAndReduceNoOP
()
def
workspace_dtype
(
self
,
act_dtype
:
torch
.
dtype
)
->
torch
.
dtype
:
return
self
.
out_dtype
if
self
.
out_dtype
is
not
None
else
act_dtype
def
workspace_shapes
(
self
,
a
:
torch
.
Tensor
,
aq
:
torch
.
Tensor
,
M
:
int
,
N
:
int
,
K
:
int
,
...
...
@@ -377,16 +378,11 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
global_num_experts
:
int
,
local_num_experts
:
int
,
expert_tokens_meta
:
Optional
[
mk
.
ExpertTokensMetadata
],
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]
,
torch
.
dtype
]:
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]]:
workspace1
=
(
M
*
topk
,
max
(
N
,
K
))
workspace2
=
(
M
*
topk
,
max
(
N
//
2
,
K
))
output
=
(
M
,
K
)
return
(
workspace1
,
workspace2
,
output
,
self
.
out_dtype
if
self
.
out_dtype
is
not
None
else
a
.
dtype
,
)
return
(
workspace1
,
workspace2
,
output
)
class
CutlassBatchedExpertsFp8
(
CutlassExpertsFp8Base
):
...
...
@@ -428,11 +424,11 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
def
supports_expert_map
(
self
)
->
bool
:
return
False
# TODO(bnell): maybe remove need for passing aq to workspace_shapes
def
workspace_dtype
(
self
,
act_dtype
:
torch
.
dtype
)
->
torch
.
dtype
:
return
self
.
out_dtype
if
self
.
out_dtype
is
not
None
else
act_dtype
def
workspace_shapes
(
self
,
a
:
torch
.
Tensor
,
aq
:
torch
.
Tensor
,
M
:
int
,
N
:
int
,
K
:
int
,
...
...
@@ -440,19 +436,13 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
global_num_experts
:
int
,
local_num_experts
:
int
,
expert_tokens_meta
:
Optional
[
mk
.
ExpertTokensMetadata
],
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...],
torch
.
dtype
]:
padded_M
=
aq
.
size
(
1
)
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]]:
num_dp
=
self
.
num_dispatchers
assert
num_dp
is
not
None
workspace1
=
(
self
.
max_experts_per_worker
,
padded_M
*
num_dp
,
max
(
N
,
K
))
workspace2
=
(
self
.
max_experts_per_worker
,
padded_M
*
num_dp
,
max
(
N
//
2
,
K
))
output
=
(
self
.
max_experts_per_worker
,
padded_M
,
K
)
return
(
workspace1
,
workspace2
,
output
,
self
.
out_dtype
if
self
.
out_dtype
is
not
None
else
a
.
dtype
,
)
workspace1
=
(
self
.
max_experts_per_worker
,
M
*
num_dp
,
max
(
N
,
K
))
workspace2
=
(
self
.
max_experts_per_worker
,
M
*
num_dp
,
max
(
N
//
2
,
K
))
output
=
(
self
.
max_experts_per_worker
,
M
,
K
)
return
(
workspace1
,
workspace2
,
output
)
def
cutlass_moe_fp8
(
...
...
@@ -767,10 +757,11 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
def
finalize_weight_and_reduce_impl
(
self
)
->
mk
.
TopKWeightAndReduce
:
return
TopKWeightAndReduceNoOP
()
def
workspace_dtype
(
self
,
act_dtype
:
torch
.
dtype
)
->
torch
.
dtype
:
return
self
.
out_dtype
if
self
.
out_dtype
is
not
None
else
act_dtype
def
workspace_shapes
(
self
,
a
:
torch
.
Tensor
,
aq
:
torch
.
Tensor
,
M
:
int
,
N
:
int
,
K
:
int
,
...
...
@@ -778,25 +769,19 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts
:
int
,
local_num_experts
:
int
,
expert_tokens_meta
:
Optional
[
mk
.
ExpertTokensMetadata
],
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]
,
torch
.
dtype
]:
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]]:
workspace1
:
tuple
[
int
,
...]
=
()
workspace2
:
tuple
[
int
,
...]
=
()
output
:
tuple
[
int
,
...]
=
()
if
self
.
use_batched_format
:
padded_M
=
aq
.
size
(
1
)
workspace1
=
(
self
.
max_experts_per_worker
,
padded_M
,
max
(
N
,
K
))
workspace2
=
(
self
.
max_experts_per_worker
,
padded_M
,
(
N
//
2
))
output
=
(
self
.
max_experts_per_worker
,
padded_M
,
K
)
workspace1
=
(
self
.
max_experts_per_worker
,
M
,
max
(
N
,
K
))
workspace2
=
(
self
.
max_experts_per_worker
,
M
,
(
N
//
2
))
output
=
(
self
.
max_experts_per_worker
,
M
,
K
)
else
:
workspace1
=
(
M
*
topk
,
max
(
2
*
N
,
K
))
workspace2
=
(
M
*
topk
,
N
)
output
=
(
M
,
K
)
return
(
workspace1
,
workspace2
,
output
,
self
.
out_dtype
if
self
.
out_dtype
is
not
None
else
a
.
dtype
,
)
return
(
workspace1
,
workspace2
,
output
)
def
apply
(
self
,
...
...
vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
View file @
da364615
...
...
@@ -198,8 +198,6 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
def
workspace_shapes
(
self
,
a
:
torch
.
Tensor
,
aq
:
torch
.
Tensor
,
M
:
int
,
N
:
int
,
K
:
int
,
...
...
@@ -207,7 +205,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts
:
int
,
local_num_experts
:
int
,
expert_tokens_meta
:
Optional
[
mk
.
ExpertTokensMetadata
],
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]
,
torch
.
dtype
]:
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]]:
assert
self
.
block_shape
is
not
None
block_m
=
self
.
block_shape
[
0
]
M_sum
=
compute_aligned_M
(
...
...
@@ -218,7 +216,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace1
=
(
M_sum
,
max
(
N
,
K
))
workspace2
=
(
M_sum
,
max
(
N
//
2
,
K
))
output
=
(
M
,
K
)
return
(
workspace1
,
workspace2
,
output
,
a
.
dtype
)
return
(
workspace1
,
workspace2
,
output
)
def
apply
(
self
,
...
...
vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py
View file @
da364615
...
...
@@ -70,6 +70,9 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def
num_dispatchers
(
self
)
->
int
:
return
self
.
num_dispatchers_
def
output_is_reduced
(
self
)
->
bool
:
return
True
@
property
def
activation_format
(
self
)
->
mk
.
FusedMoEActivationFormat
:
return
mk
.
FusedMoEActivationFormat
.
Standard
...
...
vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
View file @
da364615
...
...
@@ -73,6 +73,9 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def
num_dispatchers
(
self
)
->
int
:
return
self
.
num_dispatchers_
def
output_is_reduced
(
self
)
->
bool
:
return
True
@
property
def
activation_format
(
self
)
->
mk
.
FusedMoEActivationFormat
:
return
mk
.
FusedMoEActivationFormat
.
BatchedExperts
...
...
vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
View file @
da364615
...
...
@@ -90,8 +90,6 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
def
workspace_shapes
(
self
,
a
:
torch
.
Tensor
,
aq
:
torch
.
Tensor
,
M
:
int
,
N
:
int
,
K
:
int
,
...
...
@@ -99,7 +97,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts
:
int
,
local_num_experts
:
int
,
expert_tokens_meta
:
Optional
[
mk
.
ExpertTokensMetadata
],
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]
,
torch
.
dtype
]:
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]]:
# We use global_num_experts due to how moe_align_block_size handles
# expert_maps.
"""
...
...
@@ -118,14 +116,12 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
- Note: in order for activation chunking to work, the first dimension
of each tuple must be the number of tokens.
"""
aq_m
,
aq_n
=
aq
.
shape
workspace1
=
(
M
,
K
)
workspace2
=
(
0
,)
output_shape
=
(
aq_m
,
aq_n
*
2
)
if
self
.
quant_dtype
==
"nvfp4"
else
(
aq_m
,
aq_n
)
workspace_dtype
=
a
.
dtype
workspace1
=
output_shape
output_shape
=
(
M
,
K
*
2
if
self
.
quant_dtype
==
"nvfp4"
else
K
)
# The workspace is determined by `aq`, since it comes after any
# potential communication op and is involved in the expert computation.
return
(
workspace1
,
workspace2
,
output_shape
,
workspace_dtype
)
return
(
workspace1
,
workspace2
,
output_shape
)
def
apply
(
self
,
...
...
vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
View file @
da364615
...
...
@@ -11,6 +11,9 @@ from vllm.distributed.device_communicators.base_device_communicator import (
)
from
vllm.forward_context
import
get_forward_context
from
vllm.model_executor.layers.fused_moe.config
import
FusedMoEQuantConfig
from
vllm.model_executor.layers.fused_moe.topk_weight_and_reduce
import
(
TopKWeightAndReduceNoOP
,
)
from
vllm.model_executor.layers.fused_moe.utils
import
moe_kernel_quantize_input
from
vllm.utils.flashinfer
import
nvfp4_block_scale_interleave
...
...
@@ -45,6 +48,9 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def
num_dispatchers
(
self
)
->
int
:
return
self
.
num_dispatchers_
def
output_is_reduced
(
self
)
->
bool
:
return
False
def
_apply_router_weight_on_input
(
self
,
a1
:
torch
.
Tensor
,
...
...
@@ -194,6 +200,8 @@ class FlashInferAllGatherMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFin
apply_router_weight_on_input
:
bool
,
weight_and_reduce_impl
:
mk
.
TopKWeightAndReduce
,
)
->
None
:
assert
isinstance
(
weight_and_reduce_impl
,
TopKWeightAndReduceNoOP
)
if
self
.
use_dp
:
fused_expert_output
=
get_dp_group
().
reduce_scatterv
(
fused_expert_output
,
dim
=
0
,
sizes
=
get_local_sizes
()
...
...
vllm/model_executor/layers/fused_moe/fused_batched_moe.py
View file @
da364615
...
...
@@ -509,6 +509,9 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def
num_dispatchers
(
self
)
->
int
:
return
self
.
num_dispatchers_
def
output_is_reduced
(
self
)
->
bool
:
return
False
def
prepare
(
self
,
a1
:
torch
.
Tensor
,
...
...
@@ -665,8 +668,6 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
def
workspace_shapes
(
self
,
a
:
torch
.
Tensor
,
aq
:
torch
.
Tensor
,
M
:
int
,
N
:
int
,
K
:
int
,
...
...
@@ -674,14 +675,13 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts
:
int
,
local_num_experts
:
int
,
expert_tokens_meta
:
Optional
[
mk
.
ExpertTokensMetadata
],
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...],
torch
.
dtype
]:
assert
a
.
dim
()
==
2
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]]:
num_dp
=
self
.
num_dispatchers
num_experts
=
local_num_experts
workspace13
=
(
num_experts
,
self
.
max_num_tokens
*
num_dp
,
K
)
workspace2
=
(
self
.
max_num_tokens
*
num_dp
,
N
)
output
=
workspace13
return
(
workspace13
,
workspace2
,
output
,
a
.
dtype
)
return
(
workspace13
,
workspace2
,
output
)
def
dequant
(
self
,
t
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
)
->
torch
.
Tensor
:
assert
self
.
quant_config
.
is_quantized
...
...
@@ -862,8 +862,6 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
def
workspace_shapes
(
self
,
a
:
torch
.
Tensor
,
aq
:
torch
.
Tensor
,
M
:
int
,
N
:
int
,
K
:
int
,
...
...
@@ -871,15 +869,14 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts
:
int
,
local_num_experts
:
int
,
expert_tokens_meta
:
Optional
[
mk
.
ExpertTokensMetadata
],
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...],
torch
.
dtype
]:
assert
a
.
dim
()
==
2
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]]:
num_dp
=
self
.
num_dispatchers
num_experts
=
local_num_experts
max_num_tokens
=
self
.
max_num_tokens
workspace13
=
(
num_experts
,
max_num_tokens
*
num_dp
,
max
(
K
,
N
))
workspace2
=
(
num_experts
,
max_num_tokens
*
num_dp
,
(
N
//
2
))
output
=
(
num_experts
,
max_num_tokens
*
num_dp
,
K
)
return
(
workspace13
,
workspace2
,
output
,
a
.
dtype
)
return
(
workspace13
,
workspace2
,
output
)
def
apply
(
self
,
...
...
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
View file @
da364615
...
...
@@ -331,8 +331,6 @@ class MarlinExperts(mk.FusedMoEPermuteExpertsUnpermute):
def
workspace_shapes
(
self
,
a
:
torch
.
Tensor
,
aq
:
torch
.
Tensor
,
M
:
int
,
N
:
int
,
K
:
int
,
...
...
@@ -340,7 +338,7 @@ class MarlinExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts
:
int
,
local_num_experts
:
int
,
expert_tokens_meta
:
Optional
[
mk
.
ExpertTokensMetadata
],
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]
,
torch
.
dtype
]:
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]]:
# Modular Kernel provisions output buffer from workspace1. However in
# the fused_marlin_moe() function, the final torch.sum(), is defined
# essentially as,
...
...
@@ -360,7 +358,7 @@ class MarlinExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace2
=
(
M
*
topk
*
max
(
2
*
N
,
K
),)
output
=
(
M
,
K
)
return
(
workspace1
,
workspace2
,
output
,
a
.
dtype
)
return
(
workspace1
,
workspace2
,
output
)
def
apply
(
self
,
...
...
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
da364615
...
...
@@ -1954,8 +1954,6 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
def
workspace_shapes
(
self
,
a
:
torch
.
Tensor
,
aq
:
torch
.
Tensor
,
M
:
int
,
N
:
int
,
K
:
int
,
...
...
@@ -1963,11 +1961,11 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts
:
int
,
local_num_experts
:
int
,
expert_tokens_meta
:
Optional
[
mk
.
ExpertTokensMetadata
],
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]
,
torch
.
dtype
]:
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]]:
workspace1
=
(
M
,
topk
,
max
(
N
//
2
,
K
))
workspace2
=
(
M
,
topk
,
max
(
N
,
K
))
output
=
(
M
,
K
)
return
(
workspace1
,
workspace2
,
output
,
a
.
dtype
)
return
(
workspace1
,
workspace2
,
output
)
def
apply
(
self
,
...
...
vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
View file @
da364615
...
...
@@ -255,8 +255,6 @@ class OAITritonExperts(BaseOAITritonExperts):
def
workspace_shapes
(
self
,
a
:
torch
.
Tensor
,
aq
:
torch
.
Tensor
,
M
:
int
,
N
:
int
,
K
:
int
,
...
...
@@ -264,12 +262,12 @@ class OAITritonExperts(BaseOAITritonExperts):
global_num_experts
:
int
,
local_num_experts
:
int
,
expert_tokens_meta
:
Optional
[
mk
.
ExpertTokensMetadata
],
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]
,
torch
.
dtype
]:
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]]:
# workspace are allocated inside the kernel
workspace1
=
(
M
,
K
)
workspace2
=
(
0
,
0
)
output
=
(
M
,
K
)
return
(
workspace1
,
workspace2
,
output
,
a
.
dtype
)
return
(
workspace1
,
workspace2
,
output
)
def
apply
(
self
,
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
da364615
...
...
@@ -283,6 +283,10 @@ class FusedMoEMethodBase(QuantizeMethodBase):
)
->
Optional
[
FusedMoEQuantConfig
]:
raise
NotImplementedError
@
property
def
using_modular_kernel
(
self
)
->
bool
:
return
self
.
fused_experts
is
not
None
@
abstractmethod
def
apply
(
self
,
...
...
@@ -1237,38 +1241,24 @@ class FusedMoE(CustomOp):
self
.
batched_hidden_states
:
Optional
[
torch
.
Tensor
]
=
None
self
.
batched_router_logits
:
Optional
[
torch
.
Tensor
]
=
None
# TODO(bnell): flashinfer uses non-batched format.
# Does it really need a batched buffer?
if
(
self
.
moe_parallel_config
.
use_pplx_kernels
or
self
.
moe_parallel_config
.
use_deepep_ll_kernels
or
self
.
moe_config
.
use_flashinfer_cutlass_kernels
):
if
vllm_config
.
parallel_config
.
enable_dbo
:
self
.
batched_hidden_states
=
torch
.
zeros
(
(
2
,
moe
.
max_num_tokens
,
self
.
hidden_size
),
dtype
=
moe
.
in_dtype
,
device
=
torch
.
cuda
.
current_device
(),
)
if
self
.
use_dp_chunking
:
states_shape
:
tuple
[
int
,
...]
logits_shape
:
tuple
[
int
,
...]
# Note here we use `num_experts` which is logical expert count
self
.
batched_router_logits
=
torch
.
zeros
(
(
2
,
moe
.
max_num_tokens
,
num_experts
),
dtype
=
moe
.
in_dtype
,
device
=
torch
.
cuda
.
current_device
(),
)
if
vllm_config
.
parallel_config
.
enable_dbo
:
states_shape
=
(
2
,
moe
.
max_num_tokens
,
self
.
hidden_size
)
logits_shape
=
(
2
,
moe
.
max_num_tokens
,
num_experts
)
else
:
states_shape
=
(
moe
.
max_num_tokens
,
self
.
hidden_size
)
logits_shape
=
(
moe
.
max_num_tokens
,
num_experts
)
self
.
batched_hidden_states
=
torch
.
zeros
(
(
moe
.
max_num_tokens
,
self
.
hidden_size
),
dtype
=
moe
.
in_dtype
,
device
=
torch
.
cuda
.
current_device
(),
states_shape
,
dtype
=
moe
.
in_dtype
,
device
=
torch
.
cuda
.
current_device
()
)
# Note here we use `num_experts` which is logical expert count
self
.
batched_router_logits
=
torch
.
zeros
(
(
moe
.
max_num_tokens
,
num_experts
),
dtype
=
moe
.
in_dtype
,
device
=
torch
.
cuda
.
current_device
(),
logits_shape
,
dtype
=
moe
.
in_dtype
,
device
=
torch
.
cuda
.
current_device
()
)
@
property
...
...
@@ -1323,6 +1313,16 @@ class FusedMoE(CustomOp):
and
self
.
moe_config
.
use_flashinfer_cutlass_kernels
)
@
property
def
use_dp_chunking
(
self
)
->
bool
:
# Route to the chunked forward path using the FlashInfer Cutlass kernel
# only when data parallelism (DP) is enabled.
return
(
self
.
moe_parallel_config
.
use_pplx_kernels
or
self
.
moe_parallel_config
.
use_deepep_ll_kernels
or
(
self
.
dp_size
>
1
and
self
.
use_flashinfer_cutlass_kernels
)
)
def
update_expert_map
(
self
):
# ep_size and ep_rank should already be updated
assert
self
.
expert_map
is
not
None
...
...
@@ -1987,21 +1987,17 @@ class FusedMoE(CustomOp):
Therefore it is required that we reduce the shared_experts output
early.
"""
assert
self
.
quant_method
is
not
None
return
(
self
.
use_pplx_kernels
or
self
.
use_deepep_ht_kernels
or
self
.
use_deepep_ll_kernels
self
.
quant_method
.
fused_experts
is
not
None
and
self
.
quant_method
.
fused_experts
.
output_is_reduced
()
)
def
maybe_all_reduce_tensor_model_parallel
(
self
,
final_hidden_states
:
torch
.
Tensor
):
"""
The pplx
combine kernel reduce
s
across GPU ranks by default.
Some
combine kernel
s
reduce across GPU ranks by default.
"""
if
(
self
.
use_pplx_kernels
or
self
.
use_deepep_ht_kernels
or
self
.
use_deepep_ll_kernels
):
if
self
.
must_reduce_shared_expert_outputs
():
return
final_hidden_states
else
:
return
tensor_model_parallel_all_reduce
(
final_hidden_states
)
...
...
@@ -2209,23 +2205,11 @@ class FusedMoE(CustomOp):
self
.
ensure_moe_quant_config
()
# Route to the chunked forward path using the FlashInfer Cutlass kernel
# only when data parallelism (DP) is enabled.
_use_flashinfer_cutlass_kernels
=
(
self
.
dp_size
>
1
and
self
.
use_flashinfer_cutlass_kernels
)
if
(
self
.
moe_parallel_config
.
use_pplx_kernels
or
self
.
moe_parallel_config
.
use_deepep_ll_kernels
or
_use_flashinfer_cutlass_kernels
):
if
self
.
use_dp_chunking
:
return
self
.
forward_impl_chunked
(
hidden_states
,
router_logits
)
do_naive_dispatch_combine
:
bool
=
(
self
.
dp_size
>
1
and
not
self
.
moe_parallel_config
.
use_deepep_ht_kernels
and
not
self
.
moe_config
.
use_flashinfer_cutlass_kernels
self
.
dp_size
>
1
and
not
self
.
quant_method
.
using_modular_kernel
)
# If there are shared experts but we are not using a modular kernel, the
...
...
vllm/model_executor/layers/fused_moe/modular_kernel.py
View file @
da364615
This diff is collapsed.
Click to expand it.
vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py
View file @
da364615
...
...
@@ -91,6 +91,9 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def
num_dispatchers
(
self
)
->
int
:
return
self
.
num_dispatchers_
def
output_is_reduced
(
self
)
->
bool
:
return
True
def
supports_async
(
self
)
->
bool
:
return
True
...
...
vllm/model_executor/layers/fused_moe/prepare_finalize.py
View file @
da364615
...
...
@@ -27,6 +27,9 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
def
num_dispatchers
(
self
)
->
int
:
return
1
def
output_is_reduced
(
self
)
->
bool
:
return
False
def
prepare
(
self
,
a1
:
torch
.
Tensor
,
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment