Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0da93439
Commit
0da93439
authored
Mar 26, 2026
by
zhuwenwen
Browse files
Merge tag 'v0.18.1rc0' into v0.18.1rc0-ori
parents
25f2f756
298e5108
Changes
613
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1618 additions
and
477 deletions
+1618
-477
vllm/model_executor/layers/fused_moe/fused_moe_method_base.py
.../model_executor/layers/fused_moe/fused_moe_method_base.py
+5
-0
vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
...l_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
+180
-24
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+2
-28
vllm/model_executor/layers/fused_moe/mori_prepare_finalize.py
.../model_executor/layers/fused_moe/mori_prepare_finalize.py
+3
-6
vllm/model_executor/layers/fused_moe/oracle/fp8.py
vllm/model_executor/layers/fused_moe/oracle/fp8.py
+16
-1
vllm/model_executor/layers/fused_moe/oracle/mxfp4.py
vllm/model_executor/layers/fused_moe/oracle/mxfp4.py
+847
-0
vllm/model_executor/layers/fused_moe/oracle/mxfp8.py
vllm/model_executor/layers/fused_moe/oracle/mxfp8.py
+66
-23
vllm/model_executor/layers/fused_moe/oracle/nvfp4.py
vllm/model_executor/layers/fused_moe/oracle/nvfp4.py
+0
-11
vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
+19
-7
vllm/model_executor/layers/fused_moe/router/gate_linear.py
vllm/model_executor/layers/fused_moe/router/gate_linear.py
+52
-6
vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py
...el_executor/layers/fused_moe/runner/default_moe_runner.py
+381
-296
vllm/model_executor/layers/fused_moe/utils.py
vllm/model_executor/layers/fused_moe/utils.py
+3
-23
vllm/model_executor/layers/kda.py
vllm/model_executor/layers/kda.py
+1
-1
vllm/model_executor/layers/mamba/linear_attn.py
vllm/model_executor/layers/mamba/linear_attn.py
+1
-1
vllm/model_executor/layers/mamba/mamba_mixer.py
vllm/model_executor/layers/mamba/mamba_mixer.py
+1
-1
vllm/model_executor/layers/mamba/mamba_mixer2.py
vllm/model_executor/layers/mamba/mamba_mixer2.py
+1
-1
vllm/model_executor/layers/mamba/ops/mamba_ssm.py
vllm/model_executor/layers/mamba/ops/mamba_ssm.py
+3
-3
vllm/model_executor/layers/mamba/short_conv.py
vllm/model_executor/layers/mamba/short_conv.py
+1
-1
vllm/model_executor/layers/pooler/activations.py
vllm/model_executor/layers/pooler/activations.py
+10
-22
vllm/model_executor/layers/pooler/seqwise/heads.py
vllm/model_executor/layers/pooler/seqwise/heads.py
+26
-22
No files found.
Too many changes to show.
To preserve performance only
613 of 613+
files are displayed.
Plain diff
Email patch
vllm/model_executor/layers/fused_moe/fused_moe_method_base.py
View file @
0da93439
...
...
@@ -101,6 +101,11 @@ class FusedMoEMethodBase(QuantizeMethodBase):
return
self
.
moe_kernel
.
prepare_finalize
.
topk_indices_dtype
()
return
None
@property
def skip_forward_padding(self) -> bool:
    """Tell the caller whether forward-time padding should be skipped.

    Returns:
        Always ``False`` in this base implementation, i.e. padding is
        applied in the forward pass before the MoE method runs.
    """
    should_skip = False
    return should_skip
@property
def supports_eplb(self) -> bool:
    """Report EPLB support for this method.

    Returns:
        Always ``False`` in this base implementation.
    """
    eplb_supported = False
    return eplb_supported
...
...
vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
View file @
0da93439
...
...
@@ -11,8 +11,10 @@ from vllm.logger import init_logger
from
vllm.model_executor.layers.fused_moe.activation
import
MoEActivation
from
vllm.model_executor.layers.fused_moe.config
import
(
FUSED_MOE_UNQUANTIZED_CONFIG
,
FusedMoEConfig
,
FusedMoEParallelConfig
,
FusedMoEQuantConfig
,
RoutingMethodType
,
)
from
vllm.model_executor.layers.fused_moe.topk_weight_and_reduce
import
(
TopKWeightAndReduceNoOP
,
...
...
@@ -20,6 +22,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
from
vllm.model_executor.layers.fused_moe.utils
import
_resize_cache
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
QuantKey
,
kMxfp4Static
,
)
from
vllm.platforms
import
current_platform
from
vllm.triton_utils
import
tl
,
triton
...
...
@@ -142,6 +145,33 @@ def legacy_routing_from_bitmatrix(
return
routing_data
,
gather_idx
,
scatter_idx
def legacy_routing_from_sparsematrix(
    sparse_logits: "SparseMatrix",
    n_expts_tot: int,
    n_expts_act: int,
) -> tuple["RoutingData", "GatherIndx", "ScatterIndx"]:
    """
    Build the legacy routing triple from a SparseMatrix.

    Args:
        sparse_logits: Sparse top-k result; its ``mask_metadata`` supplies the
            row/column sorted indices and column sums, and ``vals`` supplies
            the gate values.
        n_expts_tot: Total number of experts, forwarded to ``RoutingData``.
        n_expts_act: Number of active experts, forwarded to ``RoutingData``.

    Returns:
        ``(routing_data, gather_idx, scatter_idx)``.
    """
    metadata = sparse_logits.mask_metadata
    dispatch_indx = metadata.row_sorted_indx
    combine_indx = metadata.col_sorted_indx

    # Ragged layout is derived from the column sums and the number of
    # dispatched rows.
    num_dispatched = dispatch_indx.shape[0]
    ragged_batch_metadata = make_ragged_tensor_metadata(
        metadata.col_sum, num_dispatched
    )

    # Gate values are read from the flattened values in combine order.
    flat_vals = sparse_logits.vals.flatten()
    gate_scal = flat_vals[combine_indx]

    routing_data = RoutingData(
        gate_scal,
        ragged_batch_metadata.block_sizes,
        n_expts_tot,
        n_expts_act,
        ragged_batch_metadata,
    )
    # Gather and scatter take the same two index tensors in opposite order.
    gather_idx = GatherIndx(combine_indx, dispatch_indx)
    scatter_idx = ScatterIndx(dispatch_indx, combine_indx)
    return routing_data, gather_idx, scatter_idx
def
legacy_routing
(
logits
:
torch
.
Tensor
,
n_expts_act
:
int
,
...
...
@@ -158,10 +188,8 @@ def legacy_routing(
if
sm_first
:
logits
=
torch
.
softmax
(
logits
,
dim
=-
1
)
sparse_logits
=
topk
(
logits
,
n_expts_act
,
apply_softmax
=
not
sm_first
)
return
legacy_routing_from_bitmatrix
(
sparse_logits
.
mask
,
sparse_logits
.
vals
,
sparse_logits
.
indx
,
return
legacy_routing_from_sparsematrix
(
sparse_logits
,
logits
.
shape
[
-
1
],
n_expts_act
,
)
...
...
@@ -512,43 +540,43 @@ def make_routing_data(
class
BaseOAITritonExperts
(
mk
.
FusedMoEExpertsModular
):
@property
def expects_unquantized_inputs(self) -> bool:
    """Report that this expert implementation expects unquantized inputs.

    Returns:
        Always ``True``.
    """
    wants_unquantized = True
    return wants_unquantized
@
staticmethod
def
_supports_current_device
()
->
bool
:
raise
NotImplementedError
(
"OAITritonExperts is not yet used by an Oracle. "
"This method should not be called."
)
p
=
current_platform
if
not
p
.
is_cuda_alike
():
return
False
cap
=
p
.
get_device_capability
()
if
cap
is
None
:
return
False
# (9,0) <= cap < (11,0) covers CUDA SM90 (Hopper), SM100+ (Blackwell)
# and ROCm gfx942/gfx950 (which map to 9.4/9.5).
return
(
9
,
0
)
<=
(
cap
.
major
,
cap
.
minor
)
<
(
11
,
0
)
@
staticmethod
def
_supports_no_act_and_mul
()
->
bool
:
raise
NotImplementedError
(
"OAITritonExperts is not yet used by an Oracle. "
"This method should not be called."
)
return
False
@
staticmethod
def
_supports_quant_scheme
(
weight_key
:
QuantKey
|
None
,
activation_key
:
QuantKey
|
None
,
)
->
bool
:
raise
NotImplementedError
(
"OAITritonExperts is not yet used by an Oracle. "
"This method should not be called."
)
SUPPORTED_W_A
=
[
(
kMxfp4Static
,
None
),
]
return
(
weight_key
,
activation_key
)
in
SUPPORTED_W_A
@
staticmethod
def
_supports_activation
(
activation
:
MoEActivation
)
->
bool
:
raise
NotImplementedError
(
"OAITritonExperts is not yet used by an Oracle. "
"This method should not be called."
)
raise
NotImplementedError
@
staticmethod
def
_supports_parallel_config
(
moe_parallel_config
:
FusedMoEParallelConfig
)
->
bool
:
raise
NotImplementedError
(
"OAITritonExperts is not yet used by an Oracle. "
"This method should not be called."
)
return
True
def supports_expert_map(self) -> bool:
    """Report that an expert map is supported.

    Returns:
        Always ``True``.
    """
    expert_map_supported = True
    return expert_map_supported
...
...
@@ -605,6 +633,10 @@ class BaseOAITritonExperts(mk.FusedMoEExpertsModular):
class
OAITritonExperts
(
BaseOAITritonExperts
):
"""OAI Triton-based fused MoE expert implementation."""
@staticmethod
def _supports_activation(activation: MoEActivation) -> bool:
    """Return True only when ``activation`` equals ``MoEActivation.SWIGLUOAI``."""
    only_supported = MoEActivation.SWIGLUOAI
    return activation == only_supported
@
staticmethod
def
activation_format
()
->
mk
.
FusedMoEActivationFormat
:
return
mk
.
FusedMoEActivationFormat
.
Standard
...
...
@@ -689,6 +721,15 @@ class UnfusedOAITritonExperts(BaseOAITritonExperts):
One use case for it is to inject LoRA modules on the activation and moe_sum.
"""
@staticmethod
def _supports_activation(activation: MoEActivation) -> bool:
    """Return True when ``activation`` is one of SILU, GELU, SWIGLUOAI or SWIGLUSTEP."""
    accepted = (
        MoEActivation.SILU,
        MoEActivation.GELU,
        MoEActivation.SWIGLUOAI,
        MoEActivation.SWIGLUSTEP,
    )
    return activation in accepted
@staticmethod
def activation_format() -> mk.FusedMoEActivationFormat:
    """Return the activation format used by this implementation (Standard)."""
    formats = mk.FusedMoEActivationFormat
    return formats.Standard
...
...
@@ -814,3 +855,118 @@ class UnfusedOAITritonExperts(BaseOAITritonExperts):
)
self
.
moe_sum
(
intermediate_cache3
.
view
(
-
1
,
topk
,
K
),
output
)
class OAITritonMxfp4ExpertsMonolithic(mk.FusedMoEExpertsMonolithic):
    """Monolithic Triton MXFP4 expert that delegates to triton_kernel_moe_forward()."""

    def __init__(
        self,
        moe_config: FusedMoEConfig,
        quant_config: FusedMoEQuantConfig,
    ):
        """Capture top-k and the renormalize flag from the MoE config.

        Args:
            moe_config: Supplies ``experts_per_token`` and ``routing_method``.
            quant_config: Forwarded to the base class.
        """
        super().__init__(moe_config, quant_config)
        self.topk = moe_config.experts_per_token
        # Renormalization is enabled for either renormalizing routing method.
        renormalizing_methods = (
            RoutingMethodType.Renormalize,
            RoutingMethodType.RenormalizeNaive,
        )
        self.renormalize = moe_config.routing_method in renormalizing_methods

    @staticmethod
    def activation_format() -> mk.FusedMoEActivationFormat:
        """Return the activation format used by this implementation (Standard)."""
        formats = mk.FusedMoEActivationFormat
        return formats.Standard

    @staticmethod
    def _supports_current_device() -> bool:
        """Return True on CUDA-alike devices with capability in [9.0, 11.0)."""
        platform = current_platform
        if not platform.is_cuda_alike():
            return False
        capability = platform.get_device_capability()
        if capability is None:
            return False
        # (9,0) <= cap < (11,0) covers CUDA SM90 (Hopper), SM100+ (Blackwell)
        # and ROCm gfx942/gfx950 (which map to 9.4/9.5).
        version = (capability.major, capability.minor)
        return (9, 0) <= version < (11, 0)

    @staticmethod
    def _supports_no_act_and_mul() -> bool:
        """Return False: the no-act-and-mul mode is not supported."""
        return False

    @staticmethod
    def _supports_quant_scheme(
        weight_key: QuantKey | None,
        activation_key: QuantKey | None,
    ) -> bool:
        """Return True only for static MXFP4 weights with no activation key."""
        supported_pairs = ((kMxfp4Static, None),)
        return (weight_key, activation_key) in supported_pairs

    @staticmethod
    def _supports_activation(activation: MoEActivation) -> bool:
        """Return True only when ``activation`` equals ``MoEActivation.SWIGLUOAI``."""
        only_supported = MoEActivation.SWIGLUOAI
        return activation == only_supported

    @staticmethod
    def _supports_parallel_config(
        moe_parallel_config: FusedMoEParallelConfig,
    ) -> bool:
        """Return True when all2all kernels and EPLB are off and dp_size <= 1."""
        if moe_parallel_config.use_all2all_kernels:
            return False
        if moe_parallel_config.enable_eplb:
            return False
        return moe_parallel_config.dp_size <= 1

    @staticmethod
    def _supports_routing_method(
        routing_method: RoutingMethodType,
        weight_key: QuantKey | None,
        activation_key: QuantKey | None,
    ) -> bool:
        """Return True for the Renormalize and RenormalizeNaive routing methods.

        ``weight_key`` and ``activation_key`` are accepted but not consulted.
        """
        accepted = (
            RoutingMethodType.Renormalize,
            RoutingMethodType.RenormalizeNaive,
        )
        return routing_method in accepted

    @staticmethod
    def _supports_router_logits_dtype(
        router_logits_dtype: torch.dtype | None,
        routing_method: RoutingMethodType,
    ) -> bool:
        """Return True unconditionally; neither argument is consulted."""
        return True

    def supports_expert_map(self) -> bool:
        """Return True: an expert map is supported."""
        return True

    @property
    def expects_unquantized_inputs(self) -> bool:
        """Return True: this implementation expects unquantized inputs."""
        return True

    def apply(
        self,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        router_logits: torch.Tensor,
        activation: MoEActivation,
        global_num_experts: int,
        expert_map: torch.Tensor | None,
        a1q_scale: torch.Tensor | None,
        apply_router_weight_on_input: bool,
        # grouped topk + fused topk bias parameters
        num_expert_group: int | None = None,
        e_score_correction_bias: torch.Tensor | None = None,
        routed_scaling_factor: float | None = None,
        topk_group: int | None = None,
    ) -> torch.Tensor:
        """Run the fused MoE forward through triton_kernel_moe_forward().

        ``activation``, ``a1q_scale`` and the grouped-topk / bias parameters
        are accepted but not forwarded to the kernel call. Top-k, the
        renormalize flag and the quant config come from instance state.

        Returns:
            Whatever ``triton_kernel_moe_forward`` returns.
        """
        forward_kwargs = dict(
            hidden_states=hidden_states,
            w1=w1,
            w2=w2,
            gating_output=router_logits,
            topk=self.topk,
            renormalize=self.renormalize,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            quant_config=self.quant_config,
            apply_router_weight_on_input=apply_router_weight_on_input,
        )
        return triton_kernel_moe_forward(**forward_kwargs)
vllm/model_executor/layers/fused_moe/layer.py
View file @
0da93439
...
...
@@ -52,7 +52,6 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig
,
)
from
vllm.platforms
import
current_platform
from
vllm.utils.math_utils
import
round_up
logger
=
init_logger
(
__name__
)
...
...
@@ -218,7 +217,6 @@ def maybe_roundup_hidden_size(
moe_parallel_config
:
FusedMoEParallelConfig
,
is_lora_enabled
:
bool
,
model_type
:
str
|
None
,
is_mxfp4_quant
:
bool
,
)
->
int
:
"""
Given layer hidden size and MoE configurations, round up hidden_size
...
...
@@ -232,7 +230,6 @@ def maybe_roundup_hidden_size(
is used in the case of mxfp4 quantization in selecting the
MxFP4Backend.
model_type: for checking if gpt-oss
is_mxfp4_quant: whether the layer is quantized with mxfp4
Return:
Rounded up hidden_size if rounding up is required based on the configs.
...
...
@@ -246,28 +243,6 @@ def maybe_roundup_hidden_size(
hidden_size
,
act_dtype
,
moe_parallel_config
)
# we are padding globally so EP buffer allocation works
if
model_type
==
"gpt_oss"
and
is_mxfp4_quant
:
from
vllm.model_executor.layers.quantization.mxfp4
import
(
Mxfp4Backend
,
get_mxfp4_backend
,
)
current_mxfp4_backend
=
get_mxfp4_backend
(
is_lora_enabled
)
if
(
current_mxfp4_backend
==
Mxfp4Backend
.
SM90_FI_MXFP4_BF16
or
current_mxfp4_backend
==
Mxfp4Backend
.
SM100_FI_MXFP4_MXFP8_CUTLASS
):
hidden_size
=
round_up
(
hidden_size
,
128
)
elif
(
current_platform
.
is_rocm
()
or
current_mxfp4_backend
==
Mxfp4Backend
.
SM100_FI_MXFP4_MXFP8_TRTLLM
or
current_mxfp4_backend
==
Mxfp4Backend
.
SM100_FI_MXFP4_BF16
or
current_mxfp4_backend
==
Mxfp4Backend
.
MARLIN
):
hidden_size
=
round_up
(
hidden_size
,
256
)
return
hidden_size
...
...
@@ -504,6 +479,8 @@ class FusedMoE(CustomOp):
self
.
apply_router_weight_on_input
=
apply_router_weight_on_input
self
.
activation
=
MoEActivation
.
from_str
(
activation
)
# TODO(bnell): we should not have to create a router if the kernel is
# monolithic.
self
.
router
=
create_fused_moe_router
(
top_k
=
top_k
,
global_num_experts
=
self
.
global_num_experts
,
...
...
@@ -538,9 +515,6 @@ class FusedMoE(CustomOp):
moe_parallel_config
=
self
.
moe_parallel_config
,
is_lora_enabled
=
vllm_config
.
lora_config
is
not
None
,
model_type
=
self
.
model_type
,
is_mxfp4_quant
=
(
quant_config
is
not
None
and
quant_config
.
is_mxfp4_quant
(
prefix
,
self
)
),
)
self
.
hidden_size
=
hidden_size
...
...
vllm/model_executor/layers/fused_moe/mori_prepare_finalize.py
View file @
0da93439
...
...
@@ -70,16 +70,13 @@ class MoriPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
- Optional dispatched expert topk IDs
- Optional dispatched expert topk weight
"""
if
defer_input_quant
:
raise
NotImplementedError
(
f
"
{
self
.
__class__
.
__name__
}
does not support defer_input_quant=True. "
"Please select an MoE kernel that accepts quantized inputs."
)
assert
not
apply_router_weight_on_input
,
(
"mori does not support apply_router_weight_on_input=True now."
)
scale
=
None
if
self
.
use_fp8_dispatch
:
# When defer_input_quant is True, the expert kernel handles
# quantization internally, so skip FP8 dispatch quantization.
if
self
.
use_fp8_dispatch
and
not
defer_input_quant
:
from
aiter
import
QuantType
,
get_hip_quant
if
quant_config
.
is_block_quantized
:
...
...
vllm/model_executor/layers/fused_moe/oracle/fp8.py
View file @
0da93439
...
...
@@ -444,7 +444,7 @@ def convert_to_fp8_moe_kernel_format(
Fp8MoeBackend
.
FLASHINFER_CUTLASS
,
Fp8MoeBackend
.
FLASHINFER_TRTLLM
,
]:
w13
,
w2
,
w13_scale
=
prepare_fp8_moe_layer_for_fi
(
w13
,
w2
,
w13_scale
,
w2_scale
=
prepare_fp8_moe_layer_for_fi
(
layer
=
layer
,
w13
=
w13
,
w2
=
w2
,
...
...
@@ -512,6 +512,21 @@ def make_fp8_moe_quant_config(
g1_alphas
=
(
w1_scale
*
a1_scale
).
squeeze
(),
g2_alphas
=
(
w2_scale
*
a2_scale
).
squeeze
(),
)
# MXFP8 uses "mxfp8" quant_dtype so the prepare step dispatches to
# _mxfp8_e4m3_quantize rather than standard FP8 block quantization.
# Non-swizzled layout is required since the TRTLLM kernel expects
# scales in (num_tokens, hidden_dim // 32) format.
if
block_shape
==
[
1
,
32
]:
return
FusedMoEQuantConfig
.
make
(
"mxfp8"
,
w1_scale
=
w1_scale
,
w2_scale
=
w2_scale
,
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
,
block_shape
=
block_shape
,
is_nvfp4_scale_swizzled
=
False
,
)
# All other backends use normal config.
return
fp8_w8a8_moe_quant_config
(
w1_scale
=
w1_scale
,
...
...
vllm/model_executor/layers/fused_moe/oracle/mxfp4.py
0 → 100644
View file @
0da93439
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
enum
import
Enum
from
typing
import
Union
import
torch
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm
import
envs
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe
import
(
FusedMoEConfig
,
)
from
vllm.model_executor.layers.fused_moe.all2all_utils
import
(
maybe_make_prepare_finalize
,
)
from
vllm.model_executor.layers.fused_moe.config
import
(
FusedMoEQuantConfig
,
mxfp4_mxfp8_moe_quant_config
,
mxfp4_w4a16_moe_quant_config
,
ocp_mx_moe_quant_config
,
)
from
vllm.model_executor.layers.quantization.utils.mxfp4_utils
import
(
_swizzle_mxfp4
,
get_padding_alignment
,
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
QuantKey
,
kMxfp4Static
,
kMxfp8Dynamic
,
)
from
vllm.platforms
import
current_platform
from
vllm.utils.import_utils
import
has_triton_kernels
from
vllm.utils.math_utils
import
round_up
logger
=
init_logger
(
__name__
)
if
has_triton_kernels
():
try
:
from
triton_kernels.matmul_ogs
import
PrecisionConfig
except
(
ImportError
,
AttributeError
)
as
e
:
logger
.
error
(
"Failed to import Triton kernels. Please make sure your triton "
"version is compatible. Error: %s"
,
e
,
)
class Mxfp4MoeBackend(Enum):
    """Identifiers of the MXFP4 MoE kernel backends; each value is the member's name
    as a string (except NONE, whose value is "None")."""

    NONE = "None"

    # --- FlashInfer, TRTLLM kernels ---
    FLASHINFER_TRTLLM_MXFP4_MXFP8 = "FLASHINFER_TRTLLM_MXFP4_MXFP8"
    FLASHINFER_TRTLLM_MXFP4_BF16 = "FLASHINFER_TRTLLM_MXFP4_BF16"

    # --- FlashInfer, CUTLASS kernels ---
    FLASHINFER_CUTLASS_MXFP4_MXFP8 = "FLASHINFER_CUTLASS_MXFP4_MXFP8"
    FLASHINFER_CUTLASS_MXFP4_BF16 = "FLASHINFER_CUTLASS_MXFP4_BF16"

    # --- Marlin ---
    BATCHED_MARLIN = "BATCHED_MARLIN"
    MARLIN = "MARLIN"

    # --- ROCm AITER (CK) ---
    CK = "CK"

    # --- Triton ---
    TRITON = "TRITON"
    TRITON_UNFUSED = "TRITON_UNFUSED"

    # --- XPU ---
    XPU = "XPU"
# FlashInfer TRTLLM backends; they share the same TRTLLM weight format.
TRTLLM_BACKENDS = (
    Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16,
    Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8,
)

# The two Triton backends (fused and unfused).
TRITON_BACKENDS = (
    Mxfp4MoeBackend.TRITON,
    Mxfp4MoeBackend.TRITON_UNFUSED,
)
def backend_to_kernel_cls(
    backend: Mxfp4MoeBackend,
) -> list[type[mk.FusedMoEExperts]]:
    """Return the candidate expert classes for ``backend``, most preferred first.

    Kernel modules are imported lazily, only for the backend requested.

    Raises:
        NotImplementedError: For the XPU backend, which has no expert class here.
        ValueError: For a backend this function does not know.
    """
    trtllm_family = (
        Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16,
        Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8,
    )
    if backend in trtllm_family:
        from vllm.model_executor.layers.fused_moe.experts.trtllm_mxfp4_moe import (
            TrtLlmMxfp4ExpertsModular,
            TrtLlmMxfp4ExpertsMonolithic,
        )

        # NOTE: prefer Monolithic > Modular, so return Monolithic first.
        return [TrtLlmMxfp4ExpertsMonolithic, TrtLlmMxfp4ExpertsModular]

    cutlass_family = (
        Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16,
        Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8,
    )
    if backend in cutlass_family:
        from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
            FlashInferExperts,
        )

        return [FlashInferExperts]

    if backend == Mxfp4MoeBackend.TRITON:
        from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
            OAITritonExperts,
            OAITritonMxfp4ExpertsMonolithic,
        )

        # NOTE: prefer Monolithic > Modular, so return Monolithic first.
        return [OAITritonMxfp4ExpertsMonolithic, OAITritonExperts]

    if backend == Mxfp4MoeBackend.TRITON_UNFUSED:
        from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
            UnfusedOAITritonExperts,
        )

        return [UnfusedOAITritonExperts]

    if backend == Mxfp4MoeBackend.MARLIN:
        from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
            MarlinExperts,
        )

        return [MarlinExperts]

    if backend == Mxfp4MoeBackend.BATCHED_MARLIN:
        from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
            BatchedMarlinExperts,
        )

        return [BatchedMarlinExperts]

    if backend == Mxfp4MoeBackend.CK:
        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
            AiterExperts,
        )

        return [AiterExperts]

    if backend == Mxfp4MoeBackend.XPU:
        raise NotImplementedError("XPU backend uses XpuMxfp4MoEMethod directly.")

    raise ValueError(f"Unknown MXFP4 MoE backend: {backend.value}")
def map_mxfp4_backend(runner_backend: str) -> Mxfp4MoeBackend:
    """Translate a user-supplied ``moe_backend`` string into an Mxfp4MoeBackend.

    Raises:
        ValueError: If ``runner_backend`` is not one of the known names.
    """
    name_to_backend: dict[str, Mxfp4MoeBackend] = {
        "flashinfer_trtllm": Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16,
        "flashinfer_trtllm_afp8": Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8,
        "flashinfer_cutlass": Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16,
        "flashinfer_cutlass_afp8": Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8,
        "triton": Mxfp4MoeBackend.TRITON,
        "marlin": Mxfp4MoeBackend.MARLIN,
        "ck": Mxfp4MoeBackend.CK,
    }
    selected = name_to_backend.get(runner_backend)
    if not selected:
        raise ValueError(
            f"moe_backend='{runner_backend}' is not supported for MXFP4 MoE. "
            f"Expected one of {list(name_to_backend)}."
        )
    return selected
def _get_priority_backends() -> list[Mxfp4MoeBackend]:
    """
    Return a fresh list of candidate backends, highest priority first.

    Only BF16 backends are listed; the MXFP8 variants are not part of this
    list and are selected via env vars.
    """
    # A new list is built on every call, so callers may mutate the result.
    return [
        Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16,
        Mxfp4MoeBackend.CK,
        Mxfp4MoeBackend.TRITON,
        Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16,
        Mxfp4MoeBackend.TRITON_UNFUSED,
        Mxfp4MoeBackend.MARLIN,
        Mxfp4MoeBackend.BATCHED_MARLIN,
    ]
def _backend_activation_key(backend: Mxfp4MoeBackend) -> QuantKey | None:
    """Return ``kMxfp8Dynamic`` for the two MXFP8 FlashInfer backends, else None."""
    mxfp8_backends = (
        Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8,
        Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8,
    )
    return kMxfp8Dynamic if backend in mxfp8_backends else None
def select_mxfp4_moe_backend(
    config: FusedMoEConfig,
) -> tuple[Mxfp4MoeBackend, type[mk.FusedMoEExperts] | None]:
    """
    Select the primary MXFP4 MoE backend.
    Note: Shape-specific fallbacks may still occur at runtime.

    Selection order:
      1. LoRA enabled: Triton (unfused) or Marlin, CUDA only.
      2. ``config.moe_backend`` set to something other than "auto".
      3. Explicit environment-variable overrides (FlashInfer BF16/MXFP8, Marlin).
      4. First supported backend from ``_get_priority_backends()``.
      5. XPU, then an error on CUDA/ROCm, else ``Mxfp4MoeBackend.NONE``.

    Args:
        config: MoE configuration; ``is_lora_enabled``, ``moe_backend`` and
            ``moe_parallel_config.use_batched_activation_format`` are read here
            and the object is forwarded to each kernel's ``is_supported_config``.

    Returns:
        ``(backend, experts_cls)``. ``experts_cls`` is ``None`` for the XPU and
        NONE backends.

    Raises:
        NotImplementedError: LoRA on a non-CUDA platform, or no backend
            supports the configuration on CUDA/ROCm.
        ValueError: An explicitly requested backend does not support the
            configuration, the ``moe_backend`` string is unknown, or the
            FlashInfer BF16 override is set on an unsupported device.
    """
    # The platform may report no device capability (None); comparing None with
    # a tuple would raise TypeError, so treat that case as "not supported".
    triton_kernels_supported = False
    if has_triton_kernels():
        device_capability = current_platform.get_device_capability()
        triton_kernels_supported = (
            device_capability is not None
            and (9, 0) <= device_capability < (11, 0)
        )

    # LoRA: separate experts backend path
    if config.is_lora_enabled:
        if not current_platform.is_cuda():
            raise NotImplementedError("Mxfp4 LoRA only supported on CUDA Platform.")
        if envs.VLLM_MXFP4_USE_MARLIN is False and triton_kernels_supported:
            logger.info_once("Using Triton backend for mxfp4 lora")
            return (
                Mxfp4MoeBackend.TRITON_UNFUSED,
                backend_to_kernel_cls(Mxfp4MoeBackend.TRITON_UNFUSED)[0],
            )
        logger.info_once("Using Marlin backend for mxfp4 lora")
        return (
            Mxfp4MoeBackend.MARLIN,
            backend_to_kernel_cls(Mxfp4MoeBackend.MARLIN)[0],
        )

    activation_format = (
        mk.FusedMoEActivationFormat.BatchedExperts
        if config.moe_parallel_config.use_batched_activation_format
        else mk.FusedMoEActivationFormat.Standard
    )

    def _make_log_backend(backend: Mxfp4MoeBackend):
        return f"Using '{backend.value}' Mxfp4 MoE backend."

    def _make_log_unsupported(backend: Mxfp4MoeBackend, reason: str | None) -> str:
        if reason:
            return (
                f"Mxfp4 MoE backend '{backend.value}' does not support the "
                f"deployment configuration since {reason}."
            )
        return (
            f"Mxfp4 MoE backend '{backend.value}' does not support the "
            "deployment configuration."
        )

    def _return_or_raise(
        backend: Mxfp4MoeBackend,
        config: FusedMoEConfig,
        weight_key: QuantKey | None,
        activation_key: QuantKey | None,
        activation_format: mk.FusedMoEActivationFormat,
    ) -> tuple[Mxfp4MoeBackend, type[mk.FusedMoEExperts]]:
        # Try each kernel class of an explicitly requested backend; if none
        # supports the config, raise with the last reported reason.
        reason: str | None = None
        for k_cls in backend_to_kernel_cls(backend):
            supported, reason = k_cls.is_supported_config(
                k_cls, config, weight_key, activation_key, activation_format
            )
            if supported:
                logger.info_once(_make_log_backend(backend), scope="local")
                return backend, k_cls
        raise ValueError(_make_log_unsupported(backend, reason))

    runner_backend = config.moe_backend
    if runner_backend != "auto":
        requested_backend = map_mxfp4_backend(runner_backend)
        # "marlin" with the batched activation format means the batched kernel.
        if (
            activation_format == mk.FusedMoEActivationFormat.BatchedExperts
            and requested_backend == Mxfp4MoeBackend.MARLIN
        ):
            requested_backend = Mxfp4MoeBackend.BATCHED_MARLIN
        return _return_or_raise(
            requested_backend,
            config,
            kMxfp4Static,
            _backend_activation_key(requested_backend),
            activation_format,
        )

    # Select kernels in order of backend.
    AVAILABLE_BACKENDS = _get_priority_backends()

    # Handle explicit FlashInfer MXFP4 BF16 configuration.
    if envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16"):
        if not envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16:
            # Explicitly disabled: drop both FlashInfer BF16 candidates.
            AVAILABLE_BACKENDS.remove(Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16)
            AVAILABLE_BACKENDS.remove(Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16)
        else:
            if current_platform.is_device_capability(90):
                return _return_or_raise(
                    Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16,
                    config,
                    kMxfp4Static,
                    None,
                    activation_format,
                )
            if current_platform.is_device_capability_family(100):
                return _return_or_raise(
                    Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16,
                    config,
                    kMxfp4Static,
                    None,
                    activation_format,
                )
            raise ValueError(
                "VLLM_USE_FLASHINFER_MOE_MXFP4_BF16=1 is set but the "
                "current device capability is not supported. "
                "Only SM90 (CUTLASS) and SM100+ (TRTLLM) are supported."
            )

    # Handle explicit FlashInfer MXFP4 MXFP8 TRTLLM configuration.
    if (
        envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8")
        and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
    ):
        return _return_or_raise(
            Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8,
            config,
            kMxfp4Static,
            kMxfp8Dynamic,
            activation_format,
        )

    # Handle explicit FlashInfer MXFP4 MXFP8 CUTLASS configuration.
    if (
        envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS")
        and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS
    ):
        return _return_or_raise(
            Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8,
            config,
            kMxfp4Static,
            kMxfp8Dynamic,
            activation_format,
        )

    # Handle explicit Marlin MXFP4 configuration.
    if envs.is_set("VLLM_MXFP4_USE_MARLIN") and envs.VLLM_MXFP4_USE_MARLIN:
        return _return_or_raise(
            Mxfp4MoeBackend.MARLIN,
            config,
            kMxfp4Static,
            None,
            activation_format,
        )

    # No explicit choice: take the first kernel class that supports the config.
    for backend in AVAILABLE_BACKENDS:
        activation_key = _backend_activation_key(backend)
        for k_cls in backend_to_kernel_cls(backend):
            supported, reason = k_cls.is_supported_config(
                k_cls, config, kMxfp4Static, activation_key, activation_format
            )
            if supported:
                logger.info_once(_make_log_backend(backend), scope="local")
                return backend, k_cls
            else:
                logger.debug_once(
                    _make_log_unsupported(backend, reason), scope="local"
                )

    if current_platform.is_xpu():
        backend = Mxfp4MoeBackend.XPU
        logger.info_once(_make_log_backend(backend))
        return backend, None

    if current_platform.is_cuda() or current_platform.is_rocm():
        raise NotImplementedError(
            "No MXFP4 MoE backend supports the deployment configuration."
        )

    return Mxfp4MoeBackend.NONE, None
def mxfp4_round_up_hidden_size_and_intermediate_size(
    backend: Mxfp4MoeBackend, hidden_size: int, intermediate_size: int
) -> tuple[int, int]:
    """Pad hidden and intermediate sizes to the alignment the backend requires.

    Alignment per case, as (hidden, intermediate):
      - Marlin / batched Marlin: (128 on XPU else 256, 128)
      - TRTLLM backends: (256, 256)
      - FlashInfer CUTLASS backends: (128, 128)
      - any other backend on ROCm: ``get_padding_alignment()`` for both
      - otherwise: hidden size unchanged, intermediate rounded to 64

    Returns:
        ``(hidden_size, intermediate_size)`` after rounding up.
    """
    marlin_family = (Mxfp4MoeBackend.MARLIN, Mxfp4MoeBackend.BATCHED_MARLIN)
    cutlass_family = (
        Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16,
        Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8,
    )

    # hidden_align stays None when the hidden size must be left untouched.
    if backend in marlin_family:
        inter_align = 128
        hidden_align = 128 if current_platform.is_xpu() else 256
    elif backend in TRTLLM_BACKENDS:
        inter_align = hidden_align = 256
    elif backend in cutlass_family:
        inter_align = hidden_align = 128
    elif current_platform.is_rocm():
        inter_align = hidden_align = get_padding_alignment()
    else:
        inter_align = 64
        hidden_align = None

    intermediate_size = round_up(intermediate_size, inter_align)
    if hidden_align is not None:
        hidden_size = round_up(hidden_size, hidden_align)
    return hidden_size, intermediate_size
def
convert_to_mxfp4_moe_kernel_format
(
mxfp4_backend
:
Mxfp4MoeBackend
,
layer
:
torch
.
nn
.
Module
,
w13_weight
:
torch
.
Tensor
,
w2_weight
:
torch
.
Tensor
,
w13_weight_scale
:
torch
.
Tensor
,
w2_weight_scale
:
torch
.
Tensor
,
w13_bias
:
torch
.
Tensor
|
None
=
None
,
w2_bias
:
torch
.
Tensor
|
None
=
None
,
_cache_permute_indices
:
dict
[
torch
.
Size
,
torch
.
Tensor
]
|
None
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
Union
[
torch
.
Tensor
,
"PrecisionConfig"
],
Union
[
torch
.
Tensor
,
"PrecisionConfig"
],
torch
.
Tensor
|
None
,
torch
.
Tensor
|
None
,
]:
"""Convert loaded weights into backend-specific kernel format."""
num_experts
=
w13_weight
.
shape
[
0
]
intermediate_size
=
w13_weight
.
shape
[
1
]
//
2
hidden_size
=
w13_weight
.
shape
[
2
]
*
2
sf_block_size
=
32
# mxfp4 block size
if
mxfp4_backend
in
(
Mxfp4MoeBackend
.
MARLIN
,
Mxfp4MoeBackend
.
BATCHED_MARLIN
):
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp4
import
(
prepare_moe_mxfp4_layer_for_marlin
,
)
return
prepare_moe_mxfp4_layer_for_marlin
(
layer
,
w13_weight
,
w2_weight
,
w13_weight_scale
,
w2_weight_scale
,
w13_bias
,
w2_bias
,
)
elif
mxfp4_backend
in
TRTLLM_BACKENDS
:
assert
_cache_permute_indices
is
not
None
from
flashinfer.fp4_quantization
import
nvfp4_block_scale_interleave
from
flashinfer.fused_moe.core
import
get_w2_permute_indices_with_cache
# gemm1_alpha/beta/clamp_limit are created by the expert class
# (TrtLlmMxfp4ExpertsBase), not on the layer.
w13_weight
=
w13_weight
.
data
w2_weight
=
w2_weight
.
data
w13_weight_scale
=
w13_weight_scale
.
data
w2_weight_scale
=
w2_weight_scale
.
data
assert
w13_bias
is
not
None
and
w2_bias
is
not
None
w13_bias
=
w13_bias
.
data
.
to
(
torch
.
float32
)
w2_bias
=
w2_bias
.
data
.
to
(
torch
.
float32
)
# Swap w1 and w3 as the definition of swiglu is different in trtllm-gen
def
swap_every_two_rows
(
x
,
axis
=-
1
):
shape
=
x
.
shape
if
axis
<
0
:
axis
=
len
(
shape
)
+
axis
new_shape
=
list
(
shape
)
new_shape
[
axis
]
=
shape
[
axis
]
//
2
new_shape
.
insert
(
axis
+
1
,
2
)
x
=
x
.
reshape
(
*
new_shape
)
x
=
x
.
flip
(
axis
+
1
)
new_shape
=
list
(
shape
)
return
x
.
reshape
(
*
new_shape
)
w13_weight_scale
=
swap_every_two_rows
(
w13_weight_scale
,
-
2
)
w13_weight
=
swap_every_two_rows
(
w13_weight
,
-
2
)
w13_bias
=
swap_every_two_rows
(
w13_bias
,
-
1
)
# Shuffle weights and scaling factors for transposed mma output
gemm1_weights_shuffled
=
[]
gemm1_scales_shuffled
=
[]
gemm2_weights_shuffled
=
[]
gemm2_scales_shuffled
=
[]
gemm1_bias_shuffled
=
[]
gemm2_bias_shuffled
=
[]
epilogue_tile_m
=
128
for
i
in
range
(
num_experts
):
# w13 weight
permute_indices
=
get_w2_permute_indices_with_cache
(
_cache_permute_indices
,
w13_weight
[
i
].
view
(
torch
.
uint8
),
epilogue_tile_m
,
)
gemm1_weights_shuffled
.
append
(
w13_weight
[
i
]
.
view
(
torch
.
uint8
)[
permute_indices
.
to
(
w13_weight
.
device
)]
.
contiguous
()
)
# w13 scale
permute_sf_indices
=
get_w2_permute_indices_with_cache
(
_cache_permute_indices
,
w13_weight_scale
[
i
].
view
(
torch
.
uint8
),
epilogue_tile_m
,
num_elts_per_sf
=
16
,
)
gemm1_scales_shuffled
.
append
(
nvfp4_block_scale_interleave
(
w13_weight_scale
[
i
]
.
view
(
torch
.
uint8
)[
permute_sf_indices
.
to
(
w13_weight_scale
.
device
)]
.
contiguous
()
)
)
# w13 bias
permute_bias_indices
=
get_w2_permute_indices_with_cache
(
_cache_permute_indices
,
w13_bias
[
i
].
clone
().
reshape
(
-
1
,
1
),
epilogue_tile_m
,
)
gemm1_bias_shuffled
.
append
(
w13_bias
[
i
]
.
clone
()
.
reshape
(
-
1
,
1
)[
permute_bias_indices
.
to
(
w13_bias
.
device
)]
.
contiguous
()
)
# w2 weight
permute_indices
=
get_w2_permute_indices_with_cache
(
_cache_permute_indices
,
w2_weight
[
i
].
view
(
torch
.
uint8
),
epilogue_tile_m
,
)
gemm2_weights_shuffled
.
append
(
w2_weight
[
i
]
.
view
(
torch
.
uint8
)[
permute_indices
.
to
(
w2_weight
.
device
)]
.
contiguous
()
)
# w2 scale
permute_sf_indices
=
get_w2_permute_indices_with_cache
(
_cache_permute_indices
,
w2_weight_scale
[
i
].
view
(
torch
.
uint8
),
epilogue_tile_m
,
num_elts_per_sf
=
16
,
)
gemm2_scales_shuffled
.
append
(
nvfp4_block_scale_interleave
(
w2_weight_scale
[
i
]
.
view
(
torch
.
uint8
)[
permute_sf_indices
.
to
(
w2_weight_scale
.
device
)]
.
contiguous
()
)
)
# w2 bias
permute_indices
=
get_w2_permute_indices_with_cache
(
_cache_permute_indices
,
w2_bias
[
i
].
clone
().
reshape
(
-
1
,
1
),
epilogue_tile_m
,
)
gemm2_bias_shuffled
.
append
(
w2_bias
[
i
]
.
clone
()
.
reshape
(
-
1
,
1
)[
permute_indices
.
to
(
w2_bias
.
device
)]
.
contiguous
()
)
w13_weight
=
torch
.
stack
(
gemm1_weights_shuffled
)
w13_weight_scale
=
(
torch
.
stack
(
gemm1_scales_shuffled
)
.
reshape
(
num_experts
,
2
*
intermediate_size
,
hidden_size
//
sf_block_size
)
.
view
(
torch
.
float8_e4m3fn
)
)
w2_weight
=
torch
.
stack
(
gemm2_weights_shuffled
)
w2_weight_scale
=
(
torch
.
stack
(
gemm2_scales_shuffled
)
.
reshape
(
num_experts
,
hidden_size
,
intermediate_size
//
sf_block_size
)
.
view
(
torch
.
float8_e4m3fn
)
)
w13_bias
=
torch
.
stack
(
gemm1_bias_shuffled
).
reshape
(
num_experts
,
-
1
)
w2_bias
=
torch
.
stack
(
gemm2_bias_shuffled
).
reshape
(
num_experts
,
-
1
)
return
(
w13_weight
,
w2_weight
,
w13_weight_scale
,
w2_weight_scale
,
w13_bias
,
w2_bias
,
)
elif
mxfp4_backend
in
(
Mxfp4MoeBackend
.
FLASHINFER_CUTLASS_MXFP4_BF16
,
Mxfp4MoeBackend
.
FLASHINFER_CUTLASS_MXFP4_MXFP8
,
):
# De-interleave and swap for w13 weight, bias, and scales
w13_w
=
w13_weight
.
data
gate_w
,
up_w
=
w13_w
[:,
::
2
,
:],
w13_w
[:,
1
::
2
,
:]
deinterleaved_w13_w
=
torch
.
cat
([
gate_w
,
up_w
],
dim
=
1
)
w1_w
,
w3_w
=
torch
.
chunk
(
deinterleaved_w13_w
,
2
,
dim
=
1
)
w13_weight_swapped
=
torch
.
cat
([
w3_w
,
w1_w
],
dim
=
1
)
assert
w13_bias
is
not
None
and
w2_bias
is
not
None
w13_b
=
w13_bias
.
data
.
to
(
torch
.
float32
)
gate_b
,
up_b
=
w13_b
[:,
::
2
],
w13_b
[:,
1
::
2
]
deinterleaved_w13_b
=
torch
.
cat
([
gate_b
,
up_b
],
dim
=
1
)
b1
,
b3
=
torch
.
chunk
(
deinterleaved_w13_b
,
2
,
dim
=-
1
)
w13_bias_swapped
=
torch
.
cat
([
b3
,
b1
],
dim
=-
1
).
to
(
torch
.
bfloat16
)
w13_s
=
w13_weight_scale
.
data
gate_s
,
up_s
=
w13_s
[:,
::
2
,
:],
w13_s
[:,
1
::
2
,
:]
deinterleaved_w13_s
=
torch
.
cat
([
gate_s
,
up_s
],
dim
=
1
)
s1
,
s3
=
torch
.
chunk
(
deinterleaved_w13_s
,
2
,
dim
=
1
)
w13_scale_swapped
=
torch
.
cat
([
s3
,
s1
],
dim
=
1
)
if
mxfp4_backend
==
Mxfp4MoeBackend
.
FLASHINFER_CUTLASS_MXFP4_MXFP8
:
from
flashinfer
import
block_scale_interleave
orig_shape
=
w13_scale_swapped
.
shape
w13_scale_interleaved
=
block_scale_interleave
(
w13_scale_swapped
.
view
(
torch
.
uint8
)
).
reshape
(
orig_shape
)
w2_s
=
w2_weight_scale
.
data
orig_shape
=
w2_s
.
shape
w2_scale_interleaved
=
block_scale_interleave
(
w2_s
.
view
(
torch
.
uint8
)
).
reshape
(
orig_shape
)
return
(
w13_weight_swapped
,
w2_weight
,
w13_scale_interleaved
,
w2_scale_interleaved
,
w13_bias_swapped
,
w2_bias
,
)
else
:
assert
mxfp4_backend
==
Mxfp4MoeBackend
.
FLASHINFER_CUTLASS_MXFP4_BF16
def
_interleave_mxfp4_cutlass_sm90
(
w
):
w_shape
=
w
.
shape
w_interleaved
=
w
.
reshape
(
w_shape
[
0
],
w_shape
[
1
],
(
w_shape
[
2
]
//
4
),
4
)
w_interleaved
=
w_interleaved
.
permute
(
0
,
2
,
1
,
3
)
w_interleaved
=
w_interleaved
.
reshape
(
w_shape
[
0
],
w_shape
[
2
]
//
4
,
w_shape
[
1
]
*
4
)
return
w_interleaved
w31_scales
=
w13_scale_swapped
.
to
(
torch
.
uint8
)
w31_scales_interleaved
=
_interleave_mxfp4_cutlass_sm90
(
w31_scales
)
w2_scale
=
w2_weight_scale
.
data
.
to
(
torch
.
uint8
)
w2_scale_interleaved
=
_interleave_mxfp4_cutlass_sm90
(
w2_scale
)
return
(
w13_weight_swapped
,
w2_weight
,
w31_scales_interleaved
,
w2_scale_interleaved
,
w13_bias_swapped
,
w2_bias
,
)
elif
mxfp4_backend
==
Mxfp4MoeBackend
.
CK
:
from
vllm._aiter_ops
import
rocm_aiter_ops
if
w13_bias
is
not
None
:
w13_bias
=
w13_bias
.
data
.
to
(
torch
.
float32
)
if
w2_bias
is
not
None
:
w2_bias
=
w2_bias
.
data
.
to
(
torch
.
float32
)
e
,
n
,
k
=
w13_weight
.
shape
# De-interleave w13 rows: gate/up pairs -> contiguous gate, up blocks
w13_weight
.
view
(
torch
.
uint8
).
copy_
(
w13_weight
.
data
.
view
(
torch
.
uint8
)
.
view
(
e
,
n
//
2
,
2
,
k
)
.
permute
(
0
,
2
,
1
,
3
)
.
contiguous
()
.
view
(
e
,
n
,
k
)
)
w13_weight_scale
.
data
=
(
w13_weight_scale
.
data
.
view
(
e
,
n
//
2
,
2
,
-
1
)
.
permute
(
0
,
2
,
1
,
3
)
.
contiguous
()
.
view
(
e
,
n
,
-
1
)
)
# View as native FP4 dtype for AITER shuffle
w13_weight
.
data
=
w13_weight
.
data
.
view
(
torch
.
float4_e2m1fn_x2
)
w2_weight
.
data
=
w2_weight
.
data
.
view
(
torch
.
float4_e2m1fn_x2
)
# Shuffle weights and scales for AITER CK kernel layout
w13_weight
.
data
=
rocm_aiter_ops
.
shuffle_weight_a16w4
(
w13_weight
,
16
,
True
)
shuffled_w13_scale
=
rocm_aiter_ops
.
shuffle_scale_a16w4
(
w13_weight_scale
.
view
(
-
1
,
w13_weight_scale
.
shape
[
-
1
]),
num_experts
,
True
,
)
w2_weight
.
data
=
rocm_aiter_ops
.
shuffle_weight_a16w4
(
w2_weight
,
16
,
False
)
shuffled_w2_scale
=
rocm_aiter_ops
.
shuffle_scale_a16w4
(
w2_weight_scale
.
view
(
-
1
,
w2_weight_scale
.
shape
[
-
1
]),
num_experts
,
False
,
)
# Permute bias to match de-interleaved weight layout
if
w13_bias
is
not
None
:
w13_bias
=
(
w13_bias
.
data
.
view
(
-
1
,
n
//
2
,
2
)
.
permute
(
0
,
2
,
1
)
.
contiguous
()
.
view
(
-
1
,
n
)
)
return
(
w13_weight
,
w2_weight
,
shuffled_w13_scale
,
shuffled_w2_scale
,
w13_bias
,
w2_bias
,
)
elif
mxfp4_backend
in
TRITON_BACKENDS
:
from
triton_kernels.matmul_ogs
import
FlexCtx
,
PrecisionConfig
assert
w13_bias
is
not
None
and
w2_bias
is
not
None
w13_bias
=
w13_bias
.
to
(
torch
.
float32
)
w2_bias
=
w2_bias
.
to
(
torch
.
float32
)
w13_weight
,
w13_flex
,
w13_scale
=
_swizzle_mxfp4
(
w13_weight
,
w13_weight_scale
,
)
w2_weight
,
w2_flex
,
w2_scale
=
_swizzle_mxfp4
(
w2_weight
,
w2_weight_scale
,
)
w13_precision_config
=
PrecisionConfig
(
weight_scale
=
w13_scale
,
flex_ctx
=
FlexCtx
(
rhs_data
=
w13_flex
)
)
w2_precision_config
=
PrecisionConfig
(
weight_scale
=
w2_scale
,
flex_ctx
=
FlexCtx
(
rhs_data
=
w2_flex
)
)
del
layer
.
w13_weight
del
layer
.
w2_weight
return
(
w13_weight
,
w2_weight
,
w13_precision_config
,
w2_precision_config
,
w13_bias
,
w2_bias
,
)
else
:
raise
ValueError
(
f
"Unsupported mxfp4_backend:
{
mxfp4_backend
}
: "
f
"should be one of:
{
list
(
Mxfp4MoeBackend
)
}
."
)
def make_mxfp4_moe_quant_config(
    mxfp4_backend: Mxfp4MoeBackend,
    w1_scale: Union[torch.Tensor, "PrecisionConfig"],
    w2_scale: Union[torch.Tensor, "PrecisionConfig"],
    w1_bias: torch.Tensor | None = None,
    w2_bias: torch.Tensor | None = None,
) -> FusedMoEQuantConfig | None:
    """Build the FusedMoEQuantConfig matching an MXFP4 backend.

    The backend only decides which config factory is called; the scales and
    biases are forwarded unchanged to whichever factory is picked.

    Args:
        mxfp4_backend: Backend the config is being built for.
        w1_scale: Scale (tensor or ``PrecisionConfig``) passed as ``w1_scale``.
        w2_scale: Scale (tensor or ``PrecisionConfig``) passed as ``w2_scale``.
        w1_bias: Optional bias passed as ``w1_bias``.
        w2_bias: Optional bias passed as ``w2_bias``.

    Returns:
        Whatever the selected factory returns.
    """
    # Every factory receives the same four keyword arguments.
    shared_kwargs = {
        "w1_bias": w1_bias,
        "w2_bias": w2_bias,
        "w1_scale": w1_scale,
        "w2_scale": w2_scale,
    }

    mxfp4_mxfp8_backends = (
        Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8,
        Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8,
    )
    if mxfp4_backend in mxfp4_mxfp8_backends:
        return mxfp4_mxfp8_moe_quant_config(**shared_kwargs)

    w4a16_backends = (
        Mxfp4MoeBackend.MARLIN,
        Mxfp4MoeBackend.BATCHED_MARLIN,
        Mxfp4MoeBackend.TRITON,
        Mxfp4MoeBackend.TRITON_UNFUSED,
        Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16,
        Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16,
        Mxfp4MoeBackend.CK,
    )
    if mxfp4_backend in w4a16_backends:
        return mxfp4_w4a16_moe_quant_config(**shared_kwargs)

    # Any backend not listed above falls back to the generic OCP MX config.
    return ocp_mx_moe_quant_config(quant_dtype="mxfp4", **shared_kwargs)
def make_mxfp4_moe_kernel(
    moe_quant_config: FusedMoEQuantConfig,
    moe_config: FusedMoEConfig,
    experts_cls: type[mk.FusedMoEExperts],
    mxfp4_backend: Mxfp4MoeBackend,
    routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    shared_experts: torch.nn.Module | None = None,
) -> mk.FusedMoEKernel:
    """Assemble a FusedMoEKernel for an MXFP4 backend.

    Creates the prepare/finalize object, instantiates ``experts_cls`` with
    the arguments its activation format requires, and wraps both in a
    ``mk.FusedMoEKernel``.

    Args:
        moe_quant_config: Quantization config handed to prepare/finalize and
            to the experts.
        moe_config: MoE config handed to prepare/finalize and to the experts.
        experts_cls: Experts class to instantiate.
        mxfp4_backend: Selected backend; only used to decide ``inplace``.
        routing_tables: Optional tables forwarded to prepare/finalize.
        shared_experts: Optional module, forwarded to the kernel only when
            ``moe_config.moe_parallel_config.use_deepep_ll_kernels`` is set.

    Returns:
        The constructed kernel.
    """
    # Step 1: prepare/finalize.
    prepare_finalize = maybe_make_prepare_finalize(
        moe=moe_config,
        quant_config=moe_quant_config,
        routing_tables=routing_tables,
        allow_new_interface=True,
        use_monolithic=issubclass(experts_cls, mk.FusedMoEExpertsMonolithic),
    )
    assert prepare_finalize is not None

    logger.info_once("Using %s", prepare_finalize.__class__.__name__, scope="local")

    # Step 2: experts. The batched activation format needs two extra
    # constructor arguments that are queried from prepare/finalize.
    experts_kwargs = {
        "moe_config": moe_config,
        "quant_config": moe_quant_config,
    }
    batched_format = mk.FusedMoEActivationFormat.BatchedExperts
    if prepare_finalize.activation_format == batched_format:
        max_num_tokens = prepare_finalize.max_num_tokens_per_rank()
        assert max_num_tokens is not None
        experts_kwargs["max_num_tokens"] = max_num_tokens
        experts_kwargs["num_dispatchers"] = prepare_finalize.num_dispatchers()
    experts = experts_cls(**experts_kwargs)

    # Step 3: kernel.
    if moe_config.moe_parallel_config.use_deepep_ll_kernels:
        kernel_shared_experts = shared_experts
    else:
        kernel_shared_experts = None

    # In-place is used only if it is not disabled in the config and the
    # backend is not one of the TRTLLM backends.
    inplace = (not moe_config.disable_inplace) and (
        mxfp4_backend not in TRTLLM_BACKENDS
    )

    return mk.FusedMoEKernel(
        prepare_finalize,
        experts,
        shared_experts=kernel_shared_experts,
        moe_parallel_config=moe_config.moe_parallel_config,
        inplace=inplace,
    )
vllm/model_executor/layers/fused_moe/oracle/mxfp8.py
View file @
0da93439
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
enum
import
Enum
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe.config
import
FusedMoEConfig
from
vllm.model_executor.layers.fused_moe.oracle.fp8
import
(
Fp8MoeBackend
,
backend_to_kernel_cls
,
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
kMxfp8Dynamic
,
kMxfp8Static
,
)
logger
=
init_logger
(
__name__
)
_SUPPORTED_BACKENDS
:
frozenset
[
Fp8MoeBackend
]
=
frozenset
(
{
Fp8MoeBackend
.
FLASHINFER_TRTLLM
,
}
)
class MxFp8MoeBackend(Enum):
    """MXFP8 MoE backend identifiers; a single member is defined here."""

    # The member's value is its own name as a string.
    FLASHINFER_TRTLLM = "FLASHINFER_TRTLLM"
_BACKEND_NAME_MAP
:
dict
[
str
,
Fp8MoeBackend
]
=
{
"flashinfer_trtllm"
:
Fp8MoeBackend
.
FLASHINFER_TRTLLM
,
}
def _select_kernel_cls(
    backend: Fp8MoeBackend,
    config: FusedMoEConfig,
) -> type[mk.FusedMoEExperts]:
    """Return the first expert class of ``backend`` that supports ``config``.

    Candidates come from ``backend_to_kernel_cls(backend)`` and are probed in
    order with ``is_supported_config`` using the MXFP8 static/dynamic keys.

    Raises:
        ValueError: If no candidate accepts the config. The message carries
            the reason reported by the last rejected candidate (``None`` if
            there were no candidates).
    """
    if config.moe_parallel_config.use_batched_activation_format:
        activation_format = mk.FusedMoEActivationFormat.BatchedExperts
    else:
        activation_format = mk.FusedMoEActivationFormat.Standard

    rejection_reason: str | None = None
    for candidate in backend_to_kernel_cls(backend):
        is_supported, why_not = candidate.is_supported_config(
            candidate,
            config,
            kMxfp8Static,
            kMxfp8Dynamic,
            activation_format,
        )
        if not is_supported:
            # Remember why, so the final error can report it.
            rejection_reason = why_not
            continue
        return candidate

    raise ValueError(
        f"No supported MXFP8 expert class for {backend.value}: {rejection_reason}"
    )
def
select_mxfp8_moe_backend
(
config
:
FusedMoEConfig
,
)
->
MxFp8MoeBackend
:
)
->
tuple
[
Fp8MoeBackend
,
type
[
mk
.
FusedMoEExperts
]]:
"""Select the MXFP8 MoE backend and the best expert class.
Returns:
A tuple of (fp8_backend, experts_cls).
"""
if
config
.
is_lora_enabled
:
raise
NotImplementedError
(
"LoRA is not supported for MXFP8 MoE."
)
AVAILABLE_BACKENDS
=
[
MxFp8MoeBackend
.
FLASHINFER_TRTLLM
,
]
runner_backend
=
config
.
moe_backend
if
runner_backend
!=
"auto"
:
mapping
=
{
"flashinfer_trtllm"
:
MxFp8MoeBackend
.
FLASHINFER_TRTLLM
,
}
if
backend
:
=
mapping
.
get
(
runner_backend
):
logger
.
info_once
(
"Using '%s' MxFp8 MoE backend (user-requested)."
,
backend
.
value
,
backend
=
_BACKEND_NAME_MAP
.
get
(
runner_backend
)
if
backend
is
None
:
raise
ValueError
(
f
"moe_backend='
{
runner_backend
}
' is not supported for "
f
"MXFP8 MoE. Expected one of "
f
"
{
list
(
_BACKEND_NAME_MAP
.
keys
())
}
."
)
return
backend
raise
ValueError
(
f
"moe_backend='
{
runner_backend
}
' is not supported for MXFP8 MoE. "
f
"Expected one of
{
list
(
mapping
.
keys
())
}
."
logger
.
info_once
(
"Using '%s' MxFp8 MoE backend (user-requested)."
,
backend
.
value
,
)
return
backend
,
_select_kernel_cls
(
backend
,
config
)
# Auto-select: pick the first supported backend.
for
backend
in
_SUPPORTED_BACKENDS
:
logger
.
info_once
(
"Using '%s' MxFp8 MoE backend."
,
backend
.
value
)
return
backend
,
_select_kernel_cls
(
backend
,
config
)
# Auto-select: only one backend available for now.
backend
=
AVAILABLE_BACKENDS
[
0
]
logger
.
info_once
(
"Using '%s' MxFp8 MoE backend."
,
backend
.
value
)
return
backend
raise
ValueError
(
"No MXFP8 MoE backends available."
)
vllm/model_executor/layers/fused_moe/oracle/nvfp4.py
View file @
0da93439
...
...
@@ -14,7 +14,6 @@ from vllm.model_executor.layers.fused_moe.all2all_utils import (
from
vllm.model_executor.layers.fused_moe.config
import
(
FusedMoEConfig
,
FusedMoEQuantConfig
,
mxfp4_w4a16_moe_quant_config
,
nvfp4_moe_quant_config
,
nvfp4_w4a16_moe_quant_config
,
)
...
...
@@ -347,16 +346,6 @@ def convert_to_nvfp4_moe_kernel_format(
)
def make_mxfp4_moe_quant_config(
    w13_scale: torch.Tensor,
    w2_scale: torch.Tensor,
) -> FusedMoEQuantConfig:
    """Build a quant config by forwarding both scales to the w4a16 factory.

    Args:
        w13_scale: Passed to ``mxfp4_w4a16_moe_quant_config`` as ``w1_scale``.
        w2_scale: Passed to ``mxfp4_w4a16_moe_quant_config`` as ``w2_scale``.

    Returns:
        The config returned by ``mxfp4_w4a16_moe_quant_config``.
    """
    return mxfp4_w4a16_moe_quant_config(
        w1_scale=w13_scale,
        w2_scale=w2_scale,
    )
def
make_nvfp4_moe_quant_config
(
backend
:
NvFp4MoeBackend
,
w13_scale
:
torch
.
Tensor
,
...
...
vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
View file @
0da93439
...
...
@@ -24,6 +24,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8Static128BlockSym
,
kFp8StaticChannelSym
,
kFp8StaticTensorSym
,
kMxfp4Static
,
)
...
...
@@ -201,6 +202,8 @@ def rocm_aiter_fused_experts(
activation_method
=
ActivationMethod
.
SILU
elif
activation
==
MoEActivation
.
GELU
:
activation_method
=
ActivationMethod
.
GELU
elif
activation
==
MoEActivation
.
SWIGLUOAI
:
activation_method
=
rocm_aiter_ops
.
get_aiter_activation_type
(
"swiglu"
)
else
:
raise
ValueError
(
f
"Unsupported activation:
{
activation
}
"
)
...
...
@@ -247,8 +250,8 @@ def rocm_aiter_fused_experts(
else
:
quant_method
=
QuantMethod
.
NO
.
value
#
quark moe for mxfp4 w_dtype mxfp4 a_dtype
if
quant_config
.
use_mxfp4_w4a4
:
#
mxfp4: both w4a4 (quark) and w4a16 (oracle CK) use BLOCK_1X32
if
quant_config
.
use_mxfp4_w4a4
or
quant_config
.
use_mxfp4_w4a16
:
quant_method
=
QuantMethod
.
BLOCK_1X32
.
value
# w8a8 block-scaled
if
quant_config
.
block_shape
is
not
None
and
quant_config
.
use_fp8_w8a8
:
...
...
@@ -289,13 +292,20 @@ def rocm_aiter_fused_experts(
doweight_stage1
=
apply_router_weight_on_input
,
num_local_tokens
=
num_local_tokens
,
output_dtype
=
output_dtype
,
bias1
=
quant_config
.
w1_bias
if
quant_config
.
use_mxfp4_w4a16
else
None
,
bias2
=
quant_config
.
w2_bias
if
quant_config
.
use_mxfp4_w4a16
else
None
,
)
class
AiterExperts
(
mk
.
FusedMoEExpertsModular
):
@
property
def
expects_unquantized_inputs
(
self
)
->
bool
:
return
True
# When paired with MoRI, the prepare/finalize handles FP8
# quantization during dispatch to reduce network traffic,
# so we should not defer input quantization.
# Otherwise, AITER fused MoE kernels handle input quantization
# internally via a single fused kernel.
return
not
self
.
moe_config
.
use_mori_kernels
@
staticmethod
def
activation_format
()
->
mk
.
FusedMoEActivationFormat
:
...
...
@@ -314,21 +324,23 @@ class AiterExperts(mk.FusedMoEExpertsModular):
weight_key
:
QuantKey
|
None
,
activation_key
:
QuantKey
|
None
,
)
->
bool
:
# TODO(rob): AITER also supports MXFP4, which is not
# yet supported via an Oracle. Once it is, we will add
# MXFP4 to this list.
SUPPORTED_W_A
=
[
(
None
,
None
),
(
kFp8Static128BlockSym
,
kFp8Dynamic128Sym
),
(
kFp8StaticTensorSym
,
kFp8StaticTensorSym
),
(
kFp8StaticTensorSym
,
kFp8DynamicTensorSym
),
(
kFp8StaticChannelSym
,
kFp8DynamicTokenSym
),
(
kMxfp4Static
,
None
),
]
return
(
weight_key
,
activation_key
)
in
SUPPORTED_W_A
@
staticmethod
def
_supports_activation
(
activation
:
MoEActivation
)
->
bool
:
return
activation
in
[
MoEActivation
.
SILU
,
MoEActivation
.
GELU
]
return
activation
in
[
MoEActivation
.
SILU
,
MoEActivation
.
GELU
,
MoEActivation
.
SWIGLUOAI
,
]
@
staticmethod
def
_supports_parallel_config
(
moe_parallel_config
:
FusedMoEParallelConfig
)
->
bool
:
...
...
vllm/model_executor/layers/fused_moe/router/gate_linear.py
View file @
0da93439
...
...
@@ -3,9 +3,11 @@
import
torch
from
torch.nn.parameter
import
Parameter
import
vllm._custom_ops
as
ops
from
vllm.model_executor.custom_op
import
PluggableLayer
from
vllm.model_executor.layers.linear
import
ReplicatedLinear
from
vllm.platforms
import
current_platform
from
vllm.utils.torch_utils
import
direct_register_custom_op
@
PluggableLayer
.
register
(
"gate_linear"
)
...
...
@@ -13,8 +15,9 @@ class GateLinear(ReplicatedLinear):
"""MoE gate linear layer with three-tier GEMM dispatch:
1. DSV3 specialized kernel (SM90+, batch<=16, supported dims)
2. cuBLAS bf16×bf16→fp32 (SM90+ + bf16 + fp32 out_dtype)
3. F.linear via ReplicatedLinear (ultimate fallback)
2. gpt-oss specialized kernel (SM90+, batch<=128, supported dims)
3. cuBLAS bf16×bf16→fp32 (SM90+ + bf16 + fp32 out_dtype)
4. F.linear via ReplicatedLinear (ultimate fallback)
The ``out_dtype`` attribute is mutable and can be set after init
(e.g. when the required dtype depends on the expert quantization
...
...
@@ -25,6 +28,10 @@ class GateLinear(ReplicatedLinear):
DSV3_SUPPORTED_NUM_EXPERTS
=
[
256
,
384
]
DSV3_SUPPORTED_HIDDEN_SIZES
=
[
7168
]
# Dimensions supported by the gpt-oss specialized kernel
GPT_OSS_SUPPORTED_NUM_EXPERTS
=
[
32
,
128
]
GPT_OSS_SUPPORTED_HIDDEN_SIZES
=
[
2880
]
def
__init__
(
self
,
input_size
:
int
,
...
...
@@ -65,6 +72,15 @@ class GateLinear(ReplicatedLinear):
and
input_size
in
self
.
DSV3_SUPPORTED_HIDDEN_SIZES
)
# gpt-oss specialized kernel eligibility (SM90+, exact dims)
self
.
allow_gpt_oss_router_gemm
=
(
self
.
weight
.
dtype
==
torch
.
bfloat16
and
current_platform
.
is_cuda
()
and
is_hopper_or_blackwell
and
output_size
in
self
.
GPT_OSS_SUPPORTED_NUM_EXPERTS
and
input_size
in
self
.
GPT_OSS_SUPPORTED_HIDDEN_SIZES
)
# cuBLAS bf16→fp32 eligibility
self
.
allow_cublas_router_gemm
=
(
self
.
allow_specialized_router_gemm
...
...
@@ -92,8 +108,6 @@ class GateLinear(ReplicatedLinear):
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
Parameter
|
None
]:
import
vllm._custom_ops
as
ops
# Tier 1: DSV3 specialized kernel
if
self
.
allow_dsv3_router_gemm
and
x
.
shape
[
0
]
<=
16
:
output
=
ops
.
dsv3_router_gemm
(
...
...
@@ -103,15 +117,47 @@ class GateLinear(ReplicatedLinear):
)
return
output
,
None
# Tier 2: cuBLAS bf16→fp32
# Tier 2: gpt-oss specialized kernel
if
self
.
allow_gpt_oss_router_gemm
:
output
=
torch
.
ops
.
vllm
.
gpt_oss_router_gemm
(
x
,
self
.
weight
,
self
.
bias
)
return
output
,
None
# Tier 3: cuBLAS bf16→fp32
if
self
.
allow_cublas_router_gemm
and
x
.
dtype
==
torch
.
bfloat16
:
output
=
ops
.
router_gemm_bf16_fp32
(
x
,
self
.
weight
)
return
output
,
None
# Tier
3
: F.linear (ReplicatedLinear)
# Tier
4
: F.linear (ReplicatedLinear)
if
self
.
out_dtype
is
not
None
and
x
.
dtype
!=
self
.
weight
.
dtype
:
x
=
x
.
to
(
self
.
weight
.
dtype
)
output
,
output_bias
=
super
().
forward
(
x
)
if
self
.
out_dtype
is
not
None
and
output
.
dtype
!=
self
.
out_dtype
:
output
=
output
.
to
(
self
.
out_dtype
)
return
output
,
output_bias
def gpt_oss_router_gemm_impl(
    x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
) -> torch.Tensor:
    """
    Pick the router GEMM implementation from the number of input rows.

    Inputs with at most 128 rows (``x.shape[0]``) go to the min-latency
    ``ops.gpt_oss_router_gemm`` kernel; larger inputs use
    ``torch.nn.functional.linear``.

    This must be wrapped in a custom op because our torch.compile integration
    does not support runtime dispatching on num_tokens.
    """
    num_tokens = x.shape[0]
    if num_tokens > 128:
        return torch.nn.functional.linear(x, weight, bias)
    return ops.gpt_oss_router_gemm(x, weight, bias)
def gpt_oss_router_gemm_fake(
    x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
) -> torch.Tensor:
    """Fake (shape-only) implementation of the gpt-oss router GEMM.

    Returns an uninitialized tensor created from ``x`` via ``new_empty`` with
    shape ``(x.shape[0], weight.shape[0])``. ``bias`` is accepted only to
    mirror the real implementation's signature and is not used.
    """
    num_tokens = x.shape[0]
    out_dim = weight.shape[0]
    output_shape = (num_tokens, out_dim)
    return x.new_empty(output_shape)
# Register the router GEMM as a custom op named "gpt_oss_router_gemm",
# pairing the real implementation with its fake (shape-only) counterpart.
# GateLinear.forward in this file invokes it as
# torch.ops.vllm.gpt_oss_router_gemm.
direct_register_custom_op(
    op_name="gpt_oss_router_gemm",
    op_func=gpt_oss_router_gemm_impl,
    fake_impl=gpt_oss_router_gemm_fake,
)
vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py
View file @
0da93439
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Callable
from
contextlib
import
nullcontext
from
typing
import
TYPE_CHECKING
...
...
@@ -82,9 +83,22 @@ def _moe_forward(
layer
=
get_layer_from_name
(
_resolve_layer_name
(
layer_name
))
# TODO(bnell): this can be removed after MK migration is complete.
layer
.
ensure_moe_quant_config_init
()
return
layer
.
runner
.
forward_impl
(
layer
,
hidden_states
,
router_logits
,
shared_experts_input
)
runner
=
layer
.
runner
with
runner
.
_sequence_parallel_context
():
if
runner
.
use_dp_chunking
:
return
runner
.
forward_impl_chunked
(
layer
,
hidden_states
,
router_logits
,
shared_experts_input
,
)
else
:
return
runner
.
forward_impl
(
layer
,
hidden_states
,
router_logits
,
shared_experts_input
,
)
def
_moe_forward_fake
(
...
...
@@ -105,9 +119,22 @@ def _moe_forward_shared(
layer
=
get_layer_from_name
(
_resolve_layer_name
(
layer_name
))
# TODO(bnell): this can be removed after MK migration is complete.
layer
.
ensure_moe_quant_config_init
()
return
layer
.
runner
.
forward_impl
(
layer
,
hidden_states
,
router_logits
,
shared_experts_input
)
runner
=
layer
.
runner
with
runner
.
_sequence_parallel_context
():
if
runner
.
use_dp_chunking
:
return
runner
.
forward_impl_chunked
(
layer
,
hidden_states
,
router_logits
,
shared_experts_input
,
)
else
:
return
runner
.
forward_impl
(
layer
,
hidden_states
,
router_logits
,
shared_experts_input
,
)
def
_moe_forward_shared_fake
(
...
...
@@ -191,10 +218,17 @@ class DefaultMoERunner(MoERunner):
self
.
reduce_results
=
reduce_results
self
.
enable_dbo
=
enable_dbo
# Chunked all2all staging tensor
# TODO(bnell) rename these?
self
.
batched_hidden_states
:
torch
.
Tensor
|
None
=
None
self
.
batched_router_logits
:
torch
.
Tensor
|
None
=
None
self
.
_maybe_init_dp_chunking
()
# Allow disabling of the separate shared experts stream for
# debug purposes.
# TODO: Remove this after more extensive testings with TP/DP
# and other execution modes
self
.
use_shared_experts_stream
=
False
if
envs
.
VLLM_DISABLE_SHARED_EXPERTS_STREAM
:
logger
.
debug_once
(
"Disabling MoE shared_experts cuda stream"
,
scope
=
"local"
)
self
.
shared_experts_stream
=
None
...
...
@@ -210,23 +244,20 @@ class DefaultMoERunner(MoERunner):
# Needed for string -> FusedMoE layer lookup in custom ops.
self
.
layer_name
=
layer
.
layer_name
self
.
moe_forward
=
self
.
_select_forward
(
layer
)
def
_select_forward
(
self
,
layer
:
torch
.
nn
.
Module
)
->
Callable
:
if
current_platform
.
is_tpu
()
or
current_platform
.
is_cpu
():
# TODO: Once the OOM issue for the TPU backend is resolved, we
# will switch to using the moe_forward custom op.
# Note: CPU doesn't require wrapped forward_impl.
if
self
.
shared_experts
is
None
:
self
.
moe_forward
=
_moe_forward
else
:
self
.
moe_forward
=
_moe_forward_shared
else
:
if
self
.
shared_experts
is
None
:
self
.
moe_forward
=
torch
.
ops
.
vllm
.
moe_forward
else
:
self
.
moe_forward
=
torch
.
ops
.
vllm
.
moe_forward_shared
return
_moe_forward
if
self
.
shared_experts
is
None
else
_moe_forward_shared
# Chunked all2all staging tensor
self
.
batched_hidden_states
:
torch
.
Tensor
|
None
=
None
self
.
batched_router_logits
:
torch
.
Tensor
|
None
=
None
return
(
torch
.
ops
.
vllm
.
moe_forward
if
self
.
shared_experts
is
None
else
torch
.
ops
.
vllm
.
moe_forward_shared
)
@
property
def
use_dp_chunking
(
self
)
->
bool
:
...
...
@@ -241,22 +272,8 @@ class DefaultMoERunner(MoERunner):
self
,
hidden_states
:
torch
.
Tensor
,
shared_input
:
torch
.
Tensor
|
None
,
has_separate_shared_experts
:
bool
,
use_chunked_impl
:
bool
,
)
->
tuple
[
bool
,
torch
.
Tensor
|
None
]:
use_shared_experts_stream
=
(
current_platform
.
is_cuda
()
and
has_separate_shared_experts
and
not
use_chunked_impl
and
self
.
shared_experts_stream
is
not
None
and
(
hidden_states
.
shape
[
0
]
<=
envs
.
VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD
)
)
shared_experts_input
:
torch
.
Tensor
|
None
=
None
if
use_shared_experts_stream
:
):
if
self
.
use_shared_experts_stream
:
assert
self
.
shared_experts_stream
is
not
None
assert
self
.
moe_config
.
disable_inplace
...
...
@@ -278,12 +295,11 @@ class DefaultMoERunner(MoERunner):
assert
self
.
shared_experts_stream
is
not
None
self
.
shared_experts_stream
.
wait_stream
(
current_stream
())
return
use_shared_experts_stream
,
shared_experts_input
def
ensure_dp_chunking_init
(
self
):
if
not
self
.
use_dp_chunking
or
self
.
batched_hidden_states
is
not
None
:
def
_maybe_init_dp_chunking
(
self
):
if
not
self
.
use_dp_chunking
:
return
assert
self
.
batched_hidden_states
is
None
states_shape
:
tuple
[
int
,
...]
logits_shape
:
tuple
[
int
,
...]
...
...
@@ -309,6 +325,38 @@ class DefaultMoERunner(MoERunner):
device
=
device
,
)
@property
def has_separate_shared_experts(self) -> bool:
    """Whether this runner has shared experts not owned by the quant method.

    True only when ``quant_method.mk_owns_shared_expert`` is falsy and
    ``shared_experts`` is set.
    """
    if self.quant_method.mk_owns_shared_expert:
        return False
    return self.shared_experts is not None
def
_apply_shared_experts
(
self
,
hidden_states
:
torch
.
Tensor
,
allow_streaming
:
bool
=
False
,
)
->
torch
.
Tensor
|
None
:
shared_output
:
torch
.
Tensor
|
None
=
None
if
self
.
has_separate_shared_experts
:
assert
self
.
shared_experts
is
not
None
if
self
.
use_shared_experts_stream
and
allow_streaming
:
# Run shared experts in parallel on a separate stream
# NOTE: We start the separate stream here and mark the
# sync end point immediately after it is done. This is
# important to avoid excessive stream allocations by the cuda
# graph replay later.
with
torch
.
cuda
.
stream
(
self
.
shared_experts_stream
):
# Note that hidden_states clone() is necessary here to avoid
# conflict with the main stream
shared_output
=
self
.
shared_experts
(
hidden_states
)
current_stream
().
wait_stream
(
self
.
shared_experts_stream
)
else
:
shared_output
=
self
.
shared_experts
(
hidden_states
)
return
shared_output
def
must_reduce_shared_expert_outputs
(
self
)
->
bool
:
"""
The shared_experts are typically computed using the RowParallelLinear
...
...
@@ -322,7 +370,6 @@ class DefaultMoERunner(MoERunner):
Therefore it is required that we reduce the shared_experts output
early.
"""
assert
self
.
quant_method
is
not
None
return
(
self
.
quant_method
.
moe_kernel
is
not
None
and
self
.
quant_method
.
moe_kernel
.
output_is_reduced
()
...
...
@@ -357,7 +404,7 @@ class DefaultMoERunner(MoERunner):
return
result
return
hidden_states
def
_reduce_output
(
def
_maybe
_reduce_output
(
self
,
states
:
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
trunc_sizes
:
list
[
int
],
...
...
@@ -397,25 +444,21 @@ class DefaultMoERunner(MoERunner):
return
"from_forward_context"
return
self
.
layer_name
def
forward
(
def
_maybe_pad_hidden_states
(
self
,
original_hidden_states
:
torch
.
Tensor
|
None
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# For latent MoE: save ORIGINAL hidden_states before transform
# (shared_experts need original dimension, routed experts use transformed)
if
self
.
shared_experts
is
not
None
:
original_hidden_states
=
hidden_states
original_hidden_dim
=
hidden_states
.
shape
[
-
1
]
else
:
original_hidden_states
=
None
# Apply transform for routed experts (e.g., latent projection for latent MoE)
hidden_states
=
self
.
apply_routed_input_transform
(
hidden_states
)
# This is the dimension after transform (for routed expert output slicing)
)
->
tuple
[
torch
.
Tensor
,
list
[
int
]]:
original_hidden_dim
=
(
original_hidden_states
.
shape
[
-
1
]
if
original_hidden_states
is
not
None
else
0
)
transformed_hidden_dim
=
hidden_states
.
shape
[
-
1
]
if
self
.
moe_config
.
hidden_dim
!=
transformed_hidden_dim
:
if
(
not
self
.
quant_method
.
skip_forward_padding
and
self
.
moe_config
.
hidden_dim
!=
transformed_hidden_dim
):
hidden_states
=
F
.
pad
(
hidden_states
,
(
0
,
self
.
moe_config
.
hidden_dim
-
transformed_hidden_dim
),
...
...
@@ -423,134 +466,235 @@ class DefaultMoERunner(MoERunner):
value
=
0.0
,
)
fused_output
=
self
.
moe_forward
(
hidden_states
,
router_logits
,
original_hidden_states
,
self
.
_encode_layer_name
(),
)
if
self
.
shared_experts
is
not
None
:
orig_hidden_dims
=
[
original_hidden_dim
,
transformed_hidden_dim
]
else
:
orig_hidden_dims
=
[
transformed_hidden_dim
]
return
self
.
_reduce_output
(
fused_output
,
orig_hidden_dims
)
return
hidden_states
,
orig_hidden_dims
def
forward_impl_chunke
d
(
def
_apply_quant_metho
d
(
self
,
layer
:
torch
.
nn
.
Module
,
full_hidden_states
:
torch
.
Tensor
,
full_router_logits
:
torch
.
Tensor
,
full_shared_input
:
torch
.
Tensor
|
None
,
has_separate_shared_experts
:
bool
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
shared_input
:
torch
.
Tensor
|
None
,
run_shared_experts_before
:
bool
=
True
,
)
->
tuple
[
torch
.
Tensor
|
None
,
torch
.
Tensor
]:
shared_input
=
shared_input
if
shared_input
is
not
None
else
hidden_states
shared_output
:
torch
.
Tensor
|
None
=
None
# Run this before quant_method to avoid inplace issues.
if
run_shared_experts_before
:
shared_output
=
self
.
_apply_shared_experts
(
shared_input
,
False
)
if
self
.
quant_method
.
is_monolithic
:
result
=
self
.
quant_method
.
apply_monolithic
(
layer
=
layer
,
x
=
hidden_states
,
router_logits
=
router_logits
,
)
else
:
topk_weights
,
topk_ids
=
self
.
router
.
select_experts
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
,
)
result
=
self
.
quant_method
.
apply
(
layer
=
layer
,
x
=
hidden_states
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
shared_experts_input
=
shared_input
,
)
if
isinstance
(
result
,
tuple
):
assert
shared_output
is
None
shared_output
,
hidden_states
=
result
else
:
hidden_states
=
result
if
not
run_shared_experts_before
and
self
.
has_separate_shared_experts
:
assert
shared_output
is
None
shared_output
=
self
.
_apply_shared_experts
(
shared_input
,
True
)
return
shared_output
,
hidden_states
def _sequence_parallel_context(self):
    """Return the context manager to wrap a forward call in.

    When the current forward context carries ``dp_metadata``, this is
    ``dp_metadata.sp_local_sizes(self.moe_config.sp_size)``; otherwise it is
    a no-op ``nullcontext()``.
    """
    dp_metadata = get_forward_context().dp_metadata
    if dp_metadata:
        return dp_metadata.sp_local_sizes(self.moe_config.sp_size)
    return nullcontext()
def
_allocate_dp_chunking_outputs
(
self
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
|
None
,
torch
.
Tensor
]:
assert
self
.
use_dp_chunking
# Assert the inputs are of the proper type and shape.
assert
self
.
batched_hidden_states
is
not
None
assert
self
.
batched_router_logits
is
not
None
assert
self
.
batched_hidden_states
.
dtype
==
full_hidden_states
.
dtype
,
(
f
"
{
self
.
batched_hidden_states
.
dtype
}
==
{
full_hidden_states
.
dtype
}
"
assert
self
.
batched_hidden_states
.
dtype
==
hidden_states
.
dtype
,
(
f
"
{
self
.
batched_hidden_states
.
dtype
}
==
{
hidden_states
.
dtype
}
"
)
assert
self
.
batched_router_logits
.
dtype
==
full_
router_logits
.
dtype
,
(
f
"
{
self
.
batched_router_logits
.
dtype
}
==
{
full_
router_logits
.
dtype
}
"
assert
self
.
batched_router_logits
.
dtype
==
router_logits
.
dtype
,
(
f
"
{
self
.
batched_router_logits
.
dtype
}
==
{
router_logits
.
dtype
}
"
)
# Check size compatibility.
assert
self
.
batched_hidden_states
.
size
(
-
1
)
==
full_hidden_states
.
size
(
-
1
)
assert
self
.
batched_router_logits
.
size
(
-
1
)
==
full_router_logits
.
size
(
-
1
)
# TODO(bnell): Fix shared_expert_inputs w/chunking.
# assert shared_input is None, (
# "Routed input transform is not currently supported with DP chunking."
# )
# Check size compatibility.
assert
self
.
batched_hidden_states
.
size
(
-
1
)
==
hidden_states
.
size
(
-
1
)
assert
self
.
batched_router_logits
.
size
(
-
1
)
==
router_logits
.
size
(
-
1
)
f
ul
l_fused_
final_
hidden_states
=
torch
.
empty_like
(
full_
hidden_states
)
f
ina
l_fused_hidden_states
=
torch
.
empty_like
(
hidden_states
)
if
self
.
shared_experts
is
not
None
:
full_shared_final_hidden_states
=
torch
.
empty_like
(
full_hidden_states
)
def
process_chunk
(
chunk_start
,
chunk_end
,
skip_result_store
=
False
):
chunk_size
=
chunk_end
-
chunk_start
hidden_states
=
full_hidden_states
[
chunk_start
:
chunk_end
,
:]
router_logits
=
full_router_logits
[
chunk_start
:
chunk_end
,
:]
shared_input
=
(
full_shared_input
[
chunk_start
:
chunk_end
,
:]
if
full_shared_input
is
not
None
else
None
)
final_shared_hidden_states
=
torch
.
empty_like
(
hidden_states
)
else
:
final_shared_hidden_states
=
None
assert
self
.
batched_hidden_states
is
not
None
assert
self
.
batched_router_logits
is
not
None
# This is only true when DBO has been enabled in the config.
# Both tensors will have an outer dimension for the ubatch id
if
self
.
batched_hidden_states
.
dim
()
==
3
:
assert
self
.
batched_router_logits
.
dim
()
==
3
batch_buffer_idx
=
dbo_current_ubatch_id
()
batched_hidden_states
=
self
.
batched_hidden_states
[
batch_buffer_idx
,
:]
batched_router_logits
=
self
.
batched_router_logits
[
batch_buffer_idx
,
:]
else
:
batched_hidden_states
=
self
.
batched_hidden_states
batched_router_logits
=
self
.
batched_router_logits
return
final_shared_hidden_states
,
final_fused_hidden_states
def
_maybe_gate
(
self
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
# If router/gate provided, then apply it here.
# (Note: This code runs only when "overlapped mode" is on to allow
# parallel execution of shared experts with the FusedMoE via
# separate cuda stream)
if
self
.
gate
is
not
None
:
router_logits
,
_
=
self
.
gate
(
hidden_states
)
return
router_logits
@
property
def
do_naive_dispatch_combine
(
self
)
->
bool
:
return
(
self
.
moe_config
.
dp_size
>
1
and
not
self
.
quant_method
.
supports_internal_mk
)
assert
(
batched_hidden_states
.
size
(
0
)
# type: ignore
>=
chunk_size
def
_maybe_dispatch
(
self
,
layer
:
torch
.
nn
.
Module
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# For naive dispatch/combine Dp/Ep, dispatch the hidden states and
# router logits to all experts.
# NOTE: this will be removed once all kernels are migrated into the
# MoEKernel framework.
if
self
.
do_naive_dispatch_combine
:
hidden_states
,
router_logits
=
get_ep_group
().
dispatch_router_logits
(
hidden_states
,
router_logits
,
self
.
moe_config
.
is_sequence_parallel
,
)
assert
(
batched_router_logits
.
size
(
0
)
# type: ignore
>=
chunk_size
# NOTE: Similar with DP, PCP also needs dispatch and combine. For
# simplicity, AgRsAll2All was added separately for PCP here. Maybe
# we should modify All2AllManager abstraction to better support PCP.
if
self
.
moe_config
.
pcp_size
>
1
:
hidden_states
=
get_pcp_group
().
all_gather
(
hidden_states
,
dim
=
0
,
)
staged_hidden_states
=
batched_hidden_states
[:
chunk_size
,
:]
# type: ignore
staged_router_logits
=
batched_router_logits
[:
chunk_size
,
:]
# type: ignore
staged_hidden_states
.
copy_
(
hidden_states
,
non_blocking
=
True
)
staged_router_logits
.
copy_
(
router_logits
,
non_blocking
=
True
)
router_logits
=
get_pcp_group
().
all_gather
(
router_logits
,
dim
=
0
,
)
return
hidden_states
,
router_logits
shared_input
=
(
shared_input
if
shared_input
is
not
None
else
staged_hidden_states
def
_maybe_combine
(
self
,
shared_output
:
torch
.
Tensor
|
None
,
hidden_states
:
torch
.
Tensor
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
|
None
]:
if
self
.
do_naive_dispatch_combine
:
hidden_states
=
get_ep_group
().
combine
(
hidden_states
,
self
.
moe_config
.
is_sequence_parallel
)
# Matrix multiply.
if
self
.
quant_method
.
is_monolithic
:
assert
has_separate_shared_experts
or
self
.
shared_experts
is
None
final_hidden_states
=
self
.
quant_method
.
apply_monolithic
(
layer
=
layer
,
x
=
staged_hidden_states
,
router_logits
=
staged_router_logits
,
)
else
:
topk_weights
,
topk_ids
=
self
.
router
.
select_experts
(
hidden_states
=
staged_hidden_states
,
router_logits
=
staged_router_logits
,
)
if
self
.
moe_config
.
pcp_size
>
1
:
hidden_states
=
get_pcp_group
().
reduce_scatter
(
hidden_states
,
dim
=
0
,
)
# need RS for shared_output?
final_hidden_states
=
self
.
quant_method
.
apply
(
layer
=
layer
,
x
=
staged_hidden_states
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
shared_experts_input
=
shared_input
,
)
if
self
.
shared_experts
is
not
None
:
assert
shared_output
is
not
None
return
shared_output
,
hidden_states
else
:
return
hidden_states
if
has_separate_shared_experts
:
assert
not
isinstance
(
final_hidden_states
,
tuple
)
assert
self
.
shared_experts
is
not
None
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# For latent MoE: save ORIGINAL hidden_states before transform
# (shared_experts need original dimension, routed experts use transformed)
if
self
.
shared_experts
is
not
None
:
original_hidden_states
=
hidden_states
else
:
original_hidden_states
=
None
shared_output
=
self
.
shared_experts
(
shared_input
)
# Apply transform for routed experts (e.g., latent projection for latent MoE)
hidden_states
=
self
.
apply_routed_input_transform
(
hidden_states
)
final
_hidden_states
=
(
shared_output
,
final_
hidden_states
,
)
hidden_states
,
og_hidden_dims
=
self
.
_maybe_pad
_hidden_states
(
original_hidden_states
,
hidden_states
,
)
if
not
skip_result_store
:
if
self
.
shared_experts
is
None
:
full_fused_final_hidden_states
[
chunk_start
:
chunk_end
,
:].
copy_
(
final_hidden_states
,
non_blocking
=
True
)
else
:
full_shared_final_hidden_states
[
chunk_start
:
chunk_end
,
:].
copy_
(
final_hidden_states
[
0
],
non_blocking
=
True
)
full_fused_final_hidden_states
[
chunk_start
:
chunk_end
,
:].
copy_
(
final_hidden_states
[
1
],
non_blocking
=
True
)
fused_output
=
self
.
moe_forward
(
hidden_states
,
router_logits
,
original_hidden_states
,
self
.
_encode_layer_name
(),
)
return
self
.
_maybe_reduce_output
(
fused_output
,
og_hidden_dims
)
def
_slice_and_copy_input
(
self
,
out_slice
:
torch
.
Tensor
,
orig
:
torch
.
Tensor
|
None
,
start
:
int
,
end
:
int
,
)
->
torch
.
Tensor
:
assert
orig
is
not
None
slice_size
=
end
-
start
orig_slice
=
orig
[
start
:
end
,
:]
if
self
.
enable_dbo
:
assert
out_slice
.
dim
()
==
3
batch_buffer_idx
=
dbo_current_ubatch_id
()
out_slice
=
out_slice
[
batch_buffer_idx
,
:]
assert
out_slice
.
size
(
0
)
>=
slice_size
out_slice
=
out_slice
[:
slice_size
,
:]
out_slice
.
copy_
(
orig_slice
,
non_blocking
=
True
)
return
out_slice
def
forward_impl_chunked
(
self
,
layer
:
torch
.
nn
.
Module
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
shared_input
:
torch
.
Tensor
|
None
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# Gate overlap not supported when chunking is enabled. Run the
# gate first.
router_logits
=
self
.
_maybe_gate
(
hidden_states
,
router_logits
)
final_shared_hidden_states
,
final_fused_hidden_states
=
(
self
.
_allocate_dp_chunking_outputs
(
hidden_states
,
router_logits
)
)
ctx
=
get_forward_context
()
# flashinfer_cutlass_kernels can handle: optional DP + TP/EP
...
...
@@ -564,7 +708,7 @@ class DefaultMoERunner(MoERunner):
max_tokens_across_dispatchers
,
self
.
moe_config
.
sp_size
)
num_tokens
=
full_
hidden_states
.
size
(
0
)
num_tokens
=
hidden_states
.
size
(
0
)
for
chunk_idx
,
chunk_start_
in
enumerate
(
range
(
0
,
max_tokens_across_dispatchers
,
moe_dp_chunk_size_per_rank
)
):
...
...
@@ -575,17 +719,55 @@ class DefaultMoERunner(MoERunner):
# clamp start and end
chunk_start
=
min
(
chunk_start
,
num_tokens
-
1
)
chunk_end
=
min
(
chunk_end
,
num_tokens
)
with
ctx
.
dp_metadata
.
chunked_sizes
(
chunk_sizes
=
ctx
.
dp_metadata
.
chunked_sizes
(
self
.
moe_config
.
sp_size
,
moe_dp_chunk_size_per_rank
,
chunk_idx
):
process_chunk
(
chunk_start
,
chunk_end
,
skip_result_store
=
chunk_start_
>=
num_tokens
)
with
chunk_sizes
:
hidden_states_chunk
=
self
.
_slice_and_copy_input
(
self
.
batched_hidden_states
,
hidden_states
,
chunk_start
,
chunk_end
,
)
router_logits_chunk
=
self
.
_slice_and_copy_input
(
self
.
batched_router_logits
,
router_logits
,
chunk_start
,
chunk_end
,
)
shared_input_chunk
=
(
shared_input
[
chunk_start
:
chunk_end
,
:]
if
shared_input
is
not
None
else
None
)
shared_output_chunk
,
hidden_states_chunk
=
self
.
_apply_quant_method
(
layer
=
layer
,
hidden_states
=
hidden_states_chunk
,
router_logits
=
router_logits_chunk
,
shared_input
=
shared_input_chunk
,
)
# Store outputs
# TODO(bnell): document when chunk_start >= num_tokens
if
chunk_start
<
num_tokens
:
final_fused_hidden_states
[
chunk_start
:
chunk_end
,
:].
copy_
(
hidden_states_chunk
,
non_blocking
=
True
)
if
self
.
shared_experts
is
not
None
:
assert
shared_output_chunk
is
not
None
assert
final_shared_hidden_states
is
not
None
final_shared_hidden_states
[
chunk_start
:
chunk_end
,
:].
copy_
(
shared_output_chunk
,
non_blocking
=
True
)
if
self
.
shared_experts
is
None
:
return
f
ul
l_fused_
final_
hidden_states
return
f
ina
l_fused_hidden_states
else
:
return
(
full_shared_final_hidden_states
,
full_fused_final_hidden_states
)
assert
final_shared_hidden_states
is
not
None
return
(
final_shared_hidden_states
,
final_fused_hidden_states
)
def
forward_impl
(
self
,
...
...
@@ -594,148 +776,51 @@ class DefaultMoERunner(MoERunner):
router_logits
:
torch
.
Tensor
,
shared_input
:
torch
.
Tensor
|
None
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
self
.
quant_method
is
not
None
self
.
ensure_dp_chunking_init
()
has_separate_shared_experts
=
(
not
self
.
quant_method
.
mk_owns_shared_expert
and
self
.
shared_experts
is
not
None
self
.
use_shared_experts_stream
=
(
current_platform
.
is_cuda
()
and
self
.
has_separate_shared_experts
and
not
self
.
use_dp_chunking
and
self
.
shared_experts_stream
is
not
None
and
(
hidden_states
.
shape
[
0
]
<=
envs
.
VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD
)
)
use_chunked_impl
=
self
.
use_dp_chunking
# Check if we need to run shared experts before matrix multiply because
# matrix multiply may modify the hidden_states.
run_shared_experts_before
=
(
self
.
has_separate_shared_experts
and
not
self
.
use_shared_experts_stream
)
use_shared_experts_stream
,
shared_experts_input
=
(
# The shared experts stream must be set up before calling the gate so they
# can be overlapped.
if
not
run_shared_experts_before
:
self
.
_maybe_setup_shared_experts_stream
(
hidden_states
,
shared_input
,
has_separate_shared_experts
,
use_chunked_impl
,
)
)
# If router/gate provided, then apply it here.
# (Note: This code runs only when "overlapped mode" is on to allow
# parallel execution of shared experts with the FusedMoE via
# separate cuda stream)
if
self
.
gate
is
not
None
:
router_logits
,
_
=
self
.
gate
(
hidden_states
)
if
use_chunked_impl
:
return
self
.
forward_impl_chunked
(
layer
,
hidden_states
,
router_logits
,
shared_input
,
has_separate_shared_experts
,
)
router_logits
=
self
.
_maybe_gate
(
hidden_states
,
router_logits
)
# NOTE(rob): once we finish migrating all the quant methods to use
# MKs, we can remove the naive dispatch/combine path from here.
do_naive_dispatch_combine
=
(
self
.
moe_config
.
dp_size
>
1
and
not
self
.
quant_method
.
supports_internal_mk
# TODO(bnell): parts of the dispatch/combine steps will go away once
# #32567 lands and the remaining kernels are made MKs. The PCP
# code will probably remain
hidden_states
,
router_logits
=
self
.
_maybe_dispatch
(
layer
,
hidden_states
,
router_logits
,
)
ctx
=
get_forward_context
()
sp_ctx
=
(
ctx
.
dp_metadata
.
sp_local_sizes
(
self
.
moe_config
.
sp_size
)
if
ctx
.
dp_metadata
else
nullcontext
()
shared_output
,
hidden_states
=
self
.
_apply_quant_method
(
layer
=
layer
,
hidden_states
=
hidden_states
,
router_logits
=
router_logits
,
shared_input
=
shared_input
,
run_shared_experts_before
=
run_shared_experts_before
,
)
with
sp_ctx
:
# Run shared experts before matrix multiply.
# because matrix multiply maybe modify the hidden_states.
if
has_separate_shared_experts
and
not
use_shared_experts_stream
:
assert
self
.
shared_experts
is
not
None
shared_input
=
(
shared_input
if
shared_input
is
not
None
else
hidden_states
)
shared_output
=
self
.
shared_experts
(
shared_input
)
# For naive dispatch/combine Dp/Ep, dispatch the hidden states and
# router logits to all experts.
# NOTE: this will be removed once all kernels are migrated into the
# MoEKernel framework.
if
do_naive_dispatch_combine
:
hidden_states
,
router_logits
=
get_ep_group
().
dispatch_router_logits
(
hidden_states
,
router_logits
,
self
.
moe_config
.
is_sequence_parallel
,
)
# NOTE: Similar with DP, PCP also needs dispatch and combine. For
# simplicity, AgRsAll2All was added separately for PCP here. Maybe
# we should modify All2AllManager abstract to better support PCP.
if
self
.
moe_config
.
pcp_size
>
1
:
hidden_states
=
get_pcp_group
().
all_gather
(
hidden_states
,
dim
=
0
,
)
router_logits
=
get_pcp_group
().
all_gather
(
router_logits
,
dim
=
0
,
)
# Matrix multiply.
if
self
.
quant_method
.
is_monolithic
:
final_hidden_states
=
self
.
quant_method
.
apply_monolithic
(
layer
=
layer
,
x
=
hidden_states
,
router_logits
=
router_logits
,
)
else
:
topk_weights
,
topk_ids
=
self
.
router
.
select_experts
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
,
)
final_hidden_states
=
self
.
quant_method
.
apply
(
layer
=
layer
,
x
=
hidden_states
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
shared_experts_input
=
shared_input
,
)
if
has_separate_shared_experts
:
assert
self
.
shared_experts
is
not
None
if
use_shared_experts_stream
:
# Run shared experts in parallel on a separate stream
# NOTE: We start the separate stream here and mark the
# sync end point immediately after it is done. This is
# important to avoid excessive stream allocations by the cuda
# graph replay later.
with
torch
.
cuda
.
stream
(
self
.
shared_experts_stream
):
# Note that hidden_states clone() is necessary here to avoid
# conflict with the main stream
shared_output
=
self
.
shared_experts
(
shared_experts_input
)
current_stream
().
wait_stream
(
self
.
shared_experts_stream
)
final_hidden_states
=
(
shared_output
,
final_hidden_states
,
)
def
combine_output
(
states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
if
do_naive_dispatch_combine
:
states
=
get_ep_group
().
combine
(
states
,
self
.
moe_config
.
is_sequence_parallel
)
if
self
.
moe_config
.
pcp_size
>
1
:
states
=
get_pcp_group
().
reduce_scatter
(
states
,
dim
=
0
,
)
return
states
if
self
.
shared_experts
is
not
None
:
return
(
final_hidden_states
[
0
],
combine_output
(
final_hidden_states
[
1
]),
)
else
:
return
combine_output
(
final_hidden_states
)
return
self
.
_maybe_combine
(
shared_output
,
hidden_states
,
)
vllm/model_executor/layers/fused_moe/utils.py
View file @
0da93439
...
...
@@ -25,6 +25,7 @@ from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
per_tensor_dequantize
,
)
from
vllm.platforms
import
current_platform
from
vllm.triton_utils
import
tl
,
triton
from
vllm.utils.math_utils
import
cdiv
from
vllm.utils.torch_utils
import
is_torch_equal_or_newer
...
...
@@ -199,7 +200,7 @@ def _mxfp8_e4m3_quantize(
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
A_scale
is
None
assert
not
per_act_token_quant
assert
block_shape
is
None
assert
block_shape
is
None
or
block_shape
==
[
1
,
32
]
return
mxfp8_e4m3_quantize
(
A
,
is_sf_swizzled_layout
)
...
...
@@ -265,7 +266,7 @@ def moe_kernel_quantize_input(
# weights are already dequantized, and we proceed with normal
# activation quantization below.
if
quant_dtype
==
torch
.
float8_e4m3fn
:
if
quant_dtype
==
current_platform
.
fp8_dtype
()
:
return
_fp8_quantize
(
A
,
A_scale
,
per_act_token_quant
,
block_shape
)
elif
quant_dtype
==
torch
.
int8
:
return
_int8_quantize
(
A
,
A_scale
,
per_act_token_quant
,
block_shape
)
...
...
@@ -316,27 +317,6 @@ def normalize_batched_scales_shape(
return
scales
def
_validate_scale_shape
(
a
:
torch
.
Tensor
,
a_scale
:
torch
.
Tensor
|
None
,
per_act_token_quant
:
bool
,
block_shape
:
list
[
int
]
|
None
,
)
->
None
:
if
a_scale
is
None
:
return
if
not
per_act_token_quant
and
block_shape
is
None
:
assert
a_scale
.
numel
()
==
1
,
f
"
{
a_scale
.
shape
}
"
elif
per_act_token_quant
:
assert
a_scale
.
shape
[
0
]
==
a
.
shape
[
0
]
and
a_scale
.
shape
[
1
]
==
1
,
(
f
"
{
a_scale
.
shape
[
0
]
}
==
{
a
.
shape
[
0
]
}
and
{
a_scale
.
shape
[
1
]
}
== 1"
)
else
:
assert
block_shape
is
not
None
expected
=
(
a
.
shape
[
0
],
cdiv
(
a
.
shape
[
1
],
block_shape
[
1
]))
assert
a_scale
.
shape
==
expected
,
f
"
{
a_scale
.
shape
}
==
{
expected
}
"
# Torch custom ops can't deal with outputs aliasing inputs so we need to
# disable inplace for torch >= 2.9.
# See https://github.com/vllm-project/vllm/issues/26378
...
...
vllm/model_executor/layers/kda.py
View file @
0da93439
...
...
@@ -306,7 +306,7 @@ class KimiDeltaAttention(nn.Module, MambaBase):
non_spec_query_start_loc
=
attn_metadata
.
non_spec_query_start_loc
non_spec_state_indices_tensor
=
attn_metadata
.
non_spec_state_indices_tensor
# noqa: E501
num_actual_tokens
=
attn_metadata
.
num_actual_tokens
constant_caches
=
self
.
kv_cache
[
forward_context
.
virtual_engine
]
constant_caches
=
self
.
kv_cache
[
0
]
q_proj_states
=
q_proj_states
[:
num_actual_tokens
]
k_proj_states
=
k_proj_states
[:
num_actual_tokens
]
...
...
vllm/model_executor/layers/mamba/linear_attn.py
View file @
0da93439
...
...
@@ -413,7 +413,7 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
qkvact
=
qkvact
.
view
((
qkv
.
shape
[
0
],
self
.
tp_heads
,
-
1
))
q
,
k
,
v
=
torch
.
split
(
qkvact
,
[
self
.
head_dim
]
*
3
,
dim
=-
1
)
if
attn_metadata
is
not
None
:
kv_cache
=
self
.
kv_cache
[
forward_context
.
virtual_engine
][
0
]
kv_cache
=
self
.
kv_cache
[
0
][
0
]
state_indices_tensor
=
attn_metadata
.
state_indices_tensor
clear_linear_attention_cache_for_new_sequences
(
kv_cache
,
state_indices_tensor
,
attn_metadata
...
...
vllm/model_executor/layers/mamba/mamba_mixer.py
View file @
0da93439
...
...
@@ -267,7 +267,7 @@ class MambaMixer(MambaBase, PluggableLayer):
query_start_loc_p
=
attn_metadata
.
query_start_loc_p
state_indices_tensor_p
=
attn_metadata
.
state_indices_tensor_p
state_indices_tensor_d
=
attn_metadata
.
state_indices_tensor_d
self_kv_cache
=
self
.
kv_cache
[
forward_context
.
virtual_engine
]
self_kv_cache
=
self
.
kv_cache
[
0
]
conv_state
=
self_kv_cache
[
0
].
transpose
(
-
1
,
-
2
)
ssm_state
=
self_kv_cache
[
1
]
has_initial_states_p
=
attn_metadata
.
has_initial_states_p
...
...
vllm/model_executor/layers/mamba/mamba_mixer2.py
View file @
0da93439
...
...
@@ -574,7 +574,7 @@ class MambaMixer2(MambaBase, PluggableLayer):
assert
isinstance
(
attn_metadata
,
dict
)
attn_metadata
=
attn_metadata
[
self
.
prefix
]
assert
isinstance
(
attn_metadata
,
Mamba2AttentionMetadata
)
self_kv_cache
=
self
.
kv_cache
[
forward_context
.
virtual_engine
]
self_kv_cache
=
self
.
kv_cache
[
0
]
# conv_state = (..., dim, width-1) yet contiguous along 'dim'
conv_state
=
self_kv_cache
[
0
].
transpose
(
-
1
,
-
2
)
ssm_state
=
self_kv_cache
[
1
]
...
...
vllm/model_executor/layers/mamba/ops/mamba_ssm.py
View file @
0da93439
...
...
@@ -333,13 +333,13 @@ def selective_state_update(
dt_bias
=
dt_bias
.
unsqueeze
(
0
)
if
out
.
dim
()
==
2
:
out
=
out
.
unsqueeze
(
1
)
if
num_accepted_tokens
is
not
None
:
assert
state_batch_indices
is
not
None
and
state_batch_indices
.
dim
()
==
2
assert
dst_state_batch_indices
is
None
or
dst_state_batch_indices
.
dim
()
==
2
if
state_batch_indices
is
not
None
and
state_batch_indices
.
dim
()
==
1
:
state_batch_indices
=
state_batch_indices
.
unsqueeze
(
1
)
if
dst_state_batch_indices
is
not
None
and
dst_state_batch_indices
.
dim
()
==
1
:
dst_state_batch_indices
=
dst_state_batch_indices
.
unsqueeze
(
1
)
if
num_accepted_tokens
is
not
None
:
assert
state_batch_indices
is
not
None
and
state_batch_indices
.
dim
()
==
2
assert
dst_state_batch_indices
is
None
or
dst_state_batch_indices
.
dim
()
==
2
_
,
nheads
,
dim
,
dstate
=
state
.
shape
batch
=
x
.
shape
[
0
]
...
...
vllm/model_executor/layers/mamba/short_conv.py
View file @
0da93439
...
...
@@ -117,7 +117,7 @@ class ShortConv(MambaBase, CustomOp):
assert
isinstance
(
attn_metadata
,
dict
)
attn_metadata
=
attn_metadata
[
self
.
prefix
]
assert
isinstance
(
attn_metadata
,
ShortConvAttentionMetadata
)
self_kv_cache
=
self
.
kv_cache
[
forward_context
.
virtual_engine
]
self_kv_cache
=
self
.
kv_cache
[
0
]
conv_state
=
self_kv_cache
[
0
].
transpose
(
-
1
,
-
2
)
state_indices_tensor_p
=
attn_metadata
.
state_indices_tensor_p
state_indices_tensor_d
=
attn_metadata
.
state_indices_tensor_d
...
...
vllm/model_executor/layers/pooler/activations.py
View file @
0da93439
...
...
@@ -16,25 +16,22 @@ from vllm.utils.import_utils import resolve_obj_by_qualname
logger
=
init_logger
(
__name__
)
def
get_
classification_
act_fn
(
def
get_act_fn
(
config
:
PretrainedConfig
,
static_num_labels
:
bool
=
True
,
)
->
"PoolerActivation"
:
# get classification act_fn
# Implement alignment with transformers ForSequenceClassificationLoss
# https://github.com/huggingface/transformers/blob/57bb6db6ee4cfaccc45b8d474dfad5a17811ca60/src/transformers/loss/loss_utils.py#L92
problem_type
=
getattr
(
config
,
"problem_type"
,
""
)
if
problem_type
==
"regression"
:
return
PoolerIdentity
()
if
problem_type
==
"single_label_classification"
:
return
PoolerClassify
()
return
PoolerClassify
(
static_num_labels
=
static_num_labels
)
if
problem_type
==
"multi_label_classification"
:
return
PoolerMultiLabelClassify
()
return
PoolerClassify
()
def
get_cross_encoder_act_fn
(
config
:
PretrainedConfig
,
)
->
"PoolerActivation"
:
# get cross_encoder act_fn
function_name
:
str
|
None
=
None
if
(
hasattr
(
config
,
"sentence_transformers"
)
...
...
@@ -55,24 +52,16 @@ def get_cross_encoder_act_fn(
fn
=
resolve_obj_by_qualname
(
function_name
)()
return
PoolerActivation
.
wraps
(
fn
)
return
PoolerClassify
()
return
PoolerClassify
(
static_num_labels
=
static_num_labels
)
def
resolve_classifier_act_fn
(
model_config
:
ModelConfig
,
static_num_labels
:
bool
=
True
,
act_fn
:
"PoolerActivation |
str |
None"
=
None
,
act_fn
:
"PoolerActivation | None"
=
None
,
):
if
isinstance
(
act_fn
,
str
):
if
act_fn
==
"classify"
:
return
get_classification_act_fn
(
model_config
.
hf_config
)
if
act_fn
==
"score"
:
return
get_cross_encoder_act_fn
(
model_config
.
hf_config
)
raise
ValueError
(
f
"act_fn [
{
act_fn
=
}
] not supported."
)
if
act_fn
is
None
:
return
PoolerClassify
(
static_num_labels
=
static_num_labels
)
return
get_act_fn
(
model_config
.
hf_config
,
static_num_labels
)
assert
callable
(
act_fn
)
return
act_fn
...
...
@@ -97,9 +86,8 @@ class PoolerActivation(nn.Module, ABC):
def
forward
(
self
,
pooled_data
:
_T
)
->
_T
:
# shape:
# classify (& score) -> (batch_size, num_classes)
# embed -> (batch_size, embedding_dim) or list(embedding_dim)
# (batch_size, dimensions) or list(dimensions) if using MRL
# classify -> (batch_size, num_classes)
# embed -> (batch_size, embedding_size) or list(embedding_size)
if
isinstance
(
pooled_data
,
list
):
return
[
self
.
forward_chunk
(
data
)
for
data
in
pooled_data
]
...
...
vllm/model_executor/layers/pooler/seqwise/heads.py
View file @
0da93439
...
...
@@ -56,29 +56,31 @@ class EmbeddingPoolerHead(SequencePoolerHead):
if
isinstance
(
pooled_data
,
list
):
pooled_data
=
torch
.
stack
(
pooled_data
)
# pooled_data shape: [batchsize, hidden_
dimension
]
# pooled_data shape: [batchsize, hidden_
size
]
if
self
.
head_dtype
is
not
None
:
pooled_data
=
pooled_data
.
to
(
self
.
head_dtype
)
# Apply ST projector
if
self
.
projector
is
not
None
:
pooled_data
=
self
.
projector
(
pooled_data
)
# pooled_data shape: [batchsize, embedding_dimension]
embeddings
=
self
.
projector
(
pooled_data
)
else
:
embeddings
=
pooled_data
# embeddings shape: [batchsize, embedding_size]
# for matryoshka representation
dimensions_list
=
[
pooling_param
.
dimensions
for
pooling_param
in
pooling_params
]
if
any
(
d
is
not
None
for
d
in
dimensions_list
):
# change the output dimension
assert
len
(
pooled_data
)
==
len
(
dimensions_list
)
if
len
(
set
(
dimensions_list
))
==
1
and
not
isinstance
(
pooled_data
,
list
):
assert
len
(
embeddings
)
==
len
(
dimensions_list
)
if
len
(
set
(
dimensions_list
))
==
1
and
not
isinstance
(
embeddings
,
list
):
# if all dimensions are the same
d
=
dimensions_list
[
0
]
pooled_data
=
pooled_data
[...,
:
d
]
embeddings
=
embeddings
[...,
:
d
]
else
:
pooled_data
=
[
embeddings
=
[
vecs
if
d
is
None
else
vecs
[...,
:
d
]
for
vecs
,
d
in
zip
(
pooled_data
,
dimensions_list
)
for
vecs
,
d
in
zip
(
embeddings
,
dimensions_list
)
]
# for normalize
...
...
@@ -86,15 +88,15 @@ class EmbeddingPoolerHead(SequencePoolerHead):
flags
=
[
p
.
use_activation
for
p
in
pooling_params
]
if
len
(
set
(
flags
))
==
1
:
if
flags
[
0
]:
pooled_data
=
self
.
activation
(
pooled_data
)
embeddings
=
self
.
activation
(
embeddings
)
else
:
pooled_data
=
[
embeddings
=
[
self
.
activation
(
vecs
)
if
f
else
vecs
for
vecs
,
f
in
zip
(
pooled_data
,
flags
)
for
vecs
,
f
in
zip
(
embeddings
,
flags
)
]
#
pooled_data
shape: [batchsize, embedding_
dimension
]
return
pooled_data
#
embeddings
shape: [batchsize, embedding_
size
]
return
embeddings
class
ClassifierPoolerHead
(
SequencePoolerHead
):
...
...
@@ -113,7 +115,7 @@ class ClassifierPoolerHead(SequencePoolerHead):
self
.
activation
=
activation
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
return
{
"classify"
,
"score"
}
return
{
"classify"
}
def
forward
(
self
,
...
...
@@ -131,21 +133,23 @@ class ClassifierPoolerHead(SequencePoolerHead):
pooled_data
=
pooled_data
.
to
(
self
.
head_dtype
)
if
self
.
classifier
is
not
None
:
pooled_data
=
self
.
classifier
(
pooled_data
)
# pooled_data shape: [batchsize, num_labels]
logits
=
self
.
classifier
(
pooled_data
)
else
:
logits
=
pooled_data
# logits shape: [batchsize, num_labels]
if
self
.
logit_bias
is
not
None
:
pooled_data
-=
self
.
logit_bias
logits
-=
self
.
logit_bias
if
self
.
activation
is
not
None
:
flags
=
[
p
.
use_activation
for
p
in
pooling_params
]
if
len
(
set
(
flags
))
==
1
:
pooled_data
=
self
.
activation
(
pooled_data
)
if
flags
[
0
]
else
pooled_data
logits
=
self
.
activation
(
logits
)
if
flags
[
0
]
else
logits
else
:
pooled_data
=
[
logits
=
[
self
.
activation
(
vecs
)
if
f
else
vecs
for
vecs
,
f
in
zip
(
pooled_data
,
flags
)
for
vecs
,
f
in
zip
(
logits
,
flags
)
]
#
pooled_data
shape: [batchsize, num_labels]
return
pooled_data
#
logits
shape: [batchsize, num_labels]
return
logits
Prev
1
…
20
21
22
23
24
25
26
27
28
…
31
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment