Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
da364615
Unverified
Commit
da364615
authored
Oct 08, 2025
by
bnellnm
Committed by
GitHub
Oct 08, 2025
Browse files
[Kernels] Modular kernel refactor (#24812)
Signed-off-by:
Bill Nell
<
bnell@redhat.com
>
parent
f08919b7
Changes
22
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
592 additions
and
490 deletions
+592
-490
tests/kernels/moe/modular_kernel_tools/common.py
tests/kernels/moe/modular_kernel_tools/common.py
+23
-14
tests/kernels/moe/modular_kernel_tools/make_feature_matrix.py
...s/kernels/moe/modular_kernel_tools/make_feature_matrix.py
+1
-1
tests/kernels/moe/modular_kernel_tools/mk_objects.py
tests/kernels/moe/modular_kernel_tools/mk_objects.py
+12
-14
tests/kernels/moe/test_modular_kernel_combinations.py
tests/kernels/moe/test_modular_kernel_combinations.py
+96
-48
vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
.../model_executor/layers/fused_moe/batched_deep_gemm_moe.py
+5
-10
vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py
...cutor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py
+4
-7
vllm/model_executor/layers/fused_moe/cutlass_moe.py
vllm/model_executor/layers/fused_moe/cutlass_moe.py
+21
-36
vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
+2
-4
vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py
...l_executor/layers/fused_moe/deepep_ht_prepare_finalize.py
+3
-0
vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
...l_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
+3
-0
vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
...model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
+4
-8
vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
...r/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
+8
-0
vllm/model_executor/layers/fused_moe/fused_batched_moe.py
vllm/model_executor/layers/fused_moe/fused_batched_moe.py
+7
-10
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
+2
-4
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+2
-4
vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
...l_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
+2
-4
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+37
-53
vllm/model_executor/layers/fused_moe/modular_kernel.py
vllm/model_executor/layers/fused_moe/modular_kernel.py
+354
-273
vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py
.../model_executor/layers/fused_moe/pplx_prepare_finalize.py
+3
-0
vllm/model_executor/layers/fused_moe/prepare_finalize.py
vllm/model_executor/layers/fused_moe/prepare_finalize.py
+3
-0
No files found.
tests/kernels/moe/modular_kernel_tools/common.py
View file @
da364615
...
@@ -209,18 +209,18 @@ class Config:
...
@@ -209,18 +209,18 @@ class Config:
info
=
prepare_finalize_info
(
self
.
prepare_finalize_type
)
info
=
prepare_finalize_info
(
self
.
prepare_finalize_type
)
return
info
.
backend
return
info
.
backend
def
is_valid
(
self
):
def
is_valid
(
self
)
->
tuple
[
bool
,
Optional
[
str
]]
:
# Check prepare-finalize and fused-experts compatibility
# Check prepare-finalize and fused-experts compatibility
if
self
.
is_batched_prepare_finalize
():
if
self
.
is_batched_prepare_finalize
():
if
not
self
.
is_batched_fused_experts
():
if
not
self
.
is_batched_fused_experts
():
return
False
return
False
,
"Mismatched format."
else
:
else
:
if
not
self
.
is_standard_fused_experts
():
if
not
self
.
is_standard_fused_experts
():
return
False
return
False
,
"Mismatched format."
use_chunking
=
self
.
fused_moe_chunk_size
is
not
None
use_chunking
=
self
.
fused_moe_chunk_size
is
not
None
if
use_chunking
and
not
self
.
is_fe_supports_chunking
():
if
use_chunking
and
not
self
.
is_fe_supports_chunking
():
return
False
return
False
,
"Chunking not supported."
# Check quantization sanity
# Check quantization sanity
if
(
if
(
...
@@ -229,7 +229,7 @@ class Config:
...
@@ -229,7 +229,7 @@ class Config:
+
int
(
self
.
quant_block_shape
is
not
None
)
+
int
(
self
.
quant_block_shape
is
not
None
)
)
>
1
:
)
>
1
:
# invalid quant config
# invalid quant config
return
False
return
False
,
f
"Bad quant_config
{
self
.
quant_config
}
."
# check type support
# check type support
if
self
.
quant_dtype
is
None
:
if
self
.
quant_dtype
is
None
:
...
@@ -237,34 +237,43 @@ class Config:
...
@@ -237,34 +237,43 @@ class Config:
self
.
dtype
not
in
self
.
pf_supported_types
()
self
.
dtype
not
in
self
.
pf_supported_types
()
or
self
.
dtype
not
in
self
.
fe_supported_types
()
or
self
.
dtype
not
in
self
.
fe_supported_types
()
):
):
return
False
return
False
,
(
f
"Unsupported type
{
self
.
dtype
}
not in "
f
"
{
self
.
pf_supported_types
()
}
and "
f
"
{
self
.
fe_supported_types
()
}
."
)
else
:
else
:
if
(
if
(
self
.
quant_dtype
not
in
self
.
pf_supported_types
()
self
.
quant_dtype
not
in
self
.
pf_supported_types
()
or
self
.
quant_dtype
not
in
self
.
fe_supported_types
()
or
self
.
quant_dtype
not
in
self
.
fe_supported_types
()
):
):
return
False
return
False
,
(
f
"Unsupported quant type
{
self
.
quant_dtype
}
"
f
"not in
{
self
.
pf_supported_types
()
}
and "
f
"
{
self
.
fe_supported_types
()
}
."
)
# Check block quantization support
# Check block quantization support
is_block_quatized
=
self
.
quant_block_shape
is
not
None
is_block_quatized
=
self
.
quant_block_shape
is
not
None
if
is_block_quatized
and
self
.
quant_dtype
is
None
:
if
is_block_quatized
and
self
.
quant_dtype
is
None
:
return
False
return
False
,
"No block quantization support."
if
is_block_quatized
and
not
self
.
is_block_quant_supported
():
if
is_block_quatized
and
not
self
.
is_block_quant_supported
():
return
False
return
False
,
"Mismatched block quantization support."
# deep_gemm only works with block-quantized
# deep_gemm only works with block-quantized
if
self
.
needs_deep_gemm
()
and
not
is_block_quatized
:
if
self
.
needs_deep_gemm
()
and
not
is_block_quatized
:
return
False
return
False
,
"Needs DeepGEMM but not block quantized."
# Check dependencies (turn into asserts?)
# Check dependencies (turn into asserts?)
if
self
.
needs_deep_ep
()
and
not
has_deep_ep
():
if
self
.
needs_deep_ep
()
and
not
has_deep_ep
():
return
False
return
False
,
"Needs DeepEP, but DeepEP not available."
if
self
.
needs_deep_gemm
()
and
not
has_deep_gemm
():
if
self
.
needs_deep_gemm
()
and
not
has_deep_gemm
():
return
False
return
False
,
"Needs DeepGEMM, but DeepGEMM not available."
if
self
.
needs_pplx
()
and
not
has_pplx
():
# noqa: SIM103
if
self
.
needs_pplx
()
and
not
has_pplx
():
# noqa: SIM103
return
False
return
False
,
"Needs PPLX, but PPLX not available."
return
True
return
True
,
None
@
dataclass
@
dataclass
...
...
tests/kernels/moe/modular_kernel_tools/make_feature_matrix.py
View file @
da364615
...
@@ -140,7 +140,7 @@ def make_feature_matrix(csv_file_path: str):
...
@@ -140,7 +140,7 @@ def make_feature_matrix(csv_file_path: str):
)
)
success
=
None
success
=
None
if
config
.
is_valid
():
if
config
.
is_valid
()
[
0
]
:
print
(
f
"Running config :
{
config
.
describe
()
}
..."
)
print
(
f
"Running config :
{
config
.
describe
()
}
..."
)
try
:
try
:
weights
:
WeightTensors
=
WeightTensors
.
make
(
config
)
weights
:
WeightTensors
=
WeightTensors
.
make
(
config
)
...
...
tests/kernels/moe/modular_kernel_tools/mk_objects.py
View file @
da364615
...
@@ -244,7 +244,7 @@ if has_flashinfer_cutlass_fused_moe() and current_platform.has_device_capability
...
@@ -244,7 +244,7 @@ if has_flashinfer_cutlass_fused_moe() and current_platform.has_device_capability
register_prepare_and_finalize
(
register_prepare_and_finalize
(
FlashInferCutlassMoEPrepareAndFinalize
,
FlashInferCutlassMoEPrepareAndFinalize
,
standard_format
,
standard_format
,
nvfp4_types
,
nvfp4_types
+
fp8_types
,
blocked_quantization_support
=
True
,
blocked_quantization_support
=
True
,
backend
=
None
,
backend
=
None
,
force_multigpu
=
True
,
force_multigpu
=
True
,
...
@@ -254,7 +254,7 @@ if has_flashinfer_cutlass_fused_moe() and current_platform.has_device_capability
...
@@ -254,7 +254,7 @@ if has_flashinfer_cutlass_fused_moe() and current_platform.has_device_capability
register_experts
(
register_experts
(
FlashInferExperts
,
FlashInferExperts
,
standard_format
,
standard_format
,
nvfp4_types
,
nvfp4_types
+
fp8_types
,
blocked_quantization_support
=
True
,
blocked_quantization_support
=
True
,
supports_chunking
=
True
,
supports_chunking
=
True
,
# Note: this is a hack to get it to run for now
# Note: this is a hack to get it to run for now
...
@@ -274,7 +274,6 @@ if has_deep_gemm() and is_deep_gemm_supported():
...
@@ -274,7 +274,6 @@ if has_deep_gemm() and is_deep_gemm_supported():
needs_matching_quant
=
False
,
needs_matching_quant
=
False
,
needs_deep_gemm
=
True
,
needs_deep_gemm
=
True
,
)
)
(
register_experts
(
register_experts
(
DeepGemmExperts
,
DeepGemmExperts
,
standard_format
,
standard_format
,
...
@@ -284,7 +283,6 @@ if has_deep_gemm() and is_deep_gemm_supported():
...
@@ -284,7 +283,6 @@ if has_deep_gemm() and is_deep_gemm_supported():
supports_expert_map
=
True
,
supports_expert_map
=
True
,
needs_matching_quant
=
False
,
needs_matching_quant
=
False
,
needs_deep_gemm
=
True
,
needs_deep_gemm
=
True
,
),
)
)
register_experts
(
register_experts
(
BatchedTritonOrDeepGemmExperts
,
BatchedTritonOrDeepGemmExperts
,
...
@@ -464,7 +462,7 @@ def make_fused_experts(
...
@@ -464,7 +462,7 @@ def make_fused_experts(
print
(
f
"Making BatchedTritonOrDeepGemmExperts
{
kwargs
}
..."
)
print
(
f
"Making BatchedTritonOrDeepGemmExperts
{
kwargs
}
..."
)
experts
=
BatchedTritonOrDeepGemmExperts
(
**
kwargs
)
experts
=
BatchedTritonOrDeepGemmExperts
(
**
kwargs
)
elif
fused_experts_type
==
DeepGemmExperts
:
elif
fused_experts_type
==
DeepGemmExperts
:
print
(
"Making DeepGemmExperts {quant_config} ..."
)
print
(
f
"Making DeepGemmExperts
{
quant_config
}
..."
)
experts
=
DeepGemmExperts
(
quant_config
)
experts
=
DeepGemmExperts
(
quant_config
)
elif
fused_experts_type
==
TritonExperts
:
elif
fused_experts_type
==
TritonExperts
:
kwargs
=
quant_kwargs
kwargs
=
quant_kwargs
...
...
tests/kernels/moe/test_modular_kernel_combinations.py
View file @
da364615
...
@@ -5,7 +5,7 @@ import copy
...
@@ -5,7 +5,7 @@ import copy
import
textwrap
import
textwrap
import
traceback
import
traceback
from
itertools
import
product
from
itertools
import
product
from
typing
import
Optional
from
typing
import
Any
,
Optional
import
pytest
import
pytest
import
torch
import
torch
...
@@ -13,10 +13,9 @@ import torch
...
@@ -13,10 +13,9 @@ import torch
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm.config
import
VllmConfig
,
set_current_vllm_config
from
vllm.config
import
VllmConfig
,
set_current_vllm_config
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
has_deep_ep
,
has_deep_gemm
,
has_pplx
from
vllm.utils
import
cuda_device_count_stateless
,
has_deep_ep
,
has_deep_gemm
,
has_pplx
from
vllm.utils.flashinfer
import
has_flashinfer_cutlass_fused_moe
from
vllm.utils.flashinfer
import
has_flashinfer_cutlass_fused_moe
from
...utils
import
multi_gpu_test
from
.modular_kernel_tools.common
import
(
from
.modular_kernel_tools.common
import
(
Config
,
Config
,
RankTensors
,
RankTensors
,
...
@@ -132,7 +131,8 @@ def rank_worker(
...
@@ -132,7 +131,8 @@ def rank_worker(
def
run
(
config
:
Config
,
verbose
:
bool
):
def
run
(
config
:
Config
,
verbose
:
bool
):
assert
config
.
is_valid
()
assert
config
.
is_valid
()[
0
]
assert
not
is_nyi_config
(
config
)
weights
:
WeightTensors
=
WeightTensors
.
make
(
config
)
weights
:
WeightTensors
=
WeightTensors
.
make
(
config
)
...
@@ -168,17 +168,77 @@ def is_nyi_config(config: Config) -> bool:
...
@@ -168,17 +168,77 @@ def is_nyi_config(config: Config) -> bool:
return
not
info
.
supports_expert_map
return
not
info
.
supports_expert_map
@
pytest
.
mark
.
parametrize
(
"k"
,
Ks
)
def
generate_valid_test_cases
(
@
pytest
.
mark
.
parametrize
(
"n"
,
Ns
)
world_size
:
int
,
prepare_finalize_types
@
pytest
.
mark
.
parametrize
(
"e"
,
Es
)
)
->
list
[
tuple
[
Any
,
...]]:
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPEs
)
cases
=
[]
@
pytest
.
mark
.
parametrize
(
"quant_config"
,
MK_QUANT_CONFIGS
)
total
=
0
for
k
,
n
,
e
,
dtype
,
quant_config
,
combination
,
chunk_size
in
product
(
Ks
,
Ns
,
Es
,
DTYPEs
,
MK_QUANT_CONFIGS
,
product
(
prepare_finalize_types
,
MK_FUSED_EXPERT_TYPES
),
FUSED_MOE_CHUNK_SIZEs
,
):
total
=
total
+
1
config
=
Config
(
Ms
=
Ms
,
K
=
k
,
N
=
n
,
E
=
e
,
topks
=
TOPKs
,
dtype
=
dtype
,
quant_config
=
quant_config
,
prepare_finalize_type
=
combination
[
0
],
fused_experts_type
=
combination
[
1
],
fused_moe_chunk_size
=
chunk_size
,
world_size
=
world_size
,
)
# TODO(bnell): figure out how to get verbose flag here.
verbose
=
False
# pytestconfig.getoption('verbose') > 0
valid
,
reason
=
config
.
is_valid
()
if
not
valid
:
if
verbose
:
print
(
f
"Test config
{
config
}
is not valid:
{
reason
}
"
)
continue
if
is_nyi_config
(
config
):
if
verbose
:
print
(
f
"Test config
{
config
}
is nyi."
)
continue
cases
.
append
(
(
k
,
n
,
e
,
dtype
,
quant_config
,
combination
[
0
],
combination
[
1
],
chunk_size
,
world_size
,
)
)
print
(
f
"
{
len
(
cases
)
}
of
{
total
}
valid configs generated."
)
return
cases
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"combination"
,
product
(
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES
,
MK_FUSED_EXPERT_TYPES
)
"k,n,e,dtype,quant_config,prepare_finalize_type,fused_experts_type,chunk_size,world_size"
,
generate_valid_test_cases
(
world_size
=
2
,
prepare_finalize_types
=
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES
),
)
)
@
pytest
.
mark
.
parametrize
(
"fused_moe_chunk_size"
,
FUSED_MOE_CHUNK_SIZEs
)
@
pytest
.
mark
.
parametrize
(
"world_size"
,
[
2
])
@
multi_gpu_test
(
num_gpus
=
2
)
@
meets_multi_gpu_requirements
@
meets_multi_gpu_requirements
def
test_modular_kernel_combinations_multigpu
(
def
test_modular_kernel_combinations_multigpu
(
k
:
int
,
k
:
int
,
...
@@ -186,13 +246,19 @@ def test_modular_kernel_combinations_multigpu(
...
@@ -186,13 +246,19 @@ def test_modular_kernel_combinations_multigpu(
e
:
int
,
e
:
int
,
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
quant_config
:
Optional
[
TestMoEQuantConfig
],
quant_config
:
Optional
[
TestMoEQuantConfig
],
combination
:
tuple
[
prepare_finalize_type
:
mk
.
FusedMoEPrepareAndFinalize
,
mk
.
FusedMoEPrepareAndFinalize
,
mk
.
FusedMoEPermuteExpertsUnpermute
fused_experts_type
:
mk
.
FusedMoEPermuteExpertsUnpermute
,
],
chunk_size
:
Optional
[
int
],
fused_moe_chunk_size
:
Optional
[
int
],
world_size
:
int
,
world_size
:
int
,
pytestconfig
,
pytestconfig
,
):
):
if
cuda_device_count_stateless
()
<
world_size
:
pytest
.
skip
(
f
"Not enough GPUs available to run, got "
f
"
{
cuda_device_count_stateless
()
}
exepected "
f
"
{
world_size
}
."
)
config
=
Config
(
config
=
Config
(
Ms
=
Ms
,
Ms
=
Ms
,
K
=
k
,
K
=
k
,
...
@@ -201,42 +267,30 @@ def test_modular_kernel_combinations_multigpu(
...
@@ -201,42 +267,30 @@ def test_modular_kernel_combinations_multigpu(
topks
=
TOPKs
,
topks
=
TOPKs
,
dtype
=
dtype
,
dtype
=
dtype
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prepare_finalize_type
=
combination
[
0
]
,
prepare_finalize_type
=
prepare_finalize_type
,
fused_experts_type
=
combination
[
1
]
,
fused_experts_type
=
fused_experts_type
,
fused_moe_chunk_size
=
fused_moe_
chunk_size
,
fused_moe_chunk_size
=
chunk_size
,
world_size
=
world_size
,
world_size
=
world_size
,
)
)
if
not
config
.
is_valid
():
pytest
.
skip
(
f
"Tests config
{
config
}
is not valid. Skipping ..."
)
if
is_nyi_config
(
config
):
pytest
.
skip
(
f
"Tests config
{
config
}
is nyi. Skipping ..."
)
verbosity
=
pytestconfig
.
getoption
(
"verbose"
)
verbosity
=
pytestconfig
.
getoption
(
"verbose"
)
run
(
config
,
verbosity
>
0
)
run
(
config
,
verbosity
>
0
)
@
pytest
.
mark
.
parametrize
(
"k"
,
Ks
)
@
pytest
.
mark
.
parametrize
(
"n"
,
Ns
)
@
pytest
.
mark
.
parametrize
(
"e"
,
Es
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPEs
)
@
pytest
.
mark
.
parametrize
(
"quant_config"
,
MK_QUANT_CONFIGS
)
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"combination"
,
product
(
MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES
,
MK_FUSED_EXPERT_TYPES
)
"k,n,e,dtype,quant_config,prepare_finalize_type,fused_experts_type,chunk_size,world_size"
,
generate_valid_test_cases
(
world_size
=
1
,
prepare_finalize_types
=
MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES
),
)
)
@
pytest
.
mark
.
parametrize
(
"fused_moe_chunk_size"
,
FUSED_MOE_CHUNK_SIZEs
)
@
pytest
.
mark
.
parametrize
(
"world_size"
,
[
1
])
def
test_modular_kernel_combinations_singlegpu
(
def
test_modular_kernel_combinations_singlegpu
(
k
:
int
,
k
:
int
,
n
:
int
,
n
:
int
,
e
:
int
,
e
:
int
,
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
quant_config
:
Optional
[
TestMoEQuantConfig
],
quant_config
:
Optional
[
TestMoEQuantConfig
],
combination
:
tuple
[
prepare_finalize_type
:
mk
.
FusedMoEPrepareAndFinalize
,
mk
.
FusedMoEPrepareAndFinalize
,
mk
.
FusedMoEPermuteExpertsUnpermute
fused_experts_type
:
mk
.
FusedMoEPermuteExpertsUnpermute
,
],
chunk_size
:
Optional
[
int
],
fused_moe_chunk_size
:
Optional
[
int
],
world_size
:
int
,
world_size
:
int
,
pytestconfig
,
pytestconfig
,
):
):
...
@@ -248,18 +302,12 @@ def test_modular_kernel_combinations_singlegpu(
...
@@ -248,18 +302,12 @@ def test_modular_kernel_combinations_singlegpu(
topks
=
TOPKs
,
topks
=
TOPKs
,
dtype
=
dtype
,
dtype
=
dtype
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prepare_finalize_type
=
combination
[
0
]
,
prepare_finalize_type
=
prepare_finalize_type
,
fused_experts_type
=
combination
[
1
]
,
fused_experts_type
=
fused_experts_type
,
fused_moe_chunk_size
=
fused_moe_
chunk_size
,
fused_moe_chunk_size
=
chunk_size
,
world_size
=
world_size
,
world_size
=
world_size
,
)
)
if
not
config
.
is_valid
():
pytest
.
skip
(
f
"Tests config
{
config
}
is not valid. Skipping ..."
)
if
is_nyi_config
(
config
):
pytest
.
skip
(
f
"Tests config
{
config
}
is nyi. Skipping ..."
)
verbosity
=
pytestconfig
.
getoption
(
"verbose"
)
verbosity
=
pytestconfig
.
getoption
(
"verbose"
)
run
(
config
,
verbosity
>
0
)
run
(
config
,
verbosity
>
0
)
...
...
vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
View file @
da364615
...
@@ -247,29 +247,24 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
...
@@ -247,29 +247,24 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
def
workspace_shapes
(
def
workspace_shapes
(
self
,
self
,
a
:
torch
.
Tensor
,
aq
:
torch
.
Tensor
,
M
:
int
,
M
:
int
,
N
:
int
,
N
:
int
,
K
:
int
,
K
:
int
,
topk
:
int
,
topk
:
int
,
global_num_experts
:
int
,
global_num_experts
:
int
,
local_num_experts
:
int
,
local_num_experts
:
int
,
expert_tokens_metadata
:
Optional
[
mk
.
ExpertTokensMetadata
],
expert_tokens_meta
:
Optional
[
mk
.
ExpertTokensMetadata
],
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...],
torch
.
dtype
]:
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]]:
assert
a
.
dim
()
==
2
# FIXME (varun): We should be able to dispatch only from the leader
# FIXME (varun): We should be able to dispatch only from the leader
# DP ranks in the case of TP > 1. At the moment, all the Ranks
# DP ranks in the case of TP > 1. At the moment, all the Ranks
# end up sending their tokens. This needs to be fixed.
# end up sending their tokens. This needs to be fixed.
num_dispatchers
=
self
.
num_dispatchers
num_dispatchers
=
self
.
num_dispatchers
num_experts
=
local_num_experts
num_experts
=
local_num_experts
max_num_tokens
=
(
max_num_tokens
=
M
if
self
.
max_num_tokens
is
None
else
self
.
max_num_tokens
a
.
size
(
0
)
if
self
.
max_num_tokens
is
None
else
self
.
max_num_tokens
)
workspace13
=
(
num_experts
,
max_num_tokens
*
num_dispatchers
,
max
(
K
,
N
))
workspace13
=
(
num_experts
,
max_num_tokens
*
num_dispatchers
,
max
(
K
,
N
))
workspace2
=
(
num_experts
,
max_num_tokens
*
num_dispatchers
,
(
N
//
2
))
workspace2
=
(
num_experts
,
max_num_tokens
*
num_dispatchers
,
(
N
//
2
))
output
=
(
num_experts
,
max_num_tokens
*
num_dispatchers
,
K
)
output
=
(
num_experts
,
max_num_tokens
*
num_dispatchers
,
K
)
return
(
workspace13
,
workspace2
,
output
,
a
.
dtype
)
return
(
workspace13
,
workspace2
,
output
)
def
apply
(
def
apply
(
self
,
self
,
...
@@ -300,7 +295,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
...
@@ -300,7 +295,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
assert
w2
.
size
(
1
)
==
K
assert
w2
.
size
(
1
)
==
K
E
,
max_num_tokens
,
N
,
K
,
top_k_num
=
self
.
moe_problem_size
(
E
,
max_num_tokens
,
N
,
K
,
_
=
self
.
moe_problem_size
(
hidden_states
,
w1
,
w2
,
topk_ids
hidden_states
,
w1
,
w2
,
topk_ids
)
)
...
...
vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py
View file @
da364615
...
@@ -99,10 +99,11 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
...
@@ -99,10 +99,11 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
assert
bte_war
is
not
None
assert
bte_war
is
not
None
return
bte_war
return
bte_war
def
workspace_dtype
(
self
,
act_dtype
:
torch
.
dtype
)
->
torch
.
dtype
:
return
act_dtype
def
workspace_shapes
(
def
workspace_shapes
(
self
,
self
,
a
:
torch
.
Tensor
,
aq
:
torch
.
Tensor
,
M
:
int
,
M
:
int
,
N
:
int
,
N
:
int
,
K
:
int
,
K
:
int
,
...
@@ -110,15 +111,13 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
...
@@ -110,15 +111,13 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts
:
int
,
global_num_experts
:
int
,
local_num_experts
:
int
,
local_num_experts
:
int
,
expert_tokens_metadata
:
Optional
[
mk
.
ExpertTokensMetadata
],
expert_tokens_metadata
:
Optional
[
mk
.
ExpertTokensMetadata
],
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]
,
torch
.
dtype
]:
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]]:
# Note: the deep gemm workspaces are strictly larger than the triton
# Note: the deep gemm workspaces are strictly larger than the triton
# workspaces so we can be pessimistic here and allocate for DeepGemm
# workspaces so we can be pessimistic here and allocate for DeepGemm
# even if we fall back to triton later, e.g. if expert maps are set.
# even if we fall back to triton later, e.g. if expert maps are set.
if
self
.
allow_deep_gemm
:
if
self
.
allow_deep_gemm
:
assert
self
.
batched_deep_gemm_experts
is
not
None
assert
self
.
batched_deep_gemm_experts
is
not
None
return
self
.
batched_deep_gemm_experts
.
workspace_shapes
(
return
self
.
batched_deep_gemm_experts
.
workspace_shapes
(
a
,
aq
,
M
,
M
,
N
,
N
,
K
,
K
,
...
@@ -130,8 +129,6 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
...
@@ -130,8 +129,6 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
else
:
else
:
assert
self
.
batched_triton_experts
is
not
None
assert
self
.
batched_triton_experts
is
not
None
return
self
.
batched_triton_experts
.
workspace_shapes
(
return
self
.
batched_triton_experts
.
workspace_shapes
(
a
,
aq
,
M
,
M
,
N
,
N
,
K
,
K
,
...
...
vllm/model_executor/layers/fused_moe/cutlass_moe.py
View file @
da364615
...
@@ -366,10 +366,11 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
...
@@ -366,10 +366,11 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
# topk weights and reduction are fused in moe_unpermute cuda kernel
# topk weights and reduction are fused in moe_unpermute cuda kernel
return
TopKWeightAndReduceNoOP
()
return
TopKWeightAndReduceNoOP
()
def
workspace_dtype
(
self
,
act_dtype
:
torch
.
dtype
)
->
torch
.
dtype
:
return
self
.
out_dtype
if
self
.
out_dtype
is
not
None
else
act_dtype
def
workspace_shapes
(
def
workspace_shapes
(
self
,
self
,
a
:
torch
.
Tensor
,
aq
:
torch
.
Tensor
,
M
:
int
,
M
:
int
,
N
:
int
,
N
:
int
,
K
:
int
,
K
:
int
,
...
@@ -377,16 +378,11 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
...
@@ -377,16 +378,11 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
global_num_experts
:
int
,
global_num_experts
:
int
,
local_num_experts
:
int
,
local_num_experts
:
int
,
expert_tokens_meta
:
Optional
[
mk
.
ExpertTokensMetadata
],
expert_tokens_meta
:
Optional
[
mk
.
ExpertTokensMetadata
],
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]
,
torch
.
dtype
]:
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]]:
workspace1
=
(
M
*
topk
,
max
(
N
,
K
))
workspace1
=
(
M
*
topk
,
max
(
N
,
K
))
workspace2
=
(
M
*
topk
,
max
(
N
//
2
,
K
))
workspace2
=
(
M
*
topk
,
max
(
N
//
2
,
K
))
output
=
(
M
,
K
)
output
=
(
M
,
K
)
return
(
return
(
workspace1
,
workspace2
,
output
)
workspace1
,
workspace2
,
output
,
self
.
out_dtype
if
self
.
out_dtype
is
not
None
else
a
.
dtype
,
)
class
CutlassBatchedExpertsFp8
(
CutlassExpertsFp8Base
):
class
CutlassBatchedExpertsFp8
(
CutlassExpertsFp8Base
):
...
@@ -428,11 +424,11 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
...
@@ -428,11 +424,11 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
def
supports_expert_map
(
self
)
->
bool
:
def
supports_expert_map
(
self
)
->
bool
:
return
False
return
False
# TODO(bnell): maybe remove need for passing aq to workspace_shapes
def
workspace_dtype
(
self
,
act_dtype
:
torch
.
dtype
)
->
torch
.
dtype
:
return
self
.
out_dtype
if
self
.
out_dtype
is
not
None
else
act_dtype
def
workspace_shapes
(
def
workspace_shapes
(
self
,
self
,
a
:
torch
.
Tensor
,
aq
:
torch
.
Tensor
,
M
:
int
,
M
:
int
,
N
:
int
,
N
:
int
,
K
:
int
,
K
:
int
,
...
@@ -440,19 +436,13 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
...
@@ -440,19 +436,13 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
global_num_experts
:
int
,
global_num_experts
:
int
,
local_num_experts
:
int
,
local_num_experts
:
int
,
expert_tokens_meta
:
Optional
[
mk
.
ExpertTokensMetadata
],
expert_tokens_meta
:
Optional
[
mk
.
ExpertTokensMetadata
],
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...],
torch
.
dtype
]:
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]]:
padded_M
=
aq
.
size
(
1
)
num_dp
=
self
.
num_dispatchers
num_dp
=
self
.
num_dispatchers
assert
num_dp
is
not
None
assert
num_dp
is
not
None
workspace1
=
(
self
.
max_experts_per_worker
,
padded_M
*
num_dp
,
max
(
N
,
K
))
workspace1
=
(
self
.
max_experts_per_worker
,
M
*
num_dp
,
max
(
N
,
K
))
workspace2
=
(
self
.
max_experts_per_worker
,
padded_M
*
num_dp
,
max
(
N
//
2
,
K
))
workspace2
=
(
self
.
max_experts_per_worker
,
M
*
num_dp
,
max
(
N
//
2
,
K
))
output
=
(
self
.
max_experts_per_worker
,
padded_M
,
K
)
output
=
(
self
.
max_experts_per_worker
,
M
,
K
)
return
(
return
(
workspace1
,
workspace2
,
output
)
workspace1
,
workspace2
,
output
,
self
.
out_dtype
if
self
.
out_dtype
is
not
None
else
a
.
dtype
,
)
def
cutlass_moe_fp8
(
def
cutlass_moe_fp8
(
...
@@ -767,10 +757,11 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
...
@@ -767,10 +757,11 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
def
finalize_weight_and_reduce_impl
(
self
)
->
mk
.
TopKWeightAndReduce
:
def
finalize_weight_and_reduce_impl
(
self
)
->
mk
.
TopKWeightAndReduce
:
return
TopKWeightAndReduceNoOP
()
return
TopKWeightAndReduceNoOP
()
def
workspace_dtype
(
self
,
act_dtype
:
torch
.
dtype
)
->
torch
.
dtype
:
return
self
.
out_dtype
if
self
.
out_dtype
is
not
None
else
act_dtype
def
workspace_shapes
(
def
workspace_shapes
(
self
,
self
,
a
:
torch
.
Tensor
,
aq
:
torch
.
Tensor
,
M
:
int
,
M
:
int
,
N
:
int
,
N
:
int
,
K
:
int
,
K
:
int
,
...
@@ -778,25 +769,19 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
...
@@ -778,25 +769,19 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts
:
int
,
global_num_experts
:
int
,
local_num_experts
:
int
,
local_num_experts
:
int
,
expert_tokens_meta
:
Optional
[
mk
.
ExpertTokensMetadata
],
expert_tokens_meta
:
Optional
[
mk
.
ExpertTokensMetadata
],
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]
,
torch
.
dtype
]:
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]]:
workspace1
:
tuple
[
int
,
...]
=
()
workspace1
:
tuple
[
int
,
...]
=
()
workspace2
:
tuple
[
int
,
...]
=
()
workspace2
:
tuple
[
int
,
...]
=
()
output
:
tuple
[
int
,
...]
=
()
output
:
tuple
[
int
,
...]
=
()
if
self
.
use_batched_format
:
if
self
.
use_batched_format
:
padded_M
=
aq
.
size
(
1
)
workspace1
=
(
self
.
max_experts_per_worker
,
M
,
max
(
N
,
K
))
workspace1
=
(
self
.
max_experts_per_worker
,
padded_M
,
max
(
N
,
K
))
workspace2
=
(
self
.
max_experts_per_worker
,
M
,
(
N
//
2
))
workspace2
=
(
self
.
max_experts_per_worker
,
padded_M
,
(
N
//
2
))
output
=
(
self
.
max_experts_per_worker
,
M
,
K
)
output
=
(
self
.
max_experts_per_worker
,
padded_M
,
K
)
else
:
else
:
workspace1
=
(
M
*
topk
,
max
(
2
*
N
,
K
))
workspace1
=
(
M
*
topk
,
max
(
2
*
N
,
K
))
workspace2
=
(
M
*
topk
,
N
)
workspace2
=
(
M
*
topk
,
N
)
output
=
(
M
,
K
)
output
=
(
M
,
K
)
return
(
return
(
workspace1
,
workspace2
,
output
)
workspace1
,
workspace2
,
output
,
self
.
out_dtype
if
self
.
out_dtype
is
not
None
else
a
.
dtype
,
)
def
apply
(
def
apply
(
self
,
self
,
...
...
vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
View file @
da364615
...
@@ -198,8 +198,6 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
...
@@ -198,8 +198,6 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
def
workspace_shapes
(
def
workspace_shapes
(
self
,
self
,
a
:
torch
.
Tensor
,
aq
:
torch
.
Tensor
,
M
:
int
,
M
:
int
,
N
:
int
,
N
:
int
,
K
:
int
,
K
:
int
,
...
@@ -207,7 +205,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
...
@@ -207,7 +205,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts
:
int
,
global_num_experts
:
int
,
local_num_experts
:
int
,
local_num_experts
:
int
,
expert_tokens_meta
:
Optional
[
mk
.
ExpertTokensMetadata
],
expert_tokens_meta
:
Optional
[
mk
.
ExpertTokensMetadata
],
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]
,
torch
.
dtype
]:
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]]:
assert
self
.
block_shape
is
not
None
assert
self
.
block_shape
is
not
None
block_m
=
self
.
block_shape
[
0
]
block_m
=
self
.
block_shape
[
0
]
M_sum
=
compute_aligned_M
(
M_sum
=
compute_aligned_M
(
...
@@ -218,7 +216,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
...
@@ -218,7 +216,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace1
=
(
M_sum
,
max
(
N
,
K
))
workspace1
=
(
M_sum
,
max
(
N
,
K
))
workspace2
=
(
M_sum
,
max
(
N
//
2
,
K
))
workspace2
=
(
M_sum
,
max
(
N
//
2
,
K
))
output
=
(
M
,
K
)
output
=
(
M
,
K
)
return
(
workspace1
,
workspace2
,
output
,
a
.
dtype
)
return
(
workspace1
,
workspace2
,
output
)
def
apply
(
def
apply
(
self
,
self
,
...
...
vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py
View file @
da364615
...
@@ -70,6 +70,9 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
...
@@ -70,6 +70,9 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def
num_dispatchers
(
self
)
->
int
:
def
num_dispatchers
(
self
)
->
int
:
return
self
.
num_dispatchers_
return
self
.
num_dispatchers_
def
output_is_reduced
(
self
)
->
bool
:
return
True
@
property
@
property
def
activation_format
(
self
)
->
mk
.
FusedMoEActivationFormat
:
def
activation_format
(
self
)
->
mk
.
FusedMoEActivationFormat
:
return
mk
.
FusedMoEActivationFormat
.
Standard
return
mk
.
FusedMoEActivationFormat
.
Standard
...
...
vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
View file @
da364615
...
@@ -73,6 +73,9 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
...
@@ -73,6 +73,9 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def
num_dispatchers
(
self
)
->
int
:
def
num_dispatchers
(
self
)
->
int
:
return
self
.
num_dispatchers_
return
self
.
num_dispatchers_
def
output_is_reduced
(
self
)
->
bool
:
return
True
@
property
@
property
def
activation_format
(
self
)
->
mk
.
FusedMoEActivationFormat
:
def
activation_format
(
self
)
->
mk
.
FusedMoEActivationFormat
:
return
mk
.
FusedMoEActivationFormat
.
BatchedExperts
return
mk
.
FusedMoEActivationFormat
.
BatchedExperts
...
...
vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
View file @
da364615
...
@@ -90,8 +90,6 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
...
@@ -90,8 +90,6 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
def
workspace_shapes
(
def
workspace_shapes
(
self
,
self
,
a
:
torch
.
Tensor
,
aq
:
torch
.
Tensor
,
M
:
int
,
M
:
int
,
N
:
int
,
N
:
int
,
K
:
int
,
K
:
int
,
...
@@ -99,7 +97,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
...
@@ -99,7 +97,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts
:
int
,
global_num_experts
:
int
,
local_num_experts
:
int
,
local_num_experts
:
int
,
expert_tokens_meta
:
Optional
[
mk
.
ExpertTokensMetadata
],
expert_tokens_meta
:
Optional
[
mk
.
ExpertTokensMetadata
],
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]
,
torch
.
dtype
]:
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]]:
# We use global_num_experts due to how moe_align_block_size handles
# We use global_num_experts due to how moe_align_block_size handles
# expert_maps.
# expert_maps.
"""
"""
...
@@ -118,14 +116,12 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
...
@@ -118,14 +116,12 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
- Note: in order for activation chunking to work, the first dimension
- Note: in order for activation chunking to work, the first dimension
of each tuple must be the number of tokens.
of each tuple must be the number of tokens.
"""
"""
aq_m
,
aq_n
=
aq
.
shape
workspace1
=
(
M
,
K
)
workspace2
=
(
0
,)
workspace2
=
(
0
,)
output_shape
=
(
aq_m
,
aq_n
*
2
)
if
self
.
quant_dtype
==
"nvfp4"
else
(
aq_m
,
aq_n
)
output_shape
=
(
M
,
K
*
2
if
self
.
quant_dtype
==
"nvfp4"
else
K
)
workspace_dtype
=
a
.
dtype
workspace1
=
output_shape
# The workspace is determined by `aq`, since it comes after any
# The workspace is determined by `aq`, since it comes after any
# potential communication op and is involved in the expert computation.
# potential communication op and is involved in the expert computation.
return
(
workspace1
,
workspace2
,
output_shape
,
workspace_dtype
)
return
(
workspace1
,
workspace2
,
output_shape
)
def
apply
(
def
apply
(
self
,
self
,
...
...
vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
View file @
da364615
...
@@ -11,6 +11,9 @@ from vllm.distributed.device_communicators.base_device_communicator import (
...
@@ -11,6 +11,9 @@ from vllm.distributed.device_communicators.base_device_communicator import (
)
)
from
vllm.forward_context
import
get_forward_context
from
vllm.forward_context
import
get_forward_context
from
vllm.model_executor.layers.fused_moe.config
import
FusedMoEQuantConfig
from
vllm.model_executor.layers.fused_moe.config
import
FusedMoEQuantConfig
from
vllm.model_executor.layers.fused_moe.topk_weight_and_reduce
import
(
TopKWeightAndReduceNoOP
,
)
from
vllm.model_executor.layers.fused_moe.utils
import
moe_kernel_quantize_input
from
vllm.model_executor.layers.fused_moe.utils
import
moe_kernel_quantize_input
from
vllm.utils.flashinfer
import
nvfp4_block_scale_interleave
from
vllm.utils.flashinfer
import
nvfp4_block_scale_interleave
...
@@ -45,6 +48,9 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
...
@@ -45,6 +48,9 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def
num_dispatchers
(
self
)
->
int
:
def
num_dispatchers
(
self
)
->
int
:
return
self
.
num_dispatchers_
return
self
.
num_dispatchers_
def
output_is_reduced
(
self
)
->
bool
:
return
False
def
_apply_router_weight_on_input
(
def
_apply_router_weight_on_input
(
self
,
self
,
a1
:
torch
.
Tensor
,
a1
:
torch
.
Tensor
,
...
@@ -194,6 +200,8 @@ class FlashInferAllGatherMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFin
...
@@ -194,6 +200,8 @@ class FlashInferAllGatherMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFin
apply_router_weight_on_input
:
bool
,
apply_router_weight_on_input
:
bool
,
weight_and_reduce_impl
:
mk
.
TopKWeightAndReduce
,
weight_and_reduce_impl
:
mk
.
TopKWeightAndReduce
,
)
->
None
:
)
->
None
:
assert
isinstance
(
weight_and_reduce_impl
,
TopKWeightAndReduceNoOP
)
if
self
.
use_dp
:
if
self
.
use_dp
:
fused_expert_output
=
get_dp_group
().
reduce_scatterv
(
fused_expert_output
=
get_dp_group
().
reduce_scatterv
(
fused_expert_output
,
dim
=
0
,
sizes
=
get_local_sizes
()
fused_expert_output
,
dim
=
0
,
sizes
=
get_local_sizes
()
...
...
vllm/model_executor/layers/fused_moe/fused_batched_moe.py
View file @
da364615
...
@@ -509,6 +509,9 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
...
@@ -509,6 +509,9 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def
num_dispatchers
(
self
)
->
int
:
def
num_dispatchers
(
self
)
->
int
:
return
self
.
num_dispatchers_
return
self
.
num_dispatchers_
def
output_is_reduced
(
self
)
->
bool
:
return
False
def
prepare
(
def
prepare
(
self
,
self
,
a1
:
torch
.
Tensor
,
a1
:
torch
.
Tensor
,
...
@@ -665,8 +668,6 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
...
@@ -665,8 +668,6 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
def
workspace_shapes
(
def
workspace_shapes
(
self
,
self
,
a
:
torch
.
Tensor
,
aq
:
torch
.
Tensor
,
M
:
int
,
M
:
int
,
N
:
int
,
N
:
int
,
K
:
int
,
K
:
int
,
...
@@ -674,14 +675,13 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
...
@@ -674,14 +675,13 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts
:
int
,
global_num_experts
:
int
,
local_num_experts
:
int
,
local_num_experts
:
int
,
expert_tokens_meta
:
Optional
[
mk
.
ExpertTokensMetadata
],
expert_tokens_meta
:
Optional
[
mk
.
ExpertTokensMetadata
],
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...],
torch
.
dtype
]:
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]]:
assert
a
.
dim
()
==
2
num_dp
=
self
.
num_dispatchers
num_dp
=
self
.
num_dispatchers
num_experts
=
local_num_experts
num_experts
=
local_num_experts
workspace13
=
(
num_experts
,
self
.
max_num_tokens
*
num_dp
,
K
)
workspace13
=
(
num_experts
,
self
.
max_num_tokens
*
num_dp
,
K
)
workspace2
=
(
self
.
max_num_tokens
*
num_dp
,
N
)
workspace2
=
(
self
.
max_num_tokens
*
num_dp
,
N
)
output
=
workspace13
output
=
workspace13
return
(
workspace13
,
workspace2
,
output
,
a
.
dtype
)
return
(
workspace13
,
workspace2
,
output
)
def
dequant
(
self
,
t
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
dequant
(
self
,
t
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
)
->
torch
.
Tensor
:
assert
self
.
quant_config
.
is_quantized
assert
self
.
quant_config
.
is_quantized
...
@@ -862,8 +862,6 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
...
@@ -862,8 +862,6 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
def
workspace_shapes
(
def
workspace_shapes
(
self
,
self
,
a
:
torch
.
Tensor
,
aq
:
torch
.
Tensor
,
M
:
int
,
M
:
int
,
N
:
int
,
N
:
int
,
K
:
int
,
K
:
int
,
...
@@ -871,15 +869,14 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
...
@@ -871,15 +869,14 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts
:
int
,
global_num_experts
:
int
,
local_num_experts
:
int
,
local_num_experts
:
int
,
expert_tokens_meta
:
Optional
[
mk
.
ExpertTokensMetadata
],
expert_tokens_meta
:
Optional
[
mk
.
ExpertTokensMetadata
],
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...],
torch
.
dtype
]:
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]]:
assert
a
.
dim
()
==
2
num_dp
=
self
.
num_dispatchers
num_dp
=
self
.
num_dispatchers
num_experts
=
local_num_experts
num_experts
=
local_num_experts
max_num_tokens
=
self
.
max_num_tokens
max_num_tokens
=
self
.
max_num_tokens
workspace13
=
(
num_experts
,
max_num_tokens
*
num_dp
,
max
(
K
,
N
))
workspace13
=
(
num_experts
,
max_num_tokens
*
num_dp
,
max
(
K
,
N
))
workspace2
=
(
num_experts
,
max_num_tokens
*
num_dp
,
(
N
//
2
))
workspace2
=
(
num_experts
,
max_num_tokens
*
num_dp
,
(
N
//
2
))
output
=
(
num_experts
,
max_num_tokens
*
num_dp
,
K
)
output
=
(
num_experts
,
max_num_tokens
*
num_dp
,
K
)
return
(
workspace13
,
workspace2
,
output
,
a
.
dtype
)
return
(
workspace13
,
workspace2
,
output
)
def
apply
(
def
apply
(
self
,
self
,
...
...
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
View file @
da364615
...
@@ -331,8 +331,6 @@ class MarlinExperts(mk.FusedMoEPermuteExpertsUnpermute):
...
@@ -331,8 +331,6 @@ class MarlinExperts(mk.FusedMoEPermuteExpertsUnpermute):
def
workspace_shapes
(
def
workspace_shapes
(
self
,
self
,
a
:
torch
.
Tensor
,
aq
:
torch
.
Tensor
,
M
:
int
,
M
:
int
,
N
:
int
,
N
:
int
,
K
:
int
,
K
:
int
,
...
@@ -340,7 +338,7 @@ class MarlinExperts(mk.FusedMoEPermuteExpertsUnpermute):
...
@@ -340,7 +338,7 @@ class MarlinExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts
:
int
,
global_num_experts
:
int
,
local_num_experts
:
int
,
local_num_experts
:
int
,
expert_tokens_meta
:
Optional
[
mk
.
ExpertTokensMetadata
],
expert_tokens_meta
:
Optional
[
mk
.
ExpertTokensMetadata
],
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]
,
torch
.
dtype
]:
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]]:
# Modular Kernel provisions output buffer from workspace1. However in
# Modular Kernel provisions output buffer from workspace1. However in
# the fused_marlin_moe() function, the final torch.sum(), is defined
# the fused_marlin_moe() function, the final torch.sum(), is defined
# essentially as,
# essentially as,
...
@@ -360,7 +358,7 @@ class MarlinExperts(mk.FusedMoEPermuteExpertsUnpermute):
...
@@ -360,7 +358,7 @@ class MarlinExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace2
=
(
M
*
topk
*
max
(
2
*
N
,
K
),)
workspace2
=
(
M
*
topk
*
max
(
2
*
N
,
K
),)
output
=
(
M
,
K
)
output
=
(
M
,
K
)
return
(
workspace1
,
workspace2
,
output
,
a
.
dtype
)
return
(
workspace1
,
workspace2
,
output
)
def
apply
(
def
apply
(
self
,
self
,
...
...
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
da364615
...
@@ -1954,8 +1954,6 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
...
@@ -1954,8 +1954,6 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
def
workspace_shapes
(
def
workspace_shapes
(
self
,
self
,
a
:
torch
.
Tensor
,
aq
:
torch
.
Tensor
,
M
:
int
,
M
:
int
,
N
:
int
,
N
:
int
,
K
:
int
,
K
:
int
,
...
@@ -1963,11 +1961,11 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
...
@@ -1963,11 +1961,11 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts
:
int
,
global_num_experts
:
int
,
local_num_experts
:
int
,
local_num_experts
:
int
,
expert_tokens_meta
:
Optional
[
mk
.
ExpertTokensMetadata
],
expert_tokens_meta
:
Optional
[
mk
.
ExpertTokensMetadata
],
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]
,
torch
.
dtype
]:
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]]:
workspace1
=
(
M
,
topk
,
max
(
N
//
2
,
K
))
workspace1
=
(
M
,
topk
,
max
(
N
//
2
,
K
))
workspace2
=
(
M
,
topk
,
max
(
N
,
K
))
workspace2
=
(
M
,
topk
,
max
(
N
,
K
))
output
=
(
M
,
K
)
output
=
(
M
,
K
)
return
(
workspace1
,
workspace2
,
output
,
a
.
dtype
)
return
(
workspace1
,
workspace2
,
output
)
def
apply
(
def
apply
(
self
,
self
,
...
...
vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
View file @
da364615
...
@@ -255,8 +255,6 @@ class OAITritonExperts(BaseOAITritonExperts):
...
@@ -255,8 +255,6 @@ class OAITritonExperts(BaseOAITritonExperts):
def
workspace_shapes
(
def
workspace_shapes
(
self
,
self
,
a
:
torch
.
Tensor
,
aq
:
torch
.
Tensor
,
M
:
int
,
M
:
int
,
N
:
int
,
N
:
int
,
K
:
int
,
K
:
int
,
...
@@ -264,12 +262,12 @@ class OAITritonExperts(BaseOAITritonExperts):
...
@@ -264,12 +262,12 @@ class OAITritonExperts(BaseOAITritonExperts):
global_num_experts
:
int
,
global_num_experts
:
int
,
local_num_experts
:
int
,
local_num_experts
:
int
,
expert_tokens_meta
:
Optional
[
mk
.
ExpertTokensMetadata
],
expert_tokens_meta
:
Optional
[
mk
.
ExpertTokensMetadata
],
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]
,
torch
.
dtype
]:
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]]:
# workspace are allocated inside the kernel
# workspace are allocated inside the kernel
workspace1
=
(
M
,
K
)
workspace1
=
(
M
,
K
)
workspace2
=
(
0
,
0
)
workspace2
=
(
0
,
0
)
output
=
(
M
,
K
)
output
=
(
M
,
K
)
return
(
workspace1
,
workspace2
,
output
,
a
.
dtype
)
return
(
workspace1
,
workspace2
,
output
)
def
apply
(
def
apply
(
self
,
self
,
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
da364615
...
@@ -283,6 +283,10 @@ class FusedMoEMethodBase(QuantizeMethodBase):
...
@@ -283,6 +283,10 @@ class FusedMoEMethodBase(QuantizeMethodBase):
)
->
Optional
[
FusedMoEQuantConfig
]:
)
->
Optional
[
FusedMoEQuantConfig
]:
raise
NotImplementedError
raise
NotImplementedError
@
property
def
using_modular_kernel
(
self
)
->
bool
:
return
self
.
fused_experts
is
not
None
@
abstractmethod
@
abstractmethod
def
apply
(
def
apply
(
self
,
self
,
...
@@ -1237,38 +1241,24 @@ class FusedMoE(CustomOp):
...
@@ -1237,38 +1241,24 @@ class FusedMoE(CustomOp):
self
.
batched_hidden_states
:
Optional
[
torch
.
Tensor
]
=
None
self
.
batched_hidden_states
:
Optional
[
torch
.
Tensor
]
=
None
self
.
batched_router_logits
:
Optional
[
torch
.
Tensor
]
=
None
self
.
batched_router_logits
:
Optional
[
torch
.
Tensor
]
=
None
# TODO(bnell): flashinfer uses non-batched format.
if
self
.
use_dp_chunking
:
# Does it really need a batched buffer?
states_shape
:
tuple
[
int
,
...]
if
(
logits_shape
:
tuple
[
int
,
...]
self
.
moe_parallel_config
.
use_pplx_kernels
or
self
.
moe_parallel_config
.
use_deepep_ll_kernels
or
self
.
moe_config
.
use_flashinfer_cutlass_kernels
):
if
vllm_config
.
parallel_config
.
enable_dbo
:
self
.
batched_hidden_states
=
torch
.
zeros
(
(
2
,
moe
.
max_num_tokens
,
self
.
hidden_size
),
dtype
=
moe
.
in_dtype
,
device
=
torch
.
cuda
.
current_device
(),
)
# Note here we use `num_experts` which is logical expert count
# Note here we use `num_experts` which is logical expert count
self
.
batched_router_logits
=
torch
.
zeros
(
if
vllm_config
.
parallel_config
.
enable_dbo
:
(
2
,
moe
.
max_num_tokens
,
num_experts
),
states_shape
=
(
2
,
moe
.
max_num_tokens
,
self
.
hidden_size
)
dtype
=
moe
.
in_dtype
,
logits_shape
=
(
2
,
moe
.
max_num_tokens
,
num_experts
)
device
=
torch
.
cuda
.
current_device
(),
)
else
:
else
:
states_shape
=
(
moe
.
max_num_tokens
,
self
.
hidden_size
)
logits_shape
=
(
moe
.
max_num_tokens
,
num_experts
)
self
.
batched_hidden_states
=
torch
.
zeros
(
self
.
batched_hidden_states
=
torch
.
zeros
(
(
moe
.
max_num_tokens
,
self
.
hidden_size
),
states_shape
,
dtype
=
moe
.
in_dtype
,
device
=
torch
.
cuda
.
current_device
()
dtype
=
moe
.
in_dtype
,
device
=
torch
.
cuda
.
current_device
(),
)
)
# Note here we use `num_experts` which is logical expert count
self
.
batched_router_logits
=
torch
.
zeros
(
self
.
batched_router_logits
=
torch
.
zeros
(
(
moe
.
max_num_tokens
,
num_experts
),
logits_shape
,
dtype
=
moe
.
in_dtype
,
device
=
torch
.
cuda
.
current_device
()
dtype
=
moe
.
in_dtype
,
device
=
torch
.
cuda
.
current_device
(),
)
)
@
property
@
property
...
@@ -1323,6 +1313,16 @@ class FusedMoE(CustomOp):
...
@@ -1323,6 +1313,16 @@ class FusedMoE(CustomOp):
and
self
.
moe_config
.
use_flashinfer_cutlass_kernels
and
self
.
moe_config
.
use_flashinfer_cutlass_kernels
)
)
@
property
def
use_dp_chunking
(
self
)
->
bool
:
# Route to the chunked forward path using the FlashInfer Cutlass kernel
# only when data parallelism (DP) is enabled.
return
(
self
.
moe_parallel_config
.
use_pplx_kernels
or
self
.
moe_parallel_config
.
use_deepep_ll_kernels
or
(
self
.
dp_size
>
1
and
self
.
use_flashinfer_cutlass_kernels
)
)
def
update_expert_map
(
self
):
def
update_expert_map
(
self
):
# ep_size and ep_rank should already be updated
# ep_size and ep_rank should already be updated
assert
self
.
expert_map
is
not
None
assert
self
.
expert_map
is
not
None
...
@@ -1987,21 +1987,17 @@ class FusedMoE(CustomOp):
...
@@ -1987,21 +1987,17 @@ class FusedMoE(CustomOp):
Therefore it is required that we reduce the shared_experts output
Therefore it is required that we reduce the shared_experts output
early.
early.
"""
"""
assert
self
.
quant_method
is
not
None
return
(
return
(
self
.
use_pplx_kernels
self
.
quant_method
.
fused_experts
is
not
None
or
self
.
use_deepep_ht_kernels
and
self
.
quant_method
.
fused_experts
.
output_is_reduced
()
or
self
.
use_deepep_ll_kernels
)
)
def
maybe_all_reduce_tensor_model_parallel
(
self
,
final_hidden_states
:
torch
.
Tensor
):
def
maybe_all_reduce_tensor_model_parallel
(
self
,
final_hidden_states
:
torch
.
Tensor
):
"""
"""
The pplx
combine kernel reduce
s
across GPU ranks by default.
Some
combine kernel
s
reduce across GPU ranks by default.
"""
"""
if
(
if
self
.
must_reduce_shared_expert_outputs
():
self
.
use_pplx_kernels
or
self
.
use_deepep_ht_kernels
or
self
.
use_deepep_ll_kernels
):
return
final_hidden_states
return
final_hidden_states
else
:
else
:
return
tensor_model_parallel_all_reduce
(
final_hidden_states
)
return
tensor_model_parallel_all_reduce
(
final_hidden_states
)
...
@@ -2209,23 +2205,11 @@ class FusedMoE(CustomOp):
...
@@ -2209,23 +2205,11 @@ class FusedMoE(CustomOp):
self
.
ensure_moe_quant_config
()
self
.
ensure_moe_quant_config
()
# Route to the chunked forward path using the FlashInfer Cutlass kernel
if
self
.
use_dp_chunking
:
# only when data parallelism (DP) is enabled.
_use_flashinfer_cutlass_kernels
=
(
self
.
dp_size
>
1
and
self
.
use_flashinfer_cutlass_kernels
)
if
(
self
.
moe_parallel_config
.
use_pplx_kernels
or
self
.
moe_parallel_config
.
use_deepep_ll_kernels
or
_use_flashinfer_cutlass_kernels
):
return
self
.
forward_impl_chunked
(
hidden_states
,
router_logits
)
return
self
.
forward_impl_chunked
(
hidden_states
,
router_logits
)
do_naive_dispatch_combine
:
bool
=
(
do_naive_dispatch_combine
:
bool
=
(
self
.
dp_size
>
1
self
.
dp_size
>
1
and
not
self
.
quant_method
.
using_modular_kernel
and
not
self
.
moe_parallel_config
.
use_deepep_ht_kernels
and
not
self
.
moe_config
.
use_flashinfer_cutlass_kernels
)
)
# If there are shared experts but we are not using a modular kernel, the
# If there are shared experts but we are not using a modular kernel, the
...
...
vllm/model_executor/layers/fused_moe/modular_kernel.py
View file @
da364615
...
@@ -337,6 +337,14 @@ class FusedMoEPrepareAndFinalize(ABC):
...
@@ -337,6 +337,14 @@ class FusedMoEPrepareAndFinalize(ABC):
def
num_dispatchers
(
self
)
->
int
:
def
num_dispatchers
(
self
)
->
int
:
raise
NotImplementedError
raise
NotImplementedError
@
abstractmethod
def
output_is_reduced
(
self
)
->
bool
:
"""
Indicates whether or not the output of finalize is reduced across all
ranks.
"""
raise
NotImplementedError
# TODO: add supported activations method (return string)
# TODO: add supported activations method (return string)
class
FusedMoEPermuteExpertsUnpermute
(
ABC
):
class
FusedMoEPermuteExpertsUnpermute
(
ABC
):
...
@@ -493,11 +501,15 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
...
@@ -493,11 +501,15 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
"""
"""
raise
NotImplementedError
raise
NotImplementedError
def
workspace_dtype
(
self
,
act_dtype
:
torch
.
dtype
)
->
torch
.
dtype
:
"""
Workspace type: The dtype to use for the workspace tensors.
"""
return
act_dtype
@
abstractmethod
@
abstractmethod
def
workspace_shapes
(
def
workspace_shapes
(
self
,
self
,
a
:
torch
.
Tensor
,
aq
:
torch
.
Tensor
,
M
:
int
,
M
:
int
,
N
:
int
,
N
:
int
,
K
:
int
,
K
:
int
,
...
@@ -505,22 +517,33 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
...
@@ -505,22 +517,33 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
global_num_experts
:
int
,
global_num_experts
:
int
,
local_num_experts
:
int
,
local_num_experts
:
int
,
expert_tokens_meta
:
Optional
[
ExpertTokensMetadata
],
expert_tokens_meta
:
Optional
[
ExpertTokensMetadata
],
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]
,
torch
.
dtype
]:
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]]:
"""
"""
Compute the shapes for the temporary and final outputs of the two gemms
Compute the shapes for the temporary and final outputs of the two gemms
and activation in the fused expert function. Since the gemms are
and activation in the fused expert function. Since the gemms are
independent, the workspace for the first gemm can be shared with the
independent, the workspace for the first gemm can be shared with the
workspace for the last gemm.
workspace for the last gemm.
Inputs:
- M: number of tokens.
- N: Row (or column) dimension of expert weights.
- K: hidden dimension
- topk: The number of top-k experts to select.
- global_num_experts: global number of experts.
- local_num_experts: local number of experts due to DP/EP.
- expert_tokens_meta: number of tokens per expert metadata for batched
format.
Returns a tuple of:
Returns a tuple of:
- workspace13 shape tuple: must be large enough to hold the
- workspace13 shape tuple: must be large enough to hold the
result of either expert gemm.
result of either expert gemm.
- workspace2 shape tuple: must be large enough to hold the
- workspace2 shape tuple: must be large enough to hold the
result of the activation function.
result of the activation function.
- output shape tuple: must be exact size of the final gemm output.
- output shape tuple: must be exact size of the final gemm output.
- Workspace type: The dtype to use for the workspace tensors.
- Note: workspace shapes can be 0 if the workspace is not needed.
- Note: in order for activation chunking to work, the first dimension
But in order for activation chunking to work, the first dimension
of each tuple must be the number of tokens.
of each tuple must be the number of tokens when the shape is
not 0.
"""
"""
raise
NotImplementedError
raise
NotImplementedError
...
@@ -561,7 +584,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
...
@@ -561,7 +584,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
workspace2
:
torch
.
Tensor
,
workspace2
:
torch
.
Tensor
,
expert_tokens_meta
:
Optional
[
ExpertTokensMetadata
],
expert_tokens_meta
:
Optional
[
ExpertTokensMetadata
],
apply_router_weight_on_input
:
bool
,
apply_router_weight_on_input
:
bool
,
):
)
->
None
:
"""
"""
This function computes the intermediate result of a Mixture of Experts
This function computes the intermediate result of a Mixture of Experts
(MoE) layer using two sets of weights, w1 and w2.
(MoE) layer using two sets of weights, w1 and w2.
...
@@ -600,7 +623,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
...
@@ -600,7 +623,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
raise
NotImplementedError
raise
NotImplementedError
def
_
chunk
_scales
(
def
_
slice
_scales
(
scales
:
Optional
[
torch
.
Tensor
],
start
:
int
,
end
:
int
scales
:
Optional
[
torch
.
Tensor
],
start
:
int
,
end
:
int
)
->
Optional
[
torch
.
Tensor
]:
)
->
Optional
[
torch
.
Tensor
]:
if
scales
is
not
None
:
if
scales
is
not
None
:
...
@@ -615,9 +638,10 @@ class SharedResizableBuffer:
...
@@ -615,9 +638,10 @@ class SharedResizableBuffer:
def
__init__
(
self
):
def
__init__
(
self
):
self
.
buffer
=
None
self
.
buffer
=
None
def
get
(
self
,
shape
:
tuple
[
int
,
...],
device
:
torch
.
device
,
dtype
:
torch
.
dtype
):
def
get
(
if
shape
==
()
or
shape
is
None
:
self
,
shape
:
tuple
[
int
,
...],
device
:
torch
.
device
,
dtype
:
torch
.
dtype
return
None
)
->
torch
.
Tensor
:
assert
shape
!=
()
shape_numel
=
prod
(
shape
)
shape_numel
=
prod
(
shape
)
if
(
if
(
self
.
buffer
is
None
self
.
buffer
is
None
...
@@ -678,31 +702,63 @@ class FusedMoEModularKernel(torch.nn.Module):
...
@@ -678,31 +702,63 @@ class FusedMoEModularKernel(torch.nn.Module):
f
"
{
fused_experts
.
activation_formats
[
0
]
}
"
f
"
{
fused_experts
.
activation_formats
[
0
]
}
"
)
)
def
_do_fused_experts
(
def
output_is_reduced
(
self
)
->
bool
:
"""
Indicates whether or not the output of fused MoE kernel
is reduced across all ranks.
"""
return
self
.
prepare_finalize
.
output_is_reduced
()
def
_chunk_info
(
self
,
M
:
int
)
->
tuple
[
int
,
int
]:
"""
Compute number of chunks and chunk size for given M.
If chunking is not supported, set the CHUNK_SIZE to M so we
get num_chunks == 1. Take max(M, 1) to avoid divide by zero.
If there are no tokens to process, the number of chunks will be zero.
"""
CHUNK_SIZE
=
(
max
(
M
,
1
)
if
not
self
.
fused_experts
.
supports_chunking
()
else
min
(
M
,
envs
.
VLLM_FUSED_MOE_CHUNK_SIZE
)
)
num_chunks
=
cdiv
(
M
,
CHUNK_SIZE
)
# If there are no tokens, then there should be no loop iterations.
assert
M
>
0
or
num_chunks
==
0
return
num_chunks
,
CHUNK_SIZE
def
_allocate_buffers
(
self
,
self
,
fused_out
:
Optional
[
torch
.
Tensor
],
out_dtype
:
torch
.
dtype
,
a1
:
torch
.
Tensor
,
device
:
torch
.
device
,
a1q
:
torch
.
Tensor
,
M_chunk
:
int
,
w1
:
torch
.
Tensor
,
M_full
:
int
,
w2
:
torch
.
Tensor
,
N
:
int
,
topk_weights
:
torch
.
Tensor
,
K
:
int
,
topk_ids
:
torch
.
Tensor
,
top_k
:
int
,
activation
:
str
,
global_num_experts
:
int
,
global_num_experts
:
int
,
local_num_experts
:
int
,
local_num_experts
:
int
,
expert_map
:
Optional
[
torch
.
Tensor
],
a1q_scale
:
Optional
[
torch
.
Tensor
],
a2_scale
:
Optional
[
torch
.
Tensor
],
expert_tokens_meta
:
Optional
[
ExpertTokensMetadata
],
expert_tokens_meta
:
Optional
[
ExpertTokensMetadata
],
apply_router_weight_on_input
:
bool
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
torch
.
Tensor
:
"""
_
,
M
,
N
,
K
,
top_k
=
self
.
fused_experts
.
moe_problem_size
(
a1q
,
w1
,
w2
,
topk_ids
)
Allocate temporary and output buffers for the fused experts op.
Inputs:
- out_dtype: output type of workspace and output tensors.
- device: the device of the workspace and output tensors.
See `workspace_shapes` for a description of the remainder of arguments.
Returns a tuple of (workspace13, workspace2, output) tensors.
"""
assert
M_full
>
0
and
M_chunk
>
0
(
workspace13_shape
,
workspace2_shape
,
fused_out_shape
,
workspace_dtype
)
=
(
num_chunks
,
_
=
self
.
_chunk_info
(
M_full
)
self
.
fused_experts
.
workspace_shapes
(
a1
,
# select per-ubatch buffers to avoid cross-ubatch reuse under DBO
a1q
,
ubatch_idx
=
dbo_current_ubatch_id
()
M
,
buffers
=
self
.
shared_buffers
[
ubatch_idx
]
workspace_dtype
=
self
.
fused_experts
.
workspace_dtype
(
out_dtype
)
# Get intermediate workspace shapes based off the chunked M size.
workspace13_shape
,
workspace2_shape
,
_
=
self
.
fused_experts
.
workspace_shapes
(
M_chunk
,
N
,
N
,
K
,
K
,
top_k
,
top_k
,
...
@@ -710,147 +766,69 @@ class FusedMoEModularKernel(torch.nn.Module):
...
@@ -710,147 +766,69 @@ class FusedMoEModularKernel(torch.nn.Module):
local_num_experts
,
local_num_experts
,
expert_tokens_meta
,
expert_tokens_meta
,
)
)
)
# select per-ubatch buffers to avoid cross-ubatch reuse under DBO
# Get final output shape based on the full M size.
ubatch_idx
=
dbo_current_ubatch_id
()
_
,
_
,
fused_out_shape
=
self
.
fused_experts
.
workspace_shapes
(
buffers
=
self
.
shared_buffers
[
ubatch_idx
]
M_full
,
N
,
K
,
top_k
,
global_num_experts
,
local_num_experts
,
expert_tokens_meta
,
)
# We can reuse the memory between cache1 and cache3 because by the
# We can reuse the memory between cache1 and cache3 because by the
# time we need cache3, we're done with cache1.
# time we need cache3, we're done with cache1.
workspace13
=
buffers
.
workspace13
.
get
(
workspace13
=
buffers
.
workspace13
.
get
(
workspace13_shape
,
device
=
a1
.
device
,
dtype
=
workspace_dtype
workspace13_shape
,
device
=
device
,
dtype
=
workspace_dtype
)
)
workspace2
=
buffers
.
workspace2
.
get
(
workspace2
=
buffers
.
workspace2
.
get
(
workspace2_shape
,
device
=
a1
.
device
,
dtype
=
workspace_dtype
workspace2_shape
,
device
=
device
,
dtype
=
workspace_dtype
)
)
assert
fused_out
is
None
or
fused_out
.
shape
==
fused_out_shape
,
(
f
"fused_out
{
fused_out
.
shape
}
but expected
{
fused_out_shape
}
"
)
if
fused_out
is
None
:
# reuse workspace13 for the output
fused_out
=
_resize_cache
(
workspace13
,
fused_out_shape
)
self
.
fused_experts
.
apply
(
fused_out
,
a1q
,
w1
,
w2
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
activation
=
activation
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
a1q_scale
=
a1q_scale
,
a2_scale
=
a2_scale
,
workspace13
=
workspace13
,
workspace2
=
workspace2
,
expert_tokens_meta
=
expert_tokens_meta
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
)
return
fused_out
def
_maybe_chunk_fused_experts
(
self
,
a1
:
torch
.
Tensor
,
a1q
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
activation
:
str
,
global_num_experts
:
int
,
local_num_experts
:
int
,
expert_map
:
Optional
[
torch
.
Tensor
],
a1q_scale
:
Optional
[
torch
.
Tensor
],
expert_tokens_meta
:
Optional
[
ExpertTokensMetadata
],
apply_router_weight_on_input
:
bool
,
)
->
torch
.
Tensor
:
_
,
M
,
N
,
K
,
top_k
=
self
.
fused_experts
.
moe_problem_size
(
a1q
,
w1
,
w2
,
topk_ids
)
CHUNK_SIZE
=
envs
.
VLLM_FUSED_MOE_CHUNK_SIZE
num_chunks
=
cdiv
(
M
,
CHUNK_SIZE
)
# TODO(bnell): get rid of one level here, update slice functions
# to nops on num_chunks==1
if
not
self
.
fused_experts
.
supports_chunking
()
or
num_chunks
==
1
:
return
self
.
_do_fused_experts
(
fused_out
=
None
,
a1
=
a1
,
a1q
=
a1q
,
w1
=
w1
,
w2
=
w2
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
activation
=
activation
,
global_num_experts
=
global_num_experts
,
local_num_experts
=
local_num_experts
,
expert_map
=
expert_map
,
a1q_scale
=
a1q_scale
,
a2_scale
=
self
.
fused_experts
.
a2_scale
,
expert_tokens_meta
=
expert_tokens_meta
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
)
# Chunking required case
assert
num_chunks
>
1
# Construct the entire output that can then be processed in chunks.
# Construct the entire output that can then be processed in chunks.
(
_
,
_
,
fused_out_shape
,
_
)
=
self
.
fused_experts
.
workspace_shapes
(
# Reuse workspace13 for the output in the non-chunked case as long
a1
,
# as it is large enough. This will not always be the case for standard
a1q
,
# format experts and with experts that have empty workspaces.
M
,
if
num_chunks
==
1
and
prod
(
workspace13_shape
)
>=
prod
(
fused_out_shape
):
N
,
fused_out
=
_resize_cache
(
workspace13
,
fused_out_shape
)
K
,
else
:
top_k
,
global_num_experts
,
local_num_experts
,
expert_tokens_meta
,
)
ubatch_idx
=
dbo_current_ubatch_id
()
buffers
=
self
.
shared_buffers
[
ubatch_idx
]
fused_out
=
buffers
.
fused_out
.
get
(
fused_out
=
buffers
.
fused_out
.
get
(
fused_out_shape
,
device
=
a1q
.
device
,
dtype
=
a1
.
dtype
fused_out_shape
,
device
=
device
,
dtype
=
out_
dtype
)
)
def
slice_input_tensors
(
return
workspace13
,
workspace2
,
fused_out
@
staticmethod
def
_slice_output_tensor
(
fused_out
:
torch
.
Tensor
,
chunk_idx
:
int
,
chunk_idx
:
int
,
)
->
tuple
[
num_chunks
:
int
,
torch
.
Tensor
,
CHUNK_SIZE
:
int
,
Optional
[
torch
.
Tensor
],
M
:
int
,
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
torch
.
Tensor
,
if
num_chunks
==
1
:
torch
.
Tensor
,
return
fused_out
]:
s
=
chunk_idx
*
CHUNK_SIZE
e
=
min
(
s
+
CHUNK_SIZE
,
M
)
return
(
a1q
[
s
:
e
],
_chunk_scales
(
a1q_scale
,
s
,
e
),
_chunk_scales
(
self
.
fused_experts
.
a2_scale
,
s
,
e
),
topk_ids
[
s
:
e
],
topk_weights
[
s
:
e
],
)
def
slice_output_tensor
(
chunk_idx
:
int
)
->
torch
.
Tensor
:
assert
fused_out
.
size
(
0
)
%
M
==
0
,
f
"fused_out shape
{
fused_out
.
shape
}
vs M
{
M
}
"
assert
fused_out
.
size
(
0
)
%
M
==
0
,
(
f
"fused_out shape
{
fused_out
.
shape
}
vs M
{
M
}
"
)
factor
=
fused_out
.
size
(
0
)
//
M
factor
=
fused_out
.
size
(
0
)
//
M
out_chunk_size
=
CHUNK_SIZE
*
factor
out_chunk_size
=
CHUNK_SIZE
*
factor
s
=
chunk_idx
*
out_chunk_size
s
=
chunk_idx
*
out_chunk_size
e
=
min
(
s
+
out_chunk_size
,
fused_out
.
size
(
0
))
e
=
min
(
s
+
out_chunk_size
,
fused_out
.
size
(
0
))
return
fused_out
[
s
:
e
]
return
fused_out
[
s
:
e
]
def
slice_expert_tokens_metadata
(
@
staticmethod
full_expert_tokens_meta
:
ExpertTokensMetadata
,
def
_slice_expert_tokens_metadata
(
num_chunks
:
int
,
full_expert_tokens_meta
:
Optional
[
ExpertTokensMetadata
],
chunk_topk_ids
:
torch
.
Tensor
,
chunk_topk_ids
:
torch
.
Tensor
,
local_num_experts
:
int
,
local_num_experts
:
int
,
expert_map
:
Optional
[
torch
.
Tensor
],
expert_map
:
Optional
[
torch
.
Tensor
],
)
->
ExpertTokensMetadata
:
)
->
Optional
[
ExpertTokensMetadata
]:
if
num_chunks
==
1
or
full_expert_tokens_meta
is
None
:
return
full_expert_tokens_meta
# The existing expert_num_tokens is for the entire a1q
# The existing expert_num_tokens is for the entire a1q
# input. Chunking forces recomputation of the number
# input. Chunking forces recomputation of the number
# of tokens assigned to each expert.
# of tokens assigned to each expert.
...
@@ -866,94 +844,32 @@ class FusedMoEModularKernel(torch.nn.Module):
...
@@ -866,94 +844,32 @@ class FusedMoEModularKernel(torch.nn.Module):
# This is blocking as some implementations need the count
# This is blocking as some implementations need the count
# on the CPU to determine appropriate input/out fused-moe
# on the CPU to determine appropriate input/out fused-moe
# buffers
# buffers
c_expert_num_tokens_cpu
=
c_expert_num_tokens
.
to
(
c_expert_num_tokens_cpu
=
c_expert_num_tokens
.
to
(
"cpu"
,
non_blocking
=
False
)
"cpu"
,
non_blocking
=
False
)
return
ExpertTokensMetadata
(
return
ExpertTokensMetadata
(
expert_num_tokens
=
c_expert_num_tokens
,
expert_num_tokens
=
c_expert_num_tokens
,
expert_num_tokens_cpu
=
c_expert_num_tokens_cpu
,
expert_num_tokens_cpu
=
c_expert_num_tokens_cpu
,
)
)
for
chunk_idx
in
range
(
num_chunks
):
def
_prepare
(
c_a1q
,
c_a1q_scale
,
c_a2_scale
,
c_topk_ids
,
c_topk_weights
=
(
slice_input_tensors
(
chunk_idx
)
)
c_expert_tokens_meta
=
None
if
expert_tokens_meta
is
not
None
:
c_expert_tokens_meta
=
slice_expert_tokens_metadata
(
expert_tokens_meta
,
c_topk_ids
,
local_num_experts
,
expert_map
)
self
.
_do_fused_experts
(
fused_out
=
slice_output_tensor
(
chunk_idx
),
a1
=
a1
,
a1q
=
c_a1q
,
w1
=
w1
,
w2
=
w2
,
topk_weights
=
c_topk_weights
,
topk_ids
=
c_topk_ids
,
activation
=
activation
,
global_num_experts
=
global_num_experts
,
local_num_experts
=
local_num_experts
,
expert_map
=
expert_map
,
a1q_scale
=
c_a1q_scale
,
a2_scale
=
c_a2_scale
,
expert_tokens_meta
=
c_expert_tokens_meta
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
)
return
fused_out
def
forward
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
inplace
:
bool
=
False
,
global_num_experts
:
int
,
activation
:
str
=
"silu"
,
expert_map
:
Optional
[
torch
.
Tensor
],
global_num_experts
:
int
=
-
1
,
apply_router_weight_on_input
:
bool
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
tuple
[
apply_router_weight_on_input
:
bool
=
False
,
torch
.
Tensor
,
)
->
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]:
Optional
[
torch
.
Tensor
],
Optional
[
ExpertTokensMetadata
],
torch
.
Tensor
,
torch
.
Tensor
,
]:
"""
"""
This function computes a Mixture of Experts (MoE) layer using two sets
The _prepare method is a wrapper around self.prepare_finalize.prepare
of weights, w1 and w2, and top-k gating mechanism.
that handles DBO and async.
Parameters:
- hidden_states: (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- topk_weights (torch.Tensor): The topk weights applied at the end of
the layer.
- topk_ids (torch.Tensor): A map of row to expert id.
- inplace (bool): If True, perform the operation in-place.
Defaults to False.
- activation (str): The activation function to apply after the first
MoE layer.
- global_num_experts (int): The total number of experts in the global
expert space.
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
from the global expert space to the local expert space of the expert
parallel shard.
- apply_router_weight_on_input (bool): When true, the topk weights are
applied directly on the inputs. This is only applicable when topk is
1.
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""
"""
a1
=
hidden_states
output
=
a1
if
inplace
and
self
.
shared_experts
is
None
else
torch
.
zeros_like
(
a1
)
local_num_experts
=
w1
.
size
(
0
)
if
global_num_experts
==
-
1
:
global_num_experts
=
local_num_experts
if
not
self
.
prepare_finalize
.
supports_async
():
if
not
self
.
prepare_finalize
.
supports_async
():
# We shouldn't be running an a2a kernel that doesn't
# We shouldn't be running an a2a kernel that doesn't
# support async prepare/finalize
# support async prepare/finalize
...
@@ -967,7 +883,7 @@ class FusedMoEModularKernel(torch.nn.Module):
...
@@ -967,7 +883,7 @@ class FusedMoEModularKernel(torch.nn.Module):
_expert_topk_ids
,
_expert_topk_ids
,
_expert_topk_weights
,
_expert_topk_weights
,
)
=
self
.
prepare_finalize
.
prepare
(
)
=
self
.
prepare_finalize
.
prepare
(
a1
,
hidden_states
,
topk_weights
,
topk_weights
,
topk_ids
,
topk_ids
,
global_num_experts
,
global_num_experts
,
...
@@ -979,7 +895,7 @@ class FusedMoEModularKernel(torch.nn.Module):
...
@@ -979,7 +895,7 @@ class FusedMoEModularKernel(torch.nn.Module):
# Overlap shared expert compute with all2all dispatch.
# Overlap shared expert compute with all2all dispatch.
dbo_maybe_run_recv_hook
()
dbo_maybe_run_recv_hook
()
prepare_ret
=
self
.
prepare_finalize
.
prepare_async
(
prepare_ret
=
self
.
prepare_finalize
.
prepare_async
(
a1
,
hidden_states
,
topk_weights
,
topk_weights
,
topk_ids
,
topk_ids
,
global_num_experts
,
global_num_experts
,
...
@@ -1019,33 +935,114 @@ class FusedMoEModularKernel(torch.nn.Module):
...
@@ -1019,33 +935,114 @@ class FusedMoEModularKernel(torch.nn.Module):
topk_weights
if
_expert_topk_weights
is
None
else
_expert_topk_weights
topk_weights
if
_expert_topk_weights
is
None
else
_expert_topk_weights
)
)
fused_out
=
None
return
a1q
,
a1q_scale
,
expert_tokens_meta
,
topk_ids
,
topk_weights
def
_fused_experts
(
self
,
in_dtype
:
torch
.
dtype
,
a1q
:
torch
.
Tensor
,
a1q_scale
:
Optional
[
torch
.
Tensor
],
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
activation
:
str
,
global_num_experts
:
int
,
local_num_experts
:
int
,
expert_map
:
Optional
[
torch
.
Tensor
],
apply_router_weight_on_input
:
bool
,
expert_tokens_meta
:
Optional
[
ExpertTokensMetadata
],
)
->
torch
.
Tensor
:
_
,
M_full
,
N
,
K
,
top_k
=
self
.
fused_experts
.
moe_problem_size
(
a1q
,
w1
,
w2
,
topk_ids
)
num_chunks
,
CHUNK_SIZE
=
self
.
_chunk_info
(
M_full
)
def
input_chunk_range
(
chunk_idx
:
int
)
->
tuple
[
int
,
int
]:
if
num_chunks
==
1
:
# Use a1q.size(0) here since batched format does not
# keep M in the first dimension.
return
0
,
a1q
.
size
(
0
)
else
:
s
=
chunk_idx
*
CHUNK_SIZE
e
=
min
(
s
+
CHUNK_SIZE
,
M_full
)
return
s
,
e
if
a1q
.
numel
()
==
0
:
# This happens when none of the tokens from the all2all reach this
# This happens when none of the tokens from the all2all reach this
# EP rank. Also, note that this is only relevant for CUDAGraph
# EP rank. Also, note that this is only relevant for CUDAGraph
# incompatible all2all kernels like the DeepEP high-throughput
# incompatible all2all kernels like the DeepEP high-throughput
# kernels. CUDAGraph compatible all2all kernels like the pplx
# kernels. CUDAGraph compatible all2all kernels like the pplx
# kernels and the DeepEP low-latency kernels are always batched
# kernels and the DeepEP low-latency kernels are always batched
# and can never run into the tensor.numel() == 0 case.
# and can never run into the tensor.numel() == 0 case.
fused_out
=
torch
.
empty_like
(
a1q
).
to
(
dtype
=
a1
.
dtype
)
if
M_full
==
0
:
assert
num_chunks
==
0
workspace13
=
None
workspace2
=
None
fused_out
=
torch
.
empty_like
(
a1q
)
else
:
else
:
fused_out
=
self
.
_maybe_chunk_fused_experts
(
assert
num_chunks
>
0
a1
=
a1
,
workspace13
,
workspace2
,
fused_out
=
self
.
_allocate_buffers
(
a1q
=
a1q
,
in_dtype
,
a1q
.
device
,
CHUNK_SIZE
,
M_full
,
N
,
K
,
top_k
,
global_num_experts
,
local_num_experts
,
expert_tokens_meta
,
)
for
chunk_idx
in
range
(
num_chunks
):
s
,
e
=
input_chunk_range
(
chunk_idx
)
c_expert_tokens_meta
=
self
.
_slice_expert_tokens_metadata
(
num_chunks
,
expert_tokens_meta
,
topk_ids
[
s
:
e
],
local_num_experts
,
expert_map
,
)
c_fused_out
=
self
.
_slice_output_tensor
(
fused_out
,
chunk_idx
,
num_chunks
,
CHUNK_SIZE
,
M_full
)
self
.
fused_experts
.
apply
(
output
=
c_fused_out
,
hidden_states
=
a1q
[
s
:
e
],
w1
=
w1
,
w1
=
w1
,
w2
=
w2
,
w2
=
w2
,
topk_weights
=
topk_weights
,
topk_weights
=
topk_weights
[
s
:
e
]
,
topk_ids
=
topk_ids
,
topk_ids
=
topk_ids
[
s
:
e
]
,
activation
=
activation
,
activation
=
activation
,
global_num_experts
=
global_num_experts
,
global_num_experts
=
global_num_experts
,
local_num_experts
=
local_num_experts
,
expert_map
=
expert_map
,
expert_map
=
expert_map
,
a1q_scale
=
a1q_scale
,
a1q_scale
=
_slice_scales
(
a1q_scale
,
s
,
e
),
expert_tokens_meta
=
expert_tokens_meta
,
a2_scale
=
_slice_scales
(
self
.
fused_experts
.
a2_scale
,
s
,
e
),
workspace13
=
workspace13
,
workspace2
=
workspace2
,
expert_tokens_meta
=
c_expert_tokens_meta
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
)
)
return
fused_out
def
_finalize
(
self
,
output
:
torch
.
Tensor
,
fused_out
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
apply_router_weight_on_input
:
bool
,
)
->
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]:
"""
The _finalize method is a wrapper around self.prepare_finalize.finalize
that handles DBO, async and shared expert overlap.
"""
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
if
not
self
.
prepare_finalize
.
supports_async
():
if
not
self
.
prepare_finalize
.
supports_async
():
...
@@ -1060,7 +1057,7 @@ class FusedMoEModularKernel(torch.nn.Module):
...
@@ -1060,7 +1057,7 @@ class FusedMoEModularKernel(torch.nn.Module):
self
.
fused_experts
.
finalize_weight_and_reduce_impl
(),
self
.
fused_experts
.
finalize_weight_and_reduce_impl
(),
)
)
if
self
.
shared_experts
is
not
None
:
if
self
.
shared_experts
is
not
None
:
shared_output
=
self
.
shared_experts
(
a1
)
shared_output
=
self
.
shared_experts
(
hidden_states
)
else
:
else
:
finalize_ret
=
self
.
prepare_finalize
.
finalize_async
(
finalize_ret
=
self
.
prepare_finalize
.
finalize_async
(
output
,
output
,
...
@@ -1072,7 +1069,7 @@ class FusedMoEModularKernel(torch.nn.Module):
...
@@ -1072,7 +1069,7 @@ class FusedMoEModularKernel(torch.nn.Module):
)
)
if
self
.
shared_experts
is
not
None
:
if
self
.
shared_experts
is
not
None
:
shared_output
=
self
.
shared_experts
(
a1
)
shared_output
=
self
.
shared_experts
(
hidden_states
)
# TODO(lucas): refactor this in the alternative schedules followup
# TODO(lucas): refactor this in the alternative schedules followup
# currently unpack if we have hook + receiver pair or just
# currently unpack if we have hook + receiver pair or just
...
@@ -1100,3 +1097,87 @@ class FusedMoEModularKernel(torch.nn.Module):
...
@@ -1100,3 +1097,87 @@ class FusedMoEModularKernel(torch.nn.Module):
else
:
else
:
assert
shared_output
is
not
None
assert
shared_output
is
not
None
return
shared_output
,
output
return
shared_output
,
output
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
inplace
:
bool
=
False
,
activation
:
str
=
"silu"
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
)
->
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]:
"""
This function computes a Mixture of Experts (MoE) layer using two sets
of weights, w1 and w2, and top-k gating mechanism.
Parameters:
- hidden_states: (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- topk_weights (torch.Tensor): The topk weights applied at the end of
the layer.
- topk_ids (torch.Tensor): A map of row to expert id.
- inplace (bool): If True, perform the operation in-place.
Defaults to False.
- activation (str): The activation function to apply after the first
MoE layer.
- global_num_experts (int): The total number of experts in the global
expert space.
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
from the global expert space to the local expert space of the expert
parallel shard.
- apply_router_weight_on_input (bool): When true, the topk weights are
applied directly on the inputs. This is only applicable when topk is
1.
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""
if
inplace
and
self
.
shared_experts
is
None
:
output
=
hidden_states
else
:
output
=
torch
.
zeros_like
(
hidden_states
)
local_num_experts
=
w1
.
size
(
0
)
if
global_num_experts
==
-
1
:
global_num_experts
=
local_num_experts
a1q
,
a1q_scale
,
expert_tokens_meta
,
topk_ids
,
topk_weights
=
self
.
_prepare
(
hidden_states
,
topk_weights
,
topk_ids
,
global_num_experts
,
expert_map
,
apply_router_weight_on_input
,
)
fused_out
=
self
.
_fused_experts
(
in_dtype
=
hidden_states
.
dtype
,
a1q
=
a1q
,
a1q_scale
=
a1q_scale
,
w1
=
w1
,
w2
=
w2
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
activation
=
activation
,
global_num_experts
=
global_num_experts
,
local_num_experts
=
local_num_experts
,
expert_map
=
expert_map
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
expert_tokens_meta
=
expert_tokens_meta
,
)
return
self
.
_finalize
(
output
,
fused_out
,
hidden_states
,
topk_weights
,
topk_ids
,
apply_router_weight_on_input
,
)
vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py
View file @
da364615
...
@@ -91,6 +91,9 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
...
@@ -91,6 +91,9 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def
num_dispatchers
(
self
)
->
int
:
def
num_dispatchers
(
self
)
->
int
:
return
self
.
num_dispatchers_
return
self
.
num_dispatchers_
def
output_is_reduced
(
self
)
->
bool
:
return
True
def
supports_async
(
self
)
->
bool
:
def
supports_async
(
self
)
->
bool
:
return
True
return
True
...
...
vllm/model_executor/layers/fused_moe/prepare_finalize.py
View file @
da364615
...
@@ -27,6 +27,9 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
...
@@ -27,6 +27,9 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
def
num_dispatchers
(
self
)
->
int
:
def
num_dispatchers
(
self
)
->
int
:
return
1
return
1
def
output_is_reduced
(
self
)
->
bool
:
return
False
def
prepare
(
def
prepare
(
self
,
self
,
a1
:
torch
.
Tensor
,
a1
:
torch
.
Tensor
,
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment