Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5dcd7ef1
Unverified
Commit
5dcd7ef1
authored
Jan 07, 2026
by
Robert Shaw
Committed by
GitHub
Jan 07, 2026
Browse files
[MoE Refactor][15/N] Apply Refactor to Fp8 (#31415)
parent
ffc0a279
Changes
38
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
18 changed files
with
1252 additions
and
1360 deletions
+1252
-1360
tests/kernels/moe/test_flashinfer.py
tests/kernels/moe/test_flashinfer.py
+37
-12
tests/quantization/test_fp8.py
tests/quantization/test_fp8.py
+12
-7
vllm/model_executor/layers/fused_moe/__init__.py
vllm/model_executor/layers/fused_moe/__init__.py
+0
-2
vllm/model_executor/layers/fused_moe/cutlass_moe.py
vllm/model_executor/layers/fused_moe/cutlass_moe.py
+17
-145
vllm/model_executor/layers/fused_moe/fallback.py
vllm/model_executor/layers/fused_moe/fallback.py
+126
-0
vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py
.../model_executor/layers/fused_moe/flashinfer_trtllm_moe.py
+5
-5
vllm/model_executor/layers/fused_moe/oracle/__init__.py
vllm/model_executor/layers/fused_moe/oracle/__init__.py
+2
-0
vllm/model_executor/layers/fused_moe/oracle/fp8.py
vllm/model_executor/layers/fused_moe/oracle/fp8.py
+358
-0
vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py
vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py
+75
-0
vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py
vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py
+16
-107
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
...quantization/compressed_tensors/compressed_tensors_moe.py
+135
-355
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+112
-384
vllm/model_executor/layers/quantization/modelopt.py
vllm/model_executor/layers/quantization/modelopt.py
+147
-264
vllm/model_executor/layers/quantization/quark/quark_moe.py
vllm/model_executor/layers/quantization/quark/quark_moe.py
+3
-4
vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
...el_executor/layers/quantization/utils/flashinfer_utils.py
+119
-53
vllm/model_executor/layers/quantization/utils/fp8_utils.py
vllm/model_executor/layers/quantization/utils/fp8_utils.py
+71
-4
vllm/model_executor/layers/quantization/utils/marlin_utils.py
.../model_executor/layers/quantization/utils/marlin_utils.py
+1
-1
vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py
...el_executor/layers/quantization/utils/marlin_utils_fp8.py
+16
-17
No files found.
tests/kernels/moe/test_flashinfer.py
View file @
5dcd7ef1
...
...
@@ -11,12 +11,17 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig
,
fp8_w8a8_moe_quant_config
,
)
from
vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe
import
(
FlashInferExperts
,
)
from
vllm.model_executor.layers.fused_moe.fused_moe
import
fused_experts
from
vllm.model_executor.layers.fused_moe.prepare_finalize
import
(
MoEPrepareAndFinalizeNoEP
,
)
from
vllm.model_executor.layers.quantization.utils.flashinfer_utils
import
(
apply_flashinfer_per_tensor_scale_fp8
,
flashinfer_cutlass_moe_fp8
,
apply_fi_trtllm_fp8_per_tensor_moe
,
register_scales_for_trtllm_fp8_per_tensor_moe
,
rotate_
flashinfer_fp8_moe_weights
,
rotate_
weights_for_fi_trtllm_fp8_per_tensor_moe
,
swap_w13_to_w31
,
)
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
input_to_float8
...
...
@@ -103,6 +108,7 @@ class TestData:
w2_quantized
,
w2_weight_scale
=
quant_fp8_per_tensor_batches
(
w2
)
layer
=
torch
.
nn
.
Module
()
layer
.
orig_dtype
=
torch
.
bfloat16
layer
.
w13_weight
=
w13_quantized
.
clone
()
layer
.
w2_weight
=
w2_quantized
.
clone
()
layer
.
w13_input_scale
=
a1_scale
...
...
@@ -115,10 +121,10 @@ class TestData:
pcp_size
=
1
,
dp_size
=
1
,
ep_size
=
1
,
tp_rank
=
1
,
pcp_rank
=
1
,
dp_rank
=
1
,
ep_rank
=
1
,
tp_rank
=
0
,
pcp_rank
=
0
,
dp_rank
=
0
,
ep_rank
=
0
,
use_ep
=
False
,
all2all_backend
=
"naive"
,
)
...
...
@@ -126,7 +132,9 @@ class TestData:
# flashinfer expects swapped rows for w13
layer
.
w13_weight
.
data
=
swap_w13_to_w31
(
layer
.
w13_weight
.
data
)
if
is_trtllm
:
rotate_flashinfer_fp8_moe_weights
(
layer
.
w13_weight
,
layer
.
w2_weight
)
rotate_weights_for_fi_trtllm_fp8_per_tensor_moe
(
layer
.
w13_weight
,
layer
.
w2_weight
)
register_scales_for_trtllm_fp8_per_tensor_moe
(
layer
,
layer
.
w13_weight_scale
,
...
...
@@ -199,7 +207,7 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
quant_config
=
quant_config
,
)
flashinfer_output
=
apply_f
lashinfer
_per_tensor_
scale_fp8
(
flashinfer_output
=
apply_f
i_trtllm_fp8
_per_tensor_
moe
(
layer
=
td
.
layer
,
hidden_states
=
td
.
hidden_states
,
router_logits
=
score
,
...
...
@@ -277,17 +285,34 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
td
.
layer
.
get_fused_moe_quant_config
=
get_fused_moe_quant_config
td
.
layer
.
quant_method
=
td
.
layer
flashinfer_cutlass_output
=
flashinfer_cutlass_moe_fp8
(
kernel
=
mk
.
FusedMoEModularKernel
(
MoEPrepareAndFinalizeNoEP
(
defer_input_quant
=
quant_config
.
is_block_quantized
),
FlashInferExperts
(
out_dtype
=
td
.
layer
.
orig_dtype
,
quant_config
=
quant_config
,
ep_rank
=
td
.
layer
.
moe_parallel_config
.
ep_rank
,
ep_size
=
td
.
layer
.
moe_parallel_config
.
ep_size
,
tp_rank
=
td
.
layer
.
moe_parallel_config
.
tp_rank
,
tp_size
=
td
.
layer
.
moe_parallel_config
.
tp_size
,
use_dp
=
False
,
use_deepseek_fp8_block_scale
=
False
,
),
)
flashinfer_cutlass_output
=
kernel
(
td
.
hidden_states
,
td
.
layer
,
td
.
layer
.
w13_weight
,
td
.
layer
.
w2_weight
,
topk_weights
,
topk_ids
,
inplace
=
False
,
activation
=
activation
,
global_num_experts
=
e
,
expert_map
=
None
,
apply_router_weight_on_input
=
True
,
)
torch
.
testing
.
assert_close
(
output
,
flashinfer_cutlass_output
,
atol
=
5.5e-2
,
rtol
=
1e-2
)
tests/quantization/test_fp8.py
View file @
5dcd7ef1
...
...
@@ -15,7 +15,6 @@ from vllm.model_executor.layers.quantization.fp8 import (
Fp8Config
,
Fp8KVCacheMethod
,
Fp8LinearMethod
,
Fp8MoeBackend
,
Fp8MoEMethod
,
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
...
...
@@ -278,8 +277,18 @@ def test_scaled_fp8_quant(dtype) -> None:
# this is the case for marlin as well as per-tensor Fp8MoEMethod
@
pytest
.
mark
.
parametrize
(
"use_marlin"
,
[
False
])
# skip True
def
test_fp8_reloading
(
method_cls
,
is_checkpoint_fp8_serialized
,
weight_block_size
,
use_marlin
,
dist_init
method_cls
,
is_checkpoint_fp8_serialized
,
weight_block_size
,
use_marlin
,
dist_init
,
monkeypatch
,
):
# NOTE(rob): this test fails when using DeepGEMM because the
# shapes are invalid. Previously the test was passing because
# we set fp8_backend to None, which sidestepped the issue.
monkeypatch
.
setenv
(
"VLLM_USE_DEEP_GEMM"
,
"0"
)
if
is_checkpoint_fp8_serialized
is
False
:
pytest
.
skip
(
"FP8 weight reloading does not support online quantization"
)
...
...
@@ -307,6 +316,7 @@ def test_fp8_reloading(
params_dtype
=
torch
.
bfloat16
,
weight_loader
=
default_weight_loader
,
)
method
.
use_marlin
=
use_marlin
else
:
layer
=
FusedMoE
(
...
...
@@ -325,11 +335,6 @@ def test_fp8_reloading(
weight_loader
=
default_weight_loader
,
)
# Fp8LinearMethod uses use_marlin
# Fp8MoEMethod uses fp8_backend
method
.
use_marlin
=
use_marlin
method
.
fp8_backend
=
Fp8MoeBackend
.
MARLIN
if
use_marlin
else
None
# capture weights format during loading
original_metadata
=
[
(
name
,
param
.
shape
,
getattr
(
param
,
"weight_loader"
,
default_weight_loader
))
...
...
vllm/model_executor/layers/fused_moe/__init__.py
View file @
5dcd7ef1
...
...
@@ -73,7 +73,6 @@ if HAS_TRITON:
CutlassExpertsFp8
,
CutlassExpertsW4A8Fp8
,
cutlass_moe_fp4
,
cutlass_moe_fp8
,
cutlass_moe_w4a8_fp8
,
)
from
vllm.model_executor.layers.fused_moe.deep_gemm_moe
import
DeepGemmExperts
...
...
@@ -96,7 +95,6 @@ if HAS_TRITON:
"fused_experts"
,
"get_config_file_name"
,
"GroupedTopk"
,
"cutlass_moe_fp8"
,
"cutlass_moe_fp4"
,
"cutlass_moe_w4a8_fp8"
,
"CutlassExpertsFp8"
,
...
...
vllm/model_executor/layers/fused_moe/cutlass_moe.py
View file @
5dcd7ef1
...
...
@@ -249,20 +249,28 @@ def run_cutlass_moe_fp8(
class
CutlassExpertsFp8Base
(
mk
.
FusedMoEPermuteExpertsUnpermute
):
def
__init__
(
self
,
e
:
int
,
n
:
int
,
k
:
int
,
out_dtype
:
torch
.
dtype
|
None
,
ab_strides1
:
torch
.
Tensor
,
ab_strides2
:
torch
.
Tensor
,
c_strides1
:
torch
.
Tensor
,
c_strides2
:
torch
.
Tensor
,
quant_config
:
FusedMoEQuantConfig
,
device
:
torch
.
dtype
,
):
assert
quant_config
.
use_fp8_w8a8
super
().
__init__
(
quant_config
)
# E: num_experts
# N: intermediate size per partition
# K: hidden dim
ab_strides1_c_strides2
=
torch
.
full
((
e
,),
k
,
device
=
device
,
dtype
=
torch
.
int64
)
ab_strides2
=
torch
.
full
((
e
,),
n
,
device
=
device
,
dtype
=
torch
.
int64
)
c_strides1
=
torch
.
full
((
e
,),
2
*
n
,
device
=
device
,
dtype
=
torch
.
int64
)
self
.
out_dtype
=
out_dtype
self
.
ab_strides1
=
ab_strides1
self
.
ab_strides1
=
ab_strides1
_c_strides2
self
.
ab_strides2
=
ab_strides2
self
.
c_strides1
=
c_strides1
self
.
c_strides2
=
c_strides2
self
.
c_strides2
=
ab_strides1_
c_strides2
def
finalize_weight_and_reduce_impl
(
self
)
->
mk
.
TopKWeightAndReduce
:
# Let PrepareAndFinalize::finalize() decide the impl.
...
...
@@ -329,24 +337,6 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
class
CutlassExpertsFp8
(
CutlassExpertsFp8Base
):
def
__init__
(
self
,
out_dtype
:
torch
.
dtype
|
None
,
ab_strides1
:
torch
.
Tensor
,
ab_strides2
:
torch
.
Tensor
,
c_strides1
:
torch
.
Tensor
,
c_strides2
:
torch
.
Tensor
,
quant_config
:
FusedMoEQuantConfig
,
):
super
().
__init__
(
out_dtype
,
ab_strides1
,
ab_strides2
,
c_strides1
,
c_strides2
,
quant_config
,
)
@
property
def
activation_formats
(
self
,
...
...
@@ -390,21 +380,10 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
self
,
max_experts_per_worker
:
int
,
num_dispatchers
:
int
,
out_dtype
:
torch
.
dtype
|
None
,
ab_strides1
:
torch
.
Tensor
,
ab_strides2
:
torch
.
Tensor
,
c_strides1
:
torch
.
Tensor
,
c_strides2
:
torch
.
Tensor
,
quant_config
:
FusedMoEQuantConfig
,
*
args
,
**
kwargs
,
):
super
().
__init__
(
out_dtype
,
ab_strides1
,
ab_strides2
,
c_strides1
,
c_strides2
,
quant_config
,
)
super
().
__init__
(
*
args
,
**
kwargs
)
assert
max_experts_per_worker
>
0
self
.
max_experts_per_worker
=
max_experts_per_worker
self
.
num_dispatchers
=
num_dispatchers
...
...
@@ -445,113 +424,6 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
return
(
workspace1
,
workspace2
,
output
)
def
cutlass_moe_fp8
(
a
:
torch
.
Tensor
,
w1_q
:
torch
.
Tensor
,
w2_q
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
ab_strides1
:
torch
.
Tensor
,
ab_strides2
:
torch
.
Tensor
,
c_strides1
:
torch
.
Tensor
,
c_strides2
:
torch
.
Tensor
,
quant_config
:
FusedMoEQuantConfig
,
activation
:
str
=
"silu"
,
expert_map
:
torch
.
Tensor
|
None
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
)
->
torch
.
Tensor
:
"""
This function computes a a8w8-quantized Mixture of Experts (MoE) layer
using two sets of quantized weights, w1_q and w2_q, and top-k gating
mechanism. The matrix multiplications are implemented with CUTLASS
grouped gemm.
Parameters:
- a (torch.Tensor): The input tensor to the MoE layer.
Shape: [M, K]
- w1_q (torch.Tensor): The first set of fp8-quantized expert weights.
Shape: [num_experts, K, 2N] (the weights are passed transposed)
- w2_q (torch.Tensor): The second set of fp8-quantized expert weights.
Shape: [num_experts, N, K] (the weights are passed transposed)
- topk_weights (torch.Tensor): The weights of each token->expert mapping.
- topk_ids (torch.Tensor): The token->expert mappings.
- w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q.
Shape: [num_experts] or [num_experts, 2N]
- w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
Shape: [num_experts] or [num_experts, K]
- ab_strides1 (torch.Tensor): The input/weight strides for the first gemm.
Shape: [num_experts]
- ab_strides2 (torch.Tensor): The input/weight strides for the second gemm.
Shape: [num_experts]
- c_strides1 (torch.Tensor): The output strides for the first gemm.
Shape: [num_experts]
- c_strides2 (torch.Tensor): The output strides for the second gemm.
Shape: [num_experts]
- per_act_token (Optional[bool]): Whether the scale is per-token or
per-tensor.
- activation (str): The activation function to use.
- a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
Shape: scalar or [M]
- a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
quantize the intermediate result between the gemms.
Shape: scalar or [M]
- expert_map (Optional[torch.Tensor]): In the case of Expert parallel,
every Rank is responsible for a subset of experts. expert_map is a
mapping from global expert-id to local expert-id. When expert_map[i]
is -1, it means that this Rank is not responsible for global
expert-id i.
- apply_router_weight_on_input (bool): When true, the topk weights are
applied directly on the inputs. This is only applicable when topk is 1.
- global_num_experts (int): The total number of experts.
Returns:
- torch.Tensor: The fp16 output tensor after applying the MoE layer.
"""
assert
quant_config
is
not
None
if
quant_config
.
a1_scale
is
not
None
:
assert
quant_config
.
per_act_token_quant
==
(
quant_config
.
a1_scale
.
numel
()
!=
1
)
if
quant_config
.
a2_scale
is
not
None
:
assert
quant_config
.
per_act_token_quant
==
(
quant_config
.
a2_scale
.
numel
()
!=
1
)
if
quant_config
.
w1_scale
is
not
None
:
if
quant_config
.
per_out_ch_quant
:
assert
quant_config
.
w1_scale
.
dim
()
>
1
and
quant_config
.
w1_scale
.
size
(
1
)
==
w1_q
.
size
(
1
)
else
:
assert
(
quant_config
.
w1_scale
.
dim
()
==
1
or
quant_config
.
w1_scale
.
size
(
1
)
==
1
)
num_experts
=
global_num_experts
if
global_num_experts
!=
-
1
else
w1_q
.
size
(
0
)
fn
=
mk
.
FusedMoEModularKernel
(
MoEPrepareAndFinalizeNoEP
(),
CutlassExpertsFp8
(
out_dtype
=
a
.
dtype
,
ab_strides1
=
ab_strides1
,
ab_strides2
=
ab_strides2
,
c_strides1
=
c_strides1
,
c_strides2
=
c_strides2
,
quant_config
=
quant_config
,
),
)
return
fn
(
a
,
w1_q
,
w2_q
,
topk_weights
,
topk_ids
,
activation
=
activation
,
global_num_experts
=
num_experts
,
expert_map
=
expert_map
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
)
FLOAT4_E2M1_MAX
=
scalar_types
.
float4_e2m1f
.
max
()
FLOAT8_E4M3_MAX
=
torch
.
finfo
(
torch
.
float8_e4m3fn
).
max
...
...
vllm/model_executor/layers/fused_moe/fallback.py
0 → 100644
View file @
5dcd7ef1
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
abc
import
ABC
,
abstractmethod
import
torch
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
class
FallbackExperts
(
mk
.
FusedMoEPermuteExpertsUnpermute
,
ABC
):
"""Base class for runtime dispatching of expert implementations."""
def
__init__
(
self
,
experts
:
mk
.
FusedMoEPermuteExpertsUnpermute
,
fallback_experts
:
mk
.
FusedMoEPermuteExpertsUnpermute
,
):
super
().
__init__
(
experts
.
quant_config
)
self
.
fallback_experts
=
fallback_experts
self
.
experts
=
experts
@
property
def
activation_formats
(
self
,
)
->
tuple
[
mk
.
FusedMoEActivationFormat
,
mk
.
FusedMoEActivationFormat
]:
assert
(
self
.
fallback_experts
.
activation_formats
==
self
.
experts
.
activation_formats
)
return
self
.
fallback_experts
.
activation_formats
def
supports_chunking
(
self
)
->
bool
:
assert
(
self
.
experts
.
supports_chunking
()
==
self
.
fallback_experts
.
supports_chunking
()
)
return
(
self
.
experts
.
supports_chunking
()
and
self
.
fallback_experts
.
supports_chunking
()
)
def
supports_expert_map
(
self
)
->
bool
:
assert
(
self
.
experts
.
supports_expert_map
()
==
self
.
fallback_experts
.
supports_expert_map
()
)
return
(
self
.
experts
.
supports_expert_map
()
and
self
.
fallback_experts
.
supports_expert_map
()
)
def
finalize_weight_and_reduce_impl
(
self
)
->
mk
.
TopKWeightAndReduce
:
e_war
=
self
.
experts
.
finalize_weight_and_reduce_impl
()
fbe_war
=
self
.
fallback_experts
.
finalize_weight_and_reduce_impl
()
is_dge_war
=
e_war
is
not
None
is_fbe_war
=
fbe_war
is
not
None
if
is_dge_war
and
is_fbe_war
:
assert
e_war
==
fbe_war
,
(
"Both implementations should agree on WeightAndReduce impls. "
f
"Got e_war:
{
e_war
}
, and fbe_war:
{
fbe_war
}
"
)
if
e_war
is
not
None
:
return
e_war
assert
fbe_war
is
not
None
return
fbe_war
@
abstractmethod
def
workspace_shapes
(
self
,
M
:
int
,
N
:
int
,
K
:
int
,
topk
:
int
,
global_num_experts
:
int
,
local_num_experts
:
int
,
expert_tokens_meta
:
mk
.
ExpertTokensMetadata
|
None
,
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]]:
raise
NotImplementedError
@
abstractmethod
def
_select_experts_impl
(
self
,
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
)
->
mk
.
FusedMoEPermuteExpertsUnpermute
:
raise
NotImplementedError
def
apply
(
self
,
output
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
activation
:
str
,
global_num_experts
:
int
,
expert_map
:
torch
.
Tensor
|
None
,
a1q_scale
:
torch
.
Tensor
|
None
,
a2_scale
:
torch
.
Tensor
|
None
,
workspace13
:
torch
.
Tensor
,
workspace2
:
torch
.
Tensor
,
expert_tokens_meta
:
mk
.
ExpertTokensMetadata
|
None
,
apply_router_weight_on_input
:
bool
,
):
experts
=
self
.
_select_experts_impl
(
hidden_states
,
w1
,
w2
)
experts
.
apply
(
output
,
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
activation
,
global_num_experts
,
expert_map
,
a1q_scale
,
a2_scale
,
workspace13
,
workspace2
,
expert_tokens_meta
,
apply_router_weight_on_input
,
)
vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py
View file @
5dcd7ef1
...
...
@@ -100,7 +100,7 @@ direct_register_custom_op(
)
def
f
lashinfer_fused_moe
_per_tensor_
scale_fp8
(
def
f
i_trtllm_fp8
_per_tensor_
moe
(
routing_logits
:
torch
.
Tensor
,
routing_bias
:
torch
.
Tensor
|
None
,
hidden_states
:
torch
.
Tensor
,
...
...
@@ -158,7 +158,7 @@ def flashinfer_fused_moe_per_tensor_scale_fp8(
)
def
f
lashinfer_fused_moe
_per_tensor_
scale_fp8
_fake
(
def
f
i_trtllm_fp8
_per_tensor_
moe
_fake
(
routing_logits
:
torch
.
Tensor
,
routing_bias
:
torch
.
Tensor
|
None
,
hidden_states
:
torch
.
Tensor
,
...
...
@@ -184,9 +184,9 @@ def flashinfer_fused_moe_per_tensor_scale_fp8_fake(
# TODO(bnell): Does this really need to be a torch.op?
direct_register_custom_op
(
op_name
=
"f
lashinfer_fused_moe
_per_tensor_
scale_fp8
"
,
op_func
=
f
lashinfer_fused_moe
_per_tensor_
scale_fp8
,
op_name
=
"f
i_trtllm_fp8
_per_tensor_
moe
"
,
op_func
=
f
i_trtllm_fp8
_per_tensor_
moe
,
mutates_args
=
[
"hidden_states"
],
fake_impl
=
f
lashinfer_fused_moe
_per_tensor_
scale_fp8
_fake
,
fake_impl
=
f
i_trtllm_fp8
_per_tensor_
moe
_fake
,
tags
=
(
torch
.
Tag
.
needs_fixed_stride_order
,),
)
vllm/model_executor/layers/fused_moe/oracle/__init__.py
0 → 100644
View file @
5dcd7ef1
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
vllm/model_executor/layers/fused_moe/oracle/fp8.py
0 → 100644
View file @
5dcd7ef1
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
enum
import
Enum
import
torch
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm
import
envs
from
vllm._aiter_ops
import
rocm_aiter_ops
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe.config
import
(
FusedMoEConfig
,
FusedMoEQuantConfig
,
fp8_w8a8_moe_quant_config
,
fp8_w8a16_moe_quant_config
,
)
from
vllm.model_executor.layers.quantization.utils.flashinfer_utils
import
(
FlashinferMoeBackend
,
get_flashinfer_moe_backend
,
make_fp8_moe_alpha_scales_for_fi
,
prepare_fp8_moe_layer_for_fi
,
)
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
prepare_fp8_moe_layer_for_deepgemm
,
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp8
import
(
prepare_fp8_moe_layer_for_marlin
,
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
cutlass_group_gemm_supported
,
)
from
vllm.platforms
import
current_platform
from
vllm.utils.deep_gemm
import
is_deep_gemm_supported
from
vllm.utils.flashinfer
import
has_flashinfer_moe
from
vllm.utils.import_utils
import
has_deep_gemm
logger
=
init_logger
(
__name__
)
class
Fp8MoeBackend
(
Enum
):
NONE
=
0
FLASHINFER_TRTLLM
=
1
FLASHINFER_CUTLASS
=
2
DEEPGEMM
=
3
MARLIN
=
4
TRITON
=
5
AITER
=
6
VLLM_CUTLASS
=
7
def
select_fp8_moe_backend
(
block_quant
:
bool
,
tp_size
:
int
,
with_lora_support
:
bool
,
is_act_and_mul
:
bool
=
True
,
allow_vllm_cutlass
:
bool
=
False
,
)
->
Fp8MoeBackend
:
"""
Select the primary FP8 MoE backend
Note: Shape-specific fallbacks may still occur at runtime.
"""
# TODO(rob): in a future PR, we will query each mk for
# supported features and return the mk directly, just like
# we do for the Attention Backend.
if
with_lora_support
:
return
Fp8MoeBackend
.
TRITON
def
_make_log_backend
(
backend_name
:
str
):
return
f
"Using
{
backend_name
}
backend for FP8 MoE"
# Prefer FlashInfer backends on supported GPUs; allow SM90 and SM100.
if
(
current_platform
.
is_cuda
()
and
(
current_platform
.
is_device_capability_family
(
100
)
or
current_platform
.
is_device_capability
(
90
)
)
and
envs
.
VLLM_USE_FLASHINFER_MOE_FP8
and
has_flashinfer_moe
()
):
backend
=
get_flashinfer_moe_backend
()
if
backend
==
FlashinferMoeBackend
.
TENSORRT_LLM
:
logger
.
info_once
(
_make_log_backend
(
"FlashInfer TRTLLM"
))
if
not
is_act_and_mul
:
raise
ValueError
(
"FlashInfer TRTLLM FP8 MoE backend only supports "
"act_and_mul gate_up_project fusion. Please set "
"VLLM_USE_FLASHINFER_MOE_FP8=throughput to use the "
"FlashInfer CUTLASS backend instead."
)
return
Fp8MoeBackend
.
FLASHINFER_TRTLLM
else
:
if
block_quant
and
current_platform
.
is_device_capability_family
(
100
):
raise
ValueError
(
"FlashInfer FP8 MoE throughput backend does not "
"support block quantization on SM100. Please use "
"VLLM_FLASHINFER_MOE_BACKEND=latency to use the "
"FlashInfer TRTLLM backend instead."
)
logger
.
info_once
(
_make_log_backend
(
"FlashInfer CUTLASS"
))
return
Fp8MoeBackend
.
FLASHINFER_CUTLASS
# weight-only path for older GPUs without native FP8
if
(
current_platform
.
is_cuda
()
and
not
current_platform
.
has_device_capability
(
89
)
)
or
envs
.
VLLM_TEST_FORCE_FP8_MARLIN
:
logger
.
info_once
(
_make_log_backend
(
"Marlin"
),
scope
=
"local"
)
return
Fp8MoeBackend
.
MARLIN
# Determine if we should use DeepGEMM with block-quantized weights:
# - If explicitly set by user, respect their choice
# - If not explicitly set (default), disable when TP size is >= 8
moe_use_deep_gemm
=
envs
.
VLLM_MOE_USE_DEEP_GEMM
if
not
envs
.
is_set
(
"VLLM_MOE_USE_DEEP_GEMM"
)
and
tp_size
>=
8
:
moe_use_deep_gemm
=
False
logger
.
info_once
(
"DeepGEMM MoE is disabled by default when TP size is >= 8. "
"Set VLLM_MOE_USE_DEEP_GEMM=1 to enable it."
,
scope
=
"local"
,
)
use_deep_gemm
=
envs
.
VLLM_USE_DEEP_GEMM
if
not
is_deep_gemm_supported
():
use_deep_gemm
=
False
logger
.
info_once
(
"DeepGEMM is disabled because the platform does not support it."
,
scope
=
"local"
,
)
if
use_deep_gemm
and
moe_use_deep_gemm
and
block_quant
:
if
not
has_deep_gemm
():
logger
.
warning_once
(
"DeepGEMM backend requested but not available."
,
scope
=
"local"
)
elif
is_deep_gemm_supported
():
logger
.
info_once
(
_make_log_backend
(
"DeepGEMM"
),
scope
=
"local"
)
return
Fp8MoeBackend
.
DEEPGEMM
if
envs
.
VLLM_ROCM_USE_AITER
and
envs
.
VLLM_ROCM_USE_AITER_MOE
:
logger
.
info_once
(
_make_log_backend
(
"ROCm AITER"
),
scope
=
"local"
)
return
Fp8MoeBackend
.
AITER
if
allow_vllm_cutlass
and
not
block_quant
and
cutlass_group_gemm_supported
():
logger
.
info_once
(
_make_log_backend
(
"vLLM CUTLASS"
),
scope
=
"local"
)
return
Fp8MoeBackend
.
VLLM_CUTLASS
# default to Triton
logger
.
info_once
(
_make_log_backend
(
"Triton"
),
scope
=
"local"
)
return
Fp8MoeBackend
.
TRITON
def
convert_to_fp8_moe_kernel_format
(
fp8_backend
:
Fp8MoeBackend
,
layer
:
torch
.
nn
.
Module
,
w13
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
w13_scale
:
torch
.
Tensor
,
w2_scale
:
torch
.
Tensor
,
w13_input_scale
:
torch
.
Tensor
|
None
,
w2_input_scale
:
torch
.
Tensor
|
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
block_quant
=
hasattr
(
layer
,
"weight_block_size"
)
if
fp8_backend
==
Fp8MoeBackend
.
DEEPGEMM
:
assert
block_quant
w13
,
w2
,
w13_scale
,
w2_scale
=
prepare_fp8_moe_layer_for_deepgemm
(
w13
,
w2
,
w13_scale
,
w2_scale
,
tuple
(
layer
.
weight_block_size
),
)
elif
fp8_backend
==
Fp8MoeBackend
.
AITER
:
w13
,
w2
=
rocm_aiter_ops
.
shuffle_weights
(
w13
,
w2
)
elif
fp8_backend
==
Fp8MoeBackend
.
MARLIN
:
w13
,
w2
,
w13_scale
,
w2_scale
=
prepare_fp8_moe_layer_for_marlin
(
layer
,
w13
,
w2
,
w13_scale
,
w2_scale
,
)
elif
fp8_backend
in
[
Fp8MoeBackend
.
FLASHINFER_CUTLASS
,
Fp8MoeBackend
.
FLASHINFER_TRTLLM
,
]:
w13
,
w2
,
w13_scale
=
prepare_fp8_moe_layer_for_fi
(
layer
=
layer
,
w13
=
w13
,
w2
=
w2
,
w13_scale
=
w13_scale
,
w13_input_scale
=
w13_input_scale
,
w2_scale
=
w2_scale
,
w2_input_scale
=
w2_input_scale
,
is_trtllm
=
(
fp8_backend
==
Fp8MoeBackend
.
FLASHINFER_TRTLLM
),
)
return
w13
,
w2
,
w13_scale
,
w2_scale
def
make_fp8_moe_quant_config
(
fp8_backend
:
Fp8MoeBackend
,
w1_scale
:
torch
.
Tensor
,
w2_scale
:
torch
.
Tensor
,
a1_scale
:
torch
.
Tensor
|
None
,
a2_scale
:
torch
.
Tensor
|
None
,
block_shape
:
list
[
int
]
|
None
=
None
,
)
->
FusedMoEQuantConfig
|
None
:
"""
Create FusedMoEQuantConfig for the specifed FP8 Backend.
The FusedMoEQuantConfig holds the scales that are used
at runtime by the Modular Kernel abstraction.
Note that certain kernels (e.g. Flashinfer CUTLASS) need
special Quant configs to handle non-standard inputs to
their kernel interfaces.
In a future PR, we will have this function should be
a method of the modular kernel itself.
"""
# TRTLLM does not use Modular Kernel abstraction yet.
if
fp8_backend
==
Fp8MoeBackend
.
FLASHINFER_TRTLLM
:
return
None
# MARLIN is mixed precision W8A16 config.
if
fp8_backend
==
Fp8MoeBackend
.
MARLIN
:
return
fp8_w8a16_moe_quant_config
(
w1_scale
=
w1_scale
,
w2_scale
=
w2_scale
,
block_shape
=
block_shape
,
)
# Flashinfer CUTLASS per-tensor uses single dq scale
# (alpha = w_scale * a_scale) and inverse a2 scale.
if
fp8_backend
==
Fp8MoeBackend
.
FLASHINFER_CUTLASS
and
block_shape
is
None
:
assert
a1_scale
is
not
None
and
a2_scale
is
not
None
g1_alphas
,
g2_alphas
=
make_fp8_moe_alpha_scales_for_fi
(
w1_scale
,
a1_scale
,
w2_scale
,
a2_scale
,
)
return
fp8_w8a8_moe_quant_config
(
w1_scale
=
w1_scale
,
w2_scale
=
w2_scale
,
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
,
a1_gscale
=
(
1.0
/
a1_scale
),
a2_gscale
=
(
1.0
/
a2_scale
),
g1_alphas
=
g1_alphas
,
g2_alphas
=
g2_alphas
,
)
# All other backends use normal config.
return
fp8_w8a8_moe_quant_config
(
w1_scale
=
w1_scale
,
w2_scale
=
w2_scale
,
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
,
block_shape
=
block_shape
,
)
def
make_fp8_moe_kernel
(
layer
:
torch
.
nn
.
Module
,
moe_quant_config
:
FusedMoEQuantConfig
,
moe_config
:
FusedMoEConfig
,
fp8_backend
:
Fp8MoeBackend
,
)
->
tuple
[
mk
.
FusedMoEModularKernel
,
bool
]:
# Delayed import is required since the oracle is imported
# by CPU backends which cannot import all of these experts.
# TODO: update the experts to make this not happen.
from
vllm.model_executor.layers.fused_moe.prepare_finalize
import
(
MoEPrepareAndFinalizeNoEP
,
)
# NOTE(rob): this is a WIP refactor. We are first migrating
# all of the kernels in the TP case to use mk. Once this is
# done, then we will initialzie the TP case and DP/EP case
# via the same code path (i.e. via maybe_init_modular_kernel).
# NOTE(rob): in progress migrating all into this format.
use_inplace
=
True
if
fp8_backend
==
Fp8MoeBackend
.
FLASHINFER_CUTLASS
:
from
vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe
import
(
FlashInferExperts
,
)
kernel
=
mk
.
FusedMoEModularKernel
(
MoEPrepareAndFinalizeNoEP
(
defer_input_quant
=
moe_quant_config
.
is_block_quantized
),
FlashInferExperts
(
out_dtype
=
layer
.
orig_dtype
,
quant_config
=
moe_quant_config
,
ep_rank
=
moe_config
.
ep_rank
,
ep_size
=
moe_config
.
ep_size
,
tp_rank
=
moe_config
.
tp_rank
,
tp_size
=
moe_config
.
tp_size
,
use_dp
=
(
moe_config
.
dp_size
>
1
),
use_deepseek_fp8_block_scale
=
moe_quant_config
.
is_block_quantized
,
),
)
use_inplace
=
False
elif
fp8_backend
==
Fp8MoeBackend
.
AITER
:
from
vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe
import
(
AiterExperts
,
)
kernel
=
mk
.
FusedMoEModularKernel
(
# TODO: make defer_input_quant an attr of the AiterExperts
MoEPrepareAndFinalizeNoEP
(
defer_input_quant
=
True
),
AiterExperts
(
quant_config
=
moe_quant_config
),
)
elif
fp8_backend
==
Fp8MoeBackend
.
MARLIN
:
from
vllm.model_executor.layers.fused_moe.fused_marlin_moe
import
(
MarlinExperts
,
)
kernel
=
mk
.
FusedMoEModularKernel
(
MoEPrepareAndFinalizeNoEP
(),
MarlinExperts
(
quant_config
=
moe_quant_config
),
)
elif
fp8_backend
==
Fp8MoeBackend
.
VLLM_CUTLASS
:
from
vllm.model_executor.layers.fused_moe.triton_cutlass_moe
import
(
TritonOrCutlassExperts
,
)
kernel
=
mk
.
FusedMoEModularKernel
(
MoEPrepareAndFinalizeNoEP
(),
TritonOrCutlassExperts
(
out_dtype
=
moe_config
.
in_dtype
,
e
=
layer
.
local_num_experts
,
n
=
layer
.
intermediate_size_per_partition
,
k
=
layer
.
hidden_size
,
device
=
layer
.
w13_weight
.
device
,
quant_config
=
moe_quant_config
,
),
)
elif
fp8_backend
==
Fp8MoeBackend
.
DEEPGEMM
:
from
vllm.model_executor.layers.fused_moe
import
(
TritonOrDeepGemmExperts
,
)
kernel
=
mk
.
FusedMoEModularKernel
(
MoEPrepareAndFinalizeNoEP
(),
TritonOrDeepGemmExperts
(
quant_config
=
moe_quant_config
),
)
else
:
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
TritonExperts
,
)
assert
fp8_backend
==
Fp8MoeBackend
.
TRITON
kernel
=
mk
.
FusedMoEModularKernel
(
MoEPrepareAndFinalizeNoEP
(),
TritonExperts
(
quant_config
=
moe_quant_config
),
)
return
kernel
,
use_inplace
vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py
0 → 100644
View file @
5dcd7ef1
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
torch
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm.model_executor.layers.fused_moe.config
import
FusedMoEQuantConfig
from
vllm.model_executor.layers.fused_moe.cutlass_moe
import
CutlassExpertsFp8
from
vllm.model_executor.layers.fused_moe.fallback
import
FallbackExperts
from
vllm.model_executor.layers.fused_moe.fused_moe
import
TritonExperts
from
vllm.platforms
import
current_platform
class
TritonOrCutlassExperts
(
FallbackExperts
):
"""Cutlass with fallback to Triton for low latency shapes on SM100."""
def
__init__
(
self
,
e
:
int
,
n
:
int
,
k
:
int
,
out_dtype
:
torch
.
dtype
|
None
,
quant_config
:
FusedMoEQuantConfig
,
device
:
torch
.
dtype
,
):
self
.
is_sm100
=
current_platform
.
has_device_capability
(
100
)
super
().
__init__
(
experts
=
CutlassExpertsFp8
(
e
,
n
,
k
,
out_dtype
,
quant_config
,
device
),
fallback_experts
=
TritonExperts
(
quant_config
),
)
def
workspace_shapes
(
self
,
M
:
int
,
N
:
int
,
K
:
int
,
topk
:
int
,
global_num_experts
:
int
,
local_num_experts
:
int
,
expert_tokens_meta
:
mk
.
ExpertTokensMetadata
|
None
,
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]]:
# Small batch fallback for sm100.
if
self
.
is_sm100
and
M
<=
8
:
return
self
.
fallback_experts
.
workspace_shapes
(
M
,
N
,
K
,
topk
,
global_num_experts
,
local_num_experts
,
expert_tokens_meta
,
)
else
:
return
self
.
experts
.
workspace_shapes
(
M
,
N
,
K
,
topk
,
global_num_experts
,
local_num_experts
,
expert_tokens_meta
,
)
def
_select_experts_impl
(
self
,
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
)
->
mk
.
FusedMoEPermuteExpertsUnpermute
:
# Small batch fallback for sm100.
if
self
.
is_sm100
and
hidden_states
.
shape
[
0
]
<=
8
:
return
self
.
fallback_experts
else
:
return
self
.
experts
vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py
View file @
5dcd7ef1
...
...
@@ -10,78 +10,22 @@ from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
_valid_deep_gemm
,
_valid_deep_gemm_shape
,
)
from
vllm.model_executor.layers.fused_moe.fallback
import
FallbackExperts
from
vllm.model_executor.layers.fused_moe.fused_moe
import
TritonExperts
from
vllm.utils.deep_gemm
import
(
get_mk_alignment_for_contiguous_layout
,
is_deep_gemm_e8m0_used
,
)
class
TritonOrDeepGemmExperts
(
mk
.
FusedMoEPermuteExpertsUnpermute
):
def
__init__
(
self
,
quant_config
:
FusedMoEQuantConfig
,
allow_deep_gemm
:
bool
=
False
,
):
super
().
__init__
(
quant_config
)
self
.
triton_expert
=
TritonExperts
(
quant_config
)
self
.
allow_deep_gemm
=
(
allow_deep_gemm
and
self
.
quant_config
.
use_fp8_w8a8
and
self
.
block_shape
==
get_mk_alignment_for_contiguous_layout
()
)
self
.
deep_gemm_expert
=
(
DeepGemmExperts
(
self
.
quant_config
)
if
self
.
allow_deep_gemm
else
None
)
class
TritonOrDeepGemmExperts
(
FallbackExperts
):
"""DeepGemm with fallback to Triton for low latency shapes."""
@
property
def
activation_formats
(
self
,
)
->
tuple
[
mk
.
FusedMoEActivationFormat
,
mk
.
FusedMoEActivationFormat
]:
assert
(
self
.
deep_gemm_expert
is
None
or
self
.
triton_expert
.
activation_formats
==
self
.
deep_gemm_expert
.
activation_formats
)
return
self
.
triton_expert
.
activation_formats
def
supports_chunking
(
self
)
->
bool
:
dge
=
self
.
deep_gemm_expert
te
=
self
.
triton_expert
return
(
dge
is
None
or
dge
.
supports_chunking
())
and
(
te
is
None
or
te
.
supports_chunking
()
def
__init__
(
self
,
quant_config
:
FusedMoEQuantConfig
):
super
().
__init__
(
experts
=
DeepGemmExperts
(
quant_config
),
fallback_experts
=
TritonExperts
(
quant_config
),
)
def
supports_expert_map
(
self
)
->
bool
:
dge
=
self
.
deep_gemm_expert
te
=
self
.
triton_expert
return
(
dge
is
None
or
dge
.
supports_expert_map
())
and
(
te
is
None
or
te
.
supports_expert_map
()
)
def
finalize_weight_and_reduce_impl
(
self
)
->
mk
.
TopKWeightAndReduce
:
dge
=
self
.
deep_gemm_expert
te
=
self
.
triton_expert
dge_war
=
dge
.
finalize_weight_and_reduce_impl
()
if
dge
else
None
te_war
=
te
.
finalize_weight_and_reduce_impl
()
if
te
else
None
is_dge_war
=
dge_war
is
not
None
is_te_war
=
te_war
is
not
None
if
is_dge_war
and
is_te_war
:
assert
dge_war
==
te_war
,
(
"Both implementations should agree on WeightAndReduce impls. "
f
"Got dge_war:
{
dge_war
}
, and te_war:
{
te_war
}
"
)
if
dge_war
is
not
None
:
return
dge_war
assert
te_war
is
not
None
return
te_war
def
workspace_shapes
(
self
,
M
:
int
,
...
...
@@ -95,11 +39,8 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
# Note: the deep gemm workspaces are strictly larger than the triton
# workspaces so we can be pessimistic here and allocate for DeepGemm
# even if we fall back to triton later, e.g. if expert maps are set.
if
self
.
allow_deep_gemm
and
(
is_deep_gemm_e8m0_used
()
or
_valid_deep_gemm_shape
(
M
,
N
,
K
)
):
assert
self
.
deep_gemm_expert
is
not
None
return
self
.
deep_gemm_expert
.
workspace_shapes
(
if
is_deep_gemm_e8m0_used
()
or
_valid_deep_gemm_shape
(
M
,
N
,
K
):
return
self
.
experts
.
workspace_shapes
(
M
,
N
,
K
,
...
...
@@ -109,7 +50,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
expert_tokens_meta
,
)
else
:
return
self
.
triton
_expert
.
workspace_shapes
(
return
self
.
fallback
_expert
s
.
workspace_shapes
(
M
,
N
,
K
,
...
...
@@ -119,45 +60,13 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
expert_tokens_meta
,
)
def
ap
pl
y
(
def
_select_experts_im
pl
(
self
,
output
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
activation
:
str
,
global_num_experts
:
int
,
expert_map
:
torch
.
Tensor
|
None
,
a1q_scale
:
torch
.
Tensor
|
None
,
a2_scale
:
torch
.
Tensor
|
None
,
workspace13
:
torch
.
Tensor
,
workspace2
:
torch
.
Tensor
,
expert_tokens_meta
:
mk
.
ExpertTokensMetadata
|
None
,
apply_router_weight_on_input
:
bool
,
):
use_deep_gemm
=
self
.
allow_deep_gemm
and
(
is_deep_gemm_e8m0_used
()
or
_valid_deep_gemm
(
hidden_states
,
w1
,
w2
)
)
experts
=
self
.
deep_gemm_expert
if
use_deep_gemm
else
self
.
triton_expert
assert
experts
is
not
None
experts
.
apply
(
output
,
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
activation
,
global_num_experts
,
expert_map
,
a1q_scale
,
a2_scale
,
workspace13
,
workspace2
,
expert_tokens_meta
,
apply_router_weight_on_input
,
)
)
->
mk
.
FusedMoEPermuteExpertsUnpermute
:
if
is_deep_gemm_e8m0_used
()
or
_valid_deep_gemm
(
hidden_states
,
w1
,
w2
):
return
self
.
experts
else
:
return
self
.
fallback_experts
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
View file @
5dcd7ef1
This diff is collapsed.
Click to expand it.
vllm/model_executor/layers/quantization/fp8.py
View file @
5dcd7ef1
This diff is collapsed.
Click to expand it.
vllm/model_executor/layers/quantization/modelopt.py
View file @
5dcd7ef1
This diff is collapsed.
Click to expand it.
vllm/model_executor/layers/quantization/quark/quark_moe.py
View file @
5dcd7ef1
...
...
@@ -22,7 +22,7 @@ from vllm.model_executor.layers.fused_moe.config import (
)
from
vllm.model_executor.layers.fused_moe.fused_marlin_moe
import
fused_marlin_moe
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp8
import
(
prepare_
moe_fp8
_layer_for_marlin
,
prepare_
fp8_moe
_layer_for_marlin
,
)
from
vllm.model_executor.layers.quantization.utils.ocp_mx_utils
import
(
OCP_MX_BLOCK_SIZE
,
...
...
@@ -315,8 +315,8 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
layer
.
w2_weight
=
torch
.
nn
.
Parameter
(
shuffled_w2
,
requires_grad
=
False
)
elif
self
.
use_marlin
:
(
workspace
,
w13_weight
,
w2_weight
,
w13_weight_scale
,
w2_weight_scale
)
=
(
prepare_
moe_fp8
_layer_for_marlin
(
w13_weight
,
w2_weight
,
w13_weight_scale
,
w2_weight_scale
=
(
prepare_
fp8_moe
_layer_for_marlin
(
layer
,
layer
.
w13_weight
,
layer
.
w2_weight
,
...
...
@@ -324,7 +324,6 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
layer
.
w2_weight_scale
,
)
)
layer
.
workspace
=
workspace
# TODO(rob): once we apply refactor to Quark, switch to using
# replace_parameter for compatibility with reloading in RL.
layer
.
w13_weight
=
torch
.
nn
.
Parameter
(
w13_weight
,
requires_grad
=
False
)
...
...
vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
View file @
5dcd7ef1
...
...
@@ -18,6 +18,7 @@ from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize im
create_flashinfer_prepare_finalize
,
)
from
vllm.platforms
import
current_platform
from
vllm.utils.math_utils
import
round_up
logger
=
init_logger
(
__name__
)
...
...
@@ -58,9 +59,10 @@ def swap_w13_to_w31(x: torch.Tensor) -> torch.Tensor:
)
def
rotate_
flashinfer_fp8_moe_weights
(
def
rotate_
weights_for_fi_trtllm_fp8_per_tensor_moe
(
gemm1_weights
:
torch
.
Tensor
,
gemm2_weights
:
torch
.
Tensor
):
"""Shuffle weights for for FI TRT-LLM Format"""
from
flashinfer
import
reorder_rows_for_gated_act_gemm
,
shuffle_matrix_a
epilogue_tile_m
=
128
...
...
@@ -105,16 +107,16 @@ def rotate_flashinfer_fp8_moe_weights(
def
register_scales_for_trtllm_fp8_per_tensor_moe
(
layer
:
torch
.
nn
.
Module
,
w13_
weight_
scale
:
torch
.
Tensor
,
w13_scale
:
torch
.
Tensor
,
w13_input_scale
:
torch
.
Tensor
,
w2_
weight_
scale
:
torch
.
Tensor
,
w2_scale
:
torch
.
Tensor
,
w2_input_scale
:
torch
.
Tensor
,
)
->
None
:
"""Register necessary scales for FlashInfer TRTLLM FP8 MoE kernel"""
g1_alphas
,
g2_alphas
=
make_fp8_moe_alpha_scales_for_fi
(
w13_scale
=
w13_
weight_
scale
,
w13_scale
=
w13_scale
,
w13_input_scale
=
w13_input_scale
,
w2_scale
=
w2_
weight_
scale
,
w2_scale
=
w2_scale
,
w2_input_scale
=
w2_input_scale
,
)
layer
.
w2_input_scale_inv
=
1.0
/
w2_input_scale
...
...
@@ -123,7 +125,7 @@ def register_scales_for_trtllm_fp8_per_tensor_moe(
layer
.
output2_scales_scalar
=
g2_alphas
def
apply_f
lashinfer
_per_tensor_
scale_fp8
(
def
apply_f
i_trtllm_fp8
_per_tensor_
moe
(
layer
:
torch
.
nn
.
Module
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
...
...
@@ -139,16 +141,23 @@ def apply_flashinfer_per_tensor_scale_fp8(
import
vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe
# noqa: E501, F401
from
vllm.model_executor.models.llama4
import
Llama4MoE
# Added to the layer by: register_scales_for_trtllm_fp8_per_tensor_moe
assert
(
hasattr
(
layer
,
"output1_scales_scalar"
)
and
hasattr
(
layer
,
"output1_scales_gate_scalar"
)
and
hasattr
(
layer
,
"output2_scales_scalar"
)
)
assert
layer
.
custom_routing_function
==
Llama4MoE
.
custom_routing_function
,
(
"FusedMoE flashinfer kernels are only supported for Llama4"
# Added to the layer by: register_scales_for_trtllm_fp8_per_tensor_moe
assert
(
hasattr
(
layer
,
"output1_scales_scalar"
)
and
hasattr
(
layer
,
"output1_scales_gate_scalar"
)
and
hasattr
(
layer
,
"output2_scales_scalar"
)
)
return
torch
.
ops
.
vllm
.
flashinfer_fused_moe_per_tensor_scale_fp8
(
is_llama4
=
layer
.
custom_routing_function
==
Llama4MoE
.
custom_routing_function
assert
is_llama4
,
"FusedMoE flashinfer kernels are only supported for Llama4"
return
torch
.
ops
.
vllm
.
fi_trtllm_fp8_per_tensor_moe
(
routing_logits
=
router_logits
,
routing_bias
=
routing_bias
,
hidden_states
=
hidden_states
,
...
...
@@ -221,50 +230,6 @@ def select_cutlass_fp8_gemm_impl(
)
def
flashinfer_cutlass_moe_fp8
(
hidden_states
:
torch
.
Tensor
,
layer
:
torch
.
nn
.
Module
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
inplace
:
bool
=
False
,
activation
:
str
=
"silu"
,
global_num_experts
:
int
=
-
1
,
expert_map
:
torch
.
Tensor
|
None
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
use_deepseek_fp8_block_scale
:
bool
=
False
,
moe
:
FusedMoEConfig
|
None
=
None
,
)
->
torch
.
Tensor
:
quant_config
=
layer
.
quant_method
.
get_fused_moe_quant_config
(
layer
)
assert
quant_config
is
not
None
# Construct modular kernel with block-scale support when requested.
fused_experts
=
mk
.
FusedMoEModularKernel
(
build_flashinfer_fp8_cutlass_moe_prepare_finalize
(
moe
=
moe
,
use_deepseek_fp8_block_scale
=
use_deepseek_fp8_block_scale
),
select_cutlass_fp8_gemm_impl
(
moe
=
moe
,
quant_config
=
quant_config
,
out_dtype
=
hidden_states
.
dtype
,
use_deepseek_fp8_block_scale
=
use_deepseek_fp8_block_scale
,
),
moe_parallel_config
=
layer
.
moe_parallel_config
,
)
return
fused_experts
(
hidden_states
,
layer
.
w13_weight
,
layer
.
w2_weight
,
topk_weights
,
topk_ids
,
inplace
=
inplace
,
activation
=
activation
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
)
def
get_flashinfer_moe_backend
()
->
FlashinferMoeBackend
:
backend_map
=
{
"throughput"
:
FlashinferMoeBackend
.
CUTLASS
,
...
...
@@ -301,3 +266,104 @@ def is_flashinfer_supporting_global_sf(backend: FlashinferMoeBackend | None) ->
FlashinferMoeBackend
.
TENSORRT_LLM
,
)
return
backend
in
backends_supporting_global_sf
def
align_fp8_moe_weights_for_fi
(
w13
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
is_act_and_mul
:
bool
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
int
]:
"""Pad intermediate size so FlashInfer kernels' alignment constraints hold.
Some FlashInfer FP8 MoE kernels require the (gated) intermediate size
used for GEMM to be divisible by a small alignment value. When this is
not satisfied (e.g. with certain tensor-parallel sizes), we pad the
gate/up and down projection weights along the intermediate dim.
"""
# Current local intermediate size (per partition) is the K dimension of
# the down projection.
num_experts
,
hidden_size
,
intermediate
=
w2
.
shape
min_alignment
=
16
padded_intermediate
=
round_up
(
intermediate
,
min_alignment
)
if
padded_intermediate
==
intermediate
:
return
w13
,
w2
,
intermediate
logger
.
info_once
(
"Padding intermediate size from %d to %d for up/down projection weights."
,
intermediate
,
padded_intermediate
,
scope
=
"local"
,
)
up_mult
=
2
if
is_act_and_mul
else
1
padded_gate_up_dim
=
up_mult
*
padded_intermediate
# Pad w13 and w2 along its intermediate dimension.
padded_w13
=
w13
.
new_zeros
((
num_experts
,
padded_gate_up_dim
,
hidden_size
))
padded_w13
[:,
:
w13
.
shape
[
1
],
:]
=
w13
padded_w2
=
w2
.
new_zeros
((
num_experts
,
hidden_size
,
padded_intermediate
))
padded_w2
[:,
:,
:
intermediate
]
=
w2
return
padded_w13
,
padded_w2
,
padded_intermediate
def
prepare_fp8_moe_layer_for_fi
(
layer
:
torch
.
nn
.
Module
,
w13
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
w13_scale
:
torch
.
Tensor
,
w13_input_scale
:
torch
.
Tensor
|
None
,
w2_scale
:
torch
.
Tensor
,
w2_input_scale
:
torch
.
Tensor
|
None
,
is_trtllm
:
bool
=
False
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Convert Fp8 MoE weights to flashinfer kernel format
Note that for trtllm we update the model state dict
with the scale format needed for these kernels.
Note that for per-tensor, we update the layer's
intermediate size if the weights needed padding.
"""
assert
hasattr
(
layer
.
moe_config
,
"is_act_and_mul"
)
block_quant
=
(
hasattr
(
layer
,
"weight_block_size"
)
and
layer
.
weight_block_size
is
not
None
)
# Some FI MoE kernels require internal alignment of 16
# for the gate-up proj. Pad the weights to respect this.
if
not
block_quant
:
w13
,
w2
,
new_intermediate
=
align_fp8_moe_weights_for_fi
(
w13
,
w2
,
layer
.
moe_config
.
is_act_and_mul
,
)
layer
.
intermediate_size_per_partition
=
new_intermediate
# FI kernels require W31 layout rather than W13.
if
layer
.
moe_config
.
is_act_and_mul
:
w13
=
swap_w13_to_w31
(
w13
)
if
block_quant
:
w13_scale
=
swap_w13_to_w31
(
w13_scale
)
# FI TRT-LLM FP8 per-tensor MoE kernel requires weight shuffle
# and registration of alpha scales. Note that we do not register
# as nn.Parameters since they are not needed for weight-reloading.
if
is_trtllm
and
not
block_quant
:
assert
w13_input_scale
is
not
None
assert
w2_input_scale
is
not
None
rotate_weights_for_fi_trtllm_fp8_per_tensor_moe
(
w13
,
w2
)
register_scales_for_trtllm_fp8_per_tensor_moe
(
layer
,
w13_scale
=
w13_scale
,
w13_input_scale
=
w13_input_scale
,
w2_scale
=
w2_scale
,
w2_input_scale
=
w2_input_scale
,
)
return
w13
,
w2
,
w13_scale
vllm/model_executor/layers/quantization/utils/fp8_utils.py
View file @
5dcd7ef1
...
...
@@ -21,6 +21,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
CUTLASS_BLOCK_FP8_SUPPORTED
,
all_close_1d
,
per_tensor_dequantize
,
)
from
vllm.model_executor.parameter
import
(
BlockQuantScaleParameter
,
...
...
@@ -1350,6 +1352,29 @@ def deepgemm_post_process_fp8_weight_block(
return
wq
,
dg_ws
def
prepare_fp8_moe_layer_for_deepgemm
(
w13
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
w13_scale
:
torch
.
Tensor
,
w2_scale
:
torch
.
Tensor
,
block_shape
:
tuple
[
int
],
):
w13
,
w13_scale
=
deepgemm_post_process_fp8_weight_block
(
wq
=
w13
,
ws
=
w13_scale
,
quant_block_shape
=
block_shape
,
use_e8m0
=
is_deep_gemm_e8m0_used
(),
)
w2
,
w2_scale
=
deepgemm_post_process_fp8_weight_block
(
wq
=
w2
,
ws
=
w2_scale
,
quant_block_shape
=
block_shape
,
use_e8m0
=
is_deep_gemm_e8m0_used
(),
)
return
w13
,
w2
,
w13_scale
,
w2_scale
def
_maybe_pad_fp8_weight
(
weight
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""Pad the weight tensor. This is an optimization on ROCm platform, which
can benefit from tensors located far enough from one another in memory"""
...
...
@@ -1584,7 +1609,49 @@ def maybe_post_process_fp8_weight_block(layer: torch.nn.Module):
replace_parameter
(
layer
,
scale_attr
,
dg_weight_scale
)
def
expert_weight_is_col_major
(
x
:
torch
.
Tensor
)
->
bool
:
assert
x
.
dim
()
==
3
b
,
m
,
n
=
x
.
shape
return
x
.
stride
(
0
)
==
m
*
n
and
x
.
stride
(
1
)
==
1
and
x
.
stride
(
2
)
==
m
def
process_fp8_weight_tensor_strategy_moe
(
weight
:
torch
.
Tensor
,
weight_scales
:
torch
.
Tensor
,
shard_size
:
int
,
num_experts
:
int
,
is_act_and_mul
:
bool
=
True
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""Process moe weights for tensor-wise quantization strategy."""
max_scales
=
weight_scales
.
max
(
dim
=
1
).
values
# For w1 case (i.e. not w13): just collapse the last dim since
# there is already just one scale per expert in this case.
if
not
is_act_and_mul
:
assert
weight_scales
.
shape
[
1
]
==
1
return
weight
,
weight_scales
.
max
()
# For w13 case (common): require single scale for w13 per expert, but
# on disk there is a scale for w1 and w3. Use the max to requantize.
for
expert_id
in
range
(
num_experts
):
start
=
0
for
shard_id
in
range
(
2
):
dq_weight
=
per_tensor_dequantize
(
weight
[
expert_id
][
start
:
start
+
shard_size
,
:],
weight_scales
[
expert_id
][
shard_id
],
)
weight
[
expert_id
][
start
:
start
+
shard_size
,
:],
_
=
ops
.
scaled_fp8_quant
(
dq_weight
,
max_scales
[
expert_id
]
)
start
+=
shard_size
return
weight
,
max_scales
def
process_fp8_input_tensor_strategy_moe
(
w13_input_scale
:
torch
.
Tensor
,
w2_input_scale
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""Process moe input scales for tensor-wise quantization strategy."""
if
not
all_close_1d
(
w13_input_scale
)
or
not
all_close_1d
(
w2_input_scale
):
logger
.
info_once
(
"Found input_scales that are not equal for "
"fp8 MoE layer. Using the maximum across experts "
"for each layer."
)
return
w13_input_scale
.
max
(),
w2_input_scale
.
max
()
vllm/model_executor/layers/quantization/utils/marlin_utils.py
View file @
5dcd7ef1
...
...
@@ -496,7 +496,7 @@ def get__quant_fp8_method() -> QuantFP8:
return
_quant_fp8_method
def
get_marlin_input_dtype
(
prefix
):
def
get_marlin_input_dtype
(
prefix
:
str
|
None
=
None
):
if
envs
.
VLLM_MARLIN_INPUT_DTYPE
is
None
:
return
elif
envs
.
VLLM_MARLIN_INPUT_DTYPE
.
lower
()
==
"int8"
:
...
...
vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py
View file @
5dcd7ef1
...
...
@@ -8,6 +8,7 @@ import vllm._custom_ops as ops
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
USE_FP32_REDUCE_DEFAULT
,
get_marlin_input_dtype
,
marlin_make_workspace_new
,
marlin_permute_bias
,
marlin_permute_scales
,
...
...
@@ -197,26 +198,28 @@ def prepare_fp8_layer_for_marlin(
replace_parameter
(
layer
,
"bias"
,
bias
)
def
prepare_
moe_fp8
_layer_for_marlin
(
def
prepare_
fp8_moe
_layer_for_marlin
(
layer
:
torch
.
nn
.
Module
,
w13_weight
:
torch
.
Tensor
,
w2_weight
:
torch
.
Tensor
,
w13_weight_scale
:
torch
.
Tensor
,
w2_weight_scale
:
torch
.
Tensor
,
input_dtype
:
torch
.
dtype
|
None
=
None
,
)
->
tuple
[
torch
.
Tensor
,
# workspace
torch
.
Tensor
,
# w13_weight
torch
.
Tensor
,
# w2_weight
torch
.
Tensor
,
# w13_weight_scale
torch
.
Tensor
,
# w2_weight_scale
]:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Shuffle weights and scales into marlin format.
Note that this function has the side effect of adding a `workspace`
attribute to the layer. This `workspace` does not need to be
registered as a Parameter as it is not used during weight reloading.
"""
logger
.
warning_once
(
"Your GPU does not have native support for FP8 computation but "
"FP8 quantization is being used. Weight-only FP8 compression will "
"be used leveraging the Marlin kernel. This may degrade "
"performance for compute-heavy workloads."
)
input_dtype
=
get_marlin_input_dtype
()
if
input_dtype
is
not
None
and
input_dtype
.
itemsize
==
1
:
raise
NotImplementedError
(
"Marlin W8A8 is not supported."
)
...
...
@@ -227,7 +230,9 @@ def prepare_moe_fp8_layer_for_marlin(
# WORKSPACE
device
=
layer
.
w13_weight
.
device
workspace
=
marlin_make_workspace_new
(
device
,
4
)
# NOTE(rob): we do not need to register the workspace as a param
# because it is not used as part of the weight reloading process.
layer
.
workspace
=
marlin_make_workspace_new
(
device
,
4
)
perm
=
torch
.
empty
(
0
,
dtype
=
torch
.
int
,
device
=
device
)
# WEIGHT
...
...
@@ -310,13 +315,7 @@ def prepare_moe_fp8_layer_for_marlin(
w13_weight_scale
=
permute_scales
(
w13_weight_scale
,
"w13"
)
w2_weight_scale
=
permute_scales
(
w2_weight_scale
,
"w2"
)
return
(
workspace
,
w13_weight
,
w2_weight
,
w13_weight_scale
,
w2_weight_scale
,
)
return
w13_weight
,
w2_weight
,
w13_weight_scale
,
w2_weight_scale
def
pack_fp8_to_int32
(
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment