Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
da364615
Unverified
Commit
da364615
authored
Oct 08, 2025
by
bnellnm
Committed by
GitHub
Oct 08, 2025
Browse files
[Kernels] Modular kernel refactor (#24812)
Signed-off-by:
Bill Nell
<
bnell@redhat.com
>
parent
f08919b7
Changes
22
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
592 additions
and
490 deletions
+592
-490
tests/kernels/moe/modular_kernel_tools/common.py
tests/kernels/moe/modular_kernel_tools/common.py
+23
-14
tests/kernels/moe/modular_kernel_tools/make_feature_matrix.py
...s/kernels/moe/modular_kernel_tools/make_feature_matrix.py
+1
-1
tests/kernels/moe/modular_kernel_tools/mk_objects.py
tests/kernels/moe/modular_kernel_tools/mk_objects.py
+12
-14
tests/kernels/moe/test_modular_kernel_combinations.py
tests/kernels/moe/test_modular_kernel_combinations.py
+96
-48
vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
.../model_executor/layers/fused_moe/batched_deep_gemm_moe.py
+5
-10
vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py
...cutor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py
+4
-7
vllm/model_executor/layers/fused_moe/cutlass_moe.py
vllm/model_executor/layers/fused_moe/cutlass_moe.py
+21
-36
vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
+2
-4
vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py
...l_executor/layers/fused_moe/deepep_ht_prepare_finalize.py
+3
-0
vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
...l_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
+3
-0
vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
...model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
+4
-8
vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
...r/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
+8
-0
vllm/model_executor/layers/fused_moe/fused_batched_moe.py
vllm/model_executor/layers/fused_moe/fused_batched_moe.py
+7
-10
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
+2
-4
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+2
-4
vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
...l_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
+2
-4
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+37
-53
vllm/model_executor/layers/fused_moe/modular_kernel.py
vllm/model_executor/layers/fused_moe/modular_kernel.py
+354
-273
vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py
.../model_executor/layers/fused_moe/pplx_prepare_finalize.py
+3
-0
vllm/model_executor/layers/fused_moe/prepare_finalize.py
vllm/model_executor/layers/fused_moe/prepare_finalize.py
+3
-0
No files found.
tests/kernels/moe/modular_kernel_tools/common.py
View file @
da364615
...
...
@@ -209,18 +209,18 @@ class Config:
info
=
prepare_finalize_info
(
self
.
prepare_finalize_type
)
return
info
.
backend
def
is_valid
(
self
):
def
is_valid
(
self
)
->
tuple
[
bool
,
Optional
[
str
]]
:
# Check prepare-finalize and fused-experts compatibility
if
self
.
is_batched_prepare_finalize
():
if
not
self
.
is_batched_fused_experts
():
return
False
return
False
,
"Mismatched format."
else
:
if
not
self
.
is_standard_fused_experts
():
return
False
return
False
,
"Mismatched format."
use_chunking
=
self
.
fused_moe_chunk_size
is
not
None
if
use_chunking
and
not
self
.
is_fe_supports_chunking
():
return
False
return
False
,
"Chunking not supported."
# Check quantization sanity
if
(
...
...
@@ -229,7 +229,7 @@ class Config:
+
int
(
self
.
quant_block_shape
is
not
None
)
)
>
1
:
# invalid quant config
return
False
return
False
,
f
"Bad quant_config
{
self
.
quant_config
}
."
# check type support
if
self
.
quant_dtype
is
None
:
...
...
@@ -237,34 +237,43 @@ class Config:
self
.
dtype
not
in
self
.
pf_supported_types
()
or
self
.
dtype
not
in
self
.
fe_supported_types
()
):
return
False
return
False
,
(
f
"Unsupported type
{
self
.
dtype
}
not in "
f
"
{
self
.
pf_supported_types
()
}
and "
f
"
{
self
.
fe_supported_types
()
}
."
)
else
:
if
(
self
.
quant_dtype
not
in
self
.
pf_supported_types
()
or
self
.
quant_dtype
not
in
self
.
fe_supported_types
()
):
return
False
return
False
,
(
f
"Unsupported quant type
{
self
.
quant_dtype
}
"
f
"not in
{
self
.
pf_supported_types
()
}
and "
f
"
{
self
.
fe_supported_types
()
}
."
)
# Check block quanization support
is_block_quatized
=
self
.
quant_block_shape
is
not
None
if
is_block_quatized
and
self
.
quant_dtype
is
None
:
return
False
return
False
,
"No block quantization support."
if
is_block_quatized
and
not
self
.
is_block_quant_supported
():
return
False
return
False
,
"Mismatched block quantization support."
# deep_gemm only works with block-quantized
if
self
.
needs_deep_gemm
()
and
not
is_block_quatized
:
return
False
return
False
,
"Needs DeepGEMM but not block quantized."
# Check dependencies (turn into asserts?)
if
self
.
needs_deep_ep
()
and
not
has_deep_ep
():
return
False
return
False
,
"Needs DeepEP, but DeepEP not available."
if
self
.
needs_deep_gemm
()
and
not
has_deep_gemm
():
return
False
return
False
,
"Needs DeepGEMM, but DeepGEMM not available."
if
self
.
needs_pplx
()
and
not
has_pplx
():
# noqa: SIM103
return
False
return
False
,
"Needs PPLX, but PPLX not available."
return
True
return
True
,
None
@
dataclass
...
...
tests/kernels/moe/modular_kernel_tools/make_feature_matrix.py
View file @
da364615
...
...
@@ -140,7 +140,7 @@ def make_feature_matrix(csv_file_path: str):
)
success
=
None
if
config
.
is_valid
():
if
config
.
is_valid
()
[
0
]
:
print
(
f
"Running config :
{
config
.
describe
()
}
..."
)
try
:
weights
:
WeightTensors
=
WeightTensors
.
make
(
config
)
...
...
tests/kernels/moe/modular_kernel_tools/mk_objects.py
View file @
da364615
...
...
@@ -244,7 +244,7 @@ if has_flashinfer_cutlass_fused_moe() and current_platform.has_device_capability
register_prepare_and_finalize
(
FlashInferCutlassMoEPrepareAndFinalize
,
standard_format
,
nvfp4_types
,
nvfp4_types
+
fp8_types
,
blocked_quantization_support
=
True
,
backend
=
None
,
force_multigpu
=
True
,
...
...
@@ -254,7 +254,7 @@ if has_flashinfer_cutlass_fused_moe() and current_platform.has_device_capability
register_experts
(
FlashInferExperts
,
standard_format
,
nvfp4_types
,
nvfp4_types
+
fp8_types
,
blocked_quantization_support
=
True
,
supports_chunking
=
True
,
# Note: this is a hack to get it to run for now
...
...
@@ -274,17 +274,15 @@ if has_deep_gemm() and is_deep_gemm_supported():
needs_matching_quant
=
False
,
needs_deep_gemm
=
True
,
)
(
register_experts
(
DeepGemmExperts
,
standard_format
,
fp8_types
,
blocked_quantization_support
=
True
,
supports_chunking
=
True
,
supports_expert_map
=
True
,
needs_matching_quant
=
False
,
needs_deep_gemm
=
True
,
),
register_experts
(
DeepGemmExperts
,
standard_format
,
fp8_types
,
blocked_quantization_support
=
True
,
supports_chunking
=
True
,
supports_expert_map
=
True
,
needs_matching_quant
=
False
,
needs_deep_gemm
=
True
,
)
register_experts
(
BatchedTritonOrDeepGemmExperts
,
...
...
@@ -464,7 +462,7 @@ def make_fused_experts(
print
(
f
"Making BatchedTritonOrDeepGemmExperts
{
kwargs
}
..."
)
experts
=
BatchedTritonOrDeepGemmExperts
(
**
kwargs
)
elif
fused_experts_type
==
DeepGemmExperts
:
print
(
"Making DeepGemmExperts {quant_config} ..."
)
print
(
f
"Making DeepGemmExperts
{
quant_config
}
..."
)
experts
=
DeepGemmExperts
(
quant_config
)
elif
fused_experts_type
==
TritonExperts
:
kwargs
=
quant_kwargs
...
...
tests/kernels/moe/test_modular_kernel_combinations.py
View file @
da364615
...
...
@@ -5,7 +5,7 @@ import copy
import
textwrap
import
traceback
from
itertools
import
product
from
typing
import
Optional
from
typing
import
Any
,
Optional
import
pytest
import
torch
...
...
@@ -13,10 +13,9 @@ import torch
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm.config
import
VllmConfig
,
set_current_vllm_config
from
vllm.platforms
import
current_platform
from
vllm.utils
import
has_deep_ep
,
has_deep_gemm
,
has_pplx
from
vllm.utils
import
cuda_device_count_stateless
,
has_deep_ep
,
has_deep_gemm
,
has_pplx
from
vllm.utils.flashinfer
import
has_flashinfer_cutlass_fused_moe
from
...utils
import
multi_gpu_test
from
.modular_kernel_tools.common
import
(
Config
,
RankTensors
,
...
...
@@ -132,7 +131,8 @@ def rank_worker(
def
run
(
config
:
Config
,
verbose
:
bool
):
assert
config
.
is_valid
()
assert
config
.
is_valid
()[
0
]
assert
not
is_nyi_config
(
config
)
weights
:
WeightTensors
=
WeightTensors
.
make
(
config
)
...
...
@@ -168,17 +168,77 @@ def is_nyi_config(config: Config) -> bool:
return
not
info
.
supports_expert_map
@
pytest
.
mark
.
parametrize
(
"k"
,
Ks
)
@
pytest
.
mark
.
parametrize
(
"n"
,
Ns
)
@
pytest
.
mark
.
parametrize
(
"e"
,
Es
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPEs
)
@
pytest
.
mark
.
parametrize
(
"quant_config"
,
MK_QUANT_CONFIGS
)
def
generate_valid_test_cases
(
world_size
:
int
,
prepare_finalize_types
)
->
list
[
tuple
[
Any
,
...]]:
cases
=
[]
total
=
0
for
k
,
n
,
e
,
dtype
,
quant_config
,
combination
,
chunk_size
in
product
(
Ks
,
Ns
,
Es
,
DTYPEs
,
MK_QUANT_CONFIGS
,
product
(
prepare_finalize_types
,
MK_FUSED_EXPERT_TYPES
),
FUSED_MOE_CHUNK_SIZEs
,
):
total
=
total
+
1
config
=
Config
(
Ms
=
Ms
,
K
=
k
,
N
=
n
,
E
=
e
,
topks
=
TOPKs
,
dtype
=
dtype
,
quant_config
=
quant_config
,
prepare_finalize_type
=
combination
[
0
],
fused_experts_type
=
combination
[
1
],
fused_moe_chunk_size
=
chunk_size
,
world_size
=
world_size
,
)
# TODO(bnell): figure out how to get verbose flag here.
verbose
=
False
# pytestconfig.getoption('verbose') > 0
valid
,
reason
=
config
.
is_valid
()
if
not
valid
:
if
verbose
:
print
(
f
"Test config
{
config
}
is not valid:
{
reason
}
"
)
continue
if
is_nyi_config
(
config
):
if
verbose
:
print
(
f
"Test config
{
config
}
is nyi."
)
continue
cases
.
append
(
(
k
,
n
,
e
,
dtype
,
quant_config
,
combination
[
0
],
combination
[
1
],
chunk_size
,
world_size
,
)
)
print
(
f
"
{
len
(
cases
)
}
of
{
total
}
valid configs generated."
)
return
cases
@
pytest
.
mark
.
parametrize
(
"combination"
,
product
(
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES
,
MK_FUSED_EXPERT_TYPES
)
"k,n,e,dtype,quant_config,prepare_finalize_type,fused_experts_type,chunk_size,world_size"
,
generate_valid_test_cases
(
world_size
=
2
,
prepare_finalize_types
=
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES
),
)
@
pytest
.
mark
.
parametrize
(
"fused_moe_chunk_size"
,
FUSED_MOE_CHUNK_SIZEs
)
@
pytest
.
mark
.
parametrize
(
"world_size"
,
[
2
])
@
multi_gpu_test
(
num_gpus
=
2
)
@
meets_multi_gpu_requirements
def
test_modular_kernel_combinations_multigpu
(
k
:
int
,
...
...
@@ -186,13 +246,19 @@ def test_modular_kernel_combinations_multigpu(
e
:
int
,
dtype
:
torch
.
dtype
,
quant_config
:
Optional
[
TestMoEQuantConfig
],
combination
:
tuple
[
mk
.
FusedMoEPrepareAndFinalize
,
mk
.
FusedMoEPermuteExpertsUnpermute
],
fused_moe_chunk_size
:
Optional
[
int
],
prepare_finalize_type
:
mk
.
FusedMoEPrepareAndFinalize
,
fused_experts_type
:
mk
.
FusedMoEPermuteExpertsUnpermute
,
chunk_size
:
Optional
[
int
],
world_size
:
int
,
pytestconfig
,
):
if
cuda_device_count_stateless
()
<
world_size
:
pytest
.
skip
(
f
"Not enough GPUs available to run, got "
f
"
{
cuda_device_count_stateless
()
}
exepected "
f
"
{
world_size
}
."
)
config
=
Config
(
Ms
=
Ms
,
K
=
k
,
...
...
@@ -201,42 +267,30 @@ def test_modular_kernel_combinations_multigpu(
topks
=
TOPKs
,
dtype
=
dtype
,
quant_config
=
quant_config
,
prepare_finalize_type
=
combination
[
0
]
,
fused_experts_type
=
combination
[
1
]
,
fused_moe_chunk_size
=
fused_moe_
chunk_size
,
prepare_finalize_type
=
prepare_finalize_type
,
fused_experts_type
=
fused_experts_type
,
fused_moe_chunk_size
=
chunk_size
,
world_size
=
world_size
,
)
if
not
config
.
is_valid
():
pytest
.
skip
(
f
"Tests config
{
config
}
is not valid. Skipping ..."
)
if
is_nyi_config
(
config
):
pytest
.
skip
(
f
"Tests config
{
config
}
is nyi. Skipping ..."
)
verbosity
=
pytestconfig
.
getoption
(
"verbose"
)
run
(
config
,
verbosity
>
0
)
@
pytest
.
mark
.
parametrize
(
"k"
,
Ks
)
@
pytest
.
mark
.
parametrize
(
"n"
,
Ns
)
@
pytest
.
mark
.
parametrize
(
"e"
,
Es
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPEs
)
@
pytest
.
mark
.
parametrize
(
"quant_config"
,
MK_QUANT_CONFIGS
)
@
pytest
.
mark
.
parametrize
(
"combination"
,
product
(
MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES
,
MK_FUSED_EXPERT_TYPES
)
"k,n,e,dtype,quant_config,prepare_finalize_type,fused_experts_type,chunk_size,world_size"
,
generate_valid_test_cases
(
world_size
=
1
,
prepare_finalize_types
=
MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES
),
)
@
pytest
.
mark
.
parametrize
(
"fused_moe_chunk_size"
,
FUSED_MOE_CHUNK_SIZEs
)
@
pytest
.
mark
.
parametrize
(
"world_size"
,
[
1
])
def
test_modular_kernel_combinations_singlegpu
(
k
:
int
,
n
:
int
,
e
:
int
,
dtype
:
torch
.
dtype
,
quant_config
:
Optional
[
TestMoEQuantConfig
],
combination
:
tuple
[
mk
.
FusedMoEPrepareAndFinalize
,
mk
.
FusedMoEPermuteExpertsUnpermute
],
fused_moe_chunk_size
:
Optional
[
int
],
prepare_finalize_type
:
mk
.
FusedMoEPrepareAndFinalize
,
fused_experts_type
:
mk
.
FusedMoEPermuteExpertsUnpermute
,
chunk_size
:
Optional
[
int
],
world_size
:
int
,
pytestconfig
,
):
...
...
@@ -248,18 +302,12 @@ def test_modular_kernel_combinations_singlegpu(
topks
=
TOPKs
,
dtype
=
dtype
,
quant_config
=
quant_config
,
prepare_finalize_type
=
combination
[
0
]
,
fused_experts_type
=
combination
[
1
]
,
fused_moe_chunk_size
=
fused_moe_
chunk_size
,
prepare_finalize_type
=
prepare_finalize_type
,
fused_experts_type
=
fused_experts_type
,
fused_moe_chunk_size
=
chunk_size
,
world_size
=
world_size
,
)
if
not
config
.
is_valid
():
pytest
.
skip
(
f
"Tests config
{
config
}
is not valid. Skipping ..."
)
if
is_nyi_config
(
config
):
pytest
.
skip
(
f
"Tests config
{
config
}
is nyi. Skipping ..."
)
verbosity
=
pytestconfig
.
getoption
(
"verbose"
)
run
(
config
,
verbosity
>
0
)
...
...
vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
View file @
da364615
...
...
@@ -247,29 +247,24 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
def
workspace_shapes
(
self
,
a
:
torch
.
Tensor
,
aq
:
torch
.
Tensor
,
M
:
int
,
N
:
int
,
K
:
int
,
topk
:
int
,
global_num_experts
:
int
,
local_num_experts
:
int
,
expert_tokens_metadata
:
Optional
[
mk
.
ExpertTokensMetadata
],
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...],
torch
.
dtype
]:
assert
a
.
dim
()
==
2
expert_tokens_meta
:
Optional
[
mk
.
ExpertTokensMetadata
],
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]]:
# FIXME (varun): We should be able to dispatch only from the leader
# DP ranks in the case of TP > 1. At the moment, all the Ranks
# end up sending their tokens. This needs to be fixed.
num_dispatchers
=
self
.
num_dispatchers
num_experts
=
local_num_experts
max_num_tokens
=
(
a
.
size
(
0
)
if
self
.
max_num_tokens
is
None
else
self
.
max_num_tokens
)
max_num_tokens
=
M
if
self
.
max_num_tokens
is
None
else
self
.
max_num_tokens
workspace13
=
(
num_experts
,
max_num_tokens
*
num_dispatchers
,
max
(
K
,
N
))
workspace2
=
(
num_experts
,
max_num_tokens
*
num_dispatchers
,
(
N
//
2
))
output
=
(
num_experts
,
max_num_tokens
*
num_dispatchers
,
K
)
return
(
workspace13
,
workspace2
,
output
,
a
.
dtype
)
return
(
workspace13
,
workspace2
,
output
)
def
apply
(
self
,
...
...
@@ -300,7 +295,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
assert
w2
.
size
(
1
)
==
K
E
,
max_num_tokens
,
N
,
K
,
top_k_num
=
self
.
moe_problem_size
(
E
,
max_num_tokens
,
N
,
K
,
_
=
self
.
moe_problem_size
(
hidden_states
,
w1
,
w2
,
topk_ids
)
...
...
vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py
View file @
da364615
...
...
@@ -99,10 +99,11 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
assert
bte_war
is
not
None
return
bte_war
def
workspace_dtype
(
self
,
act_dtype
:
torch
.
dtype
)
->
torch
.
dtype
:
return
act_dtype
def
workspace_shapes
(
self
,
a
:
torch
.
Tensor
,
aq
:
torch
.
Tensor
,
M
:
int
,
N
:
int
,
K
:
int
,
...
...
@@ -110,15 +111,13 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts
:
int
,
local_num_experts
:
int
,
expert_tokens_metadata
:
Optional
[
mk
.
ExpertTokensMetadata
],
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]
,
torch
.
dtype
]:
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]]:
# Note: the deep gemm workspaces are strictly larger than the triton
# workspaces so we can be pessimistic here and allocate for DeepGemm
# even if we fall back to triton later, e.g. if expert maps are set.
if
self
.
allow_deep_gemm
:
assert
self
.
batched_deep_gemm_experts
is
not
None
return
self
.
batched_deep_gemm_experts
.
workspace_shapes
(
a
,
aq
,
M
,
N
,
K
,
...
...
@@ -130,8 +129,6 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
else
:
assert
self
.
batched_triton_experts
is
not
None
return
self
.
batched_triton_experts
.
workspace_shapes
(
a
,
aq
,
M
,
N
,
K
,
...
...
vllm/model_executor/layers/fused_moe/cutlass_moe.py
View file @
da364615
...
...
@@ -366,10 +366,11 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
# topk weights and reduction are fused in moe_unpermute cuda kernel
return
TopKWeightAndReduceNoOP
()
def
workspace_dtype
(
self
,
act_dtype
:
torch
.
dtype
)
->
torch
.
dtype
:
return
self
.
out_dtype
if
self
.
out_dtype
is
not
None
else
act_dtype
def
workspace_shapes
(
self
,
a
:
torch
.
Tensor
,
aq
:
torch
.
Tensor
,
M
:
int
,
N
:
int
,
K
:
int
,
...
...
@@ -377,16 +378,11 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
global_num_experts
:
int
,
local_num_experts
:
int
,
expert_tokens_meta
:
Optional
[
mk
.
ExpertTokensMetadata
],
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]
,
torch
.
dtype
]:
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]]:
workspace1
=
(
M
*
topk
,
max
(
N
,
K
))
workspace2
=
(
M
*
topk
,
max
(
N
//
2
,
K
))
output
=
(
M
,
K
)
return
(
workspace1
,
workspace2
,
output
,
self
.
out_dtype
if
self
.
out_dtype
is
not
None
else
a
.
dtype
,
)
return
(
workspace1
,
workspace2
,
output
)
class
CutlassBatchedExpertsFp8
(
CutlassExpertsFp8Base
):
...
...
@@ -428,11 +424,11 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
def
supports_expert_map
(
self
)
->
bool
:
return
False
# TODO(bnell): maybe remove need for passing aq to workspace_shapes
def
workspace_dtype
(
self
,
act_dtype
:
torch
.
dtype
)
->
torch
.
dtype
:
return
self
.
out_dtype
if
self
.
out_dtype
is
not
None
else
act_dtype
def
workspace_shapes
(
self
,
a
:
torch
.
Tensor
,
aq
:
torch
.
Tensor
,
M
:
int
,
N
:
int
,
K
:
int
,
...
...
@@ -440,19 +436,13 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
global_num_experts
:
int
,
local_num_experts
:
int
,
expert_tokens_meta
:
Optional
[
mk
.
ExpertTokensMetadata
],
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...],
torch
.
dtype
]:
padded_M
=
aq
.
size
(
1
)
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]]:
num_dp
=
self
.
num_dispatchers
assert
num_dp
is
not
None
workspace1
=
(
self
.
max_experts_per_worker
,
padded_M
*
num_dp
,
max
(
N
,
K
))
workspace2
=
(
self
.
max_experts_per_worker
,
padded_M
*
num_dp
,
max
(
N
//
2
,
K
))
output
=
(
self
.
max_experts_per_worker
,
padded_M
,
K
)
return
(
workspace1
,
workspace2
,
output
,
self
.
out_dtype
if
self
.
out_dtype
is
not
None
else
a
.
dtype
,
)
workspace1
=
(
self
.
max_experts_per_worker
,
M
*
num_dp
,
max
(
N
,
K
))
workspace2
=
(
self
.
max_experts_per_worker
,
M
*
num_dp
,
max
(
N
//
2
,
K
))
output
=
(
self
.
max_experts_per_worker
,
M
,
K
)
return
(
workspace1
,
workspace2
,
output
)
def
cutlass_moe_fp8
(
...
...
@@ -767,10 +757,11 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
def
finalize_weight_and_reduce_impl
(
self
)
->
mk
.
TopKWeightAndReduce
:
return
TopKWeightAndReduceNoOP
()
def
workspace_dtype
(
self
,
act_dtype
:
torch
.
dtype
)
->
torch
.
dtype
:
return
self
.
out_dtype
if
self
.
out_dtype
is
not
None
else
act_dtype
def
workspace_shapes
(
self
,
a
:
torch
.
Tensor
,
aq
:
torch
.
Tensor
,
M
:
int
,
N
:
int
,
K
:
int
,
...
...
@@ -778,25 +769,19 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts
:
int
,
local_num_experts
:
int
,
expert_tokens_meta
:
Optional
[
mk
.
ExpertTokensMetadata
],
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]
,
torch
.
dtype
]:
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]]:
workspace1
:
tuple
[
int
,
...]
=
()
workspace2
:
tuple
[
int
,
...]
=
()
output
:
tuple
[
int
,
...]
=
()
if
self
.
use_batched_format
:
padded_M
=
aq
.
size
(
1
)
workspace1
=
(
self
.
max_experts_per_worker
,
padded_M
,
max
(
N
,
K
))
workspace2
=
(
self
.
max_experts_per_worker
,
padded_M
,
(
N
//
2
))
output
=
(
self
.
max_experts_per_worker
,
padded_M
,
K
)
workspace1
=
(
self
.
max_experts_per_worker
,
M
,
max
(
N
,
K
))
workspace2
=
(
self
.
max_experts_per_worker
,
M
,
(
N
//
2
))
output
=
(
self
.
max_experts_per_worker
,
M
,
K
)
else
:
workspace1
=
(
M
*
topk
,
max
(
2
*
N
,
K
))
workspace2
=
(
M
*
topk
,
N
)
output
=
(
M
,
K
)
return
(
workspace1
,
workspace2
,
output
,
self
.
out_dtype
if
self
.
out_dtype
is
not
None
else
a
.
dtype
,
)
return
(
workspace1
,
workspace2
,
output
)
def
apply
(
self
,
...
...
vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
View file @
da364615
...
...
@@ -198,8 +198,6 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
def
workspace_shapes
(
self
,
a
:
torch
.
Tensor
,
aq
:
torch
.
Tensor
,
M
:
int
,
N
:
int
,
K
:
int
,
...
...
@@ -207,7 +205,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts
:
int
,
local_num_experts
:
int
,
expert_tokens_meta
:
Optional
[
mk
.
ExpertTokensMetadata
],
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]
,
torch
.
dtype
]:
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]]:
assert
self
.
block_shape
is
not
None
block_m
=
self
.
block_shape
[
0
]
M_sum
=
compute_aligned_M
(
...
...
@@ -218,7 +216,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace1
=
(
M_sum
,
max
(
N
,
K
))
workspace2
=
(
M_sum
,
max
(
N
//
2
,
K
))
output
=
(
M
,
K
)
return
(
workspace1
,
workspace2
,
output
,
a
.
dtype
)
return
(
workspace1
,
workspace2
,
output
)
def
apply
(
self
,
...
...
vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py
View file @
da364615
...
...
@@ -70,6 +70,9 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def
num_dispatchers
(
self
)
->
int
:
return
self
.
num_dispatchers_
def
output_is_reduced
(
self
)
->
bool
:
return
True
@
property
def
activation_format
(
self
)
->
mk
.
FusedMoEActivationFormat
:
return
mk
.
FusedMoEActivationFormat
.
Standard
...
...
vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
View file @
da364615
...
...
@@ -73,6 +73,9 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def
num_dispatchers
(
self
)
->
int
:
return
self
.
num_dispatchers_
def
output_is_reduced
(
self
)
->
bool
:
return
True
@
property
def
activation_format
(
self
)
->
mk
.
FusedMoEActivationFormat
:
return
mk
.
FusedMoEActivationFormat
.
BatchedExperts
...
...
vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
View file @
da364615
...
...
@@ -90,8 +90,6 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
def
workspace_shapes
(
self
,
a
:
torch
.
Tensor
,
aq
:
torch
.
Tensor
,
M
:
int
,
N
:
int
,
K
:
int
,
...
...
@@ -99,7 +97,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts
:
int
,
local_num_experts
:
int
,
expert_tokens_meta
:
Optional
[
mk
.
ExpertTokensMetadata
],
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]
,
torch
.
dtype
]:
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]]:
# We use global_num_experts due to how moe_align_block_size handles
# expert_maps.
"""
...
...
@@ -118,14 +116,12 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
- Note: in order for activation chunking to work, the first dimension
of each tuple must be the number of tokens.
"""
aq_m
,
aq_n
=
aq
.
shape
workspace1
=
(
M
,
K
)
workspace2
=
(
0
,)
output_shape
=
(
aq_m
,
aq_n
*
2
)
if
self
.
quant_dtype
==
"nvfp4"
else
(
aq_m
,
aq_n
)
workspace_dtype
=
a
.
dtype
workspace1
=
output_shape
output_shape
=
(
M
,
K
*
2
if
self
.
quant_dtype
==
"nvfp4"
else
K
)
# The workspace is determined by `aq`, since it comes after any
# potential communication op and is involved in the expert computation.
return
(
workspace1
,
workspace2
,
output_shape
,
workspace_dtype
)
return
(
workspace1
,
workspace2
,
output_shape
)
def
apply
(
self
,
...
...
vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
View file @
da364615
...
...
@@ -11,6 +11,9 @@ from vllm.distributed.device_communicators.base_device_communicator import (
)
from
vllm.forward_context
import
get_forward_context
from
vllm.model_executor.layers.fused_moe.config
import
FusedMoEQuantConfig
from
vllm.model_executor.layers.fused_moe.topk_weight_and_reduce
import
(
TopKWeightAndReduceNoOP
,
)
from
vllm.model_executor.layers.fused_moe.utils
import
moe_kernel_quantize_input
from
vllm.utils.flashinfer
import
nvfp4_block_scale_interleave
...
...
@@ -45,6 +48,9 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def
num_dispatchers
(
self
)
->
int
:
return
self
.
num_dispatchers_
def
output_is_reduced
(
self
)
->
bool
:
return
False
def
_apply_router_weight_on_input
(
self
,
a1
:
torch
.
Tensor
,
...
...
@@ -194,6 +200,8 @@ class FlashInferAllGatherMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFin
apply_router_weight_on_input
:
bool
,
weight_and_reduce_impl
:
mk
.
TopKWeightAndReduce
,
)
->
None
:
assert
isinstance
(
weight_and_reduce_impl
,
TopKWeightAndReduceNoOP
)
if
self
.
use_dp
:
fused_expert_output
=
get_dp_group
().
reduce_scatterv
(
fused_expert_output
,
dim
=
0
,
sizes
=
get_local_sizes
()
...
...
vllm/model_executor/layers/fused_moe/fused_batched_moe.py
View file @
da364615
...
...
@@ -509,6 +509,9 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def
num_dispatchers
(
self
)
->
int
:
return
self
.
num_dispatchers_
def
output_is_reduced
(
self
)
->
bool
:
return
False
def
prepare
(
self
,
a1
:
torch
.
Tensor
,
...
...
@@ -665,8 +668,6 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
def
workspace_shapes
(
self
,
a
:
torch
.
Tensor
,
aq
:
torch
.
Tensor
,
M
:
int
,
N
:
int
,
K
:
int
,
...
...
@@ -674,14 +675,13 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts
:
int
,
local_num_experts
:
int
,
expert_tokens_meta
:
Optional
[
mk
.
ExpertTokensMetadata
],
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...],
torch
.
dtype
]:
assert
a
.
dim
()
==
2
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]]:
num_dp
=
self
.
num_dispatchers
num_experts
=
local_num_experts
workspace13
=
(
num_experts
,
self
.
max_num_tokens
*
num_dp
,
K
)
workspace2
=
(
self
.
max_num_tokens
*
num_dp
,
N
)
output
=
workspace13
return
(
workspace13
,
workspace2
,
output
,
a
.
dtype
)
return
(
workspace13
,
workspace2
,
output
)
def
dequant
(
self
,
t
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
)
->
torch
.
Tensor
:
assert
self
.
quant_config
.
is_quantized
...
...
@@ -862,8 +862,6 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
def
workspace_shapes
(
self
,
a
:
torch
.
Tensor
,
aq
:
torch
.
Tensor
,
M
:
int
,
N
:
int
,
K
:
int
,
...
...
@@ -871,15 +869,14 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts
:
int
,
local_num_experts
:
int
,
expert_tokens_meta
:
Optional
[
mk
.
ExpertTokensMetadata
],
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...],
torch
.
dtype
]:
assert
a
.
dim
()
==
2
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]]:
num_dp
=
self
.
num_dispatchers
num_experts
=
local_num_experts
max_num_tokens
=
self
.
max_num_tokens
workspace13
=
(
num_experts
,
max_num_tokens
*
num_dp
,
max
(
K
,
N
))
workspace2
=
(
num_experts
,
max_num_tokens
*
num_dp
,
(
N
//
2
))
output
=
(
num_experts
,
max_num_tokens
*
num_dp
,
K
)
return
(
workspace13
,
workspace2
,
output
,
a
.
dtype
)
return
(
workspace13
,
workspace2
,
output
)
def
apply
(
self
,
...
...
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
View file @
da364615
...
...
@@ -331,8 +331,6 @@ class MarlinExperts(mk.FusedMoEPermuteExpertsUnpermute):
def
workspace_shapes
(
self
,
a
:
torch
.
Tensor
,
aq
:
torch
.
Tensor
,
M
:
int
,
N
:
int
,
K
:
int
,
...
...
@@ -340,7 +338,7 @@ class MarlinExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts
:
int
,
local_num_experts
:
int
,
expert_tokens_meta
:
Optional
[
mk
.
ExpertTokensMetadata
],
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]
,
torch
.
dtype
]:
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]]:
# Modular Kernel provisions output buffer from workspace1. However in
# the fused_marlin_moe() function, the final torch.sum(), is defined
# essentially as,
...
...
@@ -360,7 +358,7 @@ class MarlinExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace2
=
(
M
*
topk
*
max
(
2
*
N
,
K
),)
output
=
(
M
,
K
)
return
(
workspace1
,
workspace2
,
output
,
a
.
dtype
)
return
(
workspace1
,
workspace2
,
output
)
def
apply
(
self
,
...
...
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
da364615
...
...
@@ -1954,8 +1954,6 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
def
workspace_shapes
(
self
,
a
:
torch
.
Tensor
,
aq
:
torch
.
Tensor
,
M
:
int
,
N
:
int
,
K
:
int
,
...
...
@@ -1963,11 +1961,11 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts
:
int
,
local_num_experts
:
int
,
expert_tokens_meta
:
Optional
[
mk
.
ExpertTokensMetadata
],
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]
,
torch
.
dtype
]:
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]]:
workspace1
=
(
M
,
topk
,
max
(
N
//
2
,
K
))
workspace2
=
(
M
,
topk
,
max
(
N
,
K
))
output
=
(
M
,
K
)
return
(
workspace1
,
workspace2
,
output
,
a
.
dtype
)
return
(
workspace1
,
workspace2
,
output
)
def
apply
(
self
,
...
...
vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
View file @
da364615
...
...
@@ -255,8 +255,6 @@ class OAITritonExperts(BaseOAITritonExperts):
def
workspace_shapes
(
self
,
a
:
torch
.
Tensor
,
aq
:
torch
.
Tensor
,
M
:
int
,
N
:
int
,
K
:
int
,
...
...
@@ -264,12 +262,12 @@ class OAITritonExperts(BaseOAITritonExperts):
global_num_experts
:
int
,
local_num_experts
:
int
,
expert_tokens_meta
:
Optional
[
mk
.
ExpertTokensMetadata
],
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]
,
torch
.
dtype
]:
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]]:
# workspace are allocated inside the kernel
workspace1
=
(
M
,
K
)
workspace2
=
(
0
,
0
)
output
=
(
M
,
K
)
return
(
workspace1
,
workspace2
,
output
,
a
.
dtype
)
return
(
workspace1
,
workspace2
,
output
)
def
apply
(
self
,
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
da364615
...
...
@@ -283,6 +283,10 @@ class FusedMoEMethodBase(QuantizeMethodBase):
)
->
Optional
[
FusedMoEQuantConfig
]:
raise
NotImplementedError
@
property
def
using_modular_kernel
(
self
)
->
bool
:
return
self
.
fused_experts
is
not
None
@
abstractmethod
def
apply
(
self
,
...
...
@@ -1237,39 +1241,25 @@ class FusedMoE(CustomOp):
self
.
batched_hidden_states
:
Optional
[
torch
.
Tensor
]
=
None
self
.
batched_router_logits
:
Optional
[
torch
.
Tensor
]
=
None
# TODO(bnell): flashinfer uses non-batched format.
# Does it really need a batched buffer?
if
(
self
.
moe_parallel_config
.
use_pplx_kernels
or
self
.
moe_parallel_config
.
use_deepep_ll_kernels
or
self
.
moe_config
.
use_flashinfer_cutlass_kernels
):
if
vllm_config
.
parallel_config
.
enable_dbo
:
self
.
batched_hidden_states
=
torch
.
zeros
(
(
2
,
moe
.
max_num_tokens
,
self
.
hidden_size
),
dtype
=
moe
.
in_dtype
,
device
=
torch
.
cuda
.
current_device
(),
)
if
self
.
use_dp_chunking
:
states_shape
:
tuple
[
int
,
...]
logits_shape
:
tuple
[
int
,
...]
# Note here we use `num_experts` which is logical expert count
self
.
batched_router_logits
=
torch
.
zeros
(
(
2
,
moe
.
max_num_tokens
,
num_experts
),
dtype
=
moe
.
in_dtype
,
device
=
torch
.
cuda
.
current_device
(),
)
# Note here we use `num_experts` which is logical expert count
if
vllm_config
.
parallel_config
.
enable_dbo
:
states_shape
=
(
2
,
moe
.
max_num_tokens
,
self
.
hidden_size
)
logits_shape
=
(
2
,
moe
.
max_num_tokens
,
num_experts
)
else
:
self
.
batched_hidden_states
=
torch
.
zeros
(
(
moe
.
max_num_tokens
,
self
.
hidden_size
),
dtype
=
moe
.
in_dtype
,
device
=
torch
.
cuda
.
current_device
(),
)
states_shape
=
(
moe
.
max_num_tokens
,
self
.
hidden_size
)
logits_shape
=
(
moe
.
max_num_tokens
,
num_experts
)
# Note here we use `num_experts` which is logical expert count
self
.
batched_router_logits
=
torch
.
zeros
(
(
moe
.
max_num_tokens
,
num_experts
),
dtype
=
moe
.
in_dtype
,
device
=
torch
.
cuda
.
current_device
(),
)
self
.
batched_hidden_states
=
torch
.
zeros
(
states_shape
,
dtype
=
moe
.
in_dtype
,
device
=
torch
.
cuda
.
current_device
()
)
self
.
batched_router_logits
=
torch
.
zeros
(
logits_shape
,
dtype
=
moe
.
in_dtype
,
device
=
torch
.
cuda
.
current_device
()
)
@
property
def
shared_experts
(
self
)
->
Optional
[
torch
.
nn
.
Module
]:
...
...
@@ -1323,6 +1313,16 @@ class FusedMoE(CustomOp):
and
self
.
moe_config
.
use_flashinfer_cutlass_kernels
)
@
property
def
use_dp_chunking
(
self
)
->
bool
:
# Route to the chunked forward path using the FlashInfer Cutlass kernel
# only when data parallelism (DP) is enabled.
return
(
self
.
moe_parallel_config
.
use_pplx_kernels
or
self
.
moe_parallel_config
.
use_deepep_ll_kernels
or
(
self
.
dp_size
>
1
and
self
.
use_flashinfer_cutlass_kernels
)
)
def
update_expert_map
(
self
):
# ep_size and ep_rank should already be updated
assert
self
.
expert_map
is
not
None
...
...
@@ -1987,21 +1987,17 @@ class FusedMoE(CustomOp):
Therefore it is required that we reduce the shared_experts output
early.
"""
assert
self
.
quant_method
is
not
None
return
(
self
.
use_pplx_kernels
or
self
.
use_deepep_ht_kernels
or
self
.
use_deepep_ll_kernels
self
.
quant_method
.
fused_experts
is
not
None
and
self
.
quant_method
.
fused_experts
.
output_is_reduced
()
)
def
maybe_all_reduce_tensor_model_parallel
(
self
,
final_hidden_states
:
torch
.
Tensor
):
"""
The pplx
combine kernel reduce
s
across GPU ranks by default.
Some
combine kernel
s
reduce across GPU ranks by default.
"""
if
(
self
.
use_pplx_kernels
or
self
.
use_deepep_ht_kernels
or
self
.
use_deepep_ll_kernels
):
if
self
.
must_reduce_shared_expert_outputs
():
return
final_hidden_states
else
:
return
tensor_model_parallel_all_reduce
(
final_hidden_states
)
...
...
@@ -2209,23 +2205,11 @@ class FusedMoE(CustomOp):
self
.
ensure_moe_quant_config
()
# Route to the chunked forward path using the FlashInfer Cutlass kernel
# only when data parallelism (DP) is enabled.
_use_flashinfer_cutlass_kernels
=
(
self
.
dp_size
>
1
and
self
.
use_flashinfer_cutlass_kernels
)
if
(
self
.
moe_parallel_config
.
use_pplx_kernels
or
self
.
moe_parallel_config
.
use_deepep_ll_kernels
or
_use_flashinfer_cutlass_kernels
):
if
self
.
use_dp_chunking
:
return
self
.
forward_impl_chunked
(
hidden_states
,
router_logits
)
do_naive_dispatch_combine
:
bool
=
(
self
.
dp_size
>
1
and
not
self
.
moe_parallel_config
.
use_deepep_ht_kernels
and
not
self
.
moe_config
.
use_flashinfer_cutlass_kernels
self
.
dp_size
>
1
and
not
self
.
quant_method
.
using_modular_kernel
)
# If there are shared experts but we are not using a modular kernel, the
...
...
vllm/model_executor/layers/fused_moe/modular_kernel.py
View file @
da364615
...
...
@@ -337,6 +337,14 @@ class FusedMoEPrepareAndFinalize(ABC):
def
num_dispatchers
(
self
)
->
int
:
raise
NotImplementedError
@
abstractmethod
def
output_is_reduced
(
self
)
->
bool
:
"""
Indicates whether or not the output of finalize is reduced across all
ranks.
"""
raise
NotImplementedError
# TODO: add supported activations method (return string)
class
FusedMoEPermuteExpertsUnpermute
(
ABC
):
...
...
@@ -493,11 +501,15 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
"""
raise
NotImplementedError
def
workspace_dtype
(
self
,
act_dtype
:
torch
.
dtype
)
->
torch
.
dtype
:
"""
Workspace type: The dtype to use for the workspace tensors.
"""
return
act_dtype
@
abstractmethod
def
workspace_shapes
(
self
,
a
:
torch
.
Tensor
,
aq
:
torch
.
Tensor
,
M
:
int
,
N
:
int
,
K
:
int
,
...
...
@@ -505,22 +517,33 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
global_num_experts
:
int
,
local_num_experts
:
int
,
expert_tokens_meta
:
Optional
[
ExpertTokensMetadata
],
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]
,
torch
.
dtype
]:
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]]:
"""
Compute the shapes for the temporary and final outputs of the two gemms
and activation in the fused expert function. Since the gemms are
independent, the workspace for the first gemm can be shared with the
workspace for the last gemm.
Inputs:
- M: number of tokens.
- N: Row (or column) dimension of expert weights.
- K: hidden dimension
- topk: The number of top-k experts to select.
- global_num_experts: global number of experts.
- local_num_experts: local number of experts due to DP/EP.
- expert_tokens_meta: number of tokens per expert metadata for batched
format.
Returns a tuple of:
- workspace13 shape tuple: must be large enough to hold the
result of either expert gemm.
- workspace2 shape tuple: must be large enough to hold the
result of the activation function.
- output shape tuple: must be exact size of the final gemm output.
- Workspace type: The dtype to use for the workspace tensors.
- Note: in order for activation chunking to work, the first dimension
of each tuple must be the number of tokens.
- Note: workspace shapes can be 0 if the workspace is not needed.
But in order for activation chunking to work, the first dimension
of each tuple must be the number of tokens when the shape is
not 0.
"""
raise
NotImplementedError
...
...
@@ -561,7 +584,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
workspace2
:
torch
.
Tensor
,
expert_tokens_meta
:
Optional
[
ExpertTokensMetadata
],
apply_router_weight_on_input
:
bool
,
):
)
->
None
:
"""
This function computes the intermediate result of a Mixture of Experts
(MoE) layer using two sets of weights, w1 and w2.
...
...
@@ -600,7 +623,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
raise
NotImplementedError
def
_
chunk
_scales
(
def
_
slice
_scales
(
scales
:
Optional
[
torch
.
Tensor
],
start
:
int
,
end
:
int
)
->
Optional
[
torch
.
Tensor
]:
if
scales
is
not
None
:
...
...
@@ -615,9 +638,10 @@ class SharedResizableBuffer:
def
__init__
(
self
):
self
.
buffer
=
None
def
get
(
self
,
shape
:
tuple
[
int
,
...],
device
:
torch
.
device
,
dtype
:
torch
.
dtype
):
if
shape
==
()
or
shape
is
None
:
return
None
def
get
(
self
,
shape
:
tuple
[
int
,
...],
device
:
torch
.
device
,
dtype
:
torch
.
dtype
)
->
torch
.
Tensor
:
assert
shape
!=
()
shape_numel
=
prod
(
shape
)
if
(
self
.
buffer
is
None
...
...
@@ -678,131 +702,74 @@ class FusedMoEModularKernel(torch.nn.Module):
f
"
{
fused_experts
.
activation_formats
[
0
]
}
"
)
def
_do_fused_experts
(
def
output_is_reduced
(
self
)
->
bool
:
"""
Indicates whether or not the output of fused MoE kernel
is reduced across all ranks.
"""
return
self
.
prepare_finalize
.
output_is_reduced
()
def
_chunk_info
(
self
,
M
:
int
)
->
tuple
[
int
,
int
]:
"""
Compute number of chunks and chunk size for given M.
If chunking is not supported, set the CHUNK_SIZE to M so we
get num_chunks == 1. Take max(M, 1) to avoid divide by zero.
If there are no tokens to process, the number of chunks will be zero.
"""
CHUNK_SIZE
=
(
max
(
M
,
1
)
if
not
self
.
fused_experts
.
supports_chunking
()
else
min
(
M
,
envs
.
VLLM_FUSED_MOE_CHUNK_SIZE
)
)
num_chunks
=
cdiv
(
M
,
CHUNK_SIZE
)
# If there are no tokens, then there should be no loop iterations.
assert
M
>
0
or
num_chunks
==
0
return
num_chunks
,
CHUNK_SIZE
def
_allocate_buffers
(
self
,
fused_out
:
Optional
[
torch
.
Tensor
],
a1
:
torch
.
Tensor
,
a1q
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
activation
:
str
,
out_dtype
:
torch
.
dtype
,
device
:
torch
.
device
,
M_chunk
:
int
,
M_full
:
int
,
N
:
int
,
K
:
int
,
top_k
:
int
,
global_num_experts
:
int
,
local_num_experts
:
int
,
expert_map
:
Optional
[
torch
.
Tensor
],
a1q_scale
:
Optional
[
torch
.
Tensor
],
a2_scale
:
Optional
[
torch
.
Tensor
],
expert_tokens_meta
:
Optional
[
ExpertTokensMetadata
],
apply_router_weight_on_input
:
bool
,
)
->
torch
.
Tensor
:
_
,
M
,
N
,
K
,
top_k
=
self
.
fused_experts
.
moe_problem_size
(
a1q
,
w1
,
w2
,
topk_ids
)
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Allocate temporary and output buffers for the fused experts op.
Inputs:
- out_dtype: output type of workspace and output tensors.
- device: the device of the workspace and output tensors.
See `workspace_shapes` for a description of the remainder of arguments.
Returns a tuple of (workspace13, workspace2, output) tensors.
"""
assert
M_full
>
0
and
M_chunk
>
0
(
workspace13_shape
,
workspace2_shape
,
fused_out_shape
,
workspace_dtype
)
=
(
self
.
fused_experts
.
workspace_shapes
(
a1
,
a1q
,
M
,
N
,
K
,
top_k
,
global_num_experts
,
local_num_experts
,
expert_tokens_meta
,
)
)
num_chunks
,
_
=
self
.
_chunk_info
(
M_full
)
# select per-ubatch buffers to avoid cross-ubatch reuse under DBO
ubatch_idx
=
dbo_current_ubatch_id
()
buffers
=
self
.
shared_buffers
[
ubatch_idx
]
workspace_dtype
=
self
.
fused_experts
.
workspace_dtype
(
out_dtype
)
# We can reuse the memory between cache1 and cache3 because by the
# time we need cache3, we're done with cache1.
workspace13
=
buffers
.
workspace13
.
get
(
workspace13_shape
,
device
=
a1
.
device
,
dtype
=
workspace_dtype
)
workspace2
=
buffers
.
workspace2
.
get
(
workspace2_shape
,
device
=
a1
.
device
,
dtype
=
workspace_dtype
)
assert
fused_out
is
None
or
fused_out
.
shape
==
fused_out_shape
,
(
f
"fused_out
{
fused_out
.
shape
}
but expected
{
fused_out_shape
}
"
)
if
fused_out
is
None
:
# reuse workspace13 for the output
fused_out
=
_resize_cache
(
workspace13
,
fused_out_shape
)
self
.
fused_experts
.
apply
(
fused_out
,
a1q
,
w1
,
w2
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
activation
=
activation
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
a1q_scale
=
a1q_scale
,
a2_scale
=
a2_scale
,
workspace13
=
workspace13
,
workspace2
=
workspace2
,
expert_tokens_meta
=
expert_tokens_meta
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
# Get intermediate workspace shapes based off the chunked M size.
workspace13_shape
,
workspace2_shape
,
_
=
self
.
fused_experts
.
workspace_shapes
(
M_chunk
,
N
,
K
,
top_k
,
global_num_experts
,
local_num_experts
,
expert_tokens_meta
,
)
return
fused_out
def
_maybe_chunk_fused_experts
(
self
,
a1
:
torch
.
Tensor
,
a1q
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
activation
:
str
,
global_num_experts
:
int
,
local_num_experts
:
int
,
expert_map
:
Optional
[
torch
.
Tensor
],
a1q_scale
:
Optional
[
torch
.
Tensor
],
expert_tokens_meta
:
Optional
[
ExpertTokensMetadata
],
apply_router_weight_on_input
:
bool
,
)
->
torch
.
Tensor
:
_
,
M
,
N
,
K
,
top_k
=
self
.
fused_experts
.
moe_problem_size
(
a1q
,
w1
,
w2
,
topk_ids
)
CHUNK_SIZE
=
envs
.
VLLM_FUSED_MOE_CHUNK_SIZE
num_chunks
=
cdiv
(
M
,
CHUNK_SIZE
)
# TODO(bnell): get rid of one level here, update slice functions
# to nops on num_chunks==1
if
not
self
.
fused_experts
.
supports_chunking
()
or
num_chunks
==
1
:
return
self
.
_do_fused_experts
(
fused_out
=
None
,
a1
=
a1
,
a1q
=
a1q
,
w1
=
w1
,
w2
=
w2
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
activation
=
activation
,
global_num_experts
=
global_num_experts
,
local_num_experts
=
local_num_experts
,
expert_map
=
expert_map
,
a1q_scale
=
a1q_scale
,
a2_scale
=
self
.
fused_experts
.
a2_scale
,
expert_tokens_meta
=
expert_tokens_meta
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
)
# Chunking required case
assert
num_chunks
>
1
# Construct the entire output that can then be processed in chunks.
(
_
,
_
,
fused_out_shape
,
_
)
=
self
.
fused_experts
.
workspace_shapes
(
a1
,
a1q
,
M
,
# Get final output shape based on the full M size.
_
,
_
,
fused_out_shape
=
self
.
fused_experts
.
workspace_shapes
(
M_full
,
N
,
K
,
top_k
,
...
...
@@ -810,150 +777,99 @@ class FusedMoEModularKernel(torch.nn.Module):
local_num_experts
,
expert_tokens_meta
,
)
ubatch_idx
=
dbo_current_ubatch_id
()
buffers
=
self
.
shared_buffers
[
ubatch_idx
]
fused_out
=
buffers
.
fused_out
.
get
(
fused_out_shape
,
device
=
a1q
.
device
,
dtype
=
a1
.
dtype
# We can reuse the memory between cache1 and cache3 because by the
# time we need cache3, we're done with cache1.
workspace13
=
buffers
.
workspace13
.
get
(
workspace13_shape
,
device
=
device
,
dtype
=
workspace_dtype
)
workspace2
=
buffers
.
workspace2
.
get
(
workspace2_shape
,
device
=
device
,
dtype
=
workspace_dtype
)
def
slice_input_tensors
(
chunk_idx
:
int
,
)
->
tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
],
Optional
[
torch
.
Tensor
],
torch
.
Tensor
,
torch
.
Tensor
,
]:
s
=
chunk_idx
*
CHUNK_SIZE
e
=
min
(
s
+
CHUNK_SIZE
,
M
)
return
(
a1q
[
s
:
e
],
_chunk_scales
(
a1q_scale
,
s
,
e
),
_chunk_scales
(
self
.
fused_experts
.
a2_scale
,
s
,
e
),
topk_ids
[
s
:
e
],
topk_weights
[
s
:
e
],
# Construct the entire output that can then be processed in chunks.
# Reuse workspace13 for the output in the non-chunked case as long
# as it is large enough. This will not always be the case for standard
# format experts and with experts that have empty workspaces.
if
num_chunks
==
1
and
prod
(
workspace13_shape
)
>=
prod
(
fused_out_shape
):
fused_out
=
_resize_cache
(
workspace13
,
fused_out_shape
)
else
:
fused_out
=
buffers
.
fused_out
.
get
(
fused_out_shape
,
device
=
device
,
dtype
=
out_dtype
)
def
slice_output_tensor
(
chunk_idx
:
int
)
->
torch
.
Tensor
:
assert
fused_out
.
size
(
0
)
%
M
==
0
,
(
f
"fused_out shape
{
fused_out
.
shape
}
vs M
{
M
}
"
)
factor
=
fused_out
.
size
(
0
)
//
M
out_chunk_size
=
CHUNK_SIZE
*
factor
s
=
chunk_idx
*
out_chunk_size
e
=
min
(
s
+
out_chunk_size
,
fused_out
.
size
(
0
))
return
fused_out
[
s
:
e
]
def
slice_expert_tokens_metadata
(
full_expert_tokens_meta
:
ExpertTokensMetadata
,
chunk_topk_ids
:
torch
.
Tensor
,
local_num_experts
:
int
,
expert_map
:
Optional
[
torch
.
Tensor
],
)
->
ExpertTokensMetadata
:
# The existing expert_num_tokens is for the entire a1q
# input. Chunking forces recomputation of the number
# of tokens assigned to each expert.
c_expert_num_tokens
=
count_expert_num_tokens
(
chunk_topk_ids
,
local_num_experts
,
expert_map
)
return
workspace13
,
workspace2
,
fused_out
c_expert_num_tokens_cpu
=
None
need_expert_num_tokens_cpu
=
(
full_expert_tokens_meta
.
expert_num_tokens_cpu
is
not
None
)
if
need_expert_num_tokens_cpu
:
# This is blocking as some implementations need the count
# on the CPU to determine appropriate input/out fused-moe
# buffers
c_expert_num_tokens_cpu
=
c_expert_num_tokens
.
to
(
"cpu"
,
non_blocking
=
False
)
return
ExpertTokensMetadata
(
expert_num_tokens
=
c_expert_num_tokens
,
expert_num_tokens_cpu
=
c_expert_num_tokens_cpu
,
)
@
staticmethod
def
_slice_output_tensor
(
fused_out
:
torch
.
Tensor
,
chunk_idx
:
int
,
num_chunks
:
int
,
CHUNK_SIZE
:
int
,
M
:
int
,
)
->
torch
.
Tensor
:
if
num_chunks
==
1
:
return
fused_out
for
chunk_idx
in
range
(
num_chunks
):
c_a1q
,
c_a1q_scale
,
c_a2_scale
,
c_topk_ids
,
c_topk_weights
=
(
slice_input_tensors
(
chunk_idx
)
)
assert
fused_out
.
size
(
0
)
%
M
==
0
,
f
"fused_out shape
{
fused_out
.
shape
}
vs M
{
M
}
"
factor
=
fused_out
.
size
(
0
)
//
M
out_chunk_size
=
CHUNK_SIZE
*
factor
s
=
chunk_idx
*
out_chunk_size
e
=
min
(
s
+
out_chunk_size
,
fused_out
.
size
(
0
))
return
fused_out
[
s
:
e
]
c_expert_tokens_meta
=
None
if
expert_tokens_meta
is
not
None
:
c_expert_tokens_meta
=
slice_expert_tokens_metadata
(
expert_tokens_meta
,
c_topk_ids
,
local_num_experts
,
expert_map
)
@
staticmethod
def
_slice_expert_tokens_metadata
(
num_chunks
:
int
,
full_expert_tokens_meta
:
Optional
[
ExpertTokensMetadata
],
chunk_topk_ids
:
torch
.
Tensor
,
local_num_experts
:
int
,
expert_map
:
Optional
[
torch
.
Tensor
],
)
->
Optional
[
ExpertTokensMetadata
]:
if
num_chunks
==
1
or
full_expert_tokens_meta
is
None
:
return
full_expert_tokens_meta
# The existing expert_num_tokens is for the entire a1q
# input. Chunking forces recomputation of the number
# of tokens assigned to each expert.
c_expert_num_tokens
=
count_expert_num_tokens
(
chunk_topk_ids
,
local_num_experts
,
expert_map
)
self
.
_do_fused_experts
(
fused_out
=
slice_output_tensor
(
chunk_idx
),
a1
=
a1
,
a1q
=
c_a1q
,
w1
=
w1
,
w2
=
w2
,
topk_weights
=
c_topk_weights
,
topk_ids
=
c_topk_ids
,
activation
=
activation
,
global_num_experts
=
global_num_experts
,
local_num_experts
=
local_num_experts
,
expert_map
=
expert_map
,
a1q_scale
=
c_a1q_scale
,
a2_scale
=
c_a2_scale
,
expert_tokens_meta
=
c_expert_tokens_meta
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
)
c_expert_num_tokens_cpu
=
None
need_expert_num_tokens_cpu
=
(
full_expert_tokens_meta
.
expert_num_tokens_cpu
is
not
None
)
if
need_expert_num_tokens_cpu
:
# This is blocking as some implementations need the count
# on the CPU to determine appropriate input/out fused-moe
# buffers
c_expert_num_tokens_cpu
=
c_expert_num_tokens
.
to
(
"cpu"
,
non_blocking
=
False
)
return
fused_out
return
ExpertTokensMetadata
(
expert_num_tokens
=
c_expert_num_tokens
,
expert_num_tokens_cpu
=
c_expert_num_tokens_cpu
,
)
def
forw
ar
d
(
def
_prep
ar
e
(
self
,
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
inplace
:
bool
=
False
,
activation
:
str
=
"silu"
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
)
->
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]:
global_num_experts
:
int
,
expert_map
:
Optional
[
torch
.
Tensor
],
apply_router_weight_on_input
:
bool
,
)
->
tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
],
Optional
[
ExpertTokensMetadata
],
torch
.
Tensor
,
torch
.
Tensor
,
]:
"""
This function computes a Mixture of Experts (MoE) layer using two sets
of weights, w1 and w2, and top-k gating mechanism.
Parameters:
- hidden_states: (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- topk_weights (torch.Tensor): The topk weights applied at the end of
the layer.
- topk_ids (torch.Tensor): A map of row to expert id.
- inplace (bool): If True, perform the operation in-place.
Defaults to False.
- activation (str): The activation function to apply after the first
MoE layer.
- global_num_experts (int): The total number of experts in the global
expert space.
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
from the global expert space to the local expert space of the expert
parallel shard.
- apply_router_weight_on_input (bool): When true, the topk weights are
applied directly on the inputs. This is only applicable when topk is
1.
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
The _prepare method is a wrapper around self.prepare_finalize.prepare
that handles DBO and async.
"""
a1
=
hidden_states
output
=
a1
if
inplace
and
self
.
shared_experts
is
None
else
torch
.
zeros_like
(
a1
)
local_num_experts
=
w1
.
size
(
0
)
if
global_num_experts
==
-
1
:
global_num_experts
=
local_num_experts
if
not
self
.
prepare_finalize
.
supports_async
():
# We shouldn't be running an a2a kernel that doesn't
# support async prepare/finalize
...
...
@@ -967,7 +883,7 @@ class FusedMoEModularKernel(torch.nn.Module):
_expert_topk_ids
,
_expert_topk_weights
,
)
=
self
.
prepare_finalize
.
prepare
(
a1
,
hidden_states
,
topk_weights
,
topk_ids
,
global_num_experts
,
...
...
@@ -979,7 +895,7 @@ class FusedMoEModularKernel(torch.nn.Module):
# Overlap shared expert compute with all2all dispatch.
dbo_maybe_run_recv_hook
()
prepare_ret
=
self
.
prepare_finalize
.
prepare_async
(
a1
,
hidden_states
,
topk_weights
,
topk_ids
,
global_num_experts
,
...
...
@@ -1019,33 +935,114 @@ class FusedMoEModularKernel(torch.nn.Module):
topk_weights
if
_expert_topk_weights
is
None
else
_expert_topk_weights
)
fused_out
=
None
return
a1q
,
a1q_scale
,
expert_tokens_meta
,
topk_ids
,
topk_weights
if
a1q
.
numel
()
==
0
:
# This happens when none of the tokens from the all2all reach this
# EP rank. Also, note that this is only relevant for CUDAGraph
# incompatible all2all kernels like the DeepEP high-throughput
# kernels. CUDAGraph compatible all2all kernels like the pplx
# kernels and the DeepEP low-latency kernels are always batched
# and can never run into the tensor.numel() == 0 case.
fused_out
=
torch
.
empty_like
(
a1q
).
to
(
dtype
=
a1
.
dtype
)
def
_fused_experts
(
self
,
in_dtype
:
torch
.
dtype
,
a1q
:
torch
.
Tensor
,
a1q_scale
:
Optional
[
torch
.
Tensor
],
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
activation
:
str
,
global_num_experts
:
int
,
local_num_experts
:
int
,
expert_map
:
Optional
[
torch
.
Tensor
],
apply_router_weight_on_input
:
bool
,
expert_tokens_meta
:
Optional
[
ExpertTokensMetadata
],
)
->
torch
.
Tensor
:
_
,
M_full
,
N
,
K
,
top_k
=
self
.
fused_experts
.
moe_problem_size
(
a1q
,
w1
,
w2
,
topk_ids
)
num_chunks
,
CHUNK_SIZE
=
self
.
_chunk_info
(
M_full
)
def
input_chunk_range
(
chunk_idx
:
int
)
->
tuple
[
int
,
int
]:
if
num_chunks
==
1
:
# Use a1q.size(0) here since batched format does not
# keep M in the first dimension.
return
0
,
a1q
.
size
(
0
)
else
:
s
=
chunk_idx
*
CHUNK_SIZE
e
=
min
(
s
+
CHUNK_SIZE
,
M_full
)
return
s
,
e
# This happens when none of the tokens from the all2all reach this
# EP rank. Also, note that this is only relevant for CUDAGraph
# incompatible all2all kernels like the DeepEP high-throughput
# kernels. CUDAGraph compatible all2all kernels like the pplx
# kernels and the DeepEP low-latency kernels are always batched
# and can never run into the tensor.numel() == 0 case.
if
M_full
==
0
:
assert
num_chunks
==
0
workspace13
=
None
workspace2
=
None
fused_out
=
torch
.
empty_like
(
a1q
)
else
:
fused_out
=
self
.
_maybe_chunk_fused_experts
(
a1
=
a1
,
a1q
=
a1q
,
assert
num_chunks
>
0
workspace13
,
workspace2
,
fused_out
=
self
.
_allocate_buffers
(
in_dtype
,
a1q
.
device
,
CHUNK_SIZE
,
M_full
,
N
,
K
,
top_k
,
global_num_experts
,
local_num_experts
,
expert_tokens_meta
,
)
for
chunk_idx
in
range
(
num_chunks
):
s
,
e
=
input_chunk_range
(
chunk_idx
)
c_expert_tokens_meta
=
self
.
_slice_expert_tokens_metadata
(
num_chunks
,
expert_tokens_meta
,
topk_ids
[
s
:
e
],
local_num_experts
,
expert_map
,
)
c_fused_out
=
self
.
_slice_output_tensor
(
fused_out
,
chunk_idx
,
num_chunks
,
CHUNK_SIZE
,
M_full
)
self
.
fused_experts
.
apply
(
output
=
c_fused_out
,
hidden_states
=
a1q
[
s
:
e
],
w1
=
w1
,
w2
=
w2
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
topk_weights
=
topk_weights
[
s
:
e
]
,
topk_ids
=
topk_ids
[
s
:
e
]
,
activation
=
activation
,
global_num_experts
=
global_num_experts
,
local_num_experts
=
local_num_experts
,
expert_map
=
expert_map
,
a1q_scale
=
a1q_scale
,
expert_tokens_meta
=
expert_tokens_meta
,
a1q_scale
=
_slice_scales
(
a1q_scale
,
s
,
e
),
a2_scale
=
_slice_scales
(
self
.
fused_experts
.
a2_scale
,
e
,
e
),
workspace13
=
workspace13
,
workspace2
=
workspace2
,
expert_tokens_meta
=
c_expert_tokens_meta
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
)
return
fused_out
def
_finalize
(
self
,
output
:
torch
.
Tensor
,
fused_out
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
apply_router_weight_on_input
:
bool
,
)
->
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]:
"""
The _finalize method is a wrapper around self.prepare_finalize.finalize
that handles DBO, async and shared expert overlap.
"""
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
if
not
self
.
prepare_finalize
.
supports_async
():
...
...
@@ -1060,7 +1057,7 @@ class FusedMoEModularKernel(torch.nn.Module):
self
.
fused_experts
.
finalize_weight_and_reduce_impl
(),
)
if
self
.
shared_experts
is
not
None
:
shared_output
=
self
.
shared_experts
(
a1
)
shared_output
=
self
.
shared_experts
(
hidden_states
)
else
:
finalize_ret
=
self
.
prepare_finalize
.
finalize_async
(
output
,
...
...
@@ -1072,7 +1069,7 @@ class FusedMoEModularKernel(torch.nn.Module):
)
if
self
.
shared_experts
is
not
None
:
shared_output
=
self
.
shared_experts
(
a1
)
shared_output
=
self
.
shared_experts
(
hidden_states
)
# TODO(lucas): refactor this in the alternative schedules followup
# currently unpack if we have hook + receiver pair or just
...
...
@@ -1100,3 +1097,87 @@ class FusedMoEModularKernel(torch.nn.Module):
else
:
assert
shared_output
is
not
None
return
shared_output
,
output
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
inplace
:
bool
=
False
,
activation
:
str
=
"silu"
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
)
->
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]:
"""
This function computes a Mixture of Experts (MoE) layer using two sets
of weights, w1 and w2, and top-k gating mechanism.
Parameters:
- hidden_states: (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- topk_weights (torch.Tensor): The topk weights applied at the end of
the layer.
- topk_ids (torch.Tensor): A map of row to expert id.
- inplace (bool): If True, perform the operation in-place.
Defaults to False.
- activation (str): The activation function to apply after the first
MoE layer.
- global_num_experts (int): The total number of experts in the global
expert space.
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
from the global expert space to the local expert space of the expert
parallel shard.
- apply_router_weight_on_input (bool): When true, the topk weights are
applied directly on the inputs. This is only applicable when topk is
1.
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""
if
inplace
and
self
.
shared_experts
is
None
:
output
=
hidden_states
else
:
output
=
torch
.
zeros_like
(
hidden_states
)
local_num_experts
=
w1
.
size
(
0
)
if
global_num_experts
==
-
1
:
global_num_experts
=
local_num_experts
a1q
,
a1q_scale
,
expert_tokens_meta
,
topk_ids
,
topk_weights
=
self
.
_prepare
(
hidden_states
,
topk_weights
,
topk_ids
,
global_num_experts
,
expert_map
,
apply_router_weight_on_input
,
)
fused_out
=
self
.
_fused_experts
(
in_dtype
=
hidden_states
.
dtype
,
a1q
=
a1q
,
a1q_scale
=
a1q_scale
,
w1
=
w1
,
w2
=
w2
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
activation
=
activation
,
global_num_experts
=
global_num_experts
,
local_num_experts
=
local_num_experts
,
expert_map
=
expert_map
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
expert_tokens_meta
=
expert_tokens_meta
,
)
return
self
.
_finalize
(
output
,
fused_out
,
hidden_states
,
topk_weights
,
topk_ids
,
apply_router_weight_on_input
,
)
vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py
View file @
da364615
...
...
@@ -91,6 +91,9 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def
num_dispatchers
(
self
)
->
int
:
return
self
.
num_dispatchers_
def
output_is_reduced
(
self
)
->
bool
:
return
True
def
supports_async
(
self
)
->
bool
:
return
True
...
...
vllm/model_executor/layers/fused_moe/prepare_finalize.py
View file @
da364615
...
...
@@ -27,6 +27,9 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
def
num_dispatchers
(
self
)
->
int
:
return
1
def
output_is_reduced
(
self
)
->
bool
:
return
False
def
prepare
(
self
,
a1
:
torch
.
Tensor
,
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment