Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5dcd7ef1
"vscode:/vscode.git/clone" did not exist on "6d5cfab5fdf1a2450a026c60109fa699bccd7ca8"
Unverified
Commit
5dcd7ef1
authored
Jan 07, 2026
by
Robert Shaw
Committed by
GitHub
Jan 07, 2026
Browse files
[MoE Refactor][15/N] Apply Refactor to Fp8 (#31415)
parent
ffc0a279
Changes
38
Hide whitespace changes
Inline
Side-by-side
Showing
18 changed files
with
1252 additions
and
1360 deletions
+1252
-1360
tests/kernels/moe/test_flashinfer.py
tests/kernels/moe/test_flashinfer.py
+37
-12
tests/quantization/test_fp8.py
tests/quantization/test_fp8.py
+12
-7
vllm/model_executor/layers/fused_moe/__init__.py
vllm/model_executor/layers/fused_moe/__init__.py
+0
-2
vllm/model_executor/layers/fused_moe/cutlass_moe.py
vllm/model_executor/layers/fused_moe/cutlass_moe.py
+17
-145
vllm/model_executor/layers/fused_moe/fallback.py
vllm/model_executor/layers/fused_moe/fallback.py
+126
-0
vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py
.../model_executor/layers/fused_moe/flashinfer_trtllm_moe.py
+5
-5
vllm/model_executor/layers/fused_moe/oracle/__init__.py
vllm/model_executor/layers/fused_moe/oracle/__init__.py
+2
-0
vllm/model_executor/layers/fused_moe/oracle/fp8.py
vllm/model_executor/layers/fused_moe/oracle/fp8.py
+358
-0
vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py
vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py
+75
-0
vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py
vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py
+16
-107
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
...quantization/compressed_tensors/compressed_tensors_moe.py
+135
-355
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+112
-384
vllm/model_executor/layers/quantization/modelopt.py
vllm/model_executor/layers/quantization/modelopt.py
+147
-264
vllm/model_executor/layers/quantization/quark/quark_moe.py
vllm/model_executor/layers/quantization/quark/quark_moe.py
+3
-4
vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
...el_executor/layers/quantization/utils/flashinfer_utils.py
+119
-53
vllm/model_executor/layers/quantization/utils/fp8_utils.py
vllm/model_executor/layers/quantization/utils/fp8_utils.py
+71
-4
vllm/model_executor/layers/quantization/utils/marlin_utils.py
.../model_executor/layers/quantization/utils/marlin_utils.py
+1
-1
vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py
...el_executor/layers/quantization/utils/marlin_utils_fp8.py
+16
-17
No files found.
tests/kernels/moe/test_flashinfer.py
View file @
5dcd7ef1
...
...
@@ -11,12 +11,17 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig
,
fp8_w8a8_moe_quant_config
,
)
from
vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe
import
(
FlashInferExperts
,
)
from
vllm.model_executor.layers.fused_moe.fused_moe
import
fused_experts
from
vllm.model_executor.layers.fused_moe.prepare_finalize
import
(
MoEPrepareAndFinalizeNoEP
,
)
from
vllm.model_executor.layers.quantization.utils.flashinfer_utils
import
(
apply_flashinfer_per_tensor_scale_fp8
,
flashinfer_cutlass_moe_fp8
,
apply_fi_trtllm_fp8_per_tensor_moe
,
register_scales_for_trtllm_fp8_per_tensor_moe
,
rotate_
flashinfer_fp8_moe_weights
,
rotate_
weights_for_fi_trtllm_fp8_per_tensor_moe
,
swap_w13_to_w31
,
)
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
input_to_float8
...
...
@@ -103,6 +108,7 @@ class TestData:
w2_quantized
,
w2_weight_scale
=
quant_fp8_per_tensor_batches
(
w2
)
layer
=
torch
.
nn
.
Module
()
layer
.
orig_dtype
=
torch
.
bfloat16
layer
.
w13_weight
=
w13_quantized
.
clone
()
layer
.
w2_weight
=
w2_quantized
.
clone
()
layer
.
w13_input_scale
=
a1_scale
...
...
@@ -115,10 +121,10 @@ class TestData:
pcp_size
=
1
,
dp_size
=
1
,
ep_size
=
1
,
tp_rank
=
1
,
pcp_rank
=
1
,
dp_rank
=
1
,
ep_rank
=
1
,
tp_rank
=
0
,
pcp_rank
=
0
,
dp_rank
=
0
,
ep_rank
=
0
,
use_ep
=
False
,
all2all_backend
=
"naive"
,
)
...
...
@@ -126,7 +132,9 @@ class TestData:
# flashinfer expects swapped rows for w13
layer
.
w13_weight
.
data
=
swap_w13_to_w31
(
layer
.
w13_weight
.
data
)
if
is_trtllm
:
rotate_flashinfer_fp8_moe_weights
(
layer
.
w13_weight
,
layer
.
w2_weight
)
rotate_weights_for_fi_trtllm_fp8_per_tensor_moe
(
layer
.
w13_weight
,
layer
.
w2_weight
)
register_scales_for_trtllm_fp8_per_tensor_moe
(
layer
,
layer
.
w13_weight_scale
,
...
...
@@ -199,7 +207,7 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
quant_config
=
quant_config
,
)
flashinfer_output
=
apply_f
lashinfer
_per_tensor_
scale_fp8
(
flashinfer_output
=
apply_f
i_trtllm_fp8
_per_tensor_
moe
(
layer
=
td
.
layer
,
hidden_states
=
td
.
hidden_states
,
router_logits
=
score
,
...
...
@@ -277,17 +285,34 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
td
.
layer
.
get_fused_moe_quant_config
=
get_fused_moe_quant_config
td
.
layer
.
quant_method
=
td
.
layer
flashinfer_cutlass_output
=
flashinfer_cutlass_moe_fp8
(
kernel
=
mk
.
FusedMoEModularKernel
(
MoEPrepareAndFinalizeNoEP
(
defer_input_quant
=
quant_config
.
is_block_quantized
),
FlashInferExperts
(
out_dtype
=
td
.
layer
.
orig_dtype
,
quant_config
=
quant_config
,
ep_rank
=
td
.
layer
.
moe_parallel_config
.
ep_rank
,
ep_size
=
td
.
layer
.
moe_parallel_config
.
ep_size
,
tp_rank
=
td
.
layer
.
moe_parallel_config
.
tp_rank
,
tp_size
=
td
.
layer
.
moe_parallel_config
.
tp_size
,
use_dp
=
False
,
use_deepseek_fp8_block_scale
=
False
,
),
)
flashinfer_cutlass_output
=
kernel
(
td
.
hidden_states
,
td
.
layer
,
td
.
layer
.
w13_weight
,
td
.
layer
.
w2_weight
,
topk_weights
,
topk_ids
,
inplace
=
False
,
activation
=
activation
,
global_num_experts
=
e
,
expert_map
=
None
,
apply_router_weight_on_input
=
True
,
)
torch
.
testing
.
assert_close
(
output
,
flashinfer_cutlass_output
,
atol
=
5.5e-2
,
rtol
=
1e-2
)
tests/quantization/test_fp8.py
View file @
5dcd7ef1
...
...
@@ -15,7 +15,6 @@ from vllm.model_executor.layers.quantization.fp8 import (
Fp8Config
,
Fp8KVCacheMethod
,
Fp8LinearMethod
,
Fp8MoeBackend
,
Fp8MoEMethod
,
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
...
...
@@ -278,8 +277,18 @@ def test_scaled_fp8_quant(dtype) -> None:
# this is the case for marlin as well as per-tensor Fp8MoEMethod
@
pytest
.
mark
.
parametrize
(
"use_marlin"
,
[
False
])
# skip True
def
test_fp8_reloading
(
method_cls
,
is_checkpoint_fp8_serialized
,
weight_block_size
,
use_marlin
,
dist_init
method_cls
,
is_checkpoint_fp8_serialized
,
weight_block_size
,
use_marlin
,
dist_init
,
monkeypatch
,
):
# NOTE(rob): this test fails when using DeepGEMM because the
# shapes are invalid. Previously the test was passing because
# we set fp8_backend to None, which sidestepped the issue.
monkeypatch
.
setenv
(
"VLLM_USE_DEEP_GEMM"
,
"0"
)
if
is_checkpoint_fp8_serialized
is
False
:
pytest
.
skip
(
"FP8 weight reloading does not support online quantization"
)
...
...
@@ -307,6 +316,7 @@ def test_fp8_reloading(
params_dtype
=
torch
.
bfloat16
,
weight_loader
=
default_weight_loader
,
)
method
.
use_marlin
=
use_marlin
else
:
layer
=
FusedMoE
(
...
...
@@ -325,11 +335,6 @@ def test_fp8_reloading(
weight_loader
=
default_weight_loader
,
)
# Fp8LinearMethod uses use_marlin
# Fp8MoEMethod uses fp8_backend
method
.
use_marlin
=
use_marlin
method
.
fp8_backend
=
Fp8MoeBackend
.
MARLIN
if
use_marlin
else
None
# capture weights format during loading
original_metadata
=
[
(
name
,
param
.
shape
,
getattr
(
param
,
"weight_loader"
,
default_weight_loader
))
...
...
vllm/model_executor/layers/fused_moe/__init__.py
View file @
5dcd7ef1
...
...
@@ -73,7 +73,6 @@ if HAS_TRITON:
CutlassExpertsFp8
,
CutlassExpertsW4A8Fp8
,
cutlass_moe_fp4
,
cutlass_moe_fp8
,
cutlass_moe_w4a8_fp8
,
)
from
vllm.model_executor.layers.fused_moe.deep_gemm_moe
import
DeepGemmExperts
...
...
@@ -96,7 +95,6 @@ if HAS_TRITON:
"fused_experts"
,
"get_config_file_name"
,
"GroupedTopk"
,
"cutlass_moe_fp8"
,
"cutlass_moe_fp4"
,
"cutlass_moe_w4a8_fp8"
,
"CutlassExpertsFp8"
,
...
...
vllm/model_executor/layers/fused_moe/cutlass_moe.py
View file @
5dcd7ef1
...
...
@@ -249,20 +249,28 @@ def run_cutlass_moe_fp8(
class
CutlassExpertsFp8Base
(
mk
.
FusedMoEPermuteExpertsUnpermute
):
def
__init__
(
self
,
e
:
int
,
n
:
int
,
k
:
int
,
out_dtype
:
torch
.
dtype
|
None
,
ab_strides1
:
torch
.
Tensor
,
ab_strides2
:
torch
.
Tensor
,
c_strides1
:
torch
.
Tensor
,
c_strides2
:
torch
.
Tensor
,
quant_config
:
FusedMoEQuantConfig
,
device
:
torch
.
dtype
,
):
assert
quant_config
.
use_fp8_w8a8
super
().
__init__
(
quant_config
)
# E: num_experts
# N: intermediate size per partition
# K: hidden dim
ab_strides1_c_strides2
=
torch
.
full
((
e
,),
k
,
device
=
device
,
dtype
=
torch
.
int64
)
ab_strides2
=
torch
.
full
((
e
,),
n
,
device
=
device
,
dtype
=
torch
.
int64
)
c_strides1
=
torch
.
full
((
e
,),
2
*
n
,
device
=
device
,
dtype
=
torch
.
int64
)
self
.
out_dtype
=
out_dtype
self
.
ab_strides1
=
ab_strides1
self
.
ab_strides1
=
ab_strides1
_c_strides2
self
.
ab_strides2
=
ab_strides2
self
.
c_strides1
=
c_strides1
self
.
c_strides2
=
c_strides2
self
.
c_strides2
=
ab_strides1_
c_strides2
def
finalize_weight_and_reduce_impl
(
self
)
->
mk
.
TopKWeightAndReduce
:
# Let PrepareAndFinalize::finalize() decide the impl.
...
...
@@ -329,24 +337,6 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
class
CutlassExpertsFp8
(
CutlassExpertsFp8Base
):
def
__init__
(
self
,
out_dtype
:
torch
.
dtype
|
None
,
ab_strides1
:
torch
.
Tensor
,
ab_strides2
:
torch
.
Tensor
,
c_strides1
:
torch
.
Tensor
,
c_strides2
:
torch
.
Tensor
,
quant_config
:
FusedMoEQuantConfig
,
):
super
().
__init__
(
out_dtype
,
ab_strides1
,
ab_strides2
,
c_strides1
,
c_strides2
,
quant_config
,
)
@
property
def
activation_formats
(
self
,
...
...
@@ -390,21 +380,10 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
self
,
max_experts_per_worker
:
int
,
num_dispatchers
:
int
,
out_dtype
:
torch
.
dtype
|
None
,
ab_strides1
:
torch
.
Tensor
,
ab_strides2
:
torch
.
Tensor
,
c_strides1
:
torch
.
Tensor
,
c_strides2
:
torch
.
Tensor
,
quant_config
:
FusedMoEQuantConfig
,
*
args
,
**
kwargs
,
):
super
().
__init__
(
out_dtype
,
ab_strides1
,
ab_strides2
,
c_strides1
,
c_strides2
,
quant_config
,
)
super
().
__init__
(
*
args
,
**
kwargs
)
assert
max_experts_per_worker
>
0
self
.
max_experts_per_worker
=
max_experts_per_worker
self
.
num_dispatchers
=
num_dispatchers
...
...
@@ -445,113 +424,6 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
return
(
workspace1
,
workspace2
,
output
)
def
cutlass_moe_fp8
(
a
:
torch
.
Tensor
,
w1_q
:
torch
.
Tensor
,
w2_q
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
ab_strides1
:
torch
.
Tensor
,
ab_strides2
:
torch
.
Tensor
,
c_strides1
:
torch
.
Tensor
,
c_strides2
:
torch
.
Tensor
,
quant_config
:
FusedMoEQuantConfig
,
activation
:
str
=
"silu"
,
expert_map
:
torch
.
Tensor
|
None
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
)
->
torch
.
Tensor
:
"""
This function computes a a8w8-quantized Mixture of Experts (MoE) layer
using two sets of quantized weights, w1_q and w2_q, and top-k gating
mechanism. The matrix multiplications are implemented with CUTLASS
grouped gemm.
Parameters:
- a (torch.Tensor): The input tensor to the MoE layer.
Shape: [M, K]
- w1_q (torch.Tensor): The first set of fp8-quantized expert weights.
Shape: [num_experts, K, 2N] (the weights are passed transposed)
- w2_q (torch.Tensor): The second set of fp8-quantized expert weights.
Shape: [num_experts, N, K] (the weights are passed transposed)
- topk_weights (torch.Tensor): The weights of each token->expert mapping.
- topk_ids (torch.Tensor): The token->expert mappings.
- w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q.
Shape: [num_experts] or [num_experts, 2N]
- w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
Shape: [num_experts] or [num_experts, K]
- ab_strides1 (torch.Tensor): The input/weight strides for the first gemm.
Shape: [num_experts]
- ab_strides2 (torch.Tensor): The input/weight strides for the second gemm.
Shape: [num_experts]
- c_strides1 (torch.Tensor): The output strides for the first gemm.
Shape: [num_experts]
- c_strides2 (torch.Tensor): The output strides for the second gemm.
Shape: [num_experts]
- per_act_token (Optional[bool]): Whether the scale is per-token or
per-tensor.
- activation (str): The activation function to use.
- a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
Shape: scalar or [M]
- a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
quantize the intermediate result between the gemms.
Shape: scalar or [M]
- expert_map (Optional[torch.Tensor]): In the case of Expert parallel,
every Rank is responsible for a subset of experts. expert_map is a
mapping from global expert-id to local expert-id. When expert_map[i]
is -1, it means that this Rank is not responsible for global
expert-id i.
- apply_router_weight_on_input (bool): When true, the topk weights are
applied directly on the inputs. This is only applicable when topk is 1.
- global_num_experts (int): The total number of experts.
Returns:
- torch.Tensor: The fp16 output tensor after applying the MoE layer.
"""
assert
quant_config
is
not
None
if
quant_config
.
a1_scale
is
not
None
:
assert
quant_config
.
per_act_token_quant
==
(
quant_config
.
a1_scale
.
numel
()
!=
1
)
if
quant_config
.
a2_scale
is
not
None
:
assert
quant_config
.
per_act_token_quant
==
(
quant_config
.
a2_scale
.
numel
()
!=
1
)
if
quant_config
.
w1_scale
is
not
None
:
if
quant_config
.
per_out_ch_quant
:
assert
quant_config
.
w1_scale
.
dim
()
>
1
and
quant_config
.
w1_scale
.
size
(
1
)
==
w1_q
.
size
(
1
)
else
:
assert
(
quant_config
.
w1_scale
.
dim
()
==
1
or
quant_config
.
w1_scale
.
size
(
1
)
==
1
)
num_experts
=
global_num_experts
if
global_num_experts
!=
-
1
else
w1_q
.
size
(
0
)
fn
=
mk
.
FusedMoEModularKernel
(
MoEPrepareAndFinalizeNoEP
(),
CutlassExpertsFp8
(
out_dtype
=
a
.
dtype
,
ab_strides1
=
ab_strides1
,
ab_strides2
=
ab_strides2
,
c_strides1
=
c_strides1
,
c_strides2
=
c_strides2
,
quant_config
=
quant_config
,
),
)
return
fn
(
a
,
w1_q
,
w2_q
,
topk_weights
,
topk_ids
,
activation
=
activation
,
global_num_experts
=
num_experts
,
expert_map
=
expert_map
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
)
FLOAT4_E2M1_MAX
=
scalar_types
.
float4_e2m1f
.
max
()
FLOAT8_E4M3_MAX
=
torch
.
finfo
(
torch
.
float8_e4m3fn
).
max
...
...
vllm/model_executor/layers/fused_moe/fallback.py
0 → 100644
View file @
5dcd7ef1
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
abc
import
ABC
,
abstractmethod
import
torch
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
class FallbackExperts(mk.FusedMoEPermuteExpertsUnpermute, ABC):
    """Pair a preferred experts implementation with a fallback one.

    The wrapper holds two ``FusedMoEPermuteExpertsUnpermute`` objects and
    lets a subclass pick, per call, which of them runs. Subclasses must
    implement ``workspace_shapes`` and ``_select_experts_impl``; capability
    queries are answered on behalf of both wrapped implementations.
    """

    def __init__(
        self,
        experts: mk.FusedMoEPermuteExpertsUnpermute,
        fallback_experts: mk.FusedMoEPermuteExpertsUnpermute,
    ):
        # The wrapper is initialised with the preferred implementation's
        # quant config.
        super().__init__(experts.quant_config)
        self.fallback_experts = fallback_experts
        self.experts = experts

    @property
    def activation_formats(
        self,
    ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
        """Activation formats shared by both wrapped implementations."""
        fallback_formats = self.fallback_experts.activation_formats
        assert fallback_formats == self.experts.activation_formats
        return fallback_formats

    def supports_chunking(self) -> bool:
        """Report chunking support; both implementations must agree."""
        primary_ok = self.experts.supports_chunking()
        fallback_ok = self.fallback_experts.supports_chunking()
        assert primary_ok == fallback_ok
        return primary_ok and fallback_ok

    def supports_expert_map(self) -> bool:
        """Report expert-map support; both implementations must agree."""
        primary_ok = self.experts.supports_expert_map()
        fallback_ok = self.fallback_experts.supports_expert_map()
        assert primary_ok == fallback_ok
        return primary_ok and fallback_ok

    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
        """Return the weight-and-reduce impl of the wrapped experts.

        When both implementations provide one, they must compare equal.
        The preferred implementation's value wins; otherwise the fallback's
        value is returned and must not be ``None``.
        """
        primary_war = self.experts.finalize_weight_and_reduce_impl()
        fallback_war = self.fallback_experts.finalize_weight_and_reduce_impl()

        if primary_war is None:
            assert fallback_war is not None
            return fallback_war

        if fallback_war is not None:
            assert primary_war == fallback_war, (
                "Both implementations should agree on WeightAndReduce impls. "
                f"Got e_war: {primary_war}, and fbe_war: {fallback_war}"
            )
        return primary_war

    @abstractmethod
    def workspace_shapes(
        self,
        M: int,
        N: int,
        K: int,
        topk: int,
        global_num_experts: int,
        local_num_experts: int,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
        """Return the workspace shapes for the given problem size.

        Must be provided by the subclass.
        """
        raise NotImplementedError

    @abstractmethod
    def _select_experts_impl(
        self,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
    ) -> mk.FusedMoEPermuteExpertsUnpermute:
        """Choose which wrapped implementation handles these inputs.

        Must be provided by the subclass.
        """
        raise NotImplementedError

    def apply(
        self,
        output: torch.Tensor,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        activation: str,
        global_num_experts: int,
        expert_map: torch.Tensor | None,
        a1q_scale: torch.Tensor | None,
        a2_scale: torch.Tensor | None,
        workspace13: torch.Tensor,
        workspace2: torch.Tensor,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
        apply_router_weight_on_input: bool,
    ):
        """Forward the call, unchanged, to the selected implementation.

        The selection is made from ``hidden_states``, ``w1`` and ``w2``;
        every argument is then passed through positionally. The delegate's
        result is not returned.
        """
        selected = self._select_experts_impl(hidden_states, w1, w2)
        selected.apply(
            output,
            hidden_states,
            w1,
            w2,
            topk_weights,
            topk_ids,
            activation,
            global_num_experts,
            expert_map,
            a1q_scale,
            a2_scale,
            workspace13,
            workspace2,
            expert_tokens_meta,
            apply_router_weight_on_input,
        )
vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py
View file @
5dcd7ef1
...
...
@@ -100,7 +100,7 @@ direct_register_custom_op(
)
def
f
lashinfer_fused_moe
_per_tensor_
scale_fp8
(
def
f
i_trtllm_fp8
_per_tensor_
moe
(
routing_logits
:
torch
.
Tensor
,
routing_bias
:
torch
.
Tensor
|
None
,
hidden_states
:
torch
.
Tensor
,
...
...
@@ -158,7 +158,7 @@ def flashinfer_fused_moe_per_tensor_scale_fp8(
)
def
f
lashinfer_fused_moe
_per_tensor_
scale_fp8
_fake
(
def
f
i_trtllm_fp8
_per_tensor_
moe
_fake
(
routing_logits
:
torch
.
Tensor
,
routing_bias
:
torch
.
Tensor
|
None
,
hidden_states
:
torch
.
Tensor
,
...
...
@@ -184,9 +184,9 @@ def flashinfer_fused_moe_per_tensor_scale_fp8_fake(
# TODO(bnell): Does this really need to be a torch.op?
direct_register_custom_op
(
op_name
=
"f
lashinfer_fused_moe
_per_tensor_
scale_fp8
"
,
op_func
=
f
lashinfer_fused_moe
_per_tensor_
scale_fp8
,
op_name
=
"f
i_trtllm_fp8
_per_tensor_
moe
"
,
op_func
=
f
i_trtllm_fp8
_per_tensor_
moe
,
mutates_args
=
[
"hidden_states"
],
fake_impl
=
f
lashinfer_fused_moe
_per_tensor_
scale_fp8
_fake
,
fake_impl
=
f
i_trtllm_fp8
_per_tensor_
moe
_fake
,
tags
=
(
torch
.
Tag
.
needs_fixed_stride_order
,),
)
vllm/model_executor/layers/fused_moe/oracle/__init__.py
0 → 100644
View file @
5dcd7ef1
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
vllm/model_executor/layers/fused_moe/oracle/fp8.py
0 → 100644
View file @
5dcd7ef1
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
enum
import
Enum
import
torch
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm
import
envs
from
vllm._aiter_ops
import
rocm_aiter_ops
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe.config
import
(
FusedMoEConfig
,
FusedMoEQuantConfig
,
fp8_w8a8_moe_quant_config
,
fp8_w8a16_moe_quant_config
,
)
from
vllm.model_executor.layers.quantization.utils.flashinfer_utils
import
(
FlashinferMoeBackend
,
get_flashinfer_moe_backend
,
make_fp8_moe_alpha_scales_for_fi
,
prepare_fp8_moe_layer_for_fi
,
)
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
prepare_fp8_moe_layer_for_deepgemm
,
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp8
import
(
prepare_fp8_moe_layer_for_marlin
,
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
cutlass_group_gemm_supported
,
)
from
vllm.platforms
import
current_platform
from
vllm.utils.deep_gemm
import
is_deep_gemm_supported
from
vllm.utils.flashinfer
import
has_flashinfer_moe
from
vllm.utils.import_utils
import
has_deep_gemm
logger
=
init_logger
(
__name__
)
class Fp8MoeBackend(Enum):
    """Identifiers for the FP8 MoE kernel backends known to the oracle."""

    # NOTE(review): not returned by the selection logic visible in this
    # file; presumably a "no backend chosen" placeholder - confirm.
    NONE = 0

    # FlashInfer-based backends.
    FLASHINFER_TRTLLM = 1
    FLASHINFER_CUTLASS = 2

    # DeepGEMM; only selected for block-quantized weights.
    DEEPGEMM = 3

    # Weight-only (W8A16) path.
    MARLIN = 4

    # Default backend.
    TRITON = 5

    # ROCm AITER.
    AITER = 6

    # vLLM's own CUTLASS grouped-GEMM kernels.
    VLLM_CUTLASS = 7
def select_fp8_moe_backend(
    block_quant: bool,
    tp_size: int,
    with_lora_support: bool,
    is_act_and_mul: bool = True,
    allow_vllm_cutlass: bool = False,
) -> Fp8MoeBackend:
    """
    Select the primary FP8 MoE backend.

    Candidates are tried in this order and the first match wins:
    Triton (when LoRA support is required), FlashInfer TRTLLM / CUTLASS,
    Marlin, DeepGEMM, ROCm AITER, vLLM CUTLASS, and finally Triton as the
    default.

    Note: Shape-specific fallbacks may still occur at runtime.

    Args:
        block_quant: Whether the weights are block-quantized.
        tp_size: Tensor-parallel size; DeepGEMM is disabled by default
            when it is >= 8 and VLLM_MOE_USE_DEEP_GEMM is not set.
        with_lora_support: When true, Triton is returned unconditionally.
        is_act_and_mul: Whether the layer uses act_and_mul fusion; the
            FlashInfer TRTLLM backend requires it.
        allow_vllm_cutlass: Whether the vLLM CUTLASS backend may be chosen.

    Returns:
        The selected ``Fp8MoeBackend``.

    Raises:
        ValueError: If the configured FlashInfer backend cannot serve the
            requested configuration.
    """
    # TODO(rob): in a future PR, we will query each mk for
    # supported features and return the mk directly, just like
    # we do for the Attention Backend.
    if with_lora_support:
        return Fp8MoeBackend.TRITON

    def _make_log_backend(backend_name: str) -> str:
        return f"Using {backend_name} backend for FP8 MoE"

    # Prefer FlashInfer backends on supported GPUs; allow SM90 and SM100.
    if (
        current_platform.is_cuda()
        and (
            current_platform.is_device_capability_family(100)
            or current_platform.is_device_capability(90)
        )
        and envs.VLLM_USE_FLASHINFER_MOE_FP8
        and has_flashinfer_moe()
    ):
        backend = get_flashinfer_moe_backend()
        if backend == FlashinferMoeBackend.TENSORRT_LLM:
            # Validate before logging so a rejected configuration is not
            # announced as the backend in use.
            if not is_act_and_mul:
                # The backend flavour is chosen by
                # VLLM_FLASHINFER_MOE_BACKEND; VLLM_USE_FLASHINFER_MOE_FP8
                # is only the on/off switch checked above.
                raise ValueError(
                    "FlashInfer TRTLLM FP8 MoE backend only supports "
                    "act_and_mul gate_up_project fusion. Please set "
                    "VLLM_FLASHINFER_MOE_BACKEND=throughput to use the "
                    "FlashInfer CUTLASS backend instead."
                )
            logger.info_once(_make_log_backend("FlashInfer TRTLLM"))
            return Fp8MoeBackend.FLASHINFER_TRTLLM
        else:
            if block_quant and current_platform.is_device_capability_family(100):
                raise ValueError(
                    "FlashInfer FP8 MoE throughput backend does not "
                    "support block quantization on SM100. Please use "
                    "VLLM_FLASHINFER_MOE_BACKEND=latency to use the "
                    "FlashInfer TRTLLM backend instead."
                )
            logger.info_once(_make_log_backend("FlashInfer CUTLASS"))
            return Fp8MoeBackend.FLASHINFER_CUTLASS

    # weight-only path for older GPUs without native FP8
    if (
        current_platform.is_cuda() and not current_platform.has_device_capability(89)
    ) or envs.VLLM_TEST_FORCE_FP8_MARLIN:
        logger.info_once(_make_log_backend("Marlin"), scope="local")
        return Fp8MoeBackend.MARLIN

    # Determine if we should use DeepGEMM with block-quantized weights:
    # - If explicitly set by user, respect their choice
    # - If not explicitly set (default), disable when TP size is >= 8
    moe_use_deep_gemm = envs.VLLM_MOE_USE_DEEP_GEMM
    if not envs.is_set("VLLM_MOE_USE_DEEP_GEMM") and tp_size >= 8:
        moe_use_deep_gemm = False
        logger.info_once(
            "DeepGEMM MoE is disabled by default when TP size is >= 8. "
            "Set VLLM_MOE_USE_DEEP_GEMM=1 to enable it.",
            scope="local",
        )

    use_deep_gemm = envs.VLLM_USE_DEEP_GEMM
    if not is_deep_gemm_supported():
        use_deep_gemm = False
        logger.info_once(
            "DeepGEMM is disabled because the platform does not support it.",
            scope="local",
        )

    if use_deep_gemm and moe_use_deep_gemm and block_quant:
        if not has_deep_gemm():
            # Requested but not importable: fall through to the
            # remaining candidates below.
            logger.warning_once(
                "DeepGEMM backend requested but not available.", scope="local"
            )
        elif is_deep_gemm_supported():
            logger.info_once(_make_log_backend("DeepGEMM"), scope="local")
            return Fp8MoeBackend.DEEPGEMM

    if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MOE:
        logger.info_once(_make_log_backend("ROCm AITER"), scope="local")
        return Fp8MoeBackend.AITER

    if allow_vllm_cutlass and not block_quant and cutlass_group_gemm_supported():
        logger.info_once(_make_log_backend("vLLM CUTLASS"), scope="local")
        return Fp8MoeBackend.VLLM_CUTLASS

    # default to Triton
    logger.info_once(_make_log_backend("Triton"), scope="local")
    return Fp8MoeBackend.TRITON
def convert_to_fp8_moe_kernel_format(
    fp8_backend: Fp8MoeBackend,
    layer: torch.nn.Module,
    w13: torch.Tensor,
    w2: torch.Tensor,
    w13_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    w13_input_scale: torch.Tensor | None,
    w2_input_scale: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Convert FP8 MoE weights and scales to the selected backend's layout.

    Each backend with a dedicated preparation helper has its tensors passed
    through that helper; for every other backend the inputs are returned
    unchanged.

    Returns:
        ``(w13, w2, w13_scale, w2_scale)`` after any backend-specific
        conversion.
    """
    # Block quantization is detected by the presence of the attribute.
    has_block_size = hasattr(layer, "weight_block_size")

    if fp8_backend == Fp8MoeBackend.DEEPGEMM:
        assert has_block_size
        block_size = tuple(layer.weight_block_size)
        w13, w2, w13_scale, w2_scale = prepare_fp8_moe_layer_for_deepgemm(
            w13,
            w2,
            w13_scale,
            w2_scale,
            block_size,
        )
        return w13, w2, w13_scale, w2_scale

    if fp8_backend == Fp8MoeBackend.AITER:
        # Only the weights are shuffled; the scales pass through.
        w13, w2 = rocm_aiter_ops.shuffle_weights(w13, w2)
        return w13, w2, w13_scale, w2_scale

    if fp8_backend == Fp8MoeBackend.MARLIN:
        w13, w2, w13_scale, w2_scale = prepare_fp8_moe_layer_for_marlin(
            layer,
            w13,
            w2,
            w13_scale,
            w2_scale,
        )
        return w13, w2, w13_scale, w2_scale

    flashinfer_backends = [
        Fp8MoeBackend.FLASHINFER_CUTLASS,
        Fp8MoeBackend.FLASHINFER_TRTLLM,
    ]
    if fp8_backend in flashinfer_backends:
        # The FlashInfer helper returns no w2_scale, so the incoming one
        # is returned as-is.
        w13, w2, w13_scale = prepare_fp8_moe_layer_for_fi(
            layer=layer,
            w13=w13,
            w2=w2,
            w13_scale=w13_scale,
            w13_input_scale=w13_input_scale,
            w2_scale=w2_scale,
            w2_input_scale=w2_input_scale,
            is_trtllm=(fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM),
        )

    return w13, w2, w13_scale, w2_scale
def make_fp8_moe_quant_config(
    fp8_backend: Fp8MoeBackend,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    a1_scale: torch.Tensor | None,
    a2_scale: torch.Tensor | None,
    block_shape: list[int] | None = None,
) -> FusedMoEQuantConfig | None:
    """
    Create the FusedMoEQuantConfig for the specified FP8 backend.

    The FusedMoEQuantConfig holds the scales that are used at runtime by
    the Modular Kernel abstraction. Some kernels (e.g. FlashInfer CUTLASS)
    need a special config to handle non-standard inputs to their kernel
    interfaces.

    In a future PR this function should become a method of the modular
    kernel itself.

    Returns:
        ``None`` for the FlashInfer TRTLLM backend, otherwise the quant
        config for the given backend.
    """
    # TRTLLM does not use the Modular Kernel abstraction yet.
    if fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
        return None

    # MARLIN is a mixed-precision W8A16 config: no activation scales.
    if fp8_backend == Fp8MoeBackend.MARLIN:
        return fp8_w8a16_moe_quant_config(
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            block_shape=block_shape,
        )

    is_fi_cutlass_per_tensor = (
        fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS and block_shape is None
    )
    if not is_fi_cutlass_per_tensor:
        # All other backends use the normal config.
        return fp8_w8a8_moe_quant_config(
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
            block_shape=block_shape,
        )

    # FlashInfer CUTLASS per-tensor uses a single dequant scale
    # (alpha = w_scale * a_scale) and inverse activation scales.
    assert a1_scale is not None and a2_scale is not None
    gemm1_alphas, gemm2_alphas = make_fp8_moe_alpha_scales_for_fi(
        w1_scale,
        a1_scale,
        w2_scale,
        a2_scale,
    )
    inv_a1_scale = 1.0 / a1_scale
    inv_a2_scale = 1.0 / a2_scale
    return fp8_w8a8_moe_quant_config(
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
        a1_gscale=inv_a1_scale,
        a2_gscale=inv_a2_scale,
        g1_alphas=gemm1_alphas,
        g2_alphas=gemm2_alphas,
    )
def make_fp8_moe_kernel(
    layer: torch.nn.Module,
    moe_quant_config: FusedMoEQuantConfig,
    moe_config: FusedMoEConfig,
    fp8_backend: Fp8MoeBackend,
) -> tuple[mk.FusedMoEModularKernel, bool]:
    """Build the modular kernel for the given FP8 MoE backend.

    Every backend is paired with ``MoEPrepareAndFinalizeNoEP``; only the
    experts implementation (and a few prepare/finalize flags) differ.

    Returns:
        ``(kernel, use_inplace)``. ``use_inplace`` is ``False`` for the
        FlashInfer CUTLASS backend and ``True`` for all others.
    """
    # Imports are deferred because this module is also imported by CPU
    # backends, which cannot import all of the experts implementations.
    # TODO: update the experts so this is no longer necessary.
    from vllm.model_executor.layers.fused_moe.prepare_finalize import (
        MoEPrepareAndFinalizeNoEP,
    )

    # NOTE(rob): WIP refactor. The TP-case kernels are migrated to mk
    # first; afterwards the TP and DP/EP cases will be initialized through
    # the same code path (i.e. via maybe_init_modular_kernel).
    inplace_allowed = True

    if fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
        from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
            FlashInferExperts,
        )

        prepare_finalize = MoEPrepareAndFinalizeNoEP(
            defer_input_quant=moe_quant_config.is_block_quantized
        )
        experts = FlashInferExperts(
            out_dtype=layer.orig_dtype,
            quant_config=moe_quant_config,
            ep_rank=moe_config.ep_rank,
            ep_size=moe_config.ep_size,
            tp_rank=moe_config.tp_rank,
            tp_size=moe_config.tp_size,
            use_dp=(moe_config.dp_size > 1),
            use_deepseek_fp8_block_scale=moe_quant_config.is_block_quantized,
        )
        inplace_allowed = False

    elif fp8_backend == Fp8MoeBackend.AITER:
        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
            AiterExperts,
        )

        # TODO: make defer_input_quant an attr of the AiterExperts
        prepare_finalize = MoEPrepareAndFinalizeNoEP(defer_input_quant=True)
        experts = AiterExperts(quant_config=moe_quant_config)

    elif fp8_backend == Fp8MoeBackend.MARLIN:
        from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
            MarlinExperts,
        )

        prepare_finalize = MoEPrepareAndFinalizeNoEP()
        experts = MarlinExperts(quant_config=moe_quant_config)

    elif fp8_backend == Fp8MoeBackend.VLLM_CUTLASS:
        from vllm.model_executor.layers.fused_moe.triton_cutlass_moe import (
            TritonOrCutlassExperts,
        )

        prepare_finalize = MoEPrepareAndFinalizeNoEP()
        experts = TritonOrCutlassExperts(
            out_dtype=moe_config.in_dtype,
            e=layer.local_num_experts,
            n=layer.intermediate_size_per_partition,
            k=layer.hidden_size,
            device=layer.w13_weight.device,
            quant_config=moe_quant_config,
        )

    elif fp8_backend == Fp8MoeBackend.DEEPGEMM:
        from vllm.model_executor.layers.fused_moe import (
            TritonOrDeepGemmExperts,
        )

        prepare_finalize = MoEPrepareAndFinalizeNoEP()
        experts = TritonOrDeepGemmExperts(quant_config=moe_quant_config)

    else:
        from vllm.model_executor.layers.fused_moe.fused_moe import (
            TritonExperts,
        )

        # Triton is the only backend left at this point.
        assert fp8_backend == Fp8MoeBackend.TRITON
        prepare_finalize = MoEPrepareAndFinalizeNoEP()
        experts = TritonExperts(quant_config=moe_quant_config)

    return mk.FusedMoEModularKernel(prepare_finalize, experts), inplace_allowed
vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py
0 → 100644
View file @
5dcd7ef1
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
torch
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm.model_executor.layers.fused_moe.config
import
FusedMoEQuantConfig
from
vllm.model_executor.layers.fused_moe.cutlass_moe
import
CutlassExpertsFp8
from
vllm.model_executor.layers.fused_moe.fallback
import
FallbackExperts
from
vllm.model_executor.layers.fused_moe.fused_moe
import
TritonExperts
from
vllm.platforms
import
current_platform
class TritonOrCutlassExperts(FallbackExperts):
    """Cutlass with fallback to Triton for low latency shapes on SM100.

    ``CutlassExpertsFp8`` is the preferred implementation and
    ``TritonExperts`` the fallback. The fallback is used only when the
    platform check made at construction time succeeds and the batch has at
    most ``_SMALL_BATCH_MAX_TOKENS`` tokens.
    """

    # Largest token count (M) that is routed to the Triton fallback.
    _SMALL_BATCH_MAX_TOKENS = 8

    def __init__(
        self,
        e: int,
        n: int,
        k: int,
        out_dtype: torch.dtype | None,
        quant_config: FusedMoEQuantConfig,
        device: torch.device,
    ):
        """Create the Cutlass experts and their Triton fallback.

        Args:
            e, n, k, out_dtype, device: Forwarded positionally to
                ``CutlassExpertsFp8``.
            quant_config: Shared by both implementations.
        """
        # NOTE(review): has_device_capability(100) presumably means
        # "capability >= 10.0", which would also match newer
        # architectures than SM100 - confirm.
        self.is_sm100 = current_platform.has_device_capability(100)
        super().__init__(
            experts=CutlassExpertsFp8(e, n, k, out_dtype, quant_config, device),
            fallback_experts=TritonExperts(quant_config),
        )

    def _use_fallback(self, num_tokens: int) -> bool:
        """Return whether a batch of ``num_tokens`` goes to the fallback."""
        # Small batch fallback for sm100.
        return self.is_sm100 and num_tokens <= self._SMALL_BATCH_MAX_TOKENS

    def workspace_shapes(
        self,
        M: int,
        N: int,
        K: int,
        topk: int,
        global_num_experts: int,
        local_num_experts: int,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
        """Return the workspace shapes of the implementation chosen for M.

        Uses the same predicate as ``_select_experts_impl`` so the
        workspaces match the implementation that will run.
        """
        impl = self.fallback_experts if self._use_fallback(M) else self.experts
        return impl.workspace_shapes(
            M,
            N,
            K,
            topk,
            global_num_experts,
            local_num_experts,
            expert_tokens_meta,
        )

    def _select_experts_impl(
        self,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
    ) -> mk.FusedMoEPermuteExpertsUnpermute:
        """Pick the implementation from the first dim of ``hidden_states``."""
        if self._use_fallback(hidden_states.shape[0]):
            return self.fallback_experts
        return self.experts
vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py
View file @
5dcd7ef1
...
...
@@ -10,78 +10,22 @@ from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
_valid_deep_gemm
,
_valid_deep_gemm_shape
,
)
from
vllm.model_executor.layers.fused_moe.fallback
import
FallbackExperts
from
vllm.model_executor.layers.fused_moe.fused_moe
import
TritonExperts
from
vllm.utils.deep_gemm
import
(
get_mk_alignment_for_contiguous_layout
,
is_deep_gemm_e8m0_used
,
)
class
TritonOrDeepGemmExperts
(
mk
.
FusedMoEPermuteExpertsUnpermute
):
def
__init__
(
self
,
quant_config
:
FusedMoEQuantConfig
,
allow_deep_gemm
:
bool
=
False
,
):
super
().
__init__
(
quant_config
)
self
.
triton_expert
=
TritonExperts
(
quant_config
)
self
.
allow_deep_gemm
=
(
allow_deep_gemm
and
self
.
quant_config
.
use_fp8_w8a8
and
self
.
block_shape
==
get_mk_alignment_for_contiguous_layout
()
)
self
.
deep_gemm_expert
=
(
DeepGemmExperts
(
self
.
quant_config
)
if
self
.
allow_deep_gemm
else
None
)
class
TritonOrDeepGemmExperts
(
FallbackExperts
):
"""DeepGemm with fallback to Triton for low latency shapes."""
@
property
def
activation_formats
(
self
,
)
->
tuple
[
mk
.
FusedMoEActivationFormat
,
mk
.
FusedMoEActivationFormat
]:
assert
(
self
.
deep_gemm_expert
is
None
or
self
.
triton_expert
.
activation_formats
==
self
.
deep_gemm_expert
.
activation_formats
)
return
self
.
triton_expert
.
activation_formats
def
supports_chunking
(
self
)
->
bool
:
dge
=
self
.
deep_gemm_expert
te
=
self
.
triton_expert
return
(
dge
is
None
or
dge
.
supports_chunking
())
and
(
te
is
None
or
te
.
supports_chunking
()
def
__init__
(
self
,
quant_config
:
FusedMoEQuantConfig
):
super
().
__init__
(
experts
=
DeepGemmExperts
(
quant_config
),
fallback_experts
=
TritonExperts
(
quant_config
),
)
def
supports_expert_map
(
self
)
->
bool
:
dge
=
self
.
deep_gemm_expert
te
=
self
.
triton_expert
return
(
dge
is
None
or
dge
.
supports_expert_map
())
and
(
te
is
None
or
te
.
supports_expert_map
()
)
def
finalize_weight_and_reduce_impl
(
self
)
->
mk
.
TopKWeightAndReduce
:
dge
=
self
.
deep_gemm_expert
te
=
self
.
triton_expert
dge_war
=
dge
.
finalize_weight_and_reduce_impl
()
if
dge
else
None
te_war
=
te
.
finalize_weight_and_reduce_impl
()
if
te
else
None
is_dge_war
=
dge_war
is
not
None
is_te_war
=
te_war
is
not
None
if
is_dge_war
and
is_te_war
:
assert
dge_war
==
te_war
,
(
"Both implementations should agree on WeightAndReduce impls. "
f
"Got dge_war:
{
dge_war
}
, and te_war:
{
te_war
}
"
)
if
dge_war
is
not
None
:
return
dge_war
assert
te_war
is
not
None
return
te_war
def
workspace_shapes
(
self
,
M
:
int
,
...
...
@@ -95,11 +39,8 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
# Note: the deep gemm workspaces are strictly larger than the triton
# workspaces so we can be pessimistic here and allocate for DeepGemm
# even if we fall back to triton later, e.g. if expert maps are set.
if
self
.
allow_deep_gemm
and
(
is_deep_gemm_e8m0_used
()
or
_valid_deep_gemm_shape
(
M
,
N
,
K
)
):
assert
self
.
deep_gemm_expert
is
not
None
return
self
.
deep_gemm_expert
.
workspace_shapes
(
if
is_deep_gemm_e8m0_used
()
or
_valid_deep_gemm_shape
(
M
,
N
,
K
):
return
self
.
experts
.
workspace_shapes
(
M
,
N
,
K
,
...
...
@@ -109,7 +50,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
expert_tokens_meta
,
)
else
:
return
self
.
triton
_expert
.
workspace_shapes
(
return
self
.
fallback
_expert
s
.
workspace_shapes
(
M
,
N
,
K
,
...
...
@@ -119,45 +60,13 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
expert_tokens_meta
,
)
def
ap
pl
y
(
def
_select_experts_im
pl
(
self
,
output
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
activation
:
str
,
global_num_experts
:
int
,
expert_map
:
torch
.
Tensor
|
None
,
a1q_scale
:
torch
.
Tensor
|
None
,
a2_scale
:
torch
.
Tensor
|
None
,
workspace13
:
torch
.
Tensor
,
workspace2
:
torch
.
Tensor
,
expert_tokens_meta
:
mk
.
ExpertTokensMetadata
|
None
,
apply_router_weight_on_input
:
bool
,
):
use_deep_gemm
=
self
.
allow_deep_gemm
and
(
is_deep_gemm_e8m0_used
()
or
_valid_deep_gemm
(
hidden_states
,
w1
,
w2
)
)
experts
=
self
.
deep_gemm_expert
if
use_deep_gemm
else
self
.
triton_expert
assert
experts
is
not
None
experts
.
apply
(
output
,
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
activation
,
global_num_experts
,
expert_map
,
a1q_scale
,
a2_scale
,
workspace13
,
workspace2
,
expert_tokens_meta
,
apply_router_weight_on_input
,
)
)
->
mk
.
FusedMoEPermuteExpertsUnpermute
:
if
is_deep_gemm_e8m0_used
()
or
_valid_deep_gemm
(
hidden_states
,
w1
,
w2
):
return
self
.
experts
else
:
return
self
.
fallback_experts
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
View file @
5dcd7ef1
...
...
@@ -13,10 +13,8 @@ from compressed_tensors.quantization import (
)
from
torch.nn.parameter
import
Parameter
import
vllm.envs
as
envs
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm
import
_custom_ops
as
ops
from
vllm._aiter_ops
import
rocm_aiter_ops
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe
import
(
...
...
@@ -31,6 +29,7 @@ from vllm.model_executor.layers.fused_moe import (
from
vllm.model_executor.layers.fused_moe.config
import
(
FusedMoEQuantConfig
,
fp8_w8a8_moe_quant_config
,
fp8_w8a16_moe_quant_config
,
int4_w4a16_moe_quant_config
,
int4_w4afp8_moe_quant_config
,
int8_w8a8_moe_quant_config
,
...
...
@@ -46,11 +45,16 @@ from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
MarlinExperts
,
fused_marlin_moe
,
)
from
vllm.model_executor.layers.fused_moe.oracle.fp8
import
(
Fp8MoeBackend
,
convert_to_fp8_moe_kernel_format
,
make_fp8_moe_kernel
,
select_fp8_moe_backend
,
)
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16
import
(
# noqa
WNA16_SUPPORTED_BITS
,
WNA16_SUPPORTED_TYPES_MAP
,
)
from
vllm.model_executor.layers.quantization.utils
import
replace_parameter
from
vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe
import
(
build_flashinfer_fp4_cutlass_moe_prepare_finalize
,
flashinfer_trtllm_fp4_moe
,
...
...
@@ -63,8 +67,8 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
get_flashinfer_moe_backend
,
)
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
expert_weight_is_col_major
,
requant_weight_ue8m0_inplac
e
,
process_fp8_input_tensor_strategy_moe
,
process_fp8_weight_tensor_strategy_mo
e
,
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
check_moe_marlin_supports_layer
,
...
...
@@ -76,29 +80,17 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp4
import
(
prepare_moe_fp4_layer_for_marlin
,
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp8
import
(
prepare_moe_fp8_layer_for_marlin
,
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
convert_bf16_scales_to_fp8
,
convert_packed_uint4b8_to_signed_int4_inplace
,
swizzle_blockscale
,
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
all_close_1d
,
normalize_e4m3fn_to_e4m3fnuz
,
per_tensor_dequantize
,
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
replace_parameter
,
set_weight_attrs
from
vllm.platforms
import
CpuArchEnum
,
current_platform
from
vllm.scalar_type
import
scalar_types
from
vllm.utils.deep_gemm
import
(
get_col_major_tma_aligned_tensor
,
get_mk_alignment_for_contiguous_layout
,
is_deep_gemm_e8m0_used
,
is_deep_gemm_supported
,
)
from
vllm.utils.import_utils
import
has_deep_gemm
logger
=
init_logger
(
__name__
)
...
...
@@ -657,10 +649,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
moe
:
FusedMoEConfig
,
layer_name
:
str
|
None
=
None
,
):
from
vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors
import
(
# noqa: E501
CompressedTensorsConfig
,
)
super
().
__init__
(
moe
)
self
.
weight_quant
=
weight_quant
self
.
input_quant
=
input_quant
...
...
@@ -687,42 +675,31 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
"For FP8 Fused MoE layer, we require either per tensor or "
"channelwise, dynamic per token quantization."
)
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
# kernel for fast weight-only FP8 quantization
self
.
use_marlin
=
(
not
current_platform
.
has_device_capability
(
89
)
or
envs
.
VLLM_TEST_FORCE_FP8_MARLIN
and
not
self
.
block_quant
)
# Disable marlin for rocm
if
current_platform
.
is_rocm
():
self
.
use_marlin
=
False
self
.
rocm_aiter_moe_enabled
=
rocm_aiter_ops
.
is_fused_moe_enabled
()
# cutlass path
self
.
is_fp8_w8a8_sm100
=
CompressedTensorsConfig
.
_is_fp8_w8a8_sm100
(
self
.
weight_quant
,
self
.
input_quant
self
.
fp8_backend
=
select_fp8_moe_backend
(
block_quant
=
self
.
block_quant
,
tp_size
=
moe
.
tp_size
,
with_lora_support
=
moe
.
is_lora_enabled
,
# TODO(rob): enable selecting this externally.
allow_vllm_cutlass
=
True
,
)
self
.
use_cutlass
=
not
self
.
block_quant
and
(
CompressedTensorsConfig
.
_is_fp8_w8a8_sm90
(
self
.
weight_quant
,
self
.
input_quant
if
self
.
fp8_backend
!=
Fp8MoeBackend
.
MARLIN
:
per_act_token
=
self
.
input_quant
.
strategy
==
QuantizationStrategy
.
TOKEN
per_channel_quant
=
(
self
.
weight_quant
.
strategy
==
QuantizationStrategy
.
CHANNEL
)
if
per_act_token
!=
per_channel_quant
:
raise
NotImplementedError
(
"For FP8 Fused MoE layers, per-token and per-channel must be "
"used together."
)
# TODO(rob): hook this up in a follow up PR.
if
self
.
fp8_backend
==
Fp8MoeBackend
.
FLASHINFER_TRTLLM
:
raise
NotImplementedError
(
"FlashInfer TRTLLM backend not supported for compressed-tensors yet."
)
or
self
.
is_fp8_w8a8_sm100
)
self
.
disable_expert_map
=
False
self
.
layer_name
=
layer_name
self
.
marlin_input_dtype
=
(
get_marlin_input_dtype
(
layer_name
)
if
self
.
use_marlin
else
None
)
self
.
allow_deep_gemm
=
(
self
.
block_quant
and
envs
.
VLLM_MOE_USE_DEEP_GEMM
and
is_deep_gemm_supported
()
and
list
(
self
.
weight_block_size
)
==
get_mk_alignment_for_contiguous_layout
()
)
self
.
kernel
:
mk
.
FusedMoEModularKernel
|
None
=
None
def
create_weights
(
self
,
...
...
@@ -880,163 +857,75 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
layer
.
w2_input_scale
=
None
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
# Fp8 moe kernels require a single activation scale.
# We take the max of all the scales in case they differ.
if
self
.
static_input_scales
:
assert
self
.
input_quant
.
strategy
==
QuantizationStrategy
.
TENSOR
if
layer
.
w13_input_scale
is
None
or
layer
.
w2_input_scale
is
None
:
raise
ValueError
(
"QuantConfig has static quantization, but found "
"activation scales are None."
)
if
not
all_close_1d
(
layer
.
w13_input_scale
)
or
not
all_close_1d
(
layer
.
w2_input_scale
):
logger
.
warning_once
(
"Found input_scales that are not equal for "
"fp8 MoE layer. Using the maximum across experts "
"for each layer."
)
layer
.
w13_input_scale
=
torch
.
nn
.
Parameter
(
layer
.
w13_input_scale
.
max
(),
requires_grad
=
False
)
layer
.
w2_input_scale
=
torch
.
nn
.
Parameter
(
layer
.
w2_input_scale
.
max
(),
requires_grad
=
False
)
# Allow for accessing weights and scales in standard way.
w13
=
layer
.
w13_weight
w2
=
layer
.
w2_weight
w13_scale
=
layer
.
w13_weight_scale
w2_scale
=
layer
.
w2_weight_scale
w13_input_scale
=
layer
.
w13_input_scale
w2_input_scale
=
layer
.
w2_input_scale
# MI300x and MI325x use FNUZ format for FP8. Convert if needed.
if
current_platform
.
is_fp8_fnuz
():
# Normalize the weights and scales
w13_weight
,
w13_weight_scale
,
w13_input_scale
=
(
normalize_e4m3fn_to_e4m3fnuz
(
layer
.
w13_weight
,
layer
.
w13_weight_scale
,
layer
.
w13_input_scale
)
w13
,
w13_scale
,
w13_input_scale
=
normalize_e4m3fn_to_e4m3fnuz
(
w13
,
w13_scale
,
w13_input_scale
)
w2_weight
,
w2_weight_scale
,
w2_input_scale
=
normalize_e4m3fn_to_e4m3fnuz
(
layer
.
w2_weight
,
layer
.
w2_weight_scale
,
layer
.
w2_input_scale
)
# Reset the parameter
layer
.
w13_weight
=
torch
.
nn
.
Parameter
(
w13_weight
,
requires_grad
=
False
)
layer
.
w13_weight_scale
=
torch
.
nn
.
Parameter
(
w13_weight_scale
,
requires_grad
=
False
w2
,
w2_scale
,
w2_input_scale
=
normalize_e4m3fn_to_e4m3fnuz
(
w2
,
w2_scale
,
w2_input_scale
)
if
w13_input_scale
is
not
None
:
layer
.
w13_input_scale
=
torch
.
nn
.
Parameter
(
w13_input_scale
,
requires_grad
=
False
)
layer
.
w2_weight
=
torch
.
nn
.
Parameter
(
w2_weight
,
requires_grad
=
False
)
layer
.
w2_weight_scale
=
torch
.
nn
.
Parameter
(
w2_weight_scale
,
requires_grad
=
False
)
if
w2_input_scale
is
not
None
:
layer
.
w2_input_scale
=
torch
.
nn
.
Parameter
(
w2_input_scale
,
requires_grad
=
False
)
# For Per-TENSOR case, Fp8 moe kernel needs single weight scale
# for w13 per expert. Use max then dequant and requant each expert.
if
self
.
weight_quant
.
strategy
==
QuantizationStrategy
.
TENSOR
:
assert
layer
.
w13_weight_scale
is
not
None
shard_size
=
layer
.
intermediate_size_per_partition
max_w13_scales
=
layer
.
w13_weight_scale
.
max
(
dim
=
1
).
values
for
expert_id
in
range
(
layer
.
local_num_experts
):
start
=
0
for
shard_id
in
range
(
2
):
dq_weight
=
per_tensor_dequantize
(
layer
.
w13_weight
[
expert_id
][
start
:
start
+
shard_size
,
:],
layer
.
w13_weight_scale
[
expert_id
][
shard_id
],
)
layer
.
w13_weight
[
expert_id
][
start
:
start
+
shard_size
,
:],
_
=
(
ops
.
scaled_fp8_quant
(
dq_weight
,
max_w13_scales
[
expert_id
])
)
start
+=
shard_size
layer
.
w13_weight_scale
=
torch
.
nn
.
Parameter
(
max_w13_scales
,
requires_grad
=
False
)
# Property to determine if AITER is used
if
self
.
rocm_aiter_moe_enabled
:
# reshaping weights is required for aiter moe kernel.
shuffled_w13
,
shuffled_w2
=
rocm_aiter_ops
.
shuffle_weights
(
layer
.
w13_weight
.
data
,
layer
.
w2_weight
.
data
# Per tensor kernels require single activation scale. Use the max.
if
self
.
static_input_scales
:
assert
self
.
input_quant
.
strategy
==
QuantizationStrategy
.
TENSOR
assert
w13_input_scale
is
not
None
and
w2_input_scale
is
not
None
w13_input_scale
,
w2_input_scale
=
process_fp8_input_tensor_strategy_moe
(
w13_input_scale
,
w2_input_scale
)
replace_parameter
(
layer
,
"w13_input_scale"
,
w13_input_scale
)
replace_parameter
(
layer
,
"w2_input_scale"
,
w2_input_scale
)
layer
.
w13_weight
=
torch
.
nn
.
Parameter
(
shuffled_w13
,
requires_grad
=
False
)
layer
.
w2_weight
=
torch
.
nn
.
Parameter
(
shuffled_w2
,
requires_grad
=
False
)
elif
self
.
use_marlin
:
(
workspace
,
w13_weight
,
w2_weight
,
w13_weight_scale
,
w2_weight_scale
,
)
=
prepare_moe_fp8_layer_for_marlin
(
layer
,
layer
.
w13_weight
,
layer
.
w2_weight
,
layer
.
w13_weight_scale
,
layer
.
w2_weight_scale
,
input_dtype
=
self
.
marlin_input_dtype
,
# Per-tensor kernels use a single scale, for W13, but on disk there
# is a separate scale for W1 and W3. Requantize with the max scale.
if
self
.
weight_quant
.
strategy
==
QuantizationStrategy
.
TENSOR
:
process_fp8_weight_tensor_strategy_moe
(
w13
,
w13_scale
,
shard_size
=
layer
.
intermediate_size_per_partition
,
num_experts
=
layer
.
num_local_experts
,
)
w13
,
w2
,
w13_scale
,
w2_scale
=
convert_to_fp8_moe_kernel_format
(
fp8_backend
=
self
.
fp8_backend
,
layer
=
layer
,
w13
=
w13
,
w2
=
w2
,
w13_scale
=
w13_scale
,
w2_scale
=
w2_scale
,
w13_input_scale
=
w13_input_scale
,
w2_input_scale
=
w2_input_scale
,
)
# Replace parameters with updated versions. Note that this helper
# function ensures the replacement is compatible with RL weight reloads.
replace_parameter
(
layer
,
"w13_weight"
,
w13
)
replace_parameter
(
layer
,
"w2_weight"
,
w2
)
replace_parameter
(
layer
,
"w13_weight_scale"
,
w13_scale
)
replace_parameter
(
layer
,
"w2_weight_scale"
,
w2_scale
)
self
.
moe_quant_config
=
self
.
get_fused_moe_quant_config
(
layer
)
if
self
.
moe_quant_config
:
self
.
kernel
,
self
.
use_inplace
=
make_fp8_moe_kernel
(
layer
=
layer
,
moe_quant_config
=
self
.
moe_quant_config
,
moe_config
=
self
.
moe
,
fp8_backend
=
self
.
fp8_backend
,
)
layer
.
workspace
=
workspace
replace_parameter
(
layer
,
"w13_weight"
,
w13_weight
)
replace_parameter
(
layer
,
"w2_weight"
,
w2_weight
)
replace_parameter
(
layer
,
"w13_weight_scale"
,
w13_weight_scale
)
replace_parameter
(
layer
,
"w2_weight_scale"
,
w2_weight_scale
)
if
self
.
use_cutlass
:
assert
self
.
weight_quant
.
strategy
!=
QuantizationStrategy
.
BLOCK
device
=
layer
.
w13_weight
.
device
# ab_strides1 and c_strides2 are the same
self
.
ab_strides1_c_strides2
=
torch
.
full
(
(
layer
.
local_num_experts
,),
layer
.
hidden_size
,
device
=
device
,
dtype
=
torch
.
int64
,
)
self
.
ab_strides2
=
torch
.
full
(
(
layer
.
local_num_experts
,),
layer
.
intermediate_size_per_partition
,
device
=
device
,
dtype
=
torch
.
int64
,
)
self
.
c_strides1
=
torch
.
full
(
(
layer
.
local_num_experts
,),
2
*
layer
.
intermediate_size_per_partition
,
device
=
device
,
dtype
=
torch
.
int64
,
)
if
is_deep_gemm_e8m0_used
()
and
self
.
block_quant
:
assert
layer
.
weight_block_size
is
not
None
# Re-quantise the expert weights so their scales are UE8M0.
block_sz
=
tuple
(
layer
.
weight_block_size
)
requant_weight_ue8m0_inplace
(
layer
.
w13_weight
.
data
,
layer
.
w13_weight_scale
.
data
,
block_sz
,
)
requant_weight_ue8m0_inplace
(
layer
.
w2_weight
.
data
,
layer
.
w2_weight_scale
.
data
,
block_sz
,
)
# Ensure column-major TMA alignment expected by DeepGEMM.
if
expert_weight_is_col_major
(
layer
.
w13_weight_scale
):
layer
.
w13_weight_scale
=
get_col_major_tma_aligned_tensor
(
layer
.
w13_weight_scale
)
if
expert_weight_is_col_major
(
layer
.
w2_weight_scale
):
layer
.
w2_weight_scale
=
get_col_major_tma_aligned_tensor
(
layer
.
w2_weight_scale
)
def
maybe_make_prepare_finalize
(
self
,
routing_tables
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
|
None
=
None
,
)
->
mk
.
FusedMoEPrepareAndFinalize
|
None
:
if
self
.
use_marlin
or
self
.
rocm_aiter_moe_enabled
:
if
self
.
fp8_backend
in
[
Fp8MoeBackend
.
MARLIN
,
Fp8MoeBackend
.
AITER
]
:
return
None
else
:
return
super
().
maybe_make_prepare_finalize
(
routing_tables
)
...
...
@@ -1048,7 +937,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
)
->
FusedMoEPermuteExpertsUnpermute
:
# cutlass path
assert
self
.
moe_quant_config
is
not
None
if
self
.
use_cutlass
:
if
self
.
fp8_backend
==
Fp8MoeBackend
.
VLLM_CUTLASS
:
from
vllm.model_executor.layers.fused_moe
import
(
CutlassBatchedExpertsFp8
,
CutlassExpertsFp8
,
...
...
@@ -1064,26 +953,27 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
):
logger
.
debug
(
"CutlassBatchedExpertsFp8(%s)"
,
self
.
__class__
.
__name__
)
experts
=
CutlassBatchedExpertsFp8
(
self
.
moe
.
num_local_experts
,
num_dispatchers
,
self
.
moe
.
in_dtype
,
ab_strides1
=
self
.
ab_strides1_c_strides2
,
ab_strides2
=
self
.
ab_strides2
,
c_strides1
=
self
.
c_strides1
,
c_strides2
=
self
.
ab_strides1_c_strides2
,
max_experts_per_worker
=
self
.
moe
.
num_local_experts
,
num_dispatchers
=
num_dispatchers
,
out_dtype
=
self
.
moe
.
in_dtype
,
e
=
layer
.
local_num_experts
,
n
=
layer
.
intermediate_size_per_partition
,
k
=
layer
.
hidden_size
,
device
=
layer
.
w13_weight
.
device
,
quant_config
=
self
.
moe_quant_config
,
)
else
:
logger
.
debug
(
"CutlassExpertsFp8(%s)"
,
self
.
__class__
.
__name__
)
experts
=
CutlassExpertsFp8
(
self
.
moe
.
in_dtype
,
ab_strides1
=
self
.
ab_strides1_c_strides2
,
ab_strides2
=
self
.
ab_strides2
,
c_strides1
=
self
.
c_strides1
,
c_strides2
=
self
.
ab_strides1_c_strides2
,
out_dtype
=
self
.
moe
.
in_dtype
,
e
=
layer
.
local_num_experts
,
n
=
layer
.
intermediate_size_per_partition
,
k
=
layer
.
hidden_size
,
device
=
layer
.
w13_weight
.
device
,
quant_config
=
self
.
moe_quant_config
,
)
# TODO(rob): investigate disable_expert_map
self
.
disable_expert_map
=
(
num_dispatchers
>
1
or
not
experts
.
supports_expert_map
()
)
...
...
@@ -1096,13 +986,14 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
from
vllm.model_executor.layers.fused_moe.fused_batched_moe
import
(
BatchedTritonExperts
,
)
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
TritonExperts
,
)
from
vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe
import
(
TritonOrDeepGemmExperts
,
)
assert
not
self
.
rocm_aiter_moe_enabled
and
not
self
.
use_marlin
use_deep_gemm
=
envs
.
VLLM_USE_DEEP_GEMM
and
envs
.
VLLM_MOE_USE_DEEP_GEMM
assert
self
.
fp8_backend
not
in
[
Fp8MoeBackend
.
AITER
,
Fp8MoeBackend
.
MARLIN
]
if
(
prepare_finalize
.
activation_format
...
...
@@ -1111,28 +1002,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
max_num_tokens_per_rank
=
prepare_finalize
.
max_num_tokens_per_rank
()
assert
max_num_tokens_per_rank
is
not
None
if
use_deep_gemm
and
not
has_deep_gemm
():
raise
RuntimeError
(
"DeepGEMM requested for MoE layer but not installed."
)
compatible_with_deep_gemm
=
(
self
.
moe_quant_config
.
use_fp8_w8a8
and
self
.
moe_quant_config
.
block_shape
==
get_mk_alignment_for_contiguous_layout
()
)
# If this MoE layer is compatible with DeepGEMM, the proper env
# vars are set and DeepGEMM is not installed, throw an error.
if
use_deep_gemm
and
compatible_with_deep_gemm
and
not
has_deep_gemm
():
raise
RuntimeError
(
f
"MoE layer incompatible with DeepGEMM, expected "
f
"fp8==True, got
{
self
.
moe_quant_config
.
use_fp8_w8a8
}
"
f
"or block_shape
{
self
.
moe_quant_config
.
block_shape
}
"
f
"==
{
get_mk_alignment_for_contiguous_layout
()
}
."
)
if
use_deep_gemm
and
compatible_with_deep_gemm
and
has_deep_gemm
():
if
self
.
fp8_backend
==
Fp8MoeBackend
.
DEEPGEMM
:
logger
.
debug
(
"BatchedDeepGemmExperts(%s)"
,
self
.
__class__
.
__name__
)
return
BatchedDeepGemmExperts
(
max_num_tokens
=
max_num_tokens_per_rank
,
...
...
@@ -1148,17 +1018,22 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
)
else
:
logger
.
debug
(
"TritonOrDeepGemmExperts(%s)"
,
self
.
__class__
.
__name__
)
return
TritonOrDeepGemmExperts
(
self
.
moe_quant_config
,
allow_deep_gemm
=
use_deep_gemm
,
)
if
self
.
fp8_backend
==
Fp8MoeBackend
.
DEEPGEMM
:
logger
.
debug
(
"TritonOrDeepGemmExperts(%s)"
,
self
.
__class__
.
__name__
)
return
TritonOrDeepGemmExperts
(
self
.
moe_quant_config
)
else
:
logger
.
debug
(
"TritonExperts(%s)"
,
self
.
__class__
.
__name__
)
return
TritonExperts
(
self
.
moe_quant_config
)
def
get_fused_moe_quant_config
(
self
,
layer
:
torch
.
nn
.
Module
)
->
FusedMoEQuantConfig
|
None
:
if
self
.
use_marlin
:
return
None
if
self
.
fp8_backend
==
Fp8MoeBackend
.
MARLIN
:
return
fp8_w8a16_moe_quant_config
(
w1_scale
=
layer
.
w13_weight_scale
,
w2_scale
=
layer
.
w2_weight_scale
,
block_shape
=
self
.
weight_block_size
,
)
per_act_token
=
self
.
input_quant
.
strategy
==
QuantizationStrategy
.
TOKEN
per_channel_quant
=
self
.
weight_quant
.
strategy
==
QuantizationStrategy
.
CHANNEL
...
...
@@ -1184,118 +1059,23 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
router_logits
=
router_logits
,
)
per_act_token
=
self
.
input_quant
.
strategy
==
QuantizationStrategy
.
TOKEN
per_channel_quant
=
self
.
weight_quant
.
strategy
==
QuantizationStrategy
.
CHANNEL
if
self
.
use_marlin
:
assert
layer
.
activation
==
"silu"
,
(
f
"
{
layer
.
activation
}
not supported for Marlin MoE."
)
return
fused_marlin_moe
(
x
,
layer
.
w13_weight
,
layer
.
w2_weight
,
None
,
None
,
layer
.
w13_weight_scale
,
layer
.
w2_weight_scale
,
router_logits
,
topk_weights
,
topk_ids
,
quant_type_id
=
scalar_types
.
float8_e4m3fn
.
id
,
apply_router_weight_on_input
=
layer
.
apply_router_weight_on_input
,
global_num_experts
=
layer
.
global_num_experts
,
expert_map
=
layer
.
expert_map
,
input_dtype
=
self
.
marlin_input_dtype
,
workspace
=
layer
.
workspace
,
)
elif
self
.
rocm_aiter_moe_enabled
:
from
vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe
import
(
# noqa E501
rocm_aiter_fused_experts
,
)
assert
per_act_token
==
per_channel_quant
assert
self
.
moe_quant_config
is
not
None
return
rocm_aiter_fused_experts
(
hidden_states
=
x
,
w1
=
layer
.
w13_weight
,
w2
=
layer
.
w2_weight
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
activation
=
layer
.
activation
,
apply_router_weight_on_input
=
layer
.
apply_router_weight_on_input
,
expert_map
=
layer
.
expert_map
,
quant_config
=
self
.
moe_quant_config
,
)
# cutlass path
elif
self
.
use_cutlass
:
assert
self
.
moe_quant_config
is
not
None
# small-batch fallback on SM100
if
self
.
is_fp8_w8a8_sm100
and
topk_ids
.
shape
[
0
]
<=
8
:
from
vllm.model_executor.layers.fused_moe
import
fused_experts
assert
per_act_token
==
per_channel_quant
return
fused_experts
(
hidden_states
=
x
,
w1
=
layer
.
w13_weight
,
w2
=
layer
.
w2_weight
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
inplace
=
True
,
activation
=
layer
.
activation
,
apply_router_weight_on_input
=
layer
.
apply_router_weight_on_input
,
global_num_experts
=
layer
.
global_num_experts
,
expert_map
=
None
if
self
.
disable_expert_map
else
layer
.
expert_map
,
# ???
quant_config
=
self
.
moe_quant_config
,
allow_deep_gemm
=
self
.
allow_deep_gemm
,
)
else
:
from
vllm.model_executor.layers.fused_moe.cutlass_moe
import
(
cutlass_moe_fp8
,
)
assert
per_act_token
==
per_channel_quant
assert
self
.
moe_quant_config
is
not
None
return
cutlass_moe_fp8
(
x
,
layer
.
w13_weight
,
layer
.
w2_weight
,
topk_weights
,
topk_ids
,
quant_config
=
self
.
moe_quant_config
,
activation
=
layer
.
activation
,
global_num_experts
=
layer
.
global_num_experts
,
expert_map
=
None
if
self
.
disable_expert_map
else
layer
.
expert_map
,
ab_strides1
=
self
.
ab_strides1_c_strides2
,
ab_strides2
=
self
.
ab_strides2
,
c_strides1
=
self
.
c_strides1
,
c_strides2
=
self
.
ab_strides1_c_strides2
,
)
else
:
from
vllm.model_executor.layers.fused_moe
import
fused_experts
assert
self
.
kernel
is
not
None
result
=
self
.
kernel
(
x
,
layer
.
w13_weight
,
layer
.
w2_weight
,
topk_weights
,
topk_ids
,
inplace
=
self
.
use_inplace
,
activation
=
layer
.
activation
,
global_num_experts
=
layer
.
global_num_experts
,
# TODO(rob): investigate the disable_expert_map introduced by:
# https://github.com/vllm-project/vllm/commit/84166fee9770e6fba71a96978b3e7d149392fb28 # noqa: E501
expert_map
=
None
if
self
.
disable_expert_map
else
layer
.
expert_map
,
apply_router_weight_on_input
=
layer
.
apply_router_weight_on_input
,
)
assert
per_act_token
==
per_channel_quant
assert
self
.
moe_quant_config
is
not
None
return
fused_experts
(
hidden_states
=
x
,
w1
=
layer
.
w13_weight
,
w2
=
layer
.
w2_weight
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
inplace
=
True
,
activation
=
layer
.
activation
,
apply_router_weight_on_input
=
layer
.
apply_router_weight_on_input
,
global_num_experts
=
layer
.
global_num_experts
,
expert_map
=
layer
.
expert_map
,
quant_config
=
self
.
moe_quant_config
,
allow_deep_gemm
=
self
.
allow_deep_gemm
,
)
return
result
@
property
def
supports_eplb
(
self
)
->
bool
:
...
...
vllm/model_executor/layers/quantization/fp8.py
View file @
5dcd7ef1
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
enum
import
Enum
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
import
torch
...
...
@@ -27,13 +26,17 @@ from vllm.model_executor.layers.fused_moe import (
FusedMoeWeightScaleSupported
,
)
from
vllm.model_executor.layers.fused_moe.config
import
(
FusedMoEParallelConfig
,
FusedMoEQuantConfig
,
RoutingMethodType
,
fp8_w8a8_moe_quant_config
,
fp8_w8a16_moe_quant_config
,
)
from
vllm.model_executor.layers.fused_moe.layer
import
UnquantizedFusedMoEMethod
from
vllm.model_executor.layers.fused_moe.oracle.fp8
import
(
Fp8MoeBackend
,
convert_to_fp8_moe_kernel_format
,
make_fp8_moe_kernel
,
make_fp8_moe_quant_config
,
select_fp8_moe_backend
,
)
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
,
...
...
@@ -46,25 +49,20 @@ from vllm.model_executor.layers.quantization.base_config import (
)
from
vllm.model_executor.layers.quantization.kv_cache
import
BaseKVCacheMethod
from
vllm.model_executor.layers.quantization.utils.flashinfer_utils
import
(
FlashinferMoeBackend
,
apply_flashinfer_per_tensor_scale_fp8
,
apply_fi_trtllm_fp8_per_tensor_moe
,
build_flashinfer_fp8_cutlass_moe_prepare_finalize
,
get_flashinfer_moe_backend
,
make_fp8_moe_alpha_scales_for_fi
,
register_scales_for_trtllm_fp8_per_tensor_moe
,
rotate_flashinfer_fp8_moe_weights
,
select_cutlass_fp8_gemm_impl
,
swap_w13_to_w31
,
)
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
W8A8BlockFp8LinearOp
,
create_fp8_input_scale
,
create_fp8_scale_parameter
,
create_fp8_weight_parameter
,
deepgemm_post_process_fp8_weight_block
,
maybe_post_process_fp8_weight_block
,
process_fp8_input_tensor_strategy_moe
,
process_fp8_weight_block_strategy
,
process_fp8_weight_tensor_strategy
,
process_fp8_weight_tensor_strategy_moe
,
validate_fp8_block_shape
,
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
...
...
@@ -73,7 +71,6 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp8
import
(
apply_fp8_marlin_linear
,
prepare_fp8_layer_for_marlin
,
prepare_moe_fp8_layer_for_marlin
,
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
GroupShape
,
...
...
@@ -81,12 +78,10 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
Fp8LinearOp
,
all_close_1d
,
cutlass_block_fp8_supported
,
cutlass_fp8_supported
,
maybe_create_device_identity
,
normalize_e4m3fn_to_e4m3fnuz
,
per_tensor_dequantize
,
)
from
vllm.model_executor.parameter
import
(
BlockQuantScaleParameter
,
...
...
@@ -96,11 +91,8 @@ from vllm.model_executor.parameter import (
from
vllm.model_executor.utils
import
replace_parameter
,
set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.utils.deep_gemm
import
(
is_deep_gemm_e8m0_used
,
is_deep_gemm_supported
,
)
from
vllm.utils.flashinfer
import
has_flashinfer_moe
from
vllm.utils.import_utils
import
has_deep_gemm
if
TYPE_CHECKING
:
from
vllm.model_executor.models.utils
import
WeightsMapper
...
...
@@ -110,107 +102,6 @@ ACTIVATION_SCHEMES = ["static", "dynamic"]
logger
=
init_logger
(
__name__
)
class
Fp8MoeBackend
(
Enum
):
NONE
=
0
FLASHINFER_TRTLLM
=
1
FLASHINFER_CUTLASS
=
2
DEEPGEMM
=
3
MARLIN
=
4
TRITON
=
5
AITER
=
6
def
get_fp8_moe_backend
(
block_quant
:
bool
,
moe_parallel_config
:
FusedMoEParallelConfig
,
with_lora_support
:
bool
,
)
->
Fp8MoeBackend
|
None
:
"""
Select the primary FP8 MoE backend
Note: Shape-specific fallbacks may still occur at runtime.
"""
if
current_platform
.
is_xpu
():
return
None
if
with_lora_support
:
return
Fp8MoeBackend
.
TRITON
# Prefer FlashInfer backends on supported GPUs; allow SM90 and SM100.
if
(
current_platform
.
is_cuda
()
and
(
current_platform
.
is_device_capability_family
(
100
)
or
current_platform
.
is_device_capability
(
90
)
)
and
envs
.
VLLM_USE_FLASHINFER_MOE_FP8
and
has_flashinfer_moe
()
):
backend
=
get_flashinfer_moe_backend
()
if
backend
==
FlashinferMoeBackend
.
TENSORRT_LLM
:
logger
.
info_once
(
"Using FlashInfer FP8 MoE TRTLLM backend for SM100"
)
return
Fp8MoeBackend
.
FLASHINFER_TRTLLM
else
:
if
block_quant
and
current_platform
.
is_device_capability_family
(
100
):
raise
ValueError
(
"FlashInfer FP8 MoE throughput backend does not "
"support block quantization on SM100. Please use "
"VLLM_FLASHINFER_MOE_BACKEND=latency "
"instead."
)
logger
.
info_once
(
"Using FlashInfer FP8 MoE CUTLASS backend for SM90/SM100"
)
return
Fp8MoeBackend
.
FLASHINFER_CUTLASS
# weight-only path for older GPUs without native FP8
use_marlin
=
(
not
current_platform
.
has_device_capability
(
89
)
or
envs
.
VLLM_TEST_FORCE_FP8_MARLIN
)
if
current_platform
.
is_rocm
():
use_marlin
=
False
if
use_marlin
:
logger
.
info_once
(
"Using Marlin backend for FP8 MoE"
)
return
Fp8MoeBackend
.
MARLIN
# Determine if we should use DeepGEMM with block-quantized weights:
# - If explicitly set by user, respect their choice
# - If not explicitly set (default), disable when TP size is >= 8
moe_use_deep_gemm
=
envs
.
VLLM_MOE_USE_DEEP_GEMM
if
not
envs
.
is_set
(
"VLLM_MOE_USE_DEEP_GEMM"
)
and
moe_parallel_config
.
tp_size
>=
8
:
moe_use_deep_gemm
=
False
logger
.
info_once
(
"DeepGEMM MoE is disabled by default when TP size is >= 8. "
"Set VLLM_MOE_USE_DEEP_GEMM=1 to enable it."
,
scope
=
"local"
,
)
# Determine if we should use DeepGEMM (top-level enable switch)
# - If explicitly set by user, respect their choice
# - If not platform supports DeepGEMM, disable it
# This helps avoid warning messages on unsupported platforms.
use_deep_gemm
=
envs
.
VLLM_USE_DEEP_GEMM
if
not
is_deep_gemm_supported
():
use_deep_gemm
=
False
logger
.
info_once
(
"DeepGEMM is disabled because the platform does not support it."
,
scope
=
"local"
,
)
if
use_deep_gemm
and
moe_use_deep_gemm
and
block_quant
:
if
not
has_deep_gemm
():
logger
.
warning_once
(
"DeepGEMM backend requested but not available."
,
scope
=
"local"
)
elif
is_deep_gemm_supported
():
logger
.
info_once
(
"Using DeepGEMM backend for FP8 MoE"
,
scope
=
"local"
)
return
Fp8MoeBackend
.
DEEPGEMM
if
envs
.
VLLM_ROCM_USE_AITER
and
envs
.
VLLM_ROCM_USE_AITER_MOE
:
logger
.
info_once
(
"Using ROCm AITER backend for FP8 MoE"
,
scope
=
"local"
)
return
Fp8MoeBackend
.
AITER
# default to Triton
logger
.
info_once
(
"Using Triton backend for FP8 MoE"
)
return
Fp8MoeBackend
.
TRITON
class
Fp8Config
(
QuantizationConfig
):
"""Config class for FP8."""
...
...
@@ -348,7 +239,6 @@ class Fp8Config(QuantizationConfig):
moe_quant_method
=
Fp8MoEMethod
(
self
,
layer
)
else
:
moe_quant_method
=
Fp8OnlineMoEMethod
(
self
,
layer
)
moe_quant_method
.
marlin_input_dtype
=
get_marlin_input_dtype
(
prefix
)
return
moe_quant_method
elif
isinstance
(
layer
,
Attention
):
return
Fp8KVCacheMethod
(
self
)
...
...
@@ -736,40 +626,24 @@ class Fp8MoEMethod(FusedMoEMethodBase):
def
__init__
(
self
,
quant_config
:
Fp8Config
,
layer
:
torch
.
nn
.
Module
):
super
().
__init__
(
layer
.
moe_config
)
self
.
layer
=
layer
self
.
quant_config
=
quant_config
self
.
weight_block_size
=
self
.
quant_config
.
weight_block_size
self
.
block_quant
:
bool
=
self
.
weight_block_size
is
not
None
self
.
weight_scale_name
=
(
"weight_scale_inv"
if
self
.
block_quant
else
"weight_scale"
)
self
.
fp8_backend
=
get_fp8_moe_backend
(
self
.
block_quant
,
layer
.
moe_parallel_config
,
self
.
moe
.
is_lora_enabled
self
.
fp8_backend
=
select_fp8_moe_backend
(
block_quant
=
self
.
block_quant
,
tp_size
=
layer
.
moe_parallel_config
.
tp_size
,
with_lora_support
=
self
.
moe
.
is_lora_enabled
,
)
self
.
marlin_input_dtype
=
None
self
.
flashinfer_moe_backend
:
FlashinferMoeBackend
|
None
=
None
if
self
.
fp8_backend
==
Fp8MoeBackend
.
FLASHINFER_TRTLLM
:
self
.
flashinfer_moe_backend
=
FlashinferMoeBackend
.
TENSORRT_LLM
elif
self
.
fp8_backend
==
Fp8MoeBackend
.
FLASHINFER_CUTLASS
:
self
.
flashinfer_moe_backend
=
FlashinferMoeBackend
.
CUTLASS
if
self
.
fp8_backend
==
Fp8MoeBackend
.
FLASHINFER_CUTLASS
:
if
self
.
block_quant
and
self
.
weight_block_size
!=
[
128
,
128
]:
raise
NotImplementedError
(
"FlashInfer CUTLASS FP8 MoE backend only supports block "
"size [128, 128]."
)
if
not
self
.
block_quant
:
if
layer
.
renormalize
or
layer
.
custom_routing_function
is
not
None
:
raise
NotImplementedError
(
"FlashInfer CUTLASS FP8 MoE backend does custom routing "
f
"function or renormalization, but got
{
layer
.
renormalize
}
and "
f
"
{
layer
.
custom_routing_function
}
."
)
if
layer
.
scoring_func
!=
"sigmoid"
:
raise
NotImplementedError
(
"FlashInfer CUTLASS FP8 MoE backend only supports "
f
"'sigmoid' scoring function, but got
{
layer
.
scoring_func
}
."
)
if
layer
.
activation
!=
"silu"
:
raise
NotImplementedError
(
"FlashInfer CUTLASS FP8 MoE backend only supports SiLU "
...
...
@@ -778,12 +652,17 @@ class Fp8MoEMethod(FusedMoEMethodBase):
dynamic_per_token
=
(
not
self
.
block_quant
and
self
.
quant_config
.
activation_scheme
!=
"static"
)
if
self
.
flashinfer_moe_backend
is
not
None
and
dynamic_per_token
:
if
dynamic_per_token
and
self
.
fp8_backend
in
[
Fp8MoeBackend
.
FLASHINFER_TRTLLM
,
Fp8MoeBackend
.
FLASHINFER_CUTLASS
,
]:
raise
NotImplementedError
(
"FlashInfer FP8 MoE backend does not support dynamic per token "
"activation quantization."
)
self
.
kernel
:
mk
.
FusedMoEModularKernel
|
None
=
None
def
create_weights
(
self
,
layer
:
Module
,
...
...
@@ -907,148 +786,43 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer
.
w13_input_scale
=
None
layer
.
w2_input_scale
=
None
def
_
convert_weights_to_kernel_format
(
def
_
setup_kernel
(
self
,
layer
:
Module
,
w13
_weight
:
torch
.
Tensor
,
w2
_weight
:
torch
.
Tensor
,
w13_
weight_
scale
:
torch
.
Tensor
,
w2_
weight_
scale
:
torch
.
Tensor
,
w13
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
w13_scale
:
torch
.
Tensor
,
w2_scale
:
torch
.
Tensor
,
w13_input_scale
:
torch
.
Tensor
|
None
,
w2_input_scale
:
torch
.
Tensor
|
None
,
)
->
None
:
if
self
.
fp8_backend
==
Fp8MoeBackend
.
DEEPGEMM
:
assert
self
.
block_quant
w13_weight
,
w13_weight_scale
=
deepgemm_post_process_fp8_weight_block
(
wq
=
w13_weight
,
ws
=
w13_weight_scale
,
quant_block_shape
=
tuple
(
layer
.
weight_block_size
),
use_e8m0
=
is_deep_gemm_e8m0_used
(),
)
w2_weight
,
w2_weight_scale
=
deepgemm_post_process_fp8_weight_block
(
wq
=
w2_weight
,
ws
=
w2_weight_scale
,
quant_block_shape
=
tuple
(
layer
.
weight_block_size
),
use_e8m0
=
is_deep_gemm_e8m0_used
(),
)
elif
self
.
fp8_backend
==
Fp8MoeBackend
.
AITER
:
w13_weight
,
w2_weight
=
rocm_aiter_ops
.
shuffle_weights
(
w13_weight
,
w2_weight
)
elif
self
.
fp8_backend
==
Fp8MoeBackend
.
MARLIN
:
(
workspace
,
w13_weight
,
w2_weight
,
w13_weight_scale
,
w2_weight_scale
,
)
=
prepare_moe_fp8_layer_for_marlin
(
layer
,
w13_weight
,
w2_weight
,
w13_weight_scale
,
w2_weight_scale
,
input_dtype
=
self
.
marlin_input_dtype
,
)
layer
.
workspace
=
workspace
elif
self
.
fp8_backend
in
[
Fp8MoeBackend
.
FLASHINFER_CUTLASS
,
Fp8MoeBackend
.
FLASHINFER_TRTLLM
,
]:
w13_weight
=
swap_w13_to_w31
(
w13_weight
)
if
self
.
block_quant
:
w13_weight_scale
=
swap_w13_to_w31
(
w13_weight_scale
)
else
:
if
self
.
fp8_backend
==
Fp8MoeBackend
.
FLASHINFER_TRTLLM
:
rotate_flashinfer_fp8_moe_weights
(
w13_weight
,
w2_weight
)
register_scales_for_trtllm_fp8_per_tensor_moe
(
layer
=
layer
,
w13_weight_scale
=
w13_weight
,
w13_input_scale
=
w13_input_scale
,
w2_weight_scale
=
w2_weight
,
w2_input_scale
=
w2_input_scale
,
)
elif
self
.
fp8_backend
==
Fp8MoeBackend
.
AITER
:
w13_weight
,
w2_weight
=
rocm_aiter_ops
.
shuffle_weights
(
w13_weight
,
w2_weight
)
# Shuffle weights to runtime format.
w13
,
w2
,
w13_scale
,
w2_scale
=
convert_to_fp8_moe_kernel_format
(
fp8_backend
=
self
.
fp8_backend
,
layer
=
layer
,
w13
=
w13
,
w2
=
w2
,
w13_scale
=
w13_scale
,
w2_scale
=
w2_scale
,
w13_input_scale
=
w13_input_scale
,
w2_input_scale
=
w2_input_scale
,
)
# Replace parameters with updated versions. Note that this helper
# function ensures the replacement is compatible with RL weight reloads.
replace_parameter
(
layer
,
"w13_weight"
,
w13_weight
)
replace_parameter
(
layer
,
"w2_weight"
,
w2_weight
)
replace_parameter
(
layer
,
f
"w13_
{
self
.
weight_scale_name
}
"
,
w13_weight_scale
)
replace_parameter
(
layer
,
f
"w2_
{
self
.
weight_scale_name
}
"
,
w2_weight_scale
)
def
_setup_kernel
(
self
,
layer
:
Module
)
->
None
:
"""Setup Modular Kernel for TP Case"""
# NOTE(rob): this is a WIP refactor. We are first migrating
# all of the kernels in the TP case to use mk. Once this is
# done, then we will initialize the TP case and DP/EP case
# via the same code path (i.e. via maybe_init_modular_kernel).
# NOTE(rob): in progress migrating all into this format.
from
vllm.model_executor.layers.fused_moe
import
(
TritonOrDeepGemmExperts
,
)
from
vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe
import
(
FlashInferExperts
,
)
from
vllm.model_executor.layers.fused_moe.fused_marlin_moe
import
(
MarlinExperts
,
)
from
vllm.model_executor.layers.fused_moe.prepare_finalize
import
(
MoEPrepareAndFinalizeNoEP
,
)
from
vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe
import
(
AiterExperts
,
)
# Flashinfer TRTLLM does not use the modular kernel abstraction.
if
self
.
fp8_backend
==
Fp8MoeBackend
.
FLASHINFER_TRTLLM
:
return
replace_parameter
(
layer
,
"w13_weight"
,
w13
)
replace_parameter
(
layer
,
"w2_weight"
,
w2
)
replace_parameter
(
layer
,
f
"w13_
{
self
.
weight_scale_name
}
"
,
w13_scale
)
replace_parameter
(
layer
,
f
"w2_
{
self
.
weight_scale_name
}
"
,
w2_scale
)
# Setup modular kernel for TP case.
self
.
moe_quant_config
=
self
.
get_fused_moe_quant_config
(
layer
)
assert
self
.
moe_quant_config
is
not
None
self
.
use_inplace
=
True
if
self
.
fp8_backend
==
Fp8MoeBackend
.
FLASHINFER_CUTLASS
:
self
.
kernel
=
mk
.
FusedMoEModularKernel
(
# TODO: make defer_input_quant an attr of the FlashInferExperts
MoEPrepareAndFinalizeNoEP
(
defer_input_quant
=
self
.
block_quant
),
FlashInferExperts
(
out_dtype
=
layer
.
orig_dtype
,
quant_config
=
self
.
moe_quant_config
,
ep_rank
=
self
.
moe
.
ep_rank
,
ep_size
=
self
.
moe
.
ep_size
,
tp_rank
=
self
.
moe
.
tp_rank
,
tp_size
=
self
.
moe
.
tp_size
,
use_dp
=
(
self
.
moe
.
dp_size
>
1
),
use_deepseek_fp8_block_scale
=
self
.
block_quant
,
),
)
self
.
use_inplace
=
False
elif
self
.
fp8_backend
==
Fp8MoeBackend
.
AITER
:
self
.
kernel
=
mk
.
FusedMoEModularKernel
(
# TODO: make defer_input_quant an attr of the AiterExperts
MoEPrepareAndFinalizeNoEP
(
defer_input_quant
=
True
),
AiterExperts
(
quant_config
=
self
.
moe_quant_config
),
)
elif
self
.
fp8_backend
==
Fp8MoeBackend
.
MARLIN
:
self
.
kernel
=
mk
.
FusedMoEModularKernel
(
MoEPrepareAndFinalizeNoEP
(),
MarlinExperts
(
quant_config
=
self
.
moe_quant_config
),
)
else
:
self
.
kernel
=
mk
.
FusedMoEModularKernel
(
MoEPrepareAndFinalizeNoEP
(),
TritonOrDeepGemmExperts
(
quant_config
=
self
.
moe_quant_config
,
allow_deep_gemm
=
(
self
.
fp8_backend
==
Fp8MoeBackend
.
DEEPGEMM
),
),
if
self
.
moe_quant_config
:
self
.
kernel
,
self
.
use_inplace
=
make_fp8_moe_kernel
(
layer
=
layer
,
moe_quant_config
=
self
.
moe_quant_config
,
moe_config
=
self
.
moe
,
fp8_backend
=
self
.
fp8_backend
,
)
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
...
...
@@ -1056,78 +830,58 @@ class Fp8MoEMethod(FusedMoEMethodBase):
return
# Allow for accessing weights and scales in standard way.
w13
_weight
=
layer
.
w13_weight
w2
_weight
=
layer
.
w2_weight
w13_
weight_
scale
=
getattr
(
layer
,
f
"w13_
{
self
.
weight_scale_name
}
"
)
w2_
weight_
scale
=
getattr
(
layer
,
f
"w2_
{
self
.
weight_scale_name
}
"
)
w13
=
layer
.
w13_weight
w2
=
layer
.
w2_weight
w13_scale
=
getattr
(
layer
,
f
"w13_
{
self
.
weight_scale_name
}
"
)
w2_scale
=
getattr
(
layer
,
f
"w2_
{
self
.
weight_scale_name
}
"
)
w13_input_scale
=
layer
.
w13_input_scale
w2_input_scale
=
layer
.
w2_input_scale
# MI300x and MI325x use FNUZ format for FP8. Convert if needed.
if
current_platform
.
is_fp8_fnuz
():
w13
_weight
,
w13_weight
_scale
,
w13_input_scale
=
(
normalize_e4m3fn_to_e4m3fnuz
(
w13_weight
,
w13_weight_scale
,
w13_input_scale
)
w13
,
w13
_scale
,
w13_input_scale
=
normalize_e4m3fn_to_e4m3fnuz
(
w13
,
w13_scale
,
w13_input_scale
,
)
w2_weight
,
w2_weight_scale
,
w2_input_scale
=
normalize_e4m3fn_to_e4m3fnuz
(
w2_weight
,
w2_weight_scale
,
w2_input_scale
w2
,
w2_scale
,
w2_input_scale
=
normalize_e4m3fn_to_e4m3fnuz
(
w2
,
w2_scale
,
w2_input_scale
,
)
# Per tensor kernels require single activation scale. Use the max.
if
self
.
quant_config
.
activation_scheme
==
"static"
:
assert
not
self
.
block_quant
assert
w13_input_scale
is
not
None
and
w2_input_scale
is
not
None
if
not
all_close_1d
(
w13_input_scale
)
or
not
all_close_1d
(
w2_input_scale
):
logger
.
warning_once
(
"Found input_scales that are not equal for "
"fp8 MoE layer. Using the maximum across experts "
"for each layer."
)
replace_parameter
(
layer
,
"w13_input_scale"
,
w13_input_scale
.
max
())
replace_parameter
(
layer
,
"w2_input_scale"
,
w2_input_scale
.
max
())
w13_input_scale
,
w2_input_scale
=
process_fp8_input_tensor_strategy_moe
(
w13_input_scale
,
w2_input_scale
)
replace_parameter
(
layer
,
"w13_input_scale"
,
w13_input_scale
)
replace_parameter
(
layer
,
"w2_input_scale"
,
w2_input_scale
)
# Per tensor kernels require single weight scale for w13 per expert, but
# on disk there is a scale for w1 and w3. Use the max to requantize.
if
not
self
.
block_quant
:
shard_size
=
layer
.
intermediate_size_per_partition
max_w13_scales
=
w13_weight_scale
.
max
(
dim
=
1
).
values
for
expert_id
in
range
(
layer
.
local_num_experts
):
start
=
0
for
shard_id
in
range
(
2
):
dq_weight
=
per_tensor_dequantize
(
w13_weight
[
expert_id
][
start
:
start
+
shard_size
,
:],
w13_weight_scale
[
expert_id
][
shard_id
],
)
w13_weight
[
expert_id
][
start
:
start
+
shard_size
,
:],
_
=
(
ops
.
scaled_fp8_quant
(
dq_weight
,
max_w13_scales
[
expert_id
])
)
start
+=
shard_size
w13_weight_scale
=
max_w13_scales
w13
,
w13_scale
=
process_fp8_weight_tensor_strategy_moe
(
w13
,
w13_scale
,
shard_size
,
layer
.
local_num_experts
)
# Shuffle weights into the runtime format.
self
.
_convert_weights_to_kernel_format
(
layer
=
layer
,
w13_weight
=
w13_weight
,
w2_weight
=
w2_weight
,
w13_weight_scale
=
w13_weight_scale
,
w2_weight_scale
=
w2_weight_scale
,
w13_input_scale
=
w13_input_scale
,
w2_input_scale
=
w2_input_scale
,
# Shuffle weights to runtime format and setup kernel.
self
.
_setup_kernel
(
layer
,
w13
,
w2
,
w13_scale
,
w2_scale
,
w13_input_scale
,
w2_input_scale
)
# Setup modular kernel for TP case.
self
.
_setup_kernel
(
layer
)
def
maybe_make_prepare_finalize
(
self
,
routing_tables
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
|
None
=
None
,
)
->
mk
.
FusedMoEPrepareAndFinalize
|
None
:
if
(
self
.
fp8_backend
==
Fp8MoeBackend
.
AITER
or
self
.
fp8_backend
==
Fp8MoeBackend
.
MARLIN
or
self
.
flashinfer_moe_backend
==
FlashinferMoeBackend
.
TENSOR
RT
_
LLM
)
:
if
self
.
fp8_backend
in
[
Fp8MoeBackend
.
AITER
,
Fp8MoeBackend
.
MARLIN
,
Fp8MoeBackend
.
FLASHINFER_T
RTLLM
,
]
:
return
None
elif
self
.
fp8_backend
==
Fp8MoeBackend
.
FLASHINFER_CUTLASS
:
prepare_finalize
=
build_flashinfer_fp8_cutlass_moe_prepare_finalize
(
...
...
@@ -1184,7 +938,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
)
elif
self
.
moe
.
is_lora_enabled
:
return
TritonExperts
(
quant_config
=
self
.
moe_quant_config
)
elif
self
.
f
lashinfer_moe
_backend
==
F
lashinfer
MoeBackend
.
CUTLASS
:
elif
self
.
f
p8
_backend
==
F
p8
MoeBackend
.
FLASHINFER_
CUTLASS
:
# Select GEMM experts with block-scale when weights are block-quantized
experts
=
select_cutlass_fp8_gemm_impl
(
self
.
moe
,
...
...
@@ -1193,17 +947,23 @@ class Fp8MoEMethod(FusedMoEMethodBase):
)
logger
.
debug_once
(
"Using %s"
,
experts
.
__class__
.
__name__
)
return
experts
el
se
:
el
if
self
.
fp8_backend
==
Fp8MoeBackend
.
DEEPGEMM
:
logger
.
debug
(
"TritonOrDeepGemmExperts(%s): block_size=%s, per_act_token=%s"
,
self
.
__class__
.
__name__
,
self
.
weight_block_size
,
False
,
)
return
TritonOrDeepGemmExperts
(
quant_config
=
self
.
moe_quant_config
,
allow_deep_gemm
=
(
self
.
fp8_backend
==
Fp8MoeBackend
.
DEEPGEMM
),
return
TritonOrDeepGemmExperts
(
self
.
moe_quant_config
)
else
:
assert
self
.
fp8_backend
==
Fp8MoeBackend
.
TRITON
logger
.
debug
(
"TritonExperts(%s): block_size=%s, per_act_token=%s"
,
self
.
__class__
.
__name__
,
self
.
weight_block_size
,
False
,
)
return
TritonExperts
(
self
.
moe_quant_config
)
def
get_fused_moe_quant_config
(
self
,
layer
:
torch
.
nn
.
Module
...
...
@@ -1212,42 +972,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
if
self
.
fp8_backend
==
Fp8MoeBackend
.
FLASHINFER_TRTLLM
:
return
None
# MARLIN uses mixed precision W8A16 config.
if
self
.
fp8_backend
==
Fp8MoeBackend
.
MARLIN
:
return
fp8_w8a16_moe_quant_config
(
w1_scale
=
getattr
(
layer
,
f
"w13_
{
self
.
weight_scale_name
}
"
),
w2_scale
=
getattr
(
layer
,
f
"w2_
{
self
.
weight_scale_name
}
"
),
block_shape
=
self
.
weight_block_size
,
)
w1_scale
=
getattr
(
layer
,
f
"w13_
{
self
.
weight_scale_name
}
"
)
w2_scale
=
getattr
(
layer
,
f
"w2_
{
self
.
weight_scale_name
}
"
)
a1_scale
=
layer
.
w13_input_scale
a2_scale
=
layer
.
w2_input_scale
# Flashinfer CUTLASS per-tensor uses single dq scale
# (alpha = w_scale * a_scale) and inverse a2 scale.
if
(
self
.
fp8_backend
==
Fp8MoeBackend
.
FLASHINFER_CUTLASS
and
not
self
.
block_quant
):
g1_alphas
,
g2_alphas
=
make_fp8_moe_alpha_scales_for_fi
(
w1_scale
,
a1_scale
,
w2_scale
,
a2_scale
,
)
return
fp8_w8a8_moe_quant_config
(
w1_scale
=
w1_scale
,
w2_scale
=
w2_scale
,
a1_scale
=
a1_scale
,
a2_scale
=
(
1.0
/
a2_scale
),
g1_alphas
=
g1_alphas
,
g2_alphas
=
g2_alphas
,
)
# All other backends use normal config.
return
fp8_w8a8_moe_quant_config
(
return
make_fp8_moe_quant_config
(
fp8_backend
=
self
.
fp8_backend
,
w1_scale
=
w1_scale
,
w2_scale
=
w2_scale
,
a1_scale
=
a1_scale
,
...
...
@@ -1269,7 +1000,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
if
self
.
f
lashinfer_moe
_backend
==
F
lashinfer
MoeBackend
.
TENSOR
RT
_
LLM
:
if
self
.
f
p8
_backend
==
F
p8
MoeBackend
.
FLASHINFER_T
RTLLM
:
# TODO(rob): convert this to MK.
if
layer
.
enable_eplb
:
raise
NotImplementedError
(
"EPLB not supported for `Fp8MoEMethod` yet."
)
...
...
@@ -1308,10 +1039,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
routed_scaling
=
layer
.
routed_scaling_factor
,
)
else
:
assert
(
not
layer
.
renormalize
and
layer
.
custom_routing_function
is
not
None
)
result
=
apply_flashinfer_per_tensor_scale_fp8
(
result
=
apply_fi_trtllm_fp8_per_tensor_moe
(
layer
=
layer
,
hidden_states
=
x
,
router_logits
=
router_logits
,
...
...
@@ -1327,6 +1055,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
hidden_states
=
x
,
router_logits
=
router_logits
,
)
assert
self
.
kernel
is
not
None
result
=
self
.
kernel
(
x
,
layer
.
w13_weight
,
...
...
@@ -1358,7 +1088,6 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
assert
not
quant_config
.
is_checkpoint_fp8_serialized
assert
quant_config
.
activation_scheme
==
"dynamic"
assert
quant_config
.
weight_block_size
is
None
assert
self
.
flashinfer_moe_backend
is
None
def
create_weights
(
self
,
...
...
@@ -1447,6 +1176,8 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
)
layer
.
register_parameter
(
"w13_weight_scale"
,
w13_weight_scale
)
layer
.
register_parameter
(
"w2_weight_scale"
,
w2_weight_scale
)
set_weight_attrs
(
w13_weight_scale
,
extra_weight_attrs
)
set_weight_attrs
(
w2_weight_scale
,
extra_weight_attrs
)
layer
.
w13_input_scale
=
None
layer
.
w2_input_scale
=
None
...
...
@@ -1457,33 +1188,30 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
# If checkpoint is fp16, quantize in place.
fp8_dtype
=
current_platform
.
fp8_dtype
()
w13_weight
=
torch
.
empty_like
(
layer
.
w13_weight
,
dtype
=
fp8_dtype
)
w2_weight
=
torch
.
empty_like
(
layer
.
w2_weight
,
dtype
=
fp8_dtype
)
w13
=
torch
.
empty_like
(
layer
.
w13_weight
,
dtype
=
fp8_dtype
)
w2
=
torch
.
empty_like
(
layer
.
w2_weight
,
dtype
=
fp8_dtype
)
w13_scale
=
layer
.
w13_weight_scale
w2_scale
=
layer
.
w2_weight_scale
for
expert
in
range
(
layer
.
local_num_experts
):
w13
_weight
[
expert
,
:,
:],
layer
.
w13_weight
_scale
[
expert
]
=
(
ops
.
scaled_fp8_quant
(
layer
.
w13_weight
[
expert
,
:,
:]
)
w13
[
expert
,
:,
:],
w13
_scale
[
expert
]
=
ops
.
scaled_fp8_quant
(
layer
.
w13_weight
[
expert
,
:,
:]
)
w2
_weight
[
expert
,
:,
:],
layer
.
w2_weight
_scale
[
expert
]
=
(
ops
.
scaled_fp8_quant
(
layer
.
w2_weight
[
expert
,
:,
:]
)
w2
[
expert
,
:,
:],
w2
_scale
[
expert
]
=
ops
.
scaled_fp8_quant
(
layer
.
w2_weight
[
expert
,
:,
:]
)
replace_parameter
(
layer
,
"w13_weight"
,
w13_weight
)
replace_parameter
(
layer
,
"w2_weight"
,
w2_weight
)
# Shuffle weights
into the
runtime format.
self
.
_
convert_weights_to_kernel_format
(
layer
=
layer
,
w13
_weight
=
w13_weight
,
w2
_weight
=
w2_weight
,
w13_
weight_scale
=
layer
.
w13_weight_
scale
,
w2_
weight_scale
=
layer
.
w2_weight_
scale
,
w13_input_scale
=
None
,
w2_input_scale
=
None
,
# Shuffle weights
to
runtime format
and setup kernel
.
self
.
_
setup_kernel
(
layer
,
w13
,
w2
,
w13_scale
,
w2_scale
,
layer
.
w13_input_scale
,
layer
.
w2_input_scale
,
)
# Setup modular kernel for TP case.
self
.
_setup_kernel
(
layer
)
class
Fp8KVCacheMethod
(
BaseKVCacheMethod
):
"""
...
...
vllm/model_executor/layers/quantization/modelopt.py
View file @
5dcd7ef1
...
...
@@ -15,7 +15,6 @@ from vllm.attention.layer import Attention
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe.config
import
(
FusedMoEQuantConfig
,
fp8_w8a8_moe_quant_config
,
nvfp4_moe_quant_config
,
)
from
vllm.model_executor.layers.fused_moe.fused_marlin_moe
import
fused_marlin_moe
...
...
@@ -24,6 +23,13 @@ from vllm.model_executor.layers.fused_moe.layer import (
FusedMoEMethodBase
,
FusedMoeWeightScaleSupported
,
)
from
vllm.model_executor.layers.fused_moe.oracle.fp8
import
(
Fp8MoeBackend
,
convert_to_fp8_moe_kernel_format
,
make_fp8_moe_kernel
,
make_fp8_moe_quant_config
,
select_fp8_moe_backend
,
)
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
,
...
...
@@ -45,19 +51,16 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
)
from
vllm.model_executor.layers.quantization.utils.flashinfer_utils
import
(
FlashinferMoeBackend
,
apply_f
lashinfer
_per_tensor_
scale_fp8
,
apply_f
i_trtllm_fp8
_per_tensor_
moe
,
build_flashinfer_fp8_cutlass_moe_prepare_finalize
,
flashinfer_cutlass_moe_fp8
,
get_flashinfer_moe_backend
,
is_flashinfer_supporting_global_sf
,
make_fp8_moe_alpha_scales_for_fi
,
register_scales_for_trtllm_fp8_per_tensor_moe
,
rotate_flashinfer_fp8_moe_weights
,
select_cutlass_fp8_gemm_impl
,
swap_w13_to_w31
,
)
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
W8A8BlockFp8LinearOp
,
process_fp8_input_tensor_strategy_moe
,
process_fp8_weight_tensor_strategy_moe
,
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
get_marlin_input_dtype
,
...
...
@@ -85,13 +88,12 @@ from vllm.model_executor.parameter import (
ModelWeightParameter
,
PerTensorScaleParameter
,
)
from
vllm.model_executor.utils
import
replace_parameter
from
vllm.scalar_type
import
scalar_types
from
vllm.utils.flashinfer
import
(
flashinfer_scaled_fp4_mm
,
has_flashinfer
,
has_flashinfer_moe
,
)
from
vllm.utils.math_utils
import
round_up
if
TYPE_CHECKING
:
from
vllm.model_executor.models.utils
import
WeightsMapper
...
...
@@ -721,38 +723,23 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
layer
:
FusedMoE
,
)
->
None
:
super
().
__init__
(
layer
.
moe_config
)
self
.
layer
=
layer
self
.
quant_config
=
quant_config
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
cutlass_fp8_supported
,
assert
self
.
quant_config
.
is_checkpoint_fp8_serialized
self
.
fp8_backend
=
select_fp8_moe_backend
(
block_quant
=
False
,
tp_size
=
layer
.
moe_parallel_config
.
tp_size
,
with_lora_support
=
self
.
moe
.
is_lora_enabled
,
)
self
.
cutlass_fp8_supported
=
cutlass_fp8_supported
()
self
.
flashinfer_moe_backend
:
FlashinferMoeBackend
|
None
=
None
if
envs
.
VLLM_USE_FLASHINFER_MOE_FP8
and
has_flashinfer_moe
():
self
.
flashinfer_moe_backend
=
get_flashinfer_moe_backend
()
if
(
self
.
flashinfer_moe_backend
==
FlashinferMoeBackend
.
TENSORRT_LLM
and
not
self
.
moe
.
is_act_and_mul
):
logger
.
info_once
(
"Non-gated MoE is not supported for min-latency mode,"
"falling back to high-throughput mode"
)
self
.
flashinfer_moe_backend
=
FlashinferMoeBackend
.
CUTLASS
logger
.
info_once
(
f
"Using FlashInfer
{
self
.
flashinfer_moe_backend
.
value
}
kernels"
)
self
.
kernel
:
mk
.
FusedMoEModularKernel
|
None
=
None
def
maybe_make_prepare_finalize
(
self
,
routing_tables
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
|
None
=
None
,
)
->
mk
.
FusedMoEPrepareAndFinalize
|
None
:
# TRT LLM not supported with all2all yet.
if
self
.
f
lashinfer_moe
_backend
==
F
lashinfer
MoeBackend
.
TENSOR
RT
_
LLM
:
if
self
.
f
p8
_backend
==
F
p8
MoeBackend
.
FLASHINFER_T
RTLLM
:
return
None
elif
self
.
f
lashinfer_moe
_backend
==
F
lashinfer
MoeBackend
.
CUTLASS
:
elif
self
.
f
p8
_backend
==
F
p8
MoeBackend
.
FLASHINFER_
CUTLASS
:
# TP case: avoid convert to ModularKernelMethod - to be refactored.
if
self
.
moe
.
dp_size
==
1
:
return
None
...
...
@@ -787,6 +774,9 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
,
):
layer
.
orig_dtype
=
params_dtype
layer
.
num_experts
=
num_experts
# Use FP8 dtype if checkpoint is serialized
weight_dtype
=
(
torch
.
float8_e4m3fn
...
...
@@ -826,217 +816,121 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
)
layer
.
register_parameter
(
"w2_weight"
,
w2_weight
)
if
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
# WEIGHT SCALES - Per-tensor scaling for ModelOpt
# For gated MoE, allocate 2 scales for w1 and w3 respectively.
# They will be combined to a single scale after weight loading.
# For non-gated MoE, allocate 1 scale for w13.
if
self
.
moe
.
is_act_and_mul
:
w13_weight_scale_shape
=
(
num_experts
,
2
)
else
:
w13_weight_scale_shape
=
(
num_experts
,
1
)
w13_weight_scale
=
PerTensorScaleParameter
(
data
=
torch
.
full
(
w13_weight_scale_shape
,
1.0
,
dtype
=
torch
.
float32
,
),
weight_loader
=
weight_loader
,
)
w2_weight_scale
=
PerTensorScaleParameter
(
data
=
torch
.
full
((
num_experts
,),
1.0
,
dtype
=
torch
.
float32
),
weight_loader
=
weight_loader
,
)
layer
.
register_parameter
(
"w13_weight_scale"
,
w13_weight_scale
)
layer
.
register_parameter
(
"w2_weight_scale"
,
w2_weight_scale
)
# Set weight loader attributes for scales
extra_weight_attrs
.
update
(
{
"quant_method"
:
FusedMoeWeightScaleSupported
.
TENSOR
.
value
}
)
# INPUT SCALES - Per-tensor scaling for ModelOpt
w13_input_scale
=
PerTensorScaleParameter
(
data
=
torch
.
full
((
num_experts
,),
1.0
,
dtype
=
torch
.
float32
),
weight_loader
=
weight_loader
,
)
w2_input_scale
=
PerTensorScaleParameter
(
data
=
torch
.
full
((
num_experts
,),
1.0
,
dtype
=
torch
.
float32
),
weight_loader
=
weight_loader
,
)
layer
.
register_parameter
(
"w13_input_scale"
,
w13_input_scale
)
layer
.
register_parameter
(
"w2_input_scale"
,
w2_input_scale
)
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
"""Process FP8 MoE weights after loading from serialized checkpoint.
Only supports pre-quantized checkpoints with FP8 weights and scales.
"""
if
self
.
flashinfer_moe_backend
is
not
None
:
self
.
_maybe_pad_intermediate_for_flashinfer
(
layer
)
layer
.
w13_weight
=
Parameter
(
layer
.
w13_weight
.
data
,
requires_grad
=
False
)
layer
.
w2_weight
=
Parameter
(
layer
.
w2_weight
.
data
,
requires_grad
=
False
)
# WEIGHT SCALES - Per-tensor scaling for ModelOpt
# For gated MoE, allocate 2 scales for w1 and w3 respectively.
# They will be combined to a single scale after weight loading.
# For non-gated MoE, allocate 1 scale for w13.
w13_weight_scale
=
PerTensorScaleParameter
(
data
=
torch
.
full
(
(
num_experts
,
2
if
self
.
moe
.
is_act_and_mul
else
1
),
1.0
,
dtype
=
torch
.
float32
,
),
weight_loader
=
weight_loader
,
)
w2_weight_scale
=
PerTensorScaleParameter
(
data
=
torch
.
full
((
num_experts
,),
1.0
,
dtype
=
torch
.
float32
),
weight_loader
=
weight_loader
,
)
layer
.
register_parameter
(
"w13_weight_scale"
,
w13_weight_scale
)
layer
.
register_parameter
(
"w2_weight_scale"
,
w2_weight_scale
)
from
vllm._custom_ops
import
scaled_fp8_quant
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
per_tensor_dequantize
,
# INPUT SCALES - Per-tensor scaling for ModelOpt
w13_input_scale
=
PerTensorScaleParameter
(
data
=
torch
.
full
((
num_experts
,),
1.0
,
dtype
=
torch
.
float32
),
weight_loader
=
weight_loader
,
)
w2_input_scale
=
PerTensorScaleParameter
(
data
=
torch
.
full
((
num_experts
,),
1.0
,
dtype
=
torch
.
float32
),
weight_loader
=
weight_loader
,
)
layer
.
register_parameter
(
"w13_input_scale"
,
w13_input_scale
)
layer
.
register_parameter
(
"w2_input_scale"
,
w2_input_scale
)
# Handle scale parameters
if
hasattr
(
layer
,
"w13_weight_scale"
)
and
layer
.
w13_weight_scale
is
not
None
:
# Fp8 moe kernel needs single weight scale for w13 per expert.
# We take the max of the w1 and w3 scales
# then dequant and requant each expert.
if
(
layer
.
w13_weight_scale
.
dim
()
==
2
and
layer
.
w13_weight_scale
.
shape
[
1
]
==
2
):
assert
self
.
moe
.
is_act_and_mul
,
(
"w13_weight_scale should have 2 elements per expert "
"only for gated MoE"
)
# Get the maximum scale across w1 and w3 for each expert
max_w13_scales
=
layer
.
w13_weight_scale
.
max
(
dim
=
1
).
values
# Requantize each expert's weights using the combined scale
# w13_weight (num_experts, 2 * intermediate_size, hidden_size)
# where the first intermediate_size rows are w1, the next are w3
intermediate_size
=
layer
.
w13_weight
.
shape
[
1
]
//
2
for
expert_id
in
range
(
layer
.
w13_weight
.
shape
[
0
]):
start
=
0
for
shard_id
in
range
(
2
):
# w1 and w3
# Dequantize using the original scale for this shard
dq_weight
=
per_tensor_dequantize
(
layer
.
w13_weight
[
expert_id
][
start
:
start
+
intermediate_size
,
:
],
layer
.
w13_weight_scale
[
expert_id
][
shard_id
],
)
# Requantize using the combined max scale
(
layer
.
w13_weight
[
expert_id
][
start
:
start
+
intermediate_size
,
:
],
_
,
)
=
scaled_fp8_quant
(
dq_weight
,
max_w13_scales
[
expert_id
])
start
+=
intermediate_size
# Update the scale parameter to be per-expert
layer
.
w13_weight_scale
=
Parameter
(
max_w13_scales
,
requires_grad
=
False
)
else
:
layer
.
w13_weight_scale
=
Parameter
(
layer
.
w13_weight_scale
.
data
,
requires_grad
=
False
)
def
_setup_kernel
(
self
,
layer
:
torch
.
nn
.
Module
,
w13
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
w13_scale
:
torch
.
Tensor
,
w2_scale
:
torch
.
Tensor
,
w13_input_scale
:
torch
.
Tensor
,
w2_input_scale
:
torch
.
Tensor
,
):
w13
,
w2
,
w13_scale
,
w2_scale
=
convert_to_fp8_moe_kernel_format
(
fp8_backend
=
self
.
fp8_backend
,
layer
=
layer
,
w13
=
w13
,
w2
=
w2
,
w13_scale
=
w13_scale
,
w2_scale
=
w2_scale
,
w13_input_scale
=
w13_input_scale
,
w2_input_scale
=
w2_input_scale
,
)
if
hasattr
(
layer
,
"w2_weight_scale"
)
and
layer
.
w2_weight_scale
is
not
None
:
layer
.
w2_weight_scale
=
Parameter
(
layer
.
w2_weight_scale
.
data
,
requires_grad
=
False
)
# Input scales must be equal for each expert in fp8 MoE layers.
if
hasattr
(
layer
,
"w13_input_scale"
)
and
layer
.
w13_input_scale
is
not
None
:
layer
.
w13_input_scale
=
Parameter
(
layer
.
w13_input_scale
.
max
(),
requires_grad
=
False
)
if
hasattr
(
layer
,
"w2_input_scale"
)
and
layer
.
w2_input_scale
is
not
None
:
layer
.
w2_input_scale
=
Parameter
(
layer
.
w2_input_scale
.
max
(),
requires_grad
=
False
# Replace parameters with updated versions. Note that this helper
# function ensures the replacement is compatible with RL weight reloads.
replace_parameter
(
layer
,
"w13_weight"
,
w13
)
replace_parameter
(
layer
,
"w2_weight"
,
w2
)
replace_parameter
(
layer
,
"w13_weight_scale"
,
w13_scale
)
replace_parameter
(
layer
,
"w2_weight_scale"
,
w2_scale
)
# Setup modular kernel for TP case.
self
.
moe_quant_config
=
self
.
get_fused_moe_quant_config
(
layer
)
if
self
.
moe_quant_config
:
self
.
kernel
,
self
.
use_inplace
=
make_fp8_moe_kernel
(
layer
=
layer
,
moe_quant_config
=
self
.
moe_quant_config
,
moe_config
=
self
.
moe
,
fp8_backend
=
self
.
fp8_backend
,
)
if
self
.
flashinfer_moe_backend
is
not
None
:
if
self
.
moe
.
is_act_and_mul
:
layer
.
w13_weight
.
data
=
swap_w13_to_w31
(
layer
.
w13_weight
.
data
)
# NOTE: this adds some attributes used by the trtllm kernel,
# which does not conform to the modular kernels abstraction (yet).
if
self
.
flashinfer_moe_backend
==
FlashinferMoeBackend
.
TENSORRT_LLM
:
rotate_flashinfer_fp8_moe_weights
(
layer
.
w13_weight
,
layer
.
w2_weight
)
register_scales_for_trtllm_fp8_per_tensor_moe
(
layer
=
layer
,
w13_weight_scale
=
layer
.
w13_weight_scale
,
w13_input_scale
=
layer
.
w13_input_scale
,
w2_weight_scale
=
layer
.
w2_weight_scale
,
w2_input_scale
=
layer
.
w2_input_scale
,
)
def
_maybe_pad_intermediate_for_flashinfer
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
"""Pad intermediate size so FlashInfer kernels' alignment constraints hold.
Some FlashInfer FP8 MoE kernels require the (gated) intermediate size
used for GEMM to be divisible by a small alignment value. When this is
not satisfied (e.g. with certain tensor-parallel sizes), we pad the
gate/up and down projection weights along the intermediate dim.
"""
if
not
hasattr
(
layer
,
"w13_weight"
)
or
not
hasattr
(
layer
,
"w2_weight"
):
return
# Current local intermediate size (per partition) is the K dimension of
# the down projection.
num_experts
,
hidden_size
,
intermediate
=
layer
.
w2_weight
.
shape
min_alignment
=
16
padded_intermediate
=
round_up
(
intermediate
,
min_alignment
)
if
padded_intermediate
==
intermediate
:
return
logger
.
info
(
"Padding intermediate size from %d to %d for up/down projection weights."
,
intermediate
,
padded_intermediate
,
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
w13
=
layer
.
w13_weight
w2
=
layer
.
w2_weight
w13_scale
=
layer
.
w13_weight_scale
w2_scale
=
layer
.
w2_weight_scale
w13_input_scale
=
layer
.
w13_input_scale
w2_input_scale
=
layer
.
w2_input_scale
# Per tensor kernels require single activation scale. Use the max.
w13_input_scale
,
w2_input_scale
=
process_fp8_input_tensor_strategy_moe
(
w13_input_scale
,
w2_input_scale
)
replace_parameter
(
layer
,
"w13_input_scale"
,
w13_input_scale
)
replace_parameter
(
layer
,
"w2_input_scale"
,
w2_input_scale
)
# Per tensor kernels require single weight scale for w13 per expert, but
# on disk there is a scale for w1 and w3. Use the max to requantize.
shard_size
=
layer
.
intermediate_size_per_partition
w13
,
w13_scale
=
process_fp8_weight_tensor_strategy_moe
(
w13
,
w13_scale
,
shard_size
,
num_experts
=
layer
.
w13_weight
.
shape
[
0
],
is_act_and_mul
=
self
.
moe
.
is_act_and_mul
,
)
up_mult
=
2
if
self
.
moe
.
is_act_and_mul
else
1
padded_gate_up_dim
=
up_mult
*
padded_intermediate
# Pad w13 and w2 along their intermediate dimension.
w13
=
layer
.
w13_weight
.
data
padded_w13
=
w13
.
new_zeros
((
num_experts
,
padded_gate_up_dim
,
hidden_size
))
padded_w13
[:,
:
w13
.
shape
[
1
],
:]
=
w13
layer
.
w13_weight
.
data
=
padded_w13
w2
=
layer
.
w2_weight
.
data
padded_w2
=
w2
.
new_zeros
((
num_experts
,
hidden_size
,
padded_intermediate
))
padded_w2
[:,
:,
:
intermediate
]
=
w2
layer
.
w2_weight
.
data
=
padded_w2
if
hasattr
(
layer
,
"intermediate_size_per_partition"
):
layer
.
intermediate_size_per_partition
=
padded_intermediate
# Shuffle weights to runtime format and setup kernel.
self
.
_setup_kernel
(
layer
,
w13
,
w2
,
w13_scale
,
w2_scale
,
w13_input_scale
,
w2_input_scale
)
def
get_fused_moe_quant_config
(
self
,
layer
:
torch
.
nn
.
Module
)
->
FusedMoEQuantConfig
|
None
:
if
self
.
flashinfer_moe_backend
==
FlashinferMoeBackend
.
TENSORRT_LLM
:
# TRTLLM does not use modular kernels
return
None
elif
self
.
flashinfer_moe_backend
==
FlashinferMoeBackend
.
CUTLASS
:
g1_alphas
,
g2_alphas
=
make_fp8_moe_alpha_scales_for_fi
(
layer
.
w13_weight_scale
,
layer
.
w13_input_scale
,
layer
.
w2_weight_scale
,
layer
.
w2_input_scale
,
)
return
fp8_w8a8_moe_quant_config
(
w1_scale
=
layer
.
w13_weight_scale
,
w2_scale
=
layer
.
w2_weight_scale
,
a1_scale
=
layer
.
w13_input_scale
,
a2_scale
=
layer
.
w2_input_scale
,
a1_gscale
=
(
1.0
/
layer
.
w13_input_scale
),
a2_gscale
=
(
1.0
/
layer
.
w2_input_scale
),
g1_alphas
=
g1_alphas
,
g2_alphas
=
g2_alphas
,
)
else
:
assert
self
.
flashinfer_moe_backend
is
None
return
fp8_w8a8_moe_quant_config
(
w1_scale
=
layer
.
w13_weight_scale
,
w2_scale
=
layer
.
w2_weight_scale
,
a1_scale
=
layer
.
w13_input_scale
,
a2_scale
=
layer
.
w2_input_scale
,
)
w1_scale
=
layer
.
w13_weight_scale
w2_scale
=
layer
.
w2_weight_scale
a1_scale
=
layer
.
w13_input_scale
a2_scale
=
layer
.
w2_input_scale
return
make_fp8_moe_quant_config
(
fp8_backend
=
self
.
fp8_backend
,
w1_scale
=
w1_scale
,
w2_scale
=
w2_scale
,
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
,
)
def
apply
(
self
,
...
...
@@ -1044,17 +938,18 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
if
self
.
f
lashinfer_moe
_backend
==
F
lashinfer
MoeBackend
.
TENSOR
RT
_
LLM
:
if
self
.
f
p8
_backend
==
F
p8
MoeBackend
.
FLASHINFER_T
RTLLM
:
if
layer
.
enable_eplb
:
raise
NotImplementedError
(
"EPLB not supported for
`ModelOptFp8MoEMethod` yet
."
"EPLB not supported for
FlashInfer TRTLLM FP8 MoE Backend
."
)
# TODO(rob): this validation should happen at kernel selection
# time in the oracle rather than here.
assert
layer
.
activation
==
"silu"
,
(
f
"Expected 'silu' activation but got
{
layer
.
activation
}
"
)
assert
not
layer
.
renormalize
return
apply_f
lashinfer
_per_tensor_
scale_fp8
(
return
apply_f
i_trtllm_fp8
_per_tensor_
moe
(
layer
=
layer
,
hidden_states
=
x
,
router_logits
=
router_logits
,
...
...
@@ -1066,46 +961,34 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
apply_router_weight_on_input
=
layer
.
apply_router_weight_on_input
,
)
# Expert selection
topk_weights
,
topk_ids
=
layer
.
select_experts
(
hidden_states
=
x
,
router_logits
=
router_logits
,
)
if
self
.
flashinfer_moe_backend
==
FlashinferMoeBackend
.
CUTLASS
:
# TODO(rob): this validation should happen at kernel selection
# time in the oracle rather than here.
if
self
.
fp8_backend
==
Fp8MoeBackend
.
FLASHINFER_CUTLASS
:
assert
layer
.
activation
in
(
"silu"
,
"relu2_no_mul"
),
(
"Expected activation to be in ('silu', 'relu2_no_mul'),"
f
"but got
{
layer
.
activation
}
"
)
return
flashinfer_cutlass_moe_fp8
(
x
,
layer
,
topk_weights
,
topk_ids
,
inplace
=
False
,
activation
=
layer
.
activation
,
global_num_experts
=
layer
.
global_num_experts
,
expert_map
=
layer
.
expert_map
,
apply_router_weight_on_input
=
layer
.
apply_router_weight_on_input
,
)
else
:
from
vllm.model_executor.layers.fused_moe.fused_moe
import
fused_experts
assert
self
.
moe_quant_config
is
not
None
assert
self
.
kernel
is
not
None
result
=
self
.
kernel
(
x
,
layer
.
w13_weight
,
layer
.
w2_weight
,
topk_weights
,
topk_ids
,
inplace
=
self
.
use_inplace
,
activation
=
layer
.
activation
,
global_num_experts
=
layer
.
global_num_experts
,
expert_map
=
layer
.
expert_map
,
apply_router_weight_on_input
=
layer
.
apply_router_weight_on_input
,
)
return
fused_experts
(
x
,
layer
.
w13_weight
,
layer
.
w2_weight
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
inplace
=
True
,
activation
=
layer
.
activation
,
quant_config
=
self
.
moe_quant_config
,
global_num_experts
=
layer
.
global_num_experts
,
expert_map
=
layer
.
expert_map
,
apply_router_weight_on_input
=
layer
.
apply_router_weight_on_input
,
)
return
result
ModelOptFp8Config
.
LinearMethodCls
=
ModelOptFp8LinearMethod
...
...
vllm/model_executor/layers/quantization/quark/quark_moe.py
View file @
5dcd7ef1
...
...
@@ -22,7 +22,7 @@ from vllm.model_executor.layers.fused_moe.config import (
)
from
vllm.model_executor.layers.fused_moe.fused_marlin_moe
import
fused_marlin_moe
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp8
import
(
prepare_
moe_fp8
_layer_for_marlin
,
prepare_
fp8_moe
_layer_for_marlin
,
)
from
vllm.model_executor.layers.quantization.utils.ocp_mx_utils
import
(
OCP_MX_BLOCK_SIZE
,
...
...
@@ -315,8 +315,8 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
layer
.
w2_weight
=
torch
.
nn
.
Parameter
(
shuffled_w2
,
requires_grad
=
False
)
elif
self
.
use_marlin
:
(
workspace
,
w13_weight
,
w2_weight
,
w13_weight_scale
,
w2_weight_scale
)
=
(
prepare_
moe_fp8
_layer_for_marlin
(
w13_weight
,
w2_weight
,
w13_weight_scale
,
w2_weight_scale
=
(
prepare_
fp8_moe
_layer_for_marlin
(
layer
,
layer
.
w13_weight
,
layer
.
w2_weight
,
...
...
@@ -324,7 +324,6 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
layer
.
w2_weight_scale
,
)
)
layer
.
workspace
=
workspace
# TODO(rob): once we apply refactor to Quark, switch to using
# replace_parameter for compatibility with reloading in RL.
layer
.
w13_weight
=
torch
.
nn
.
Parameter
(
w13_weight
,
requires_grad
=
False
)
...
...
vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
View file @
5dcd7ef1
...
...
@@ -18,6 +18,7 @@ from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize im
create_flashinfer_prepare_finalize
,
)
from
vllm.platforms
import
current_platform
from
vllm.utils.math_utils
import
round_up
logger
=
init_logger
(
__name__
)
...
...
@@ -58,9 +59,10 @@ def swap_w13_to_w31(x: torch.Tensor) -> torch.Tensor:
)
def
rotate_
flashinfer_fp8_moe_weights
(
def
rotate_
weights_for_fi_trtllm_fp8_per_tensor_moe
(
gemm1_weights
:
torch
.
Tensor
,
gemm2_weights
:
torch
.
Tensor
):
"""Shuffle weights for FI TRT-LLM Format"""
from
flashinfer
import
reorder_rows_for_gated_act_gemm
,
shuffle_matrix_a
epilogue_tile_m
=
128
...
...
@@ -105,16 +107,16 @@ def rotate_flashinfer_fp8_moe_weights(
def
register_scales_for_trtllm_fp8_per_tensor_moe
(
layer
:
torch
.
nn
.
Module
,
w13_
weight_
scale
:
torch
.
Tensor
,
w13_scale
:
torch
.
Tensor
,
w13_input_scale
:
torch
.
Tensor
,
w2_
weight_
scale
:
torch
.
Tensor
,
w2_scale
:
torch
.
Tensor
,
w2_input_scale
:
torch
.
Tensor
,
)
->
None
:
"""Register necessary scales for FlashInfer TRTLLM FP8 MoE kernel"""
g1_alphas
,
g2_alphas
=
make_fp8_moe_alpha_scales_for_fi
(
w13_scale
=
w13_
weight_
scale
,
w13_scale
=
w13_scale
,
w13_input_scale
=
w13_input_scale
,
w2_scale
=
w2_
weight_
scale
,
w2_scale
=
w2_scale
,
w2_input_scale
=
w2_input_scale
,
)
layer
.
w2_input_scale_inv
=
1.0
/
w2_input_scale
...
...
@@ -123,7 +125,7 @@ def register_scales_for_trtllm_fp8_per_tensor_moe(
layer
.
output2_scales_scalar
=
g2_alphas
def
apply_f
lashinfer
_per_tensor_
scale_fp8
(
def
apply_f
i_trtllm_fp8
_per_tensor_
moe
(
layer
:
torch
.
nn
.
Module
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
...
...
@@ -139,16 +141,23 @@ def apply_flashinfer_per_tensor_scale_fp8(
import
vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe
# noqa: E501, F401
from
vllm.model_executor.models.llama4
import
Llama4MoE
# Added to the layer by: register_scales_for_trtllm_fp8_per_tensor_moe
assert
(
hasattr
(
layer
,
"output1_scales_scalar"
)
and
hasattr
(
layer
,
"output1_scales_gate_scalar"
)
and
hasattr
(
layer
,
"output2_scales_scalar"
)
)
assert
layer
.
custom_routing_function
==
Llama4MoE
.
custom_routing_function
,
(
"FusedMoE flashinfer kernels are only supported for Llama4"
# Added to the layer by: register_scales_for_trtllm_fp8_per_tensor_moe
assert
(
hasattr
(
layer
,
"output1_scales_scalar"
)
and
hasattr
(
layer
,
"output1_scales_gate_scalar"
)
and
hasattr
(
layer
,
"output2_scales_scalar"
)
)
return
torch
.
ops
.
vllm
.
flashinfer_fused_moe_per_tensor_scale_fp8
(
is_llama4
=
layer
.
custom_routing_function
==
Llama4MoE
.
custom_routing_function
assert
is_llama4
,
"FusedMoE flashinfer kernels are only supported for Llama4"
return
torch
.
ops
.
vllm
.
fi_trtllm_fp8_per_tensor_moe
(
routing_logits
=
router_logits
,
routing_bias
=
routing_bias
,
hidden_states
=
hidden_states
,
...
...
@@ -221,50 +230,6 @@ def select_cutlass_fp8_gemm_impl(
)
def flashinfer_cutlass_moe_fp8(
    hidden_states: torch.Tensor,
    layer: torch.nn.Module,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    inplace: bool = False,
    activation: str = "silu",
    global_num_experts: int = -1,
    expert_map: torch.Tensor | None = None,
    apply_router_weight_on_input: bool = False,
    use_deepseek_fp8_block_scale: bool = False,
    moe: FusedMoEConfig | None = None,
) -> torch.Tensor:
    """Run FP8 MoE experts through a FlashInfer CUTLASS modular kernel.

    The quantization config is fetched from ``layer.quant_method`` and must
    not be ``None``. A ``FusedMoEModularKernel`` is assembled on every call
    from a prepare/finalize object and a CUTLASS FP8 experts implementation,
    then invoked on ``layer.w13_weight`` / ``layer.w2_weight``.

    Args:
        hidden_states: Activations; its dtype is used as the output dtype.
        layer: MoE layer providing the weights, ``quant_method`` and
            ``moe_parallel_config``.
        topk_weights: Routing weights forwarded to the kernel.
        topk_ids: Selected expert ids forwarded to the kernel.
        inplace: Forwarded to the kernel.
        activation: Activation name forwarded to the kernel.
        global_num_experts: Forwarded to the kernel.
        expert_map: Forwarded to the kernel.
        apply_router_weight_on_input: Forwarded to the kernel.
        use_deepseek_fp8_block_scale: Passed to both builder helpers to
            request block-scale support.
        moe: MoE config passed to both builder helpers.

    Returns:
        The tensor produced by the modular kernel.
    """
    quant_config = layer.quant_method.get_fused_moe_quant_config(layer)
    assert quant_config is not None

    # Build the two halves of the modular kernel first, then combine them.
    prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
        moe=moe,
        use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
    )
    experts_impl = select_cutlass_fp8_gemm_impl(
        moe=moe,
        quant_config=quant_config,
        out_dtype=hidden_states.dtype,
        use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
    )
    modular_kernel = mk.FusedMoEModularKernel(
        prepare_finalize,
        experts_impl,
        moe_parallel_config=layer.moe_parallel_config,
    )

    return modular_kernel(
        hidden_states,
        layer.w13_weight,
        layer.w2_weight,
        topk_weights,
        topk_ids,
        inplace=inplace,
        activation=activation,
        global_num_experts=global_num_experts,
        expert_map=expert_map,
        apply_router_weight_on_input=apply_router_weight_on_input,
    )
def
get_flashinfer_moe_backend
()
->
FlashinferMoeBackend
:
backend_map
=
{
"throughput"
:
FlashinferMoeBackend
.
CUTLASS
,
...
...
@@ -301,3 +266,104 @@ def is_flashinfer_supporting_global_sf(backend: FlashinferMoeBackend | None) ->
FlashinferMoeBackend
.
TENSORRT_LLM
,
)
return
backend
in
backends_supporting_global_sf
def align_fp8_moe_weights_for_fi(
    w13: torch.Tensor, w2: torch.Tensor, is_act_and_mul: bool
) -> tuple[torch.Tensor, torch.Tensor, int]:
    """Pad intermediate size so FlashInfer kernels' alignment constraints hold.

    Some FlashInfer FP8 MoE kernels require the (gated) intermediate size
    used for GEMM to be divisible by a small alignment value. When this is
    not satisfied (e.g. with certain tensor-parallel sizes), the gate/up and
    down projection weights are zero-padded along the intermediate dim.

    Args:
        w13: Up (or fused gate/up) projection weights, indexed as
            ``(num_experts, up_mult * intermediate, hidden_size)``.
        w2: Down projection weights of shape
            ``(num_experts, hidden_size, intermediate)``.
        is_act_and_mul: Whether ``w13`` holds two projections stacked along
            dim 1 (gated activation) rather than a single one.

    Returns:
        ``(w13, w2, intermediate)``. The inputs are returned untouched when
        no padding is needed; otherwise freshly allocated zero-padded copies
        and the padded intermediate size are returned.
    """
    # Current local intermediate size (per partition) is the K dimension of
    # the down projection.
    num_experts, hidden_size, intermediate = w2.shape

    min_alignment = 16
    padded_intermediate = round_up(intermediate, min_alignment)

    if padded_intermediate == intermediate:
        return w13, w2, intermediate

    logger.info_once(
        "Padding intermediate size from %d to %d for up/down projection weights.",
        intermediate,
        padded_intermediate,
        scope="local",
    )

    up_mult = 2 if is_act_and_mul else 1
    padded_gate_up_dim = up_mult * padded_intermediate

    padded_w13 = w13.new_zeros((num_experts, padded_gate_up_dim, hidden_size))
    if is_act_and_mul:
        # Pad each stacked projection on its own so that the second one
        # starts exactly at `padded_intermediate`. Copying w13 as one slab
        # would leave the second projection starting at `intermediate`, and
        # any consumer that splits the padded tensor in half would then mix
        # rows of both projections.
        # NOTE(review): assumes w13.shape[1] == 2 * intermediate - confirm.
        padded_w13[:, :intermediate, :] = w13[:, :intermediate, :]
        padded_w13[
            :, padded_intermediate : padded_intermediate + intermediate, :
        ] = w13[:, intermediate:, :]
    else:
        padded_w13[:, : w13.shape[1], :] = w13

    # The down projection is only padded along its trailing (K) dimension.
    padded_w2 = w2.new_zeros((num_experts, hidden_size, padded_intermediate))
    padded_w2[:, :, :intermediate] = w2

    return padded_w13, padded_w2, padded_intermediate
def prepare_fp8_moe_layer_for_fi(
    layer: torch.nn.Module,
    w13: torch.Tensor,
    w2: torch.Tensor,
    w13_scale: torch.Tensor,
    w13_input_scale: torch.Tensor | None,
    w2_scale: torch.Tensor,
    w2_input_scale: torch.Tensor | None,
    is_trtllm: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Bring FP8 MoE weights into the layout expected by FlashInfer kernels.

    Side effects on ``layer``:
      * Without block quantization, ``intermediate_size_per_partition`` is
        overwritten with the (possibly padded) intermediate size.
      * For the TRT-LLM per-tensor path, scales are registered on the layer
        via ``register_scales_for_trtllm_fp8_per_tensor_moe``.

    Args:
        layer: MoE layer; must expose ``moe_config.is_act_and_mul``.
        w13: Up (or fused gate/up) projection weights.
        w2: Down projection weights.
        w13_scale: Weight scale for ``w13``.
        w13_input_scale: Input scale for ``w13``; required for the TRT-LLM
            per-tensor path.
        w2_scale: Weight scale for ``w2``.
        w2_input_scale: Input scale for ``w2``; required for the TRT-LLM
            per-tensor path.
        is_trtllm: Whether the TRT-LLM backend is the target.

    Returns:
        ``(w13, w2, w13_scale)`` after the conversions.
    """
    assert hasattr(layer.moe_config, "is_act_and_mul")
    gated = layer.moe_config.is_act_and_mul

    uses_block_quant = getattr(layer, "weight_block_size", None) is not None
    per_tensor = not uses_block_quant

    # Per-tensor path: some FI MoE kernels need the gate-up projection to be
    # aligned to 16, so pad the weights and record the resulting size.
    if per_tensor:
        w13, w2, aligned_intermediate = align_fp8_moe_weights_for_fi(
            w13,
            w2,
            gated,
        )
        layer.intermediate_size_per_partition = aligned_intermediate

    # FI kernels consume the W31 ordering instead of W13; with block
    # quantization the scales are reordered the same way.
    if gated:
        w13 = swap_w13_to_w31(w13)
        if uses_block_quant:
            w13_scale = swap_w13_to_w31(w13_scale)

    # The FI TRT-LLM FP8 per-tensor kernel additionally needs shuffled
    # weights and registered alpha scales. These are not registered as
    # nn.Parameters since they are not needed for weight reloading.
    if is_trtllm and per_tensor:
        assert w13_input_scale is not None
        assert w2_input_scale is not None

        # Return value is not used here; presumably the helper updates the
        # tensors in place - confirm against its definition.
        rotate_weights_for_fi_trtllm_fp8_per_tensor_moe(w13, w2)
        register_scales_for_trtllm_fp8_per_tensor_moe(
            layer,
            w13_scale=w13_scale,
            w13_input_scale=w13_input_scale,
            w2_scale=w2_scale,
            w2_input_scale=w2_input_scale,
        )

    return w13, w2, w13_scale
vllm/model_executor/layers/quantization/utils/fp8_utils.py
View file @
5dcd7ef1
...
...
@@ -21,6 +21,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
CUTLASS_BLOCK_FP8_SUPPORTED
,
all_close_1d
,
per_tensor_dequantize
,
)
from
vllm.model_executor.parameter
import
(
BlockQuantScaleParameter
,
...
...
@@ -1350,6 +1352,29 @@ def deepgemm_post_process_fp8_weight_block(
return
wq
,
dg_ws
def prepare_fp8_moe_layer_for_deepgemm(
    w13: torch.Tensor,
    w2: torch.Tensor,
    w13_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    block_shape: tuple[int],
):
    """Post-process block-quantized FP8 MoE weights and scales for DeepGEMM.

    Both projections go through ``deepgemm_post_process_fp8_weight_block``
    with the same ``block_shape``; the E8M0 setting is queried from
    ``is_deep_gemm_e8m0_used()`` for each call.

    Returns:
        ``(w13, w2, w13_scale, w2_scale)`` as produced by the post-processing.
    """

    def _post_process(weight: torch.Tensor, scale: torch.Tensor):
        # Shared call shape for both projections.
        return deepgemm_post_process_fp8_weight_block(
            wq=weight,
            ws=scale,
            quant_block_shape=block_shape,
            use_e8m0=is_deep_gemm_e8m0_used(),
        )

    dg_w13, dg_w13_scale = _post_process(w13, w13_scale)
    dg_w2, dg_w2_scale = _post_process(w2, w2_scale)
    return dg_w13, dg_w2, dg_w13_scale, dg_w2_scale
def
_maybe_pad_fp8_weight
(
weight
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""Pad the weight tensor. This is an optimization on ROCm platform, which
can benefit from tensors located far enough from one another in memory"""
...
...
@@ -1584,7 +1609,49 @@ def maybe_post_process_fp8_weight_block(layer: torch.nn.Module):
replace_parameter
(
layer
,
scale_attr
,
dg_weight_scale
)
def expert_weight_is_col_major(x: torch.Tensor) -> bool:
    """Tell whether a 3-D expert weight tensor is laid out column-major.

    For a tensor of shape ``(b, m, n)`` this is true exactly when its strides
    are ``(m * n, 1, m)``.

    Raises:
        AssertionError: If ``x`` is not 3-dimensional.
    """
    assert x.dim() == 3
    _, rows, cols = x.shape
    col_major_strides = (rows * cols, 1, rows)
    return tuple(x.stride()) == col_major_strides
def process_fp8_weight_tensor_strategy_moe(
    weight: torch.Tensor,
    weight_scales: torch.Tensor,
    shard_size: int,
    num_experts: int,
    is_act_and_mul: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Process moe weights for tensor-wise quantization strategy.

    Args:
        weight: Per-expert weights, indexed as ``weight[expert_id]``. In the
            gated case the tensor is requantized in place.
        weight_scales: Per-expert scales with the shard index on dim 1.
        shard_size: Number of rows of each of the two shards in the gated
            case.
        num_experts: Number of experts to iterate over.
        is_act_and_mul: Whether the weight holds two shards (w1 and w3) per
            expert.

    Returns:
        ``(weight, scales)`` where ``scales`` holds one scale per expert
        (``weight_scales`` reduced with max over dim 1).
    """
    max_scales = weight_scales.max(dim=1).values

    # For w1 case (i.e. not w13): just collapse the last dim since
    # there is already just one scale per expert in this case. The per-expert
    # reduction is returned; a global `.max()` would collapse all experts
    # into a single scalar and lose the per-expert scales.
    if not is_act_and_mul:
        assert weight_scales.shape[1] == 1
        return weight, max_scales

    # For w13 case (common): require single scale for w13 per expert, but
    # on disk there is a scale for w1 and w3. Use the max to requantize.
    for expert_id in range(num_experts):
        start = 0
        for shard_id in range(2):
            stop = start + shard_size
            # Dequantize the shard with its own scale, then requantize it
            # with the shared per-expert scale, writing back in place.
            dq_weight = per_tensor_dequantize(
                weight[expert_id][start:stop, :],
                weight_scales[expert_id][shard_id],
            )
            weight[expert_id][start:stop, :], _ = ops.scaled_fp8_quant(
                dq_weight, max_scales[expert_id]
            )
            start = stop
    return weight, max_scales
def process_fp8_input_tensor_strategy_moe(
    w13_input_scale: torch.Tensor,
    w2_input_scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Process moe input scales for tensor-wise quantization strategy.

    Each input-scale tensor is reduced to its maximum. If either tensor
    fails the ``all_close_1d`` check, an informational message is logged
    once before reducing.

    Returns:
        ``(w13_input_scale.max(), w2_input_scale.max())``.
    """
    scales_are_uniform = all_close_1d(w13_input_scale) and all_close_1d(
        w2_input_scale
    )
    if not scales_are_uniform:
        logger.info_once(
            "Found input_scales that are not equal for "
            "fp8 MoE layer. Using the maximum across experts "
            "for each layer."
        )

    reduced_w13 = w13_input_scale.max()
    reduced_w2 = w2_input_scale.max()
    return reduced_w13, reduced_w2
vllm/model_executor/layers/quantization/utils/marlin_utils.py
View file @
5dcd7ef1
...
...
@@ -496,7 +496,7 @@ def get__quant_fp8_method() -> QuantFP8:
return
_quant_fp8_method
def
get_marlin_input_dtype
(
prefix
):
def
get_marlin_input_dtype
(
prefix
:
str
|
None
=
None
):
if
envs
.
VLLM_MARLIN_INPUT_DTYPE
is
None
:
return
elif
envs
.
VLLM_MARLIN_INPUT_DTYPE
.
lower
()
==
"int8"
:
...
...
vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py
View file @
5dcd7ef1
...
...
@@ -8,6 +8,7 @@ import vllm._custom_ops as ops
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
USE_FP32_REDUCE_DEFAULT
,
get_marlin_input_dtype
,
marlin_make_workspace_new
,
marlin_permute_bias
,
marlin_permute_scales
,
...
...
@@ -197,26 +198,28 @@ def prepare_fp8_layer_for_marlin(
replace_parameter
(
layer
,
"bias"
,
bias
)
def
prepare_
moe_fp8
_layer_for_marlin
(
def
prepare_
fp8_moe
_layer_for_marlin
(
layer
:
torch
.
nn
.
Module
,
w13_weight
:
torch
.
Tensor
,
w2_weight
:
torch
.
Tensor
,
w13_weight_scale
:
torch
.
Tensor
,
w2_weight_scale
:
torch
.
Tensor
,
input_dtype
:
torch
.
dtype
|
None
=
None
,
)
->
tuple
[
torch
.
Tensor
,
# workspace
torch
.
Tensor
,
# w13_weight
torch
.
Tensor
,
# w2_weight
torch
.
Tensor
,
# w13_weight_scale
torch
.
Tensor
,
# w2_weight_scale
]:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Shuffle weights and scales into marlin format.
Note that this function has the side effect of adding a `workspace`
attribute to the layer. This `workspace` does not need to be
registered as a Parameter as it is not used during weight reloading.
"""
logger
.
warning_once
(
"Your GPU does not have native support for FP8 computation but "
"FP8 quantization is being used. Weight-only FP8 compression will "
"be used leveraging the Marlin kernel. This may degrade "
"performance for compute-heavy workloads."
)
input_dtype
=
get_marlin_input_dtype
()
if
input_dtype
is
not
None
and
input_dtype
.
itemsize
==
1
:
raise
NotImplementedError
(
"Marlin W8A8 is not supported."
)
...
...
@@ -227,7 +230,9 @@ def prepare_moe_fp8_layer_for_marlin(
# WORKSPACE
device
=
layer
.
w13_weight
.
device
workspace
=
marlin_make_workspace_new
(
device
,
4
)
# NOTE(rob): we do not need to register the workspace as a param
# because it is not used as part of the weight reloading process.
layer
.
workspace
=
marlin_make_workspace_new
(
device
,
4
)
perm
=
torch
.
empty
(
0
,
dtype
=
torch
.
int
,
device
=
device
)
# WEIGHT
...
...
@@ -310,13 +315,7 @@ def prepare_moe_fp8_layer_for_marlin(
w13_weight_scale
=
permute_scales
(
w13_weight_scale
,
"w13"
)
w2_weight_scale
=
permute_scales
(
w2_weight_scale
,
"w2"
)
return
(
workspace
,
w13_weight
,
w2_weight
,
w13_weight_scale
,
w2_weight_scale
,
)
return
w13_weight
,
w2_weight
,
w13_weight_scale
,
w2_weight_scale
def
pack_fp8_to_int32
(
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment