Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a1448b4b
Unverified
Commit
a1448b4b
authored
Nov 11, 2025
by
bnellnm
Committed by
GitHub
Nov 11, 2025
Browse files
[Kernels] Split up fused_moe/layer.py, isolate more modular kernel code (#28064)
parent
fa197020
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
1064 additions
and
948 deletions
+1064
-948
tests/kernels/moe/modular_kernel_tools/mk_objects.py
tests/kernels/moe/modular_kernel_tools/mk_objects.py
+5
-4
vllm/lora/layers/fused_moe.py
vllm/lora/layers/fused_moe.py
+3
-1
vllm/model_executor/layers/fused_moe/__init__.py
vllm/model_executor/layers/fused_moe/__init__.py
+3
-1
vllm/model_executor/layers/fused_moe/all2all_utils.py
vllm/model_executor/layers/fused_moe/all2all_utils.py
+160
-0
vllm/model_executor/layers/fused_moe/fused_moe_method_base.py
.../model_executor/layers/fused_moe/fused_moe_method_base.py
+112
-0
vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py
...del_executor/layers/fused_moe/fused_moe_modular_method.py
+164
-0
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+32
-918
vllm/model_executor/layers/fused_moe/shared_fused_moe.py
vllm/model_executor/layers/fused_moe/shared_fused_moe.py
+1
-1
vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
...executor/layers/fused_moe/unquantized_fused_moe_method.py
+578
-0
vllm/model_executor/layers/quantization/mxfp4.py
vllm/model_executor/layers/quantization/mxfp4.py
+6
-23
No files found.
tests/kernels/moe/modular_kernel_tools/mk_objects.py
View file @
a1448b4b
...
...
@@ -6,6 +6,10 @@ import torch
# Fused experts and PrepareFinalize imports
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm.model_executor.layers.fused_moe
import
TritonExperts
from
vllm.model_executor.layers.fused_moe.all2all_utils
import
(
maybe_make_prepare_finalize
,
)
from
vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe
import
(
BatchedDeepGemmExperts
,
)
...
...
@@ -21,7 +25,6 @@ from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedTritonExperts
,
NaiveBatchedExperts
,
)
from
vllm.model_executor.layers.fused_moe.layer
import
FusedMoEMethodBase
,
TritonExperts
from
vllm.model_executor.layers.fused_moe.prepare_finalize
import
(
MoEPrepareAndFinalizeNoEP
,
)
...
...
@@ -399,9 +402,7 @@ def make_prepare_finalize(
quant_config
:
FusedMoEQuantConfig
,
)
->
mk
.
FusedMoEPrepareAndFinalize
:
if
backend
!=
"naive"
and
backend
is
not
None
:
prepare_finalize
=
FusedMoEMethodBase
.
_maybe_make_prepare_finalize
(
moe
,
quant_config
)
prepare_finalize
=
maybe_make_prepare_finalize
(
moe
,
quant_config
)
assert
prepare_finalize
is
not
None
return
prepare_finalize
elif
prepare_finalize_type
==
FlashInferCutlassMoEPrepareAndFinalize
:
...
...
vllm/lora/layers/fused_moe.py
View file @
a1448b4b
...
...
@@ -25,7 +25,9 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (
modular_triton_fused_moe
,
try_get_optimal_moe_config
,
)
from
vllm.model_executor.layers.fused_moe.layer
import
FusedMoEModularMethod
from
vllm.model_executor.layers.fused_moe.fused_moe_modular_method
import
(
FusedMoEModularMethod
,
)
class
FusedMoEWithLoRA
(
BaseLayerWithLoRA
):
...
...
vllm/model_executor/layers/fused_moe/__init__.py
View file @
a1448b4b
...
...
@@ -5,9 +5,11 @@ from contextlib import contextmanager
from
typing
import
Any
from
vllm.model_executor.layers.fused_moe.config
import
FusedMoEConfig
from
vllm.model_executor.layers.fused_moe.fused_moe_method_base
import
(
FusedMoEMethodBase
,
)
from
vllm.model_executor.layers.fused_moe.layer
import
(
FusedMoE
,
FusedMoEMethodBase
,
FusedMoeWeightScaleSupported
,
)
from
vllm.model_executor.layers.fused_moe.modular_kernel
import
(
...
...
vllm/model_executor/layers/fused_moe/all2all_utils.py
0 → 100644
View file @
a1448b4b
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
torch
from
vllm.distributed
import
(
get_ep_group
,
)
from
vllm.model_executor.layers.fused_moe.config
import
(
FusedMoEConfig
,
FusedMoEParallelConfig
,
FusedMoEQuantConfig
,
)
from
vllm.model_executor.layers.fused_moe.modular_kernel
import
(
FusedMoEPrepareAndFinalize
,
)
from
vllm.platforms
import
current_platform
from
vllm.utils.import_utils
import
has_deep_ep
,
has_pplx
if
current_platform
.
is_cuda_alike
():
if
has_pplx
():
from
.pplx_prepare_finalize
import
(
PplxPrepareAndFinalize
,
pplx_hidden_dim_scale_bytes
,
)
if
has_deep_ep
():
from
.deepep_ht_prepare_finalize
import
DeepEPHTPrepareAndFinalize
from
.deepep_ll_prepare_finalize
import
(
DEEPEP_QUANT_BLOCK_SHAPE
,
DeepEPLLPrepareAndFinalize
,
)
def maybe_roundup_layer_hidden_size(
    hidden_size: int,
    act_dtype: torch.dtype,
    moe_parallel_config: FusedMoEParallelConfig,
) -> int:
    """
    Round up the layer hidden size when the configured all2all backend
    requires it.

    The DeepEP high-throughput and low-latency checks are independent of
    each other: each enabled backend gets a chance to adjust the value, in
    that order.

    Args:
        hidden_size: Layer hidden-size
        act_dtype: Data type of the layer activations. Only forwarded to the
            DeepEP high-throughput rounding hook.
        moe_parallel_config: Fused MoE parallelization strategy configuration.

    Return:
        The hidden size after every applicable backend hook has run, or the
        original hidden size when neither DeepEP flag is set.
    """
    padded = hidden_size

    if moe_parallel_config.use_deepep_ht_kernels:
        ht_roundup = DeepEPHTPrepareAndFinalize.maybe_roundup_layer_hidden_size
        padded = ht_roundup(padded, act_dtype)

    if moe_parallel_config.use_deepep_ll_kernels:
        ll_roundup = DeepEPLLPrepareAndFinalize.maybe_roundup_layer_hidden_size
        padded = ll_roundup(padded)

    return padded
def maybe_make_prepare_finalize(
    moe: FusedMoEConfig,
    quant_config: FusedMoEQuantConfig | None,
) -> FusedMoEPrepareAndFinalize | None:
    """
    Build the all2all prepare/finalize object selected by `moe`, if any.

    Args:
        moe: MoE layer configuration; its `use_*_kernels` flags pick the
            backend (pplx, DeepEP high-throughput, DeepEP low-latency).
        quant_config: Quantization settings. Asserted to be non-None for the
            pplx and DeepEP low-latency backends.

    Return:
        The backend-specific prepare/finalize instance, or None when all2all
        kernels are disabled or none of the three backend flags is set.
    """
    if not moe.moe_parallel_config.use_all2all_kernels:
        return None

    manager = get_ep_group().device_communicator.all2all_manager
    assert manager is not None

    # TODO: could allow this now
    assert not moe.use_flashinfer_cutlass_kernels, "Must be created in modelopt.py"

    if moe.use_pplx_kernels:
        assert quant_config is not None

        hidden_dim_bytes, hidden_scale_bytes = pplx_hidden_dim_scale_bytes(
            moe.max_num_tokens,
            moe.hidden_dim,
            moe.in_dtype,
            quant_config.quant_dtype,
            per_act_token_quant=quant_config.per_act_token_quant,
            block_shape=quant_config.block_shape,
        )

        pplx_args = {
            "max_num_tokens": moe.max_num_tokens,
            "num_experts": moe.num_experts,
            "experts_per_token": moe.experts_per_token,  # topk
            "rank": manager.rank,
            "world_size": manager.world_size,
            # dp_size actually means tp_size, bug in pplx kernels
            "dp_size": manager.tp_group.world_size,
            "hidden_dim": moe.hidden_dim,
            "hidden_dim_bytes": hidden_dim_bytes,
            "hidden_dim_scale_bytes": hidden_scale_bytes,
        }

        dispatcher_count = manager.world_size // manager.tp_group.world_size

        # Intranode pplx a2a takes a group name while internode does not.
        if not manager.internode:
            pplx_args["group_name"] = manager.cpu_group.group_name

        return PplxPrepareAndFinalize(
            manager.get_handle(pplx_args),
            max_num_tokens=moe.max_num_tokens,
            num_local_experts=moe.num_local_experts,
            num_dispatchers=dispatcher_count,
        )

    if moe.use_deepep_ht_kernels:
        assert moe.dp_size == manager.dp_world_size

        # The high-throughput handle is requested with no arguments.
        ht_handle = manager.get_handle({})
        return DeepEPHTPrepareAndFinalize(
            ht_handle,
            num_dispatchers=manager.world_size,
            dp_size=manager.dp_world_size,
            rank_expert_offset=manager.rank * moe.num_local_experts,
        )

    if moe.use_deepep_ll_kernels:
        assert quant_config is not None

        ll_args = {
            "max_num_tokens_per_dp_rank": moe.max_num_tokens,
            "token_hidden_size": moe.hidden_dim,
            "num_ep_ranks": manager.world_size,
            "num_global_experts": moe.num_experts,
            "num_local_experts": moe.num_experts // manager.world_size,
        }
        ll_handle = manager.get_handle(ll_args)

        # Note: We may want to use FP8 dispatch just to reduce
        # data movement.
        fp8_dispatch = (
            quant_config.quant_dtype == current_platform.fp8_dtype()
            and quant_config.block_shape == DEEPEP_QUANT_BLOCK_SHAPE
        )

        return DeepEPLLPrepareAndFinalize(
            ll_handle,
            max_tokens_per_rank=moe.max_num_tokens,
            num_dispatchers=manager.world_size,
            use_fp8_dispatch=fp8_dispatch,
        )

    # No recognised all2all backend flag was set.
    return None
vllm/model_executor/layers/fused_moe/fused_moe_method_base.py
0 → 100644
View file @
a1448b4b
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
abc
import
abstractmethod
from
collections.abc
import
Callable
import
torch
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe.config
import
(
FusedMoEConfig
,
FusedMoEQuantConfig
,
)
from
vllm.model_executor.layers.fused_moe.modular_kernel
import
(
FusedMoEPermuteExpertsUnpermute
,
FusedMoEPrepareAndFinalize
,
)
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizeMethodBase
,
)
logger
=
init_logger
(
__name__
)
class FusedMoEMethodBase(QuantizeMethodBase):
    """
    Base interface for fused-MoE methods.

    Stores the `FusedMoEConfig` given at construction and a
    `moe_quant_config` slot that starts out as None.
    """

    def __init__(self, moe: FusedMoEConfig):
        super().__init__()
        self.moe: FusedMoEConfig = moe
        # Not known at construction time, so it starts unset.
        self.moe_quant_config: FusedMoEQuantConfig | None = None

    @abstractmethod
    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Abstract hook; concrete methods must override it."""
        raise NotImplementedError

    def uses_weight_scale_2_pattern(self) -> bool:
        """
        Report whether per-tensor weight scales follow the 'weight_scale_2'
        pattern (e.g., FP4 variants) rather than the standard 'weight_scale'
        pattern.

        The base implementation answers False; subclasses that use the
        'weight_scale_2' pattern should override this method.
        """
        return False

    def maybe_make_prepare_finalize(self) -> FusedMoEPrepareAndFinalize | None:
        """
        Delegate to `all2all_utils.maybe_make_prepare_finalize` using this
        method's MoE config and current quant config.
        """
        # Imported inside the method, exactly as the module did originally.
        from .all2all_utils import maybe_make_prepare_finalize as _make

        return _make(self.moe, self.moe_quant_config)

    def select_gemm_impl(
        self,
        prepare_finalize: FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> FusedMoEPermuteExpertsUnpermute:
        """
        Pick the gemm implementation matching the all2all implementation.

        The base class always raises NotImplementedError.
        """
        owner = self.__class__.__name__
        raise NotImplementedError(
            f"{owner} must select appropriate gemm "
            "implementation based on the prepare_finalize"
        )

    @abstractmethod
    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Abstract hook returning the quant config for `layer`, or None."""
        raise NotImplementedError

    @property
    def topk_indices_dtype(self) -> torch.dtype | None:
        """Dtype for top-k indices; None in the base class."""
        return None

    @property
    def supports_eplb(self) -> bool:
        """Whether EPLB is supported; False in the base class."""
        return False

    @property
    def allow_inplace(self) -> bool:
        """Whether in-place execution is allowed; False in the base class."""
        return False

    @abstractmethod
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: int | None = None,
        num_expert_group: int | None = None,
        global_num_experts: int = -1,
        expert_map: torch.Tensor | None = None,
        custom_routing_function: Callable | None = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: torch.Tensor | None = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: torch.Tensor | None = None,
        logical_to_physical_map: torch.Tensor | None = None,
        logical_replica_count: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """
        Abstract forward hook; concrete methods must override it.

        Returns either a single tensor or a pair of tensors.
        """
        raise NotImplementedError
vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py
0 → 100644
View file @
a1448b4b
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Callable
import
torch
from
vllm.logger
import
init_logger
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.model_executor.layers.fused_moe.config
import
(
FusedMoEQuantConfig
,
)
from
vllm.model_executor.layers.fused_moe.fused_moe_method_base
import
(
FusedMoEMethodBase
,
)
from
vllm.model_executor.layers.fused_moe.modular_kernel
import
(
FusedMoEModularKernel
,
FusedMoEPrepareAndFinalize
,
)
logger
=
init_logger
(
__name__
)
@CustomOp.register("modular_fused_moe")
class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
    """
    MoE method that runs a `FusedMoEModularKernel` in place of the method it
    was built from, while forwarding capability queries to that method.
    """

    def __init__(
        self, old_quant_method: FusedMoEMethodBase, experts: FusedMoEModularKernel
    ):
        super().__init__(old_quant_method.moe)
        self.moe_quant_config = old_quant_method.moe_quant_config
        self.fused_experts = experts
        # The kernel's capability is only the fallback; an explicit
        # `disable_expert_map` on the wrapped method wins.
        kernel_lacks_expert_map = not self.fused_experts.supports_expert_map()
        self.disable_expert_map = getattr(
            old_quant_method, "disable_expert_map", kernel_lacks_expert_map
        )
        self.old_quant_method = old_quant_method
        logger.debug("Swapping out %s", self.old_quant_method.__class__.__name__)

    @staticmethod
    def make(
        moe_layer: torch.nn.Module,
        old_quant_method: FusedMoEMethodBase,
        prepare_finalize: FusedMoEPrepareAndFinalize,
        shared_experts: torch.nn.Module | None,
    ) -> "FusedMoEModularMethod":
        """
        Assemble a modular kernel from `prepare_finalize`, the gemm
        implementation chosen by `old_quant_method`, and `shared_experts`,
        then wrap it in a new `FusedMoEModularMethod`.
        """
        gemm_impl = old_quant_method.select_gemm_impl(prepare_finalize, moe_layer)
        kernel = FusedMoEModularKernel(prepare_finalize, gemm_impl, shared_experts)
        return FusedMoEModularMethod(old_quant_method, kernel)

    @property
    def topk_indices_dtype(self) -> torch.dtype | None:
        """Top-k index dtype reported by the kernel's prepare/finalize."""
        return self.fused_experts.prepare_finalize.topk_indices_dtype()

    @property
    def supports_eplb(self) -> bool:
        """Forwarded from the wrapped method."""
        return self.old_quant_method.supports_eplb

    @property
    def allow_inplace(self) -> bool:
        """Forwarded from the wrapped method."""
        return self.old_quant_method.allow_inplace

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Not implemented for this method; always raises."""
        raise NotImplementedError

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Return the quant config copied from the wrapped method."""
        return self.moe_quant_config

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: int | None = None,
        num_expert_group: int | None = None,
        global_num_experts: int = -1,
        expert_map: torch.Tensor | None = None,
        custom_routing_function: Callable | None = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: torch.Tensor | None = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: torch.Tensor | None = None,
        logical_to_physical_map: torch.Tensor | None = None,
        logical_replica_count: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """
        Select experts through `layer.select_experts`, then run the modular
        kernel on `layer.w13_weight` / `layer.w2_weight`.

        Returns the kernel result alone, or `(result, zero_expert_result)`
        when the layer has a non-zero `zero_expert_num` and a non-None
        `zero_expert_type`.

        Raises:
            NotImplementedError: `enable_eplb` is set but the wrapped method
                does not support EPLB.
        """
        # Is getattr needed?
        zero_expert_num = getattr(layer, "zero_expert_num", 0)
        zero_expert_type = getattr(layer, "zero_expert_type", None)

        if enable_eplb:
            if not self.supports_eplb:
                wrapped_name = self.old_quant_method.__class__.__name__
                raise NotImplementedError(
                    "EPLB is not supported for "
                    f"{wrapped_name}."
                )
            assert expert_load_view is not None
            assert logical_to_physical_map is not None
            assert logical_replica_count is not None

        topk_weights, topk_ids, zero_expert_result = layer.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
            e_score_correction_bias=e_score_correction_bias,
            indices_type=self.topk_indices_dtype,
            enable_eplb=enable_eplb,
            expert_map=expert_map,
            expert_load_view=expert_load_view,
            logical_to_physical_map=logical_to_physical_map,
            logical_replica_count=logical_replica_count,
            global_num_experts=global_num_experts,
            zero_expert_num=zero_expert_num,
            zero_expert_type=zero_expert_type,
        )

        # The expert map is withheld from the kernel when it is disabled.
        if self.disable_expert_map:
            kernel_expert_map = None
        else:
            kernel_expert_map = expert_map

        result = self.fused_experts(
            hidden_states=x,
            w1=layer.w13_weight,
            w2=layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=self.allow_inplace,
            activation=activation,
            global_num_experts=global_num_experts,
            apply_router_weight_on_input=apply_router_weight_on_input,
            expert_map=kernel_expert_map,
        )

        has_zero_experts = zero_expert_num != 0 and zero_expert_type is not None
        if not has_zero_experts:
            return result

        assert not isinstance(result, tuple), (
            "Shared + zero experts are mutually exclusive not yet supported"
        )
        return result, zero_expert_result
vllm/model_executor/layers/fused_moe/layer.py
View file @
a1448b4b
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
abc
import
abstractmethod
from
collections.abc
import
Callable
,
Iterable
from
contextlib
import
nullcontext
from
enum
import
Enum
...
...
@@ -27,17 +26,13 @@ from vllm.forward_context import ForwardContext, get_forward_context
from
vllm.logger
import
init_logger
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.model_executor.layers.fused_moe.config
import
(
FUSED_MOE_UNQUANTIZED_CONFIG
,
FusedMoEConfig
,
FusedMoEParallelConfig
,
FusedMoEQuantConfig
,
RoutingMethodType
,
biased_moe_quant_config
,
)
from
vllm.model_executor.layers.fused_moe.fused_moe
import
zero_experts_compute_triton
from
vllm.model_executor.layers.fused_moe.modular_kernel
import
(
FusedMoEActivationFormat
,
FusedMoEModularKernel
,
FusedMoEPermuteExpertsUnpermute
,
FusedMoEPrepareAndFinalize
,
)
...
...
@@ -47,35 +42,17 @@ from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
from
vllm.model_executor.layers.fused_moe.routing_simulator
import
RoutingSimulator
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
,
)
from
vllm.model_executor.layers.quantization.utils.flashinfer_utils
import
(
is_flashinfer_supporting_global_sf
,
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.platforms.interface
import
CpuArchEnum
from
vllm.utils.flashinfer
import
has_flashinfer_cutlass_fused_moe
from
vllm.utils.import_utils
import
has_deep_ep
,
has_pplx
from
vllm.utils.math_utils
import
cdiv
,
round_up
from
vllm.utils.torch_utils
import
current_stream
,
direct_register_custom_op
from
vllm.v1.worker.ubatching
import
dbo_current_ubatch_id
if
current_platform
.
is_cuda_alike
():
from
.fused_batched_moe
import
BatchedTritonExperts
from
.fused_moe
import
TritonExperts
,
eplb_map_to_physical_and_record
,
fused_experts
if
has_pplx
():
from
.pplx_prepare_finalize
import
(
PplxPrepareAndFinalize
,
pplx_hidden_dim_scale_bytes
,
)
if
has_deep_ep
():
from
.deepep_ht_prepare_finalize
import
DeepEPHTPrepareAndFinalize
from
.deepep_ll_prepare_finalize
import
(
DEEPEP_QUANT_BLOCK_SHAPE
,
DeepEPLLPrepareAndFinalize
,
)
from
.fused_moe
import
eplb_map_to_physical_and_record
,
fused_experts
else
:
fused_experts
=
None
# type: ignore
FusedMoEPermuteExpertsUnpermute
=
object
# type: ignore
...
...
@@ -102,6 +79,16 @@ if current_platform.is_tpu():
else
:
fused_moe_pallas
=
None
# type: ignore
from
vllm.model_executor.layers.fused_moe.fused_moe_method_base
import
(
FusedMoEMethodBase
,
)
from
vllm.model_executor.layers.fused_moe.fused_moe_modular_method
import
(
FusedMoEModularMethod
,
)
from
vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method
import
(
UnquantizedFusedMoEMethod
,
)
logger
=
init_logger
(
__name__
)
...
...
@@ -112,885 +99,6 @@ class FusedMoeWeightScaleSupported(Enum):
BLOCK
=
"block"
class FusedMoEMethodBase(QuantizeMethodBase):
    """
    Base interface for fused-MoE methods.

    Stores the `FusedMoEConfig` given at construction and a
    `moe_quant_config` slot that starts out as None, and provides the
    helpers that build the all2all prepare/finalize object and the modular
    kernel.
    """

    def __init__(self, moe: FusedMoEConfig):
        super().__init__()
        self.moe: FusedMoEConfig = moe
        # Filled in later by maybe_init_modular_kernel.
        self.moe_quant_config: FusedMoEQuantConfig | None = None

    @abstractmethod
    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Abstract hook; concrete methods must override it."""
        raise NotImplementedError

    def uses_weight_scale_2_pattern(self) -> bool:
        """
        Report whether per-tensor weight scales follow the 'weight_scale_2'
        pattern (e.g., FP4 variants) rather than the standard 'weight_scale'
        pattern.

        The base implementation answers False; subclasses that use the
        'weight_scale_2' pattern should override this method.
        """
        return False

    @staticmethod
    def _maybe_make_prepare_finalize(
        moe: FusedMoEConfig,
        quant_config: FusedMoEQuantConfig | None,
    ) -> FusedMoEPrepareAndFinalize | None:
        """
        Build the prepare/finalize object for whichever all2all backend the
        `use_*_kernels` flags on `moe` select.

        Returns None when none of the pplx / DeepEP HT / DeepEP LL flags is
        set. `quant_config` is asserted non-None for pplx and DeepEP LL.
        """
        manager = get_ep_group().device_communicator.all2all_manager
        assert manager is not None

        # TODO: could allow this now
        assert not moe.use_flashinfer_cutlass_kernels, "Must be created in modelopt.py"

        if moe.use_pplx_kernels:
            assert quant_config is not None

            hidden_dim_bytes, hidden_scale_bytes = pplx_hidden_dim_scale_bytes(
                moe.max_num_tokens,
                moe.hidden_dim,
                moe.in_dtype,
                quant_config.quant_dtype,
                per_act_token_quant=quant_config.per_act_token_quant,
                block_shape=quant_config.block_shape,
            )

            pplx_args = {
                "max_num_tokens": moe.max_num_tokens,
                "num_experts": moe.num_experts,
                "experts_per_token": moe.experts_per_token,  # topk
                "rank": manager.rank,
                "world_size": manager.world_size,
                # dp_size actually means tp_size, bug in pplx kernels
                "dp_size": manager.tp_group.world_size,
                "hidden_dim": moe.hidden_dim,
                "hidden_dim_bytes": hidden_dim_bytes,
                "hidden_dim_scale_bytes": hidden_scale_bytes,
            }

            dispatcher_count = manager.world_size // manager.tp_group.world_size

            # Intranode pplx a2a takes a group name while internode does not.
            if not manager.internode:
                pplx_args["group_name"] = manager.cpu_group.group_name

            return PplxPrepareAndFinalize(
                manager.get_handle(pplx_args),
                max_num_tokens=moe.max_num_tokens,
                num_local_experts=moe.num_local_experts,
                num_dispatchers=dispatcher_count,
            )

        if moe.use_deepep_ht_kernels:
            assert moe.dp_size == manager.dp_world_size

            # The high-throughput handle is requested with no arguments.
            ht_handle = manager.get_handle({})
            return DeepEPHTPrepareAndFinalize(
                ht_handle,
                num_dispatchers=manager.world_size,
                dp_size=manager.dp_world_size,
                rank_expert_offset=manager.rank * moe.num_local_experts,
            )

        if moe.use_deepep_ll_kernels:
            assert quant_config is not None

            ll_args = {
                "max_num_tokens_per_dp_rank": moe.max_num_tokens,
                "token_hidden_size": moe.hidden_dim,
                "num_ep_ranks": manager.world_size,
                "num_global_experts": moe.num_experts,
                "num_local_experts": moe.num_experts // manager.world_size,
            }
            ll_handle = manager.get_handle(ll_args)

            # Note: We may want to use FP8 dispatch just to reduce
            # data movement.
            fp8_dispatch = (
                quant_config.quant_dtype == current_platform.fp8_dtype()
                and quant_config.block_shape == DEEPEP_QUANT_BLOCK_SHAPE
            )

            return DeepEPLLPrepareAndFinalize(
                ll_handle,
                max_tokens_per_rank=moe.max_num_tokens,
                num_dispatchers=manager.world_size,
                use_fp8_dispatch=fp8_dispatch,
            )

        # No recognised all2all backend flag was set.
        return None

    def maybe_make_prepare_finalize(self) -> FusedMoEPrepareAndFinalize | None:
        """
        Build the prepare/finalize object only when all2all kernels are
        enabled in the parallel config; otherwise return None.
        """
        if not self.moe.moe_parallel_config.use_all2all_kernels:
            return None
        return FusedMoEMethodBase._maybe_make_prepare_finalize(
            self.moe, self.moe_quant_config
        )

    def maybe_init_modular_kernel(
        self, layer: torch.nn.Module
    ) -> FusedMoEModularKernel | None:
        """
        Refresh `moe_quant_config` from `layer` and, when a prepare/finalize
        object is available, wrap it with the selected gemm implementation
        and the layer's shared experts in a `FusedMoEModularKernel`.

        Returns None when no prepare/finalize object is produced.
        """
        assert self.moe is not None

        # We must get the quant config here so that the layer is
        # completely initialized, i.e. all weights loaded and post
        # processed.
        self.moe_quant_config = self.get_fused_moe_quant_config(layer)

        prepare_finalize = self.maybe_make_prepare_finalize()
        if prepare_finalize is None:
            return None

        logger.debug(
            "%s for %s(%s)", prepare_finalize.__class__.__name__, self, id(self)
        )
        gemm_impl = self.select_gemm_impl(prepare_finalize, layer)
        return FusedMoEModularKernel(
            prepare_finalize,
            gemm_impl,
            layer.shared_experts,
        )

    def select_gemm_impl(
        self,
        prepare_finalize: FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> FusedMoEPermuteExpertsUnpermute:
        """
        Pick the gemm implementation matching the all2all implementation.

        The base class always raises NotImplementedError.
        """
        owner = self.__class__.__name__
        raise NotImplementedError(
            f"{owner} must select appropriate gemm "
            "implementation based on the prepare_finalize"
        )

    @abstractmethod
    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Abstract hook returning the quant config for `layer`, or None."""
        raise NotImplementedError

    @property
    def topk_indices_dtype(self) -> torch.dtype | None:
        """Dtype for top-k indices; None in the base class."""
        return None

    @property
    def supports_eplb(self) -> bool:
        """Whether EPLB is supported; False in the base class."""
        return False

    @property
    def allow_inplace(self) -> bool:
        """Whether in-place execution is allowed; False in the base class."""
        return False

    @abstractmethod
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: int | None = None,
        num_expert_group: int | None = None,
        global_num_experts: int = -1,
        expert_map: torch.Tensor | None = None,
        custom_routing_function: Callable | None = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: torch.Tensor | None = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: torch.Tensor | None = None,
        logical_to_physical_map: torch.Tensor | None = None,
        logical_replica_count: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """
        Abstract forward hook; concrete methods must override it.

        Returns either a single tensor or a pair of tensors.
        """
        raise NotImplementedError
@CustomOp.register("modular_fused_moe")
class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
    """
    MoE method that runs a `FusedMoEModularKernel` in place of the method it
    was built from, while forwarding capability queries to that method.
    """

    def __init__(
        self,
        old_quant_method: FusedMoEMethodBase,
        fused_experts: FusedMoEModularKernel,
    ):
        super().__init__(old_quant_method.moe)
        # Find better way to copy attributes?  Should we even copy attributes?
        # self.__dict__.update(old_quant_method.__dict__)
        self.moe_quant_config = old_quant_method.moe_quant_config
        self.fused_experts = fused_experts
        # The kernel's capability is only the fallback; an explicit
        # `disable_expert_map` on the wrapped method wins.
        kernel_lacks_expert_map = not fused_experts.supports_expert_map()
        self.disable_expert_map = getattr(
            old_quant_method, "disable_expert_map", kernel_lacks_expert_map
        )
        self.old_quant_method = old_quant_method
        logger.debug("Swapping out %s", self.old_quant_method.__class__.__name__)

    @property
    def topk_indices_dtype(self) -> torch.dtype | None:
        """Top-k index dtype reported by the kernel's prepare/finalize."""
        return self.fused_experts.prepare_finalize.topk_indices_dtype()

    @property
    def supports_eplb(self) -> bool:
        """Forwarded from the wrapped method."""
        return self.old_quant_method.supports_eplb

    @property
    def allow_inplace(self) -> bool:
        """Forwarded from the wrapped method."""
        return self.old_quant_method.allow_inplace

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Not implemented for this method; always raises."""
        raise NotImplementedError

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Return the quant config copied from the wrapped method."""
        return self.moe_quant_config

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: int | None = None,
        num_expert_group: int | None = None,
        global_num_experts: int = -1,
        expert_map: torch.Tensor | None = None,
        custom_routing_function: Callable | None = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: torch.Tensor | None = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: torch.Tensor | None = None,
        logical_to_physical_map: torch.Tensor | None = None,
        logical_replica_count: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """
        Select experts through `FusedMoE.select_experts`, then run the
        modular kernel on `layer.w13_weight` / `layer.w2_weight`.

        Returns the kernel result alone, or `(result, zero_expert_result)`
        when the layer has a non-zero `zero_expert_num` and a non-None
        `zero_expert_type`.

        Raises:
            NotImplementedError: `enable_eplb` is set but the wrapped method
                does not support EPLB.
        """
        # Is getattr needed?
        zero_expert_num = getattr(layer, "zero_expert_num", 0)
        zero_expert_type = getattr(layer, "zero_expert_type", None)

        if enable_eplb:
            if not self.supports_eplb:
                wrapped_name = self.old_quant_method.__class__.__name__
                raise NotImplementedError(
                    "EPLB is not supported for "
                    f"{wrapped_name}."
                )
            assert expert_load_view is not None
            assert logical_to_physical_map is not None
            assert logical_replica_count is not None
            assert isinstance(layer, FusedMoE)

        topk_weights, topk_ids, zero_expert_result = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
            e_score_correction_bias=e_score_correction_bias,
            indices_type=self.topk_indices_dtype,
            enable_eplb=enable_eplb,
            expert_map=expert_map,
            expert_load_view=expert_load_view,
            logical_to_physical_map=logical_to_physical_map,
            logical_replica_count=logical_replica_count,
            global_num_experts=global_num_experts,
            zero_expert_num=zero_expert_num,
            zero_expert_type=zero_expert_type,
        )

        # The expert map is withheld from the kernel when it is disabled.
        if self.disable_expert_map:
            kernel_expert_map = None
        else:
            kernel_expert_map = expert_map

        result = self.fused_experts(
            hidden_states=x,
            w1=layer.w13_weight,
            w2=layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=self.allow_inplace,
            activation=activation,
            global_num_experts=global_num_experts,
            apply_router_weight_on_input=apply_router_weight_on_input,
            expert_map=kernel_expert_map,
        )

        has_zero_experts = zero_expert_num != 0 and zero_expert_type is not None
        if not has_zero_experts:
            return result

        assert not isinstance(result, tuple), (
            "Shared + zero experts are mutually exclusive not yet supported"
        )
        return result, zero_expert_result
@
CustomOp
.
register
(
"unquantized_fused_moe"
)
class
UnquantizedFusedMoEMethod
(
FusedMoEMethodBase
,
CustomOp
):
"""MoE method without quantization."""
def
__init__
(
self
,
moe
:
FusedMoEConfig
):
super
().
__init__
(
moe
)
self
.
rocm_aiter_moe_enabled
=
rocm_aiter_ops
.
is_fused_moe_enabled
()
if
self
.
rocm_aiter_moe_enabled
:
from
.rocm_aiter_fused_moe
import
rocm_aiter_fused_experts
self
.
rocm_aiter_fused_experts
=
rocm_aiter_fused_experts
else
:
self
.
rocm_aiter_fused_experts
=
None
# type: ignore
# FlashInfer CUTLASS MoE is only supported on Hopper and later GPUS
self
.
flashinfer_cutlass_moe_enabled
=
(
has_flashinfer_cutlass_fused_moe
()
and
envs
.
VLLM_USE_FLASHINFER_MOE_FP16
and
self
.
moe
.
moe_parallel_config
.
use_ep
and
self
.
moe
.
moe_parallel_config
.
dp_size
==
1
and
current_platform
.
get_device_capability
()[
0
]
>=
9
)
if
self
.
flashinfer_cutlass_moe_enabled
:
logger
.
info_once
(
"Enabling FlashInfer CUTLASS MoE for UnquantizedFusedMoEMethod"
)
from
functools
import
partial
from
.flashinfer_cutlass_moe
import
flashinfer_cutlass_moe
self
.
flashinfer_cutlass_moe
=
partial
(
flashinfer_cutlass_moe
,
quant_config
=
FUSED_MOE_UNQUANTIZED_CONFIG
,
tp_rank
=
self
.
moe
.
moe_parallel_config
.
tp_rank
,
tp_size
=
self
.
moe
.
moe_parallel_config
.
tp_size
,
ep_rank
=
self
.
moe
.
moe_parallel_config
.
ep_rank
,
ep_size
=
self
.
moe
.
moe_parallel_config
.
ep_size
,
)
else
:
if
(
self
.
moe
.
moe_parallel_config
.
use_ep
and
self
.
moe
.
moe_parallel_config
.
dp_size
==
1
):
logger
.
info_once
(
"FlashInfer CUTLASS MoE is available for EP"
" but not enabled, consider setting"
" VLLM_USE_FLASHINFER_MOE_FP16=1 to enable it."
,
scope
=
"local"
,
)
elif
self
.
moe
.
moe_parallel_config
.
dp_size
>
1
:
logger
.
info_once
(
"FlashInfer CUTLASS MoE is currently not available for DP."
,
scope
=
"local"
,
)
self
.
flashinfer_cutlass_moe
=
None
# type: ignore
@property
def supports_eplb(self) -> bool:
    """Always ``True`` for this method."""
    return True
@property
def allow_inplace(self) -> bool:
    """Always ``True`` for this method."""
    return True
def maybe_make_prepare_finalize(self) -> FusedMoEPrepareAndFinalize | None:
    """Return the prepare/finalize object for this method, if any.

    Returns:
        ``None`` when the ROCm AITER fused-MoE path is enabled; otherwise
        whatever the base-class implementation returns.
    """
    if not self.rocm_aiter_moe_enabled:
        return super().maybe_make_prepare_finalize()
    return None
def select_gemm_impl(
    self,
    prepare_finalize: FusedMoEPrepareAndFinalize,
    layer: torch.nn.Module,
) -> FusedMoEPermuteExpertsUnpermute:
    """Choose the Triton experts implementation matching ``prepare_finalize``.

    Args:
        prepare_finalize: Object whose ``activation_format`` selects the
            implementation and whose ``num_dispatchers()`` is forwarded to
            the batched variant.
        layer: Unused here.

    Returns:
        ``BatchedTritonExperts`` when the activation format is
        ``BatchedExperts``; ``TritonExperts`` otherwise.
    """
    assert self.moe_quant_config is not None

    wants_batched = (
        prepare_finalize.activation_format == FusedMoEActivationFormat.BatchedExperts
    )
    if not wants_batched:
        logger.debug("TritonExperts %s", self.moe)
        return TritonExperts(self.moe_quant_config)

    logger.debug("BatchedTritonExperts %s", self.moe)
    return BatchedTritonExperts(
        max_num_tokens=self.moe.max_num_tokens,
        num_dispatchers=prepare_finalize.num_dispatchers(),
        quant_config=self.moe_quant_config,
    )
def create_weights(
    self,
    layer: torch.nn.Module,
    num_experts: int,
    hidden_size: int,
    intermediate_size_per_partition: int,
    params_dtype: torch.dtype,
    **extra_weight_attrs,
):
    """Allocate and register the expert weight (and bias) parameters on ``layer``.

    Registers, in this order: ``w13_weight``, ``w13_bias`` (only when
    ``self.moe.has_bias``), ``w2_weight``, ``w2_bias`` (only when
    ``self.moe.has_bias``). Weights are uninitialised (``torch.empty``),
    biases are zero-filled, and none of them require gradients.

    Args:
        layer: Module that receives the parameters.
        num_experts: Size of the leading dimension of every parameter.
        hidden_size: Model hidden dimension.
        intermediate_size_per_partition: Intermediate dimension; doubled for
            ``w13`` when ``self.moe.is_act_and_mul`` is set.
        params_dtype: dtype of all created tensors.
        **extra_weight_attrs: Attributes attached to every parameter through
            ``set_weight_attrs``.
    """

    def _register(name: str, data: torch.Tensor) -> None:
        # Wrap as a frozen parameter, attach it to the layer, tag its attrs.
        param = torch.nn.Parameter(data, requires_grad=False)
        layer.register_parameter(name, param)
        set_weight_attrs(param, extra_weight_attrs)

    w13_up_dim = (
        2 * intermediate_size_per_partition
        if self.moe.is_act_and_mul
        else intermediate_size_per_partition
    )

    # Fused gate_up_proj (column parallel)
    _register(
        "w13_weight",
        torch.empty(num_experts, w13_up_dim, hidden_size, dtype=params_dtype),
    )
    if self.moe.has_bias:
        _register(
            "w13_bias",
            torch.zeros(num_experts, w13_up_dim, dtype=params_dtype),
        )

    # down_proj (row parallel)
    _register(
        "w2_weight",
        torch.empty(
            num_experts,
            hidden_size,
            intermediate_size_per_partition,
            dtype=params_dtype,
        ),
    )
    if self.moe.has_bias:
        _register(
            "w2_bias",
            torch.zeros(num_experts, hidden_size, dtype=params_dtype),
        )
def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor:
    """Optionally re-lay-out ``weight`` with padded rows (ROCm only).

    This is an optimization on the ROCm platform, which can benefit from
    tensors located far enough from one another in memory. The tensor is
    padded along its last dimension and then sliced back, so the returned
    tensor keeps the original shape.

    Args:
        weight: Tensor to pad.

    Returns:
        The padded-and-sliced tensor when padding applies, otherwise
        ``weight`` unchanged.
    """
    padding_applies = (
        envs.VLLM_ROCM_MOE_PADDING
        and current_platform.is_rocm()
        and weight.stride(-1) == 1
        and (weight.stride(-2) * weight.element_size()) % 512 == 0
    )
    if not padding_applies:
        return weight

    # 256 bytes of padding, expressed in elements.
    num_pad = 256 // weight.element_size()
    weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
    torch.cuda.empty_cache()
    return weight
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    """Post-process the loaded expert weights for the active backend.

    Steps, in order:
      1. Run the base-class hook.
      2. Optionally pad ``w13_weight`` / ``w2_weight`` (see
         ``_maybe_pad_weight``).
      3. Shuffle both weights when the ROCm AITER path is enabled.
      4. Swap the two halves of ``w13_weight`` when the FlashInfer CUTLASS
         path is enabled.
      5. On XPU attach ``layer.ipex_fusion``; on CPU attach
         ``layer.cpu_fused_moe``.

    Args:
        layer: Module holding ``w13_weight`` and ``w2_weight``.
    """
    super().process_weights_after_loading(layer)

    # Padding the weight for better performance on ROCm
    layer.w13_weight.data = self._maybe_pad_weight(layer.w13_weight.data)
    layer.w2_weight.data = self._maybe_pad_weight(layer.w2_weight.data)

    if self.rocm_aiter_moe_enabled:
        shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
            layer.w13_weight.data, layer.w2_weight.data
        )
        layer.w13_weight.data = shuffled_w13
        layer.w2_weight.data = shuffled_w2

    if self.flashinfer_cutlass_moe_enabled:
        # Swap halves to arrange as [w3; w1] (kernel expectation)
        w1_w, w3_w = torch.chunk(layer.w13_weight.data, 2, dim=1)
        layer.w13_weight.data = torch.cat([w3_w, w1_w], dim=1).contiguous()

    if current_platform.is_xpu():
        import intel_extension_for_pytorch as ipex

        ep_rank_start = self.moe.ep_rank * self.moe.num_local_experts
        layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE(
            layer.w13_weight,
            layer.w2_weight,
            use_prepack=True,
            experts_start_id=ep_rank_start,
        )
        return

    if not current_platform.is_cpu():
        return

    from vllm.model_executor.layers.fused_moe import cpu_fused_moe

    is_x86 = current_platform.get_cpu_architecture() == CpuArchEnum.X86
    if not is_x86:
        layer.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer)
        return

    from vllm.model_executor.layers.utils import check_cpu_sgl_kernel

    dtype_w13 = layer.w13_weight.dtype
    _, n_w13, k_w13 = layer.w13_weight.size()
    dtype_w2 = layer.w2_weight.dtype
    _, n_w2, k_w2 = layer.w2_weight.size()

    use_sgl_kernel = (
        envs.VLLM_CPU_SGL_KERNEL
        and check_cpu_sgl_kernel(n_w13, k_w13, dtype_w13)
        and check_cpu_sgl_kernel(n_w2, k_w2, dtype_w2)
    )
    if not use_sgl_kernel:
        layer.cpu_fused_moe = cpu_fused_moe.IPEXFusedMOE(layer)
        return

    # Repack each weight and copy the packed data back into the parameter;
    # the packed tensor must have the same size as the original.
    for weight in (layer.w13_weight, layer.w2_weight):
        packed = torch.ops._C.convert_weight_packed(weight)
        assert packed.size() == weight.size()
        weight.copy_(packed)
        del packed
    layer.cpu_fused_moe = cpu_fused_moe.SGLFusedMOE(layer)
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
,
use_grouped_topk
:
bool
=
False
,
topk_group
:
int
|
None
=
None
,
num_expert_group
:
int
|
None
=
None
,
global_num_experts
:
int
=
-
1
,
expert_map
:
torch
.
Tensor
|
None
=
None
,
custom_routing_function
:
Callable
|
None
=
None
,
scoring_func
:
str
=
"softmax"
,
routed_scaling_factor
:
float
=
1.0
,
e_score_correction_bias
:
torch
.
Tensor
|
None
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
enable_eplb
:
bool
=
False
,
expert_load_view
:
torch
.
Tensor
|
None
=
None
,
logical_to_physical_map
:
torch
.
Tensor
|
None
=
None
,
logical_replica_count
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
if
enable_eplb
:
assert
expert_load_view
is
not
None
assert
logical_to_physical_map
is
not
None
assert
logical_replica_count
is
not
None
assert
isinstance
(
layer
,
FusedMoE
)
return
self
.
forward
(
x
=
x
,
layer
=
layer
,
router_logits
=
router_logits
,
top_k
=
top_k
,
renormalize
=
renormalize
,
use_grouped_topk
=
use_grouped_topk
,
topk_group
=
topk_group
,
num_expert_group
=
num_expert_group
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
routed_scaling_factor
=
routed_scaling_factor
,
e_score_correction_bias
=
e_score_correction_bias
,
activation
=
activation
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
enable_eplb
=
enable_eplb
,
expert_load_view
=
expert_load_view
,
logical_to_physical_map
=
logical_to_physical_map
,
logical_replica_count
=
logical_replica_count
,
)
def get_fused_moe_quant_config(
    self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
    """Return the quant config describing this unquantized layer.

    Args:
        layer: Module whose ``w13_bias`` / ``w2_bias`` are read when
            ``self.moe.has_bias`` is set.

    Returns:
        ``FUSED_MOE_UNQUANTIZED_CONFIG`` when there is no bias, otherwise
        the result of ``biased_moe_quant_config`` for the two bias tensors.
    """
    if not self.moe.has_bias:
        return FUSED_MOE_UNQUANTIZED_CONFIG
    return biased_moe_quant_config(
        layer.w13_bias,
        layer.w2_bias,
    )
def forward_cuda(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    use_grouped_topk: bool,
    top_k: int,
    router_logits: torch.Tensor,
    renormalize: bool,
    topk_group: int | None = None,
    num_expert_group: int | None = None,
    global_num_experts: int = -1,
    expert_map: torch.Tensor | None = None,
    custom_routing_function: Callable | None = None,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,
    e_score_correction_bias: torch.Tensor | None = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
    enable_eplb: bool = False,
    expert_load_view: torch.Tensor | None = None,
    logical_to_physical_map: torch.Tensor | None = None,
    logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    """Route tokens to experts and run the enabled fused-experts backend.

    Expert selection is done by ``FusedMoE.select_experts``; the expert
    computation is dispatched to ROCm AITER, FlashInfer CUTLASS, or the
    default ``fused_experts`` (called with ``inplace=True``), in that
    priority order.

    Returns:
        ``(result, zero_expert_result)`` when the layer has a non-zero
        ``zero_expert_num`` and a ``zero_expert_type``; otherwise just
        ``result``. This holds for every backend, including FlashInfer.
    """
    zero_expert_num = getattr(layer, "zero_expert_num", 0)
    zero_expert_type = getattr(layer, "zero_expert_type", None)

    topk_weights, topk_ids, zero_expert_result = FusedMoE.select_experts(
        hidden_states=x,
        router_logits=router_logits,
        use_grouped_topk=use_grouped_topk,
        top_k=top_k,
        renormalize=renormalize,
        topk_group=topk_group,
        num_expert_group=num_expert_group,
        custom_routing_function=custom_routing_function,
        scoring_func=scoring_func,
        routed_scaling_factor=routed_scaling_factor,
        e_score_correction_bias=e_score_correction_bias,
        indices_type=self.topk_indices_dtype,
        enable_eplb=enable_eplb,
        expert_map=expert_map,
        expert_load_view=expert_load_view,
        logical_to_physical_map=logical_to_physical_map,
        logical_replica_count=logical_replica_count,
        global_num_experts=global_num_experts,
        zero_expert_num=zero_expert_num,
        zero_expert_type=zero_expert_type,
        num_fused_shared_experts=layer.num_fused_shared_experts,
    )

    if self.rocm_aiter_moe_enabled:
        result = self.rocm_aiter_fused_experts(
            hidden_states=x,
            w1=layer.w13_weight,
            w2=layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            expert_map=expert_map,
            activation=activation,
            apply_router_weight_on_input=apply_router_weight_on_input,
        )
    elif self.flashinfer_cutlass_moe_enabled:
        # Assign rather than return, so the zero-expert handling below is
        # applied to this backend too instead of dropping zero_expert_result.
        result = self.flashinfer_cutlass_moe(
            hidden_states=x,
            w1=layer.w13_weight,
            w2=layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            activation=activation,
            apply_router_weight_on_input=apply_router_weight_on_input,
        )
    else:
        result = fused_experts(
            hidden_states=x,
            w1=layer.w13_weight,
            w2=layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=True,
            activation=activation,
            quant_config=self.moe_quant_config,
            apply_router_weight_on_input=apply_router_weight_on_input,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
        )

    if zero_expert_num != 0 and zero_expert_type is not None:
        assert not isinstance(result, tuple), (
            "Shared + zero experts are mutually exclusive not yet supported"
        )
        return result, zero_expert_result
    return result
def
forward_cpu
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
use_grouped_topk
:
bool
,
top_k
:
int
,
router_logits
:
torch
.
Tensor
,
renormalize
:
bool
,
topk_group
:
int
|
None
=
None
,
num_expert_group
:
int
|
None
=
None
,
global_num_experts
:
int
=
-
1
,
expert_map
:
torch
.
Tensor
|
None
=
None
,
custom_routing_function
:
Callable
|
None
=
None
,
scoring_func
:
str
=
"softmax"
,
routed_scaling_factor
:
float
=
1.0
,
e_score_correction_bias
:
torch
.
Tensor
|
None
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
enable_eplb
:
bool
=
False
,
expert_load_view
:
torch
.
Tensor
|
None
=
None
,
logical_to_physical_map
:
torch
.
Tensor
|
None
=
None
,
logical_replica_count
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
if
(
enable_eplb
is
not
False
or
expert_load_view
is
not
None
or
logical_to_physical_map
is
not
None
or
logical_replica_count
is
not
None
):
raise
NotImplementedError
(
"Expert load balancing is not supported for CPU."
)
return
layer
.
cpu_fused_moe
(
layer
,
x
,
use_grouped_topk
,
top_k
,
router_logits
,
renormalize
,
topk_group
,
num_expert_group
,
global_num_experts
,
expert_map
,
custom_routing_function
,
scoring_func
,
routed_scaling_factor
,
e_score_correction_bias
,
apply_router_weight_on_input
,
activation
,
)
def
forward_xpu
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
use_grouped_topk
:
bool
,
top_k
:
int
,
router_logits
:
torch
.
Tensor
,
renormalize
:
bool
,
topk_group
:
int
|
None
=
None
,
num_expert_group
:
int
|
None
=
None
,
global_num_experts
:
int
=
-
1
,
expert_map
:
torch
.
Tensor
|
None
=
None
,
custom_routing_function
:
Callable
|
None
=
None
,
scoring_func
:
str
=
"softmax"
,
routed_scaling_factor
:
float
=
1.0
,
e_score_correction_bias
:
torch
.
Tensor
|
None
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
enable_eplb
:
bool
=
False
,
expert_load_view
:
torch
.
Tensor
|
None
=
None
,
logical_to_physical_map
:
torch
.
Tensor
|
None
=
None
,
logical_replica_count
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
if
(
enable_eplb
is
not
False
or
expert_load_view
is
not
None
or
logical_to_physical_map
is
not
None
or
logical_replica_count
is
not
None
):
raise
NotImplementedError
(
"Expert load balancing is not supported for XPU."
)
return
layer
.
ipex_fusion
(
x
,
use_grouped_topk
,
top_k
,
router_logits
,
renormalize
,
topk_group
,
num_expert_group
,
custom_routing_function
=
custom_routing_function
,
)
def
forward_tpu
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
use_grouped_topk
:
bool
,
top_k
:
int
,
router_logits
:
torch
.
Tensor
,
renormalize
:
bool
,
topk_group
:
int
|
None
=
None
,
num_expert_group
:
int
|
None
=
None
,
global_num_experts
:
int
=
-
1
,
expert_map
:
torch
.
Tensor
|
None
=
None
,
custom_routing_function
:
Callable
|
None
=
None
,
scoring_func
:
str
=
"softmax"
,
routed_scaling_factor
:
float
=
1.0
,
e_score_correction_bias
:
torch
.
Tensor
|
None
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
enable_eplb
:
bool
=
False
,
expert_load_view
:
torch
.
Tensor
|
None
=
None
,
logical_to_physical_map
:
torch
.
Tensor
|
None
=
None
,
logical_replica_count
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
not
use_grouped_topk
assert
num_expert_group
is
None
assert
topk_group
is
None
assert
custom_routing_function
is
None
assert
apply_router_weight_on_input
is
False
if
scoring_func
!=
"softmax"
:
raise
NotImplementedError
(
"Only softmax scoring function is supported for TPU."
)
if
e_score_correction_bias
is
not
None
:
raise
NotImplementedError
(
"Expert score correction bias is not supported for TPU."
)
assert
activation
==
"silu"
,
f
"
{
activation
}
is not supported for TPU."
assert
routed_scaling_factor
==
1.0
,
(
f
"routed_scaling_factor
{
routed_scaling_factor
}
is not supported for TPU."
)
if
(
enable_eplb
is
not
False
or
expert_load_view
is
not
None
or
logical_to_physical_map
is
not
None
or
logical_replica_count
is
not
None
):
raise
NotImplementedError
(
"Expert load balancing is not supported for TPU."
)
return
fused_moe_pallas
(
hidden_states
=
x
,
w1
=
layer
.
w13_weight
,
w2
=
layer
.
w2_weight
,
topk
=
top_k
,
gating_output
=
router_logits
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
renormalize
=
renormalize
,
)
# Bind the platform-specific implementation as ``forward_native``. The
# platform predicates are evaluated in the order TPU, CPU, XPU; any other
# platform falls back to the CUDA implementation.
forward_native = (
    forward_tpu
    if current_platform.is_tpu()
    else forward_cpu
    if current_platform.is_cpu()
    else forward_xpu
    if current_platform.is_xpu()
    else forward_cuda
)
def
determine_expert_map
(
ep_size
:
int
,
ep_rank
:
int
,
...
...
@@ -1125,16 +233,13 @@ def maybe_roundup_hidden_size(
Rounded up hidden_size if rounding up is required based on the configs.
Original hidden size otherwise.
"""
from
vllm.model_executor.layers.fused_moe.all2all_utils
import
(
maybe_roundup_layer_hidden_size
,
)
if
moe_parallel_config
.
use_deepep_ht_kernels
:
hidden_size
=
DeepEPHTPrepareAndFinalize
.
maybe_roundup_layer_hidden_size
(
hidden_size
,
act_dtype
)
if
moe_parallel_config
.
use_deepep_ll_kernels
:
hidden_size
=
DeepEPLLPrepareAndFinalize
.
maybe_roundup_layer_hidden_size
(
hidden_size
)
hidden_size
=
maybe_roundup_layer_hidden_size
(
hidden_size
,
act_dtype
,
moe_parallel_config
)
# we are padding globally so EP buffer allocation works
if
quant_config
and
quant_config
.
get_name
()
==
"mxfp4"
:
...
...
@@ -1430,7 +535,6 @@ class FusedMoE(CustomOp):
is_lora_enabled
=
vllm_config
.
lora_config
is
not
None
,
)
self
.
moe_quant_config
:
FusedMoEQuantConfig
|
None
=
None
self
.
quant_config
=
quant_config
def
_get_quant_method
()
->
FusedMoEMethodBase
:
...
...
@@ -1508,9 +612,15 @@ class FusedMoE(CustomOp):
# This is called after all weight loading and post-processing, so it
# should be safe to swap out the quant_method.
def
maybe_init_modular_kernel
(
self
)
->
None
:
mk
=
self
.
quant_method
.
maybe_init_modular_kernel
(
self
)
if
mk
is
not
None
:
self
.
quant_method
=
FusedMoEModularMethod
(
self
.
quant_method
,
mk
)
self
.
ensure_moe_quant_config_init
()
prepare_finalize
=
self
.
quant_method
.
maybe_make_prepare_finalize
()
if
prepare_finalize
is
not
None
:
logger
.
debug
(
"%s for %s(%s)"
,
prepare_finalize
.
__class__
.
__name__
,
self
,
id
(
self
)
)
self
.
quant_method
=
FusedMoEModularMethod
.
make
(
self
,
self
.
quant_method
,
prepare_finalize
,
self
.
shared_experts
)
@
property
def
shared_experts
(
self
)
->
torch
.
nn
.
Module
|
None
:
...
...
@@ -2142,12 +1252,16 @@ class FusedMoE(CustomOp):
def
ensure_moe_quant_config_init
(
self
):
if
self
.
quant_method
.
moe_quant_config
is
None
:
# Note: the moe_quant_config can't be constructed until after
# weight loading post processing.
self
.
quant_method
.
moe_quant_config
=
(
self
.
quant_method
.
get_fused_moe_quant_config
(
self
)
)
if
self
.
moe_quant_config
is
None
:
self
.
moe_quant_config
=
self
.
quant_method
.
moe_quant_config
@
property
def
moe_quant_config
(
self
)
->
FusedMoEQuantConfig
|
None
:
self
.
ensure_moe_quant_config_init
()
return
self
.
quant_method
.
moe_quant_config
def
ensure_dp_chunking_init
(
self
):
if
not
self
.
use_dp_chunking
or
self
.
batched_hidden_states
is
not
None
:
...
...
vllm/model_executor/layers/fused_moe/shared_fused_moe.py
View file @
a1448b4b
...
...
@@ -38,7 +38,7 @@ class SharedFusedMoE(FusedMoE):
and
not
(
# TODO(wentao): find the root cause and remove this condition
self
.
enable_eplb
or
(
self
.
use_flashinfer_cutlass_kernels
and
self
.
dp_size
>
1
)
or
(
self
.
moe_config
.
use_flashinfer_cutlass_kernels
and
self
.
dp_size
>
1
)
)
and
self
.
_shared_experts
is
not
None
)
...
...
vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
0 → 100644
View file @
a1448b4b
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Callable
import
torch
import
torch.nn.functional
as
F
import
vllm.envs
as
envs
from
vllm._aiter_ops
import
rocm_aiter_ops
from
vllm.logger
import
init_logger
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.model_executor.layers.fused_moe.config
import
(
FUSED_MOE_UNQUANTIZED_CONFIG
,
FusedMoEConfig
,
FusedMoEQuantConfig
,
biased_moe_quant_config
,
)
from
vllm.model_executor.layers.fused_moe.fused_moe_method_base
import
(
FusedMoEMethodBase
,
)
from
vllm.model_executor.layers.fused_moe.modular_kernel
import
(
FusedMoEActivationFormat
,
FusedMoEPermuteExpertsUnpermute
,
FusedMoEPrepareAndFinalize
,
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.platforms.interface
import
CpuArchEnum
from
vllm.utils.flashinfer
import
has_flashinfer_cutlass_fused_moe
if
current_platform
.
is_cuda_alike
():
from
.fused_batched_moe
import
BatchedTritonExperts
from
.fused_moe
import
TritonExperts
,
fused_experts
else
:
fused_experts
=
None
# type: ignore
if
current_platform
.
is_tpu
():
from
.moe_pallas
import
fused_moe
as
fused_moe_pallas
else
:
fused_moe_pallas
=
None
# type: ignore
logger
=
init_logger
(
__name__
)
@
CustomOp
.
register
(
"unquantized_fused_moe"
)
class
UnquantizedFusedMoEMethod
(
FusedMoEMethodBase
,
CustomOp
):
"""MoE method without quantization."""
def __init__(self, moe: FusedMoEConfig):
    """Initialise the method and decide which fused-MoE backends are usable.

    Records whether the ROCm AITER path and the FlashInfer CUTLASS path are
    enabled and binds the matching callable (or ``None``) on the instance.
    Backend modules are imported lazily, only when their path is enabled.

    Args:
        moe: MoE configuration handed to the base-class constructor.
            NOTE(review): presumably stored as ``self.moe`` by the base
            class, since ``self.moe.moe_parallel_config`` is read below.
    """
    super().__init__(moe)

    # ROCm AITER backend: bind the fused-experts callable only when enabled.
    self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
    if not self.rocm_aiter_moe_enabled:
        self.rocm_aiter_fused_experts = None  # type: ignore
    else:
        from .rocm_aiter_fused_moe import rocm_aiter_fused_experts

        self.rocm_aiter_fused_experts = rocm_aiter_fused_experts

    parallel_cfg = self.moe.moe_parallel_config

    # FlashInfer CUTLASS MoE is only supported on Hopper and later GPUs
    # (device capability major >= 9), and here only with expert parallelism
    # and dp_size == 1.
    self.flashinfer_cutlass_moe_enabled = (
        has_flashinfer_cutlass_fused_moe()
        and envs.VLLM_USE_FLASHINFER_MOE_FP16
        and parallel_cfg.use_ep
        and parallel_cfg.dp_size == 1
        and current_platform.get_device_capability()[0] >= 9
    )

    if not self.flashinfer_cutlass_moe_enabled:
        # Tell the user why the FlashInfer path is off for this configuration.
        if parallel_cfg.use_ep and parallel_cfg.dp_size == 1:
            logger.info_once(
                "FlashInfer CUTLASS MoE is available for EP"
                " but not enabled, consider setting"
                " VLLM_USE_FLASHINFER_MOE_FP16=1 to enable it.",
                scope="local",
            )
        elif parallel_cfg.dp_size > 1:
            logger.info_once(
                "FlashInfer CUTLASS MoE is currently not available for DP.",
                scope="local",
            )
        self.flashinfer_cutlass_moe = None  # type: ignore
        return

    logger.info_once("Enabling FlashInfer CUTLASS MoE for UnquantizedFusedMoEMethod")
    from functools import partial

    from .flashinfer_cutlass_moe import flashinfer_cutlass_moe

    # Pre-bind the parallelism layout and the unquantized config so that
    # callers only pass the per-call tensors.
    self.flashinfer_cutlass_moe = partial(
        flashinfer_cutlass_moe,
        quant_config=FUSED_MOE_UNQUANTIZED_CONFIG,
        tp_rank=parallel_cfg.tp_rank,
        tp_size=parallel_cfg.tp_size,
        ep_rank=parallel_cfg.ep_rank,
        ep_size=parallel_cfg.ep_size,
    )
@property
def supports_eplb(self) -> bool:
    """Always ``True`` for this method."""
    return True
@property
def allow_inplace(self) -> bool:
    """Always ``True`` for this method."""
    return True
def maybe_make_prepare_finalize(self) -> FusedMoEPrepareAndFinalize | None:
    """Return the prepare/finalize object for this method, if any.

    Returns:
        ``None`` when the ROCm AITER fused-MoE path is enabled; otherwise
        whatever the base-class implementation returns.
    """
    if not self.rocm_aiter_moe_enabled:
        return super().maybe_make_prepare_finalize()
    return None
def select_gemm_impl(
    self,
    prepare_finalize: FusedMoEPrepareAndFinalize,
    layer: torch.nn.Module,
) -> FusedMoEPermuteExpertsUnpermute:
    """Choose the Triton experts implementation matching ``prepare_finalize``.

    Args:
        prepare_finalize: Object whose ``activation_format`` selects the
            implementation and whose ``num_dispatchers()`` is forwarded to
            the batched variant.
        layer: Unused here.

    Returns:
        ``BatchedTritonExperts`` when the activation format is
        ``BatchedExperts``; ``TritonExperts`` otherwise.
    """
    assert self.moe_quant_config is not None

    wants_batched = (
        prepare_finalize.activation_format == FusedMoEActivationFormat.BatchedExperts
    )
    if not wants_batched:
        logger.debug("TritonExperts %s", self.moe)
        return TritonExperts(self.moe_quant_config)

    logger.debug("BatchedTritonExperts %s", self.moe)
    return BatchedTritonExperts(
        max_num_tokens=self.moe.max_num_tokens,
        num_dispatchers=prepare_finalize.num_dispatchers(),
        quant_config=self.moe_quant_config,
    )
def create_weights(
    self,
    layer: torch.nn.Module,
    num_experts: int,
    hidden_size: int,
    intermediate_size_per_partition: int,
    params_dtype: torch.dtype,
    **extra_weight_attrs,
):
    """Allocate and register the expert weight (and bias) parameters on ``layer``.

    Registers, in this order: ``w13_weight``, ``w13_bias`` (only when
    ``self.moe.has_bias``), ``w2_weight``, ``w2_bias`` (only when
    ``self.moe.has_bias``). Weights are uninitialised (``torch.empty``),
    biases are zero-filled, and none of them require gradients.

    Args:
        layer: Module that receives the parameters.
        num_experts: Size of the leading dimension of every parameter.
        hidden_size: Model hidden dimension.
        intermediate_size_per_partition: Intermediate dimension; doubled for
            ``w13`` when ``self.moe.is_act_and_mul`` is set.
        params_dtype: dtype of all created tensors.
        **extra_weight_attrs: Attributes attached to every parameter through
            ``set_weight_attrs``.
    """

    def _register(name: str, data: torch.Tensor) -> None:
        # Wrap as a frozen parameter, attach it to the layer, tag its attrs.
        param = torch.nn.Parameter(data, requires_grad=False)
        layer.register_parameter(name, param)
        set_weight_attrs(param, extra_weight_attrs)

    w13_up_dim = (
        2 * intermediate_size_per_partition
        if self.moe.is_act_and_mul
        else intermediate_size_per_partition
    )

    # Fused gate_up_proj (column parallel)
    _register(
        "w13_weight",
        torch.empty(num_experts, w13_up_dim, hidden_size, dtype=params_dtype),
    )
    if self.moe.has_bias:
        _register(
            "w13_bias",
            torch.zeros(num_experts, w13_up_dim, dtype=params_dtype),
        )

    # down_proj (row parallel)
    _register(
        "w2_weight",
        torch.empty(
            num_experts,
            hidden_size,
            intermediate_size_per_partition,
            dtype=params_dtype,
        ),
    )
    if self.moe.has_bias:
        _register(
            "w2_bias",
            torch.zeros(num_experts, hidden_size, dtype=params_dtype),
        )
def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor:
    """Optionally re-lay-out ``weight`` with padded rows (ROCm only).

    This is an optimization on the ROCm platform, which can benefit from
    tensors located far enough from one another in memory. The tensor is
    padded along its last dimension and then sliced back, so the returned
    tensor keeps the original shape.

    Args:
        weight: Tensor to pad.

    Returns:
        The padded-and-sliced tensor when padding applies, otherwise
        ``weight`` unchanged.
    """
    padding_applies = (
        envs.VLLM_ROCM_MOE_PADDING
        and current_platform.is_rocm()
        and weight.stride(-1) == 1
        and (weight.stride(-2) * weight.element_size()) % 512 == 0
    )
    if not padding_applies:
        return weight

    # 256 bytes of padding, expressed in elements.
    num_pad = 256 // weight.element_size()
    weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
    torch.cuda.empty_cache()
    return weight
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    """Post-process the loaded expert weights for the active backend.

    Steps, in order:
      1. Run the base-class hook.
      2. Optionally pad ``w13_weight`` / ``w2_weight`` (see
         ``_maybe_pad_weight``).
      3. Shuffle both weights when the ROCm AITER path is enabled.
      4. Swap the two halves of ``w13_weight`` when the FlashInfer CUTLASS
         path is enabled.
      5. On XPU attach ``layer.ipex_fusion``; on CPU attach
         ``layer.cpu_fused_moe``.

    Args:
        layer: Module holding ``w13_weight`` and ``w2_weight``.
    """
    super().process_weights_after_loading(layer)

    # Padding the weight for better performance on ROCm
    layer.w13_weight.data = self._maybe_pad_weight(layer.w13_weight.data)
    layer.w2_weight.data = self._maybe_pad_weight(layer.w2_weight.data)

    if self.rocm_aiter_moe_enabled:
        shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
            layer.w13_weight.data, layer.w2_weight.data
        )
        layer.w13_weight.data = shuffled_w13
        layer.w2_weight.data = shuffled_w2

    if self.flashinfer_cutlass_moe_enabled:
        # Swap halves to arrange as [w3; w1] (kernel expectation)
        w1_w, w3_w = torch.chunk(layer.w13_weight.data, 2, dim=1)
        layer.w13_weight.data = torch.cat([w3_w, w1_w], dim=1).contiguous()

    if current_platform.is_xpu():
        import intel_extension_for_pytorch as ipex

        ep_rank_start = self.moe.ep_rank * self.moe.num_local_experts
        layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE(
            layer.w13_weight,
            layer.w2_weight,
            use_prepack=True,
            experts_start_id=ep_rank_start,
        )
        return

    if not current_platform.is_cpu():
        return

    from vllm.model_executor.layers.fused_moe import cpu_fused_moe

    is_x86 = current_platform.get_cpu_architecture() == CpuArchEnum.X86
    if not is_x86:
        layer.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer)
        return

    from vllm.model_executor.layers.utils import check_cpu_sgl_kernel

    dtype_w13 = layer.w13_weight.dtype
    _, n_w13, k_w13 = layer.w13_weight.size()
    dtype_w2 = layer.w2_weight.dtype
    _, n_w2, k_w2 = layer.w2_weight.size()

    use_sgl_kernel = (
        envs.VLLM_CPU_SGL_KERNEL
        and check_cpu_sgl_kernel(n_w13, k_w13, dtype_w13)
        and check_cpu_sgl_kernel(n_w2, k_w2, dtype_w2)
    )
    if not use_sgl_kernel:
        layer.cpu_fused_moe = cpu_fused_moe.IPEXFusedMOE(layer)
        return

    # Repack each weight and copy the packed data back into the parameter;
    # the packed tensor must have the same size as the original.
    for weight in (layer.w13_weight, layer.w2_weight):
        packed = torch.ops._C.convert_weight_packed(weight)
        assert packed.size() == weight.size()
        weight.copy_(packed)
        del packed
    layer.cpu_fused_moe = cpu_fused_moe.SGLFusedMOE(layer)
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
,
use_grouped_topk
:
bool
=
False
,
topk_group
:
int
|
None
=
None
,
num_expert_group
:
int
|
None
=
None
,
global_num_experts
:
int
=
-
1
,
expert_map
:
torch
.
Tensor
|
None
=
None
,
custom_routing_function
:
Callable
|
None
=
None
,
scoring_func
:
str
=
"softmax"
,
routed_scaling_factor
:
float
=
1.0
,
e_score_correction_bias
:
torch
.
Tensor
|
None
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
enable_eplb
:
bool
=
False
,
expert_load_view
:
torch
.
Tensor
|
None
=
None
,
logical_to_physical_map
:
torch
.
Tensor
|
None
=
None
,
logical_replica_count
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
if
enable_eplb
:
assert
expert_load_view
is
not
None
assert
logical_to_physical_map
is
not
None
assert
logical_replica_count
is
not
None
return
self
.
forward
(
x
=
x
,
layer
=
layer
,
router_logits
=
router_logits
,
top_k
=
top_k
,
renormalize
=
renormalize
,
use_grouped_topk
=
use_grouped_topk
,
topk_group
=
topk_group
,
num_expert_group
=
num_expert_group
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
routed_scaling_factor
=
routed_scaling_factor
,
e_score_correction_bias
=
e_score_correction_bias
,
activation
=
activation
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
enable_eplb
=
enable_eplb
,
expert_load_view
=
expert_load_view
,
logical_to_physical_map
=
logical_to_physical_map
,
logical_replica_count
=
logical_replica_count
,
)
def get_fused_moe_quant_config(
    self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
    """Return the quant config describing this unquantized layer.

    Args:
        layer: Module whose ``w13_bias`` / ``w2_bias`` are read when
            ``self.moe.has_bias`` is set.

    Returns:
        ``FUSED_MOE_UNQUANTIZED_CONFIG`` when there is no bias, otherwise
        the result of ``biased_moe_quant_config`` for the two bias tensors.
    """
    if not self.moe.has_bias:
        return FUSED_MOE_UNQUANTIZED_CONFIG
    return biased_moe_quant_config(
        layer.w13_bias,
        layer.w2_bias,
    )
def
forward_cuda
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
use_grouped_topk
:
bool
,
top_k
:
int
,
router_logits
:
torch
.
Tensor
,
renormalize
:
bool
,
topk_group
:
int
|
None
=
None
,
num_expert_group
:
int
|
None
=
None
,
global_num_experts
:
int
=
-
1
,
expert_map
:
torch
.
Tensor
|
None
=
None
,
custom_routing_function
:
Callable
|
None
=
None
,
scoring_func
:
str
=
"softmax"
,
routed_scaling_factor
:
float
=
1.0
,
e_score_correction_bias
:
torch
.
Tensor
|
None
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
enable_eplb
:
bool
=
False
,
expert_load_view
:
torch
.
Tensor
|
None
=
None
,
logical_to_physical_map
:
torch
.
Tensor
|
None
=
None
,
logical_replica_count
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
zero_expert_num
=
getattr
(
layer
,
"zero_expert_num"
,
0
)
zero_expert_type
=
getattr
(
layer
,
"zero_expert_type"
,
None
)
topk_weights
,
topk_ids
,
zero_expert_result
=
layer
.
select_experts
(
hidden_states
=
x
,
router_logits
=
router_logits
,
use_grouped_topk
=
use_grouped_topk
,
top_k
=
top_k
,
renormalize
=
renormalize
,
topk_group
=
topk_group
,
num_expert_group
=
num_expert_group
,
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
routed_scaling_factor
=
routed_scaling_factor
,
e_score_correction_bias
=
e_score_correction_bias
,
indices_type
=
self
.
topk_indices_dtype
,
enable_eplb
=
enable_eplb
,
expert_map
=
expert_map
,
expert_load_view
=
expert_load_view
,
logical_to_physical_map
=
logical_to_physical_map
,
logical_replica_count
=
logical_replica_count
,
global_num_experts
=
global_num_experts
,
zero_expert_num
=
zero_expert_num
,
zero_expert_type
=
zero_expert_type
,
num_fused_shared_experts
=
layer
.
num_fused_shared_experts
,
)
if
self
.
rocm_aiter_moe_enabled
:
result
=
self
.
rocm_aiter_fused_experts
(
hidden_states
=
x
,
w1
=
layer
.
w13_weight
,
w2
=
layer
.
w2_weight
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
expert_map
=
expert_map
,
activation
=
activation
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
)
elif
self
.
flashinfer_cutlass_moe_enabled
:
return
self
.
flashinfer_cutlass_moe
(
hidden_states
=
x
,
w1
=
layer
.
w13_weight
,
w2
=
layer
.
w2_weight
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
activation
=
activation
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
)
else
:
result
=
fused_experts
(
hidden_states
=
x
,
w1
=
layer
.
w13_weight
,
w2
=
layer
.
w2_weight
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
inplace
=
True
,
activation
=
activation
,
quant_config
=
self
.
moe_quant_config
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
)
if
zero_expert_num
!=
0
and
zero_expert_type
is
not
None
:
assert
not
isinstance
(
result
,
tuple
),
(
"Shared + zero experts are mutually exclusive not yet supported"
)
return
result
,
zero_expert_result
else
:
return
result
def
forward_cpu
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
use_grouped_topk
:
bool
,
top_k
:
int
,
router_logits
:
torch
.
Tensor
,
renormalize
:
bool
,
topk_group
:
int
|
None
=
None
,
num_expert_group
:
int
|
None
=
None
,
global_num_experts
:
int
=
-
1
,
expert_map
:
torch
.
Tensor
|
None
=
None
,
custom_routing_function
:
Callable
|
None
=
None
,
scoring_func
:
str
=
"softmax"
,
routed_scaling_factor
:
float
=
1.0
,
e_score_correction_bias
:
torch
.
Tensor
|
None
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
enable_eplb
:
bool
=
False
,
expert_load_view
:
torch
.
Tensor
|
None
=
None
,
logical_to_physical_map
:
torch
.
Tensor
|
None
=
None
,
logical_replica_count
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
if
(
enable_eplb
is
not
False
or
expert_load_view
is
not
None
or
logical_to_physical_map
is
not
None
or
logical_replica_count
is
not
None
):
raise
NotImplementedError
(
"Expert load balancing is not supported for CPU."
)
return
layer
.
cpu_fused_moe
(
layer
,
x
,
use_grouped_topk
,
top_k
,
router_logits
,
renormalize
,
topk_group
,
num_expert_group
,
global_num_experts
,
expert_map
,
custom_routing_function
,
scoring_func
,
routed_scaling_factor
,
e_score_correction_bias
,
apply_router_weight_on_input
,
activation
,
)
def
forward_xpu
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
use_grouped_topk
:
bool
,
top_k
:
int
,
router_logits
:
torch
.
Tensor
,
renormalize
:
bool
,
topk_group
:
int
|
None
=
None
,
num_expert_group
:
int
|
None
=
None
,
global_num_experts
:
int
=
-
1
,
expert_map
:
torch
.
Tensor
|
None
=
None
,
custom_routing_function
:
Callable
|
None
=
None
,
scoring_func
:
str
=
"softmax"
,
routed_scaling_factor
:
float
=
1.0
,
e_score_correction_bias
:
torch
.
Tensor
|
None
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
enable_eplb
:
bool
=
False
,
expert_load_view
:
torch
.
Tensor
|
None
=
None
,
logical_to_physical_map
:
torch
.
Tensor
|
None
=
None
,
logical_replica_count
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
if
(
enable_eplb
is
not
False
or
expert_load_view
is
not
None
or
logical_to_physical_map
is
not
None
or
logical_replica_count
is
not
None
):
raise
NotImplementedError
(
"Expert load balancing is not supported for XPU."
)
return
layer
.
ipex_fusion
(
x
,
use_grouped_topk
,
top_k
,
router_logits
,
renormalize
,
topk_group
,
num_expert_group
,
custom_routing_function
=
custom_routing_function
,
)
def
forward_tpu
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
use_grouped_topk
:
bool
,
top_k
:
int
,
router_logits
:
torch
.
Tensor
,
renormalize
:
bool
,
topk_group
:
int
|
None
=
None
,
num_expert_group
:
int
|
None
=
None
,
global_num_experts
:
int
=
-
1
,
expert_map
:
torch
.
Tensor
|
None
=
None
,
custom_routing_function
:
Callable
|
None
=
None
,
scoring_func
:
str
=
"softmax"
,
routed_scaling_factor
:
float
=
1.0
,
e_score_correction_bias
:
torch
.
Tensor
|
None
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
enable_eplb
:
bool
=
False
,
expert_load_view
:
torch
.
Tensor
|
None
=
None
,
logical_to_physical_map
:
torch
.
Tensor
|
None
=
None
,
logical_replica_count
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
not
use_grouped_topk
assert
num_expert_group
is
None
assert
topk_group
is
None
assert
custom_routing_function
is
None
assert
apply_router_weight_on_input
is
False
if
scoring_func
!=
"softmax"
:
raise
NotImplementedError
(
"Only softmax scoring function is supported for TPU."
)
if
e_score_correction_bias
is
not
None
:
raise
NotImplementedError
(
"Expert score correction bias is not supported for TPU."
)
assert
activation
==
"silu"
,
f
"
{
activation
}
is not supported for TPU."
assert
routed_scaling_factor
==
1.0
,
(
f
"routed_scaling_factor
{
routed_scaling_factor
}
is not supported for TPU."
)
if
(
enable_eplb
is
not
False
or
expert_load_view
is
not
None
or
logical_to_physical_map
is
not
None
or
logical_replica_count
is
not
None
):
raise
NotImplementedError
(
"Expert load balancing is not supported for TPU."
)
return
fused_moe_pallas
(
hidden_states
=
x
,
w1
=
layer
.
w13_weight
,
w2
=
layer
.
w2_weight
,
topk
=
top_k
,
gating_output
=
router_logits
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
renormalize
=
renormalize
,
)
# Bind ``forward_native`` to the platform-specific implementation. The
# platform is probed in the order TPU, CPU, XPU; anything else falls back to
# ``forward_cuda``. The choice is made once, when this body is executed.
forward_native = (
    forward_tpu
    if current_platform.is_tpu()
    else forward_cpu
    if current_platform.is_cpu()
    else forward_xpu
    if current_platform.is_xpu()
    else forward_cuda
)
vllm/model_executor/layers/quantization/mxfp4.py
View file @
a1448b4b
...
...
@@ -741,15 +741,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
weight_scale
=
w2_scale
,
flex_ctx
=
FlexCtx
(
rhs_data
=
w2_flex
)
)
self
.
w13_weight_triton_tensor
=
w13_weight
self
.
w2_weight_triton_tensor
=
w2_weight
# need to delete the original weights to save memory on single GPU
del
layer
.
w13_weight
del
layer
.
w2_weight
layer
.
w13_weight
=
None
layer
.
w2_weight
=
None
torch
.
cuda
.
empty_cache
()
self
.
w13_weight
=
w13_weight
self
.
w2_weight
=
w2_weight
layer
.
w13_weight
=
w13_weight
layer
.
w2_weight
=
w2_weight
else
:
raise
ValueError
(
f
"Unsupported backend:
{
self
.
mxfp4_backend
}
"
)
...
...
@@ -824,18 +819,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
"EP batched experts format"
)
else
:
layer
.
w13_weight
=
(
self
.
w13_weight_triton_tensor
if
layer
.
w13_weight
is
None
else
layer
.
w13_weight
)
layer
.
w2_weight
=
(
self
.
w2_weight_triton_tensor
if
layer
.
w2_weight
is
None
else
layer
.
w2_weight
)
assert
all
([
w
is
not
None
for
w
in
[
layer
.
w13_weight
,
layer
.
w2_weight
]])
assert
self
.
moe_quant_config
is
not
None
if
(
self
.
mxfp4_backend
==
Mxfp4Backend
.
SM100_FI_MXFP4_MXFP8_TRTLLM
...
...
@@ -1070,8 +1053,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
return
triton_kernel_moe_forward
(
hidden_states
=
x
,
w1
=
self
.
w13_weight
_triton_tensor
,
w2
=
self
.
w2_weight
_triton_tensor
,
w1
=
self
.
w13_weight
,
w2
=
self
.
w2_weight
,
gating_output
=
router_logits
,
topk
=
top_k
,
renormalize
=
renormalize
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment