Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c3a9752b
Unverified
Commit
c3a9752b
authored
Jan 30, 2026
by
Pavani Majety
Committed by
GitHub
Jan 30, 2026
Browse files
[Hardware][SM100] Add TRTLLM Kernel for INT4 W4A16 Kernel. (#32437)
Signed-off-by:
Pavani Majety
<
pmajety@nvidia.com
>
parent
f451b455
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
727 additions
and
23 deletions
+727
-23
tests/kernels/moe/test_marlin_vs_trtllm_mxint4.py
tests/kernels/moe/test_marlin_vs_trtllm_mxint4.py
+272
-0
vllm/envs.py
vllm/envs.py
+8
-3
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+9
-5
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
...quantization/compressed_tensors/compressed_tensors_moe.py
+170
-13
vllm/model_executor/layers/quantization/utils/flashinfer_mxint4_moe.py
...ecutor/layers/quantization/utils/flashinfer_mxint4_moe.py
+266
-0
vllm/utils/flashinfer.py
vllm/utils/flashinfer.py
+2
-2
No files found.
tests/kernels/moe/test_marlin_vs_trtllm_mxint4.py
0 → 100644
View file @
c3a9752b
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Test comparing Marlin INT4 MoE vs FlashInfer TRT-LLM MXINT4 MoE."""
import
pytest
import
torch
from
vllm.model_executor.layers.fused_moe.fused_marlin_moe
import
(
fused_marlin_moe
,
)
from
vllm.model_executor.layers.fused_moe.router.grouped_topk_router
import
(
grouped_topk
,
)
from
vllm.model_executor.layers.quantization.utils.flashinfer_mxint4_moe
import
(
prepare_static_weights_for_trtllm_mxint4_moe
,
)
from
vllm.platforms
import
current_platform
from
vllm.scalar_type
import
scalar_types
def mxint4_quantize(
    x: torch.Tensor, sf_vec_size: int = 32
) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize a tensor to MXINT4 with one scale per group of elements.

    The last dimension of ``x`` is split into consecutive groups of
    ``sf_vec_size`` elements. Each group gets a single scale and its values
    are rounded to signed 4-bit integers in ``[-8, 7]``; two of them are
    packed per byte (even element in the low nibble, odd element in the high
    nibble).

    Args:
        x: Tensor of shape ``[..., k]``; ``k`` must be a multiple of
            ``sf_vec_size``.
        sf_vec_size: Number of elements sharing one scale. Must be a
            positive even number so that pairs can be packed into bytes.

    Returns:
        - Packed weights, ``uint8``, shape ``[..., k // 2]``.
        - Scales, same dtype as ``x``, shape ``[..., k // sf_vec_size]``.
          Groups that are entirely zero get scale 0 and all-zero nibbles.

    Raises:
        ValueError: If ``sf_vec_size`` is not a positive even number or does
            not divide the last dimension of ``x``.
    """
    if sf_vec_size <= 0 or sf_vec_size % 2 != 0:
        raise ValueError(
            f"sf_vec_size must be a positive even number, got {sf_vec_size}"
        )
    if x.shape[-1] % sf_vec_size != 0:
        # Without this check reshape(-1, sf_vec_size) could silently build
        # groups that straddle two rows.
        raise ValueError(
            f"Last dimension ({x.shape[-1]}) must be divisible by "
            f"sf_vec_size ({sf_vec_size})"
        )

    groups = x.reshape(-1, sf_vec_size)
    group_max = groups.max(dim=-1, keepdim=True)[0].to(torch.float32)
    group_min = groups.min(dim=-1, keepdim=True)[0].to(torch.float32)

    # The signed range is asymmetric ([-8, 7]). Stretching the positive
    # extreme by 8/7 makes the largest positive value land on 7 once the
    # scale is computed as amax / 8.
    group_max = group_max * 8.0 / 7.0
    amax = torch.where(group_max > -group_min, group_max, -group_min)
    scales = amax / 8.0

    # An all-zero group has scale 0; 0 * (1 / 0) would be NaN and the
    # subsequent integer cast would be undefined. Divide by 1 instead so the
    # group quantizes to zeros while the reported scale stays 0.
    safe_scales = torch.where(scales == 0, torch.ones_like(scales), scales)
    scaled = groups * safe_scales.reciprocal()

    nibble_pairs = (
        scaled.round().clamp(-8, 7).to(torch.int8).reshape(-1, sf_vec_size // 2, 2)
    )
    # Mask to the two's-complement nibble, then pack in uint8 so the shift
    # cannot overflow a signed byte.
    low = (nibble_pairs[..., 0] & 0x0F).to(torch.uint8)
    high = (nibble_pairs[..., 1] & 0x0F).to(torch.uint8)
    packed = low | (high << 4)

    return (
        packed.reshape(*x.shape[:-1], x.shape[-1] // 2),
        scales.to(x.dtype).reshape(*x.shape[:-1], x.shape[-1] // sf_vec_size),
    )
def mxint4_quantize_moe_weights(
    weights_bf16: torch.Tensor, group_size: int = 32
) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize stacked MoE expert weights to MxInt4, one expert at a time.

    Each slice ``weights_bf16[i]`` is passed to ``mxint4_quantize`` and the
    per-expert results are stacked back along a new leading dimension.

    Args:
        weights_bf16: Weights of shape [num_experts, out_features, in_features].
        group_size: Quantization group size (default: 32).

    Returns:
        - weights_mxint4: Quantized weights [e, n, k//2] uint8
        - scales_mxint4: Quantization scales [e, n, k//group_size]
    """
    num_experts = weights_bf16.shape[0]
    # Experts are quantized separately rather than in one call on the full
    # tensor; the outputs are only combined at the end.
    per_expert = [
        mxint4_quantize(weights_bf16[expert_idx], sf_vec_size=group_size)
        for expert_idx in range(num_experts)
    ]
    packed_weights = torch.stack([packed for packed, _ in per_expert])
    group_scales = torch.stack([scale for _, scale in per_expert])
    return packed_weights, group_scales
__all__
=
[
"mxint4_quantize"
,
"mxint4_quantize_moe_weights"
,
"marlin_quantize_moe_weights"
,
]
def marlin_quantize_moe_weights(
    weights_bf16: torch.Tensor, group_size: int = 32
) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize stacked MoE expert weights with the Marlin test quantizer.

    Every expert matrix is transposed from [n, k] to [k, n] and handed to
    ``marlin_quantize`` with the ``uint4b8`` scalar type and ``act_order``
    disabled. Only the quantized weights and scales from its result tuple are
    kept.

    Args:
        weights_bf16: Weights of shape [num_experts, out_features, in_features].
        group_size: Quantization group size (default: 32).

    Returns:
        - weights_marlin: Marlin quantized weights [e, k//8, n] int32
        - scales_marlin: Marlin quantization scales [e, k//group_size, n]
    """
    from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
        marlin_quantize,
    )

    # The three-way unpack also rejects inputs that are not 3-D.
    num_experts, _out_features, _in_features = weights_bf16.shape

    def _quantize_expert(expert_idx):
        # Marlin consumes [k, n], so flip the per-expert [n, k] matrix first.
        transposed = weights_bf16[expert_idx].T.contiguous()
        _, q_weight, q_scale, _, _, _ = marlin_quantize(
            transposed, scalar_types.uint4b8, group_size, act_order=False
        )
        return q_weight, q_scale

    quantized = [_quantize_expert(expert_idx) for expert_idx in range(num_experts)]
    weights_marlin = torch.stack([q_weight for q_weight, _ in quantized])
    scales_marlin = torch.stack([q_scale for _, q_scale in quantized])
    return weights_marlin, scales_marlin
TRTLLM_GEN_AVAILABLE
=
(
current_platform
.
is_cuda
()
and
current_platform
.
is_device_capability_family
(
100
)
)
@pytest.mark.skipif(not TRTLLM_GEN_AVAILABLE, reason="Skip for non SM100")
@pytest.mark.parametrize("m", [1, 33])
@pytest.mark.parametrize("n", [7168])
@pytest.mark.parametrize("k", [512])
@pytest.mark.parametrize("e", [384])
@pytest.mark.parametrize("topk", [8])
@pytest.mark.parametrize("group_size", [32])
def test_marlin_vs_trtllm_mxint4_moe_kimik2(
    monkeypatch, m, n, k, e, topk, group_size
):
    """Compare Marlin INT4 MoE vs FlashInfer TRT-LLM MXINT4 MoE.

    Both paths start from the same randomly generated BF16 expert weights,
    but each quantizes them with its own scheme: mxint4_quantize_moe_weights()
    for the TRT-LLM kernel and marlin_quantize_moe_weights() for Marlin.
    The two kernel outputs are then checked against an unquantized BF16
    reference and against each other.

    Parameters (all supplied by pytest.mark.parametrize):
        m: number of tokens.
        n: per-expert intermediate size.
        k: hidden size.
        e: number of experts.
        topk: experts selected per token.
        group_size: quantization group size.
    """
    pytest.importorskip("flashinfer")
    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_INT4", "1")
    # All random tensors below are created on CUDA, so seeding the CUDA
    # generator is what makes the test reproducible.
    torch.cuda.manual_seed(0)
    dtype = torch.bfloat16
    # DeepSeekV3 routing config (from Kimi-K2-Thinking config.json)
    n_group = 1  # n_group from model config
    topk_group = 1  # topk_group from model config
    routed_scaling = 2.827  # routed_scaling_factor from model config
    # Input - realistic activation range for LLM (after LayerNorm: mean~0, std~1)
    a = torch.randn((m, k), device="cuda", dtype=dtype) * 0.5
    # Generate routing logits and bias (DeepSeekV3 expects float logits)
    # Realistic ranges: logits typically [-3, 3], bias [-2, 2]
    routing_logits = torch.randn((m, e), device="cuda", dtype=torch.float32) * 1.5
    routing_bias = torch.randn(e, device="cuda", dtype=torch.float32) * 0.8
    # 1. Generate BF16 weights (SHARED between both paths)
    # Glorot-style scaling applied to normally distributed samples:
    # std = sqrt(2 / (fan_in + fan_out))
    std_w1 = (2.0 / (k + 2 * n)) ** 0.5
    std_w2 = (2.0 / (n + k)) ** 0.5
    w1_bf16 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) * std_w1
    w2_bf16 = torch.randn((e, k, n), device="cuda", dtype=dtype) * std_w2
    # === Path 1: TRT-LLM FlashInfer MXINT4 MoE ===
    # Similar to: if self.use_flashinfer_mxint4_moe
    # Quantize using MXINT4 method (signed INT4)
    w1_int4, w1_scales = mxint4_quantize_moe_weights(w1_bf16, group_size)
    w2_int4, w2_scales = mxint4_quantize_moe_weights(w2_bf16, group_size)
    # Shuffle/interleave weights and scales into the layout the kernel reads.
    trtllm_weights = prepare_static_weights_for_trtllm_mxint4_moe(
        gemm1_weights=w1_int4,
        gemm1_scales=w1_scales,
        gemm2_weights=w2_int4,
        gemm2_scales=w2_scales,
    )
    from flashinfer import RoutingMethodType
    from flashinfer.fused_moe import trtllm_mxint4_block_scale_moe

    # Routing handled internally by trtllm_mxint4_block_scale_moe
    trtllm_output = trtllm_mxint4_block_scale_moe(
        routing_logits=routing_logits,
        routing_bias=routing_bias.to(torch.bfloat16),
        hidden_states=a,
        gemm1_weights=trtllm_weights["gemm1_weights"],
        gemm1_weights_scale=trtllm_weights["gemm1_scales"],
        gemm1_alpha=None,
        gemm1_beta=None,
        gemm1_clamp_limit=None,
        gemm2_weights=trtllm_weights["gemm2_weights"],
        gemm2_weights_scale=trtllm_weights["gemm2_scales"],
        num_experts=e,
        top_k=topk,
        n_group=n_group,
        topk_group=topk_group,
        intermediate_size=n,
        local_expert_offset=0,
        local_num_experts=e,
        routed_scaling_factor=routed_scaling,
        routing_method_type=RoutingMethodType.DeepSeekV3,
        enable_pdl=None,
        output=None,
        tune_max_num_tokens=8192,
    ).to(dtype)
    # === Path 2: Marlin INT4 MoE ===
    # Similar to: else (non-flashinfer path)
    # Quantize using Marlin's method (UINT4b8)
    w1_marlin, w1_scales_marlin = marlin_quantize_moe_weights(w1_bf16, group_size)
    w2_marlin, w2_scales_marlin = marlin_quantize_moe_weights(w2_bf16, group_size)
    # Use production routing kernel (same as router.select_experts internally uses)
    # Unlike path 1, the bias is passed here in float32 without a bf16 cast.
    topk_weights, topk_ids = grouped_topk(
        hidden_states=a,
        gating_output=routing_logits,
        topk=topk,
        renormalize=False,  # DeepSeekV3 doesn't renormalize
        num_expert_group=n_group,
        topk_group=topk_group,
        scoring_func="sigmoid",  # DeepSeekV3 uses sigmoid
        routed_scaling_factor=routed_scaling,
        e_score_correction_bias=routing_bias,
    )
    marlin_output = fused_marlin_moe(
        a,
        w1_marlin,
        w2_marlin,
        None,
        None,
        w1_scales_marlin,
        w2_scales_marlin,
        None,  # gating_output not needed when topk_weights/ids provided
        topk_weights,
        topk_ids,
        global_num_experts=e,
        expert_map=None,
        global_scale1=None,
        global_scale2=None,
        g_idx1=None,
        g_idx2=None,
        input_global_scale1=None,
        input_global_scale2=None,
        sort_indices1=None,
        sort_indices2=None,
        w1_zeros=None,
        w2_zeros=None,
        input_dtype=dtype,
        quant_type_id=scalar_types.uint4b8.id,
        is_k_full=True,
    )
    # Sanity check: manually compute BF16 reference for comparison
    # Use same routing as Marlin path for consistency
    bf16_output = torch.zeros((m, k), device="cuda", dtype=dtype)
    for token_idx in range(m):
        for expert_rank in range(topk):
            expert_id = topk_ids[token_idx, expert_rank].item()
            weight = topk_weights[token_idx, expert_rank].item()
            # w1: [2*n, k] @ [k] -> [2*n]
            up_gate = a[token_idx] @ w1_bf16[expert_id].T  # [2*n]
            # First half of the fused projection feeds SiLU, second half
            # multiplies the activated result.
            gate, up = up_gate.chunk(2, dim=0)
            intermediate = torch.nn.functional.silu(gate) * up  # [n]
            # w2: [k, n] @ [n] -> [k]
            expert_out = intermediate @ w2_bf16[expert_id].T  # [k]
            bf16_output[token_idx] += weight * expert_out
    # Compare against BF16 reference.
    torch.testing.assert_close(marlin_output, bf16_output, atol=0.3, rtol=1.0)
    torch.testing.assert_close(trtllm_output, bf16_output, atol=0.3, rtol=1.0)
    # Compare against each other for sanity.
    # Note: Different quantization schemes (UINT4b8 vs signed MXINT4) cause
    # some differences
    torch.testing.assert_close(marlin_output, trtllm_output, atol=0.3, rtol=6.0)
vllm/envs.py
View file @
c3a9752b
...
...
@@ -174,6 +174,7 @@ if TYPE_CHECKING:
VLLM_USE_FLASHINFER_MOE_FP16
:
bool
=
False
VLLM_USE_FLASHINFER_MOE_FP8
:
bool
=
False
VLLM_USE_FLASHINFER_MOE_FP4
:
bool
=
False
VLLM_USE_FLASHINFER_MOE_INT4
:
bool
=
False
VLLM_FLASHINFER_MOE_BACKEND
:
Literal
[
"throughput"
,
"latency"
,
"masked_gemm"
]
=
(
"latency"
)
...
...
@@ -1240,18 +1241,22 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_BLOCKSCALE_FP8_GEMM_FLASHINFER"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_BLOCKSCALE_FP8_GEMM_FLASHINFER"
,
"0"
))
),
# Allow use of FlashInfer MoE kernels for fused moe ops.
# Allow use of FlashInfer
BF16
MoE kernels for fused moe ops.
"VLLM_USE_FLASHINFER_MOE_FP16"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_USE_FLASHINFER_MOE_FP16"
,
"0"
))
),
# Allow use of FlashInfer MoE kernels for fused moe ops.
# Allow use of FlashInfer
FP8
MoE kernels for fused moe ops.
"VLLM_USE_FLASHINFER_MOE_FP8"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_USE_FLASHINFER_MOE_FP8"
,
"0"
))
),
# Allow use of FlashInfer
CUTLASS
kernels for fused moe ops.
# Allow use of FlashInfer
NVFP4 MoE
kernels for fused moe ops.
"VLLM_USE_FLASHINFER_MOE_FP4"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_USE_FLASHINFER_MOE_FP4"
,
"0"
))
),
# Allow use of FlashInfer MxInt4 MoE kernels for fused moe ops.
"VLLM_USE_FLASHINFER_MOE_INT4"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_USE_FLASHINFER_MOE_INT4"
,
"0"
))
),
# If set to 1, use the FlashInfer
# MXFP8 (activation) x MXFP4 (weight) MoE backend.
"VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8"
:
lambda
:
bool
(
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
c3a9752b
...
...
@@ -1138,6 +1138,11 @@ class FusedMoE(CustomOp):
return
False
if
return_success
else
None
# Hereafter, `expert_id` is local physical id
# is_transposed: if the dim to shard the weight
# should be flipped. Required by GPTQ, compressed-tensors
# should be whatever dimension intermediate_size_per_partition is
is_transposed
=
getattr
(
param
,
"is_transposed"
,
False
)
# compressed-tensors checkpoints with packed weights are stored flipped
# TODO (mgoin): check self.quant_method.quant_config.quant_format
# against known CompressionFormat enum values that have this quality
...
...
@@ -1145,7 +1150,10 @@ class FusedMoE(CustomOp):
"CompressedTensorsWNA16MarlinMoEMethod"
,
"CompressedTensorsWNA16MoEMethod"
,
):
loaded_weight
=
loaded_weight
.
t
().
contiguous
()
if
is_transposed
:
loaded_weight
=
loaded_weight
.
t
().
contiguous
()
else
:
loaded_weight
=
loaded_weight
if
shard_id
not
in
(
"w1"
,
"w2"
,
"w3"
):
raise
ValueError
(
f
"shard_id must be ['w1','w2','w3'] but got
{
shard_id
}
."
)
...
...
@@ -1183,10 +1191,6 @@ class FusedMoE(CustomOp):
)
return
True
if
return_success
else
None
# is_transposed: if the dim to shard the weight
# should be flipped. Required by GPTQ, compressed-tensors
# should be whatever dimension intermediate_size_per_partition is
is_transposed
=
getattr
(
param
,
"is_transposed"
,
False
)
shard_dim
=
SHARD_ID_TO_SHARDED_DIM
[
shard_id
]
if
is_transposed
:
shard_dim
=
int
(
not
shard_dim
)
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
View file @
c3a9752b
...
...
@@ -63,6 +63,11 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
flashinfer_trtllm_fp4_moe
,
flashinfer_trtllm_fp4_routed_moe
,
)
from
vllm.model_executor.layers.quantization.utils.flashinfer_mxint4_moe
import
(
flashinfer_trtllm_mxint4_moe
,
is_flashinfer_mxint4_moe_available
,
prepare_static_weights_for_trtllm_mxint4_moe
,
)
from
vllm.model_executor.layers.quantization.utils.flashinfer_utils
import
(
apply_fi_trtllm_fp8_per_tensor_moe
,
)
...
...
@@ -1247,8 +1252,89 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
self
.
actorder
=
weight_quant
.
actorder
self
.
quant_type
=
WNA16_SUPPORTED_TYPES_MAP
[
self
.
num_bits
]
self
.
use_marlin
=
True
self
.
marlin_input_dtype
=
get_marlin_input_dtype
(
layer_name
)
self
.
use_flashinfer_mxint4_moe
=
(
is_flashinfer_mxint4_moe_available
()
and
self
.
group_size
==
32
and
weight_quant
.
num_bits
==
4
)
self
.
kernel_backend
=
(
"Flashinfer"
if
self
.
use_flashinfer_mxint4_moe
else
"Marlin"
)
logger
.
info_once
(
f
"Using
{
self
.
kernel_backend
}
backend for WNA16 MoE "
f
"(group_size=
{
self
.
group_size
}
, num_bits=
{
self
.
num_bits
}
)"
,
scope
=
"local"
,
)
def
get_weight_shape
(
self
,
weight_name
:
str
,
num_experts
:
int
,
hidden_size
:
int
,
intermediate_size_per_partition
:
int
,
num_groups_w2
:
int
|
None
=
None
,
num_groups_w13
:
int
|
None
=
None
,
)
->
tuple
[
int
,
int
,
int
]:
"""
Get the shape of the weight based on the weight name, number of experts
hidden size, intermediate size per partition, number of groups for w2,
and number of groups for w13. Pass in num_groups_w2 and num_groups_w13
for weight scales.
"""
if
weight_name
==
"w13_scale"
:
assert
num_groups_w13
is
not
None
,
(
"num_groups_w13 must be provided for weight scales"
)
if
weight_name
==
"w2_scale"
:
assert
num_groups_w2
is
not
None
,
(
"num_groups_w2 must be provided for weight scales"
)
w13_num_shards
=
2
if
self
.
moe
.
is_act_and_mul
else
1
shape_map
=
{
"w13_weight"
:
{
"Flashinfer"
:
(
num_experts
,
w13_num_shards
*
intermediate_size_per_partition
,
hidden_size
//
self
.
packed_factor
,
),
"Marlin"
:
(
num_experts
,
hidden_size
//
self
.
packed_factor
,
w13_num_shards
*
intermediate_size_per_partition
,
),
},
"w13_scale"
:
{
"Flashinfer"
:
(
num_experts
,
w13_num_shards
*
intermediate_size_per_partition
,
num_groups_w13
,
),
"Marlin"
:
(
num_experts
,
num_groups_w13
,
w13_num_shards
*
intermediate_size_per_partition
,
),
},
"w2_weight"
:
{
"Flashinfer"
:
(
num_experts
,
hidden_size
,
intermediate_size_per_partition
//
self
.
packed_factor
,
),
"Marlin"
:
(
num_experts
,
intermediate_size_per_partition
//
self
.
packed_factor
,
hidden_size
,
),
},
"w2_scale"
:
{
"Flashinfer"
:
(
num_experts
,
hidden_size
,
num_groups_w2
),
"Marlin"
:
(
num_experts
,
num_groups_w2
,
hidden_size
),
},
}
return
shape_map
[
weight_name
][
self
.
kernel_backend
]
def
create_weights
(
self
,
...
...
@@ -1260,19 +1346,23 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
**
extra_weight_attrs
,
):
intermediate_size_full
=
extra_weight_attrs
.
pop
(
"intermediate_size_full"
)
w13_num_shards
=
2
if
self
.
moe
.
is_act_and_mul
else
1
# Will transpose the loaded weight along the
# intermediate and hidden dim sizes. Will
# shard for TP along the transposed dims
is_transposed
=
self
.
kernel_backend
!=
"Flashinfer"
extra_weight_attrs
.
update
(
{
"is_transposed"
:
True
,
"quant_method"
:
self
.
strategy
}
{
"is_transposed"
:
is_transposed
,
"quant_method"
:
self
.
strategy
}
)
w13_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
hidden_size
//
self
.
packed_factor
,
w13_num_shards
*
intermediate_size_per_partition
,
*
self
.
get_weight_shape
(
"w13_weight"
,
num_experts
,
hidden_size
,
intermediate_size_per_partition
,
),
dtype
=
torch
.
int32
,
),
requires_grad
=
False
,
...
...
@@ -1282,9 +1372,12 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
w2_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
intermediate_size_per_partition
//
self
.
packed_factor
,
hidden_size
,
*
self
.
get_weight_shape
(
"w2_weight"
,
num_experts
,
hidden_size
,
intermediate_size_per_partition
,
),
dtype
=
torch
.
int32
,
),
requires_grad
=
False
,
...
...
@@ -1315,9 +1408,13 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
w13_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
num_groups_w13
,
w13_num_shards
*
intermediate_size_per_partition
,
*
self
.
get_weight_shape
(
"w13_scale"
,
num_experts
,
hidden_size
,
intermediate_size_per_partition
,
num_groups_w13
=
num_groups_w13
,
),
dtype
=
params_dtype
,
),
requires_grad
=
False
,
...
...
@@ -1326,7 +1423,16 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
set_weight_attrs
(
w13_scale
,
extra_weight_attrs
)
w2_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
num_groups_w2
,
hidden_size
,
dtype
=
params_dtype
),
torch
.
ones
(
*
self
.
get_weight_shape
(
"w2_scale"
,
num_experts
,
hidden_size
,
intermediate_size_per_partition
,
num_groups_w2
=
num_groups_w2
,
),
dtype
=
params_dtype
,
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w2_weight_scale"
,
w2_scale
)
...
...
@@ -1396,6 +1502,27 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
num_experts
=
layer
.
w13_weight_g_idx
.
shape
[
0
]
device
=
layer
.
w13_weight_g_idx
.
device
if
self
.
kernel_backend
==
"Flashinfer"
:
dict_weights_mxint4
=
prepare_static_weights_for_trtllm_mxint4_moe
(
layer
.
w13_weight_packed
,
layer
.
w13_weight_scale
,
layer
.
w2_weight_packed
,
layer
.
w2_weight_scale
,
)
replace_parameter
(
layer
,
"w13_weight_packed"
,
dict_weights_mxint4
[
"gemm1_weights"
]
)
replace_parameter
(
layer
,
"w13_weight_scale"
,
dict_weights_mxint4
[
"gemm1_scales"
]
)
replace_parameter
(
layer
,
"w2_weight_packed"
,
dict_weights_mxint4
[
"gemm2_weights"
]
)
replace_parameter
(
layer
,
"w2_weight_scale"
,
dict_weights_mxint4
[
"gemm2_scales"
]
)
return
None
is_a_8bit
=
(
self
.
marlin_input_dtype
is
not
None
and
self
.
marlin_input_dtype
.
itemsize
==
1
...
...
@@ -1560,6 +1687,35 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
is_k_full
=
self
.
is_k_full
,
)
@property
def is_monolithic(self) -> bool:
    """Whether this method runs through ``apply_monolithic``.

    True exactly when the selected kernel backend is "Flashinfer".
    """
    backend = self.kernel_backend
    return backend == "Flashinfer"
def apply_monolithic(
    self,
    layer: FusedMoE,
    x: torch.Tensor,
    router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    """Run the fused FlashInfer MxInt4 MoE for ``layer`` on ``x``.

    Only valid when the kernel backend is "Flashinfer". Raw router logits
    are forwarded together with the layer's routing configuration, its
    expert-parallel placement and its packed weights/scales to
    ``flashinfer_trtllm_mxint4_moe``.
    """
    assert self.kernel_backend == "Flashinfer"

    kernel_kwargs = dict(
        x=x,
        router_logits=router_logits,
        # Quantized expert weights and their scales.
        w13_weight_packed=layer.w13_weight_packed,
        w13_weight_scale=layer.w13_weight_scale,
        w2_weight_packed=layer.w2_weight_packed,
        w2_weight_scale=layer.w2_weight_scale,
        # Expert counts, sizes and expert-parallel placement.
        global_num_experts=layer.global_num_experts,
        top_k=layer.top_k,
        intermediate_size_per_partition=layer.intermediate_size_per_partition,
        local_num_experts=layer.local_num_experts,
        ep_rank=layer.ep_rank,
        # Routing configuration.
        num_expert_group=layer.num_expert_group,
        topk_group=layer.topk_group,
        e_score_correction_bias=layer.e_score_correction_bias,
        routing_method_type=layer.routing_method_type,
    )
    return flashinfer_trtllm_mxint4_moe(**kernel_kwargs)
def
apply
(
self
,
layer
:
FusedMoE
,
...
...
@@ -1567,6 +1723,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
self
.
kernel_backend
==
"Marlin"
return
fused_marlin_moe
(
x
,
layer
.
w13_weight_packed
,
...
...
vllm/model_executor/layers/quantization/utils/flashinfer_mxint4_moe.py
0 → 100644
View file @
c3a9752b
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Utility helpers for MxInt4 + FlashInfer fused-MoE path"""
import
functools
import
torch
import
vllm.envs
as
envs
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.utils.flashinfer
import
has_flashinfer_trtllm_fused_moe
__all__
=
[
"prepare_static_weights_for_trtllm_mxint4_moe"
,
"flashinfer_trtllm_mxint4_moe"
,
"is_flashinfer_mxint4_moe_available"
,
]
logger
=
init_logger
(
__name__
)
@functools.cache
def is_flashinfer_mxint4_moe_available() -> bool:
    """Return `True` when FlashInfer MxInt4 kernels can be used.

    All of the following must hold, checked in this order:
    the ``VLLM_USE_FLASHINFER_MOE_INT4`` env flag is set, the FlashInfer
    TRT-LLM fused-MoE entry points are importable, the platform is CUDA, and
    the device belongs to capability family 100.

    The result is memoized by ``functools.cache``, so changes to the env flag
    after the first call are not observed.
    """
    if not envs.VLLM_USE_FLASHINFER_MOE_INT4:
        return False
    if not has_flashinfer_trtllm_fused_moe():
        return False
    if not current_platform.is_cuda():
        return False
    return current_platform.is_device_capability_family(100)
def prepare_static_weights_for_trtllm_mxint4_moe(
    gemm1_weights: torch.Tensor,
    gemm1_scales: torch.Tensor,
    gemm2_weights: torch.Tensor,
    gemm2_scales: torch.Tensor,
) -> dict[str, torch.Tensor]:
    """
    Prepare MxInt4 weights for TRT-LLM kernel.

    Steps, per the code below: optionally convert int32 checkpoint weights
    from uint4b8 to signed int4, reorder gemm1 from w1/w3 to w3/w1 order,
    then, expert by expert, permute rows of weights and scales, interleave
    the scales and convert the weights to a block layout.

    Note: when both weight tensors are int32 they are passed to
    ``convert_packed_uint4b8_to_signed_int4_inplace`` and its return value is
    ignored, so the caller's tensors are presumably modified in place.

    Input:
        gemm1_weights: [num_experts, 2*intermediate_size, hidden_size//8] int32
            (checkpoint uint4b8 packed) or uint8 (already packed signed int4)
        gemm1_scales: [num_experts, 2*intermediate_size, hidden_size//32] bf16
        gemm2_weights: [num_experts, hidden_size, intermediate_size//8] int32
            (checkpoint uint4b8 packed) or uint8 (already packed signed int4)
        gemm2_scales: [num_experts, hidden_size, intermediate_size//32] bf16

    Returns:
        Dict with keys 'gemm1_weights', 'gemm1_scales', 'gemm2_weights',
        'gemm2_scales' containing shuffled/packed tensors ready for kernel

    Raises:
        AssertionError: if any input is not 3-D or a scale tensor is not
            bfloat16.
    """
    from flashinfer import block_scale_interleave
    from flashinfer.fused_moe import (
        convert_to_block_layout,
    )
    from flashinfer.fused_moe.core import (
        _maybe_get_cached_w3_w1_permute_indices,
        get_w2_permute_indices_with_cache,
    )

    from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
        reorder_w1w3_to_w3w1,
    )
    from vllm.model_executor.layers.quantization.utils.quant_utils import (
        convert_packed_uint4b8_to_signed_int4_inplace,
    )

    device = gemm1_weights.device
    assert gemm1_weights.ndim == 3, (
        f"Expected a 3D gemm1_weights tensor, got {gemm1_weights.shape}"
    )
    assert gemm1_scales.ndim == 3, (
        f"Expected a 3D gemm1_scales tensor, got {gemm1_scales.shape}"
    )
    assert gemm2_weights.ndim == 3, (
        f"Expected a 3D gemm2_weights tensor, got {gemm2_weights.shape}"
    )
    assert gemm2_scales.ndim == 3, (
        f"Expected a 3D gemm2_scales tensor, got {gemm2_scales.shape}"
    )
    # Convert checkpoint format (uint4b8 in int32) to signed int4
    # Checkpoint stores INT4 as unsigned [0, 15], kernel expects signed [-8, 7]
    # The conversion is skipped unless BOTH weight tensors are int32.
    if gemm1_weights.dtype == torch.int32 and gemm2_weights.dtype == torch.int32:
        convert_packed_uint4b8_to_signed_int4_inplace(gemm1_weights)
        convert_packed_uint4b8_to_signed_int4_inplace(gemm2_weights)
    # Weights and scales are reordered together along the row dimension.
    gemm1_weights, gemm1_scales = reorder_w1w3_to_w3w1(
        gemm1_weights, gemm1_scales, dim=-2
    )
    # Cache of permutation indices shared by all experts within this call.
    _cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
    num_experts = gemm1_weights.shape[0]
    # Convert quantized weights to proper formats -
    # view(torch.uint8) reinterprets the packed storage as bytes.
    gemm1_weights_mxint4 = gemm1_weights.view(torch.uint8)
    assert gemm1_scales.dtype == torch.bfloat16
    gemm2_weights_mxint4 = gemm2_weights.view(torch.uint8)
    assert gemm2_scales.dtype == torch.bfloat16
    epilogue_tile_m = 128
    gemm1_weights_mxint4_shuffled = []
    gemm1_scales_shuffled = []
    gemm2_weights_mxint4_shuffled = []
    gemm2_scales_shuffled = []
    for i in range(num_experts):
        # Calculate the permute indices for the following:
        # 1. Reorder rows of W1 and scales for fused gated activation
        # 2. Shuffle weights and scaling factors for transposed mma output
        # for both w3_w1 and w2 weights and scale factors
        permute_indices = _maybe_get_cached_w3_w1_permute_indices(
            _cache_permute_indices,
            gemm1_weights_mxint4[i],
            epilogue_tile_m,
        )
        gemm1_weights_shuffled = gemm1_weights_mxint4[i][
            permute_indices.to(gemm1_weights.device)
        ].contiguous()
        permute_sf_indices = _maybe_get_cached_w3_w1_permute_indices(
            _cache_permute_indices,
            gemm1_scales[i],
            epilogue_tile_m,
            num_elts_per_sf=32,
        ).to(device)
        gemm1_scales_shuffled.append(
            block_scale_interleave(gemm1_scales[i][permute_sf_indices].contiguous())
        )
        permute_indices = get_w2_permute_indices_with_cache(
            _cache_permute_indices,
            gemm2_weights_mxint4[i],
            epilogue_tile_m,
        )
        gemm2_weights_shuffled = gemm2_weights_mxint4[i][
            permute_indices.to(gemm2_weights.device)
        ].contiguous()
        # NOTE(review): gemm1 scales use num_elts_per_sf=32 above while gemm2
        # scales use 16 here - confirm this asymmetry is intended.
        permute_sf_indices = get_w2_permute_indices_with_cache(
            _cache_permute_indices,
            gemm2_scales[i],
            epilogue_tile_m,
            num_elts_per_sf=16,
        )
        gemm2_scales_shuffled.append(
            block_scale_interleave(
                gemm2_scales[i][permute_sf_indices.to(gemm2_scales.device)].contiguous()
            )
        )
        block_k = 128
        gemm1_weights_shuffled = convert_to_block_layout(
            gemm1_weights_shuffled.view(torch.uint8), block_k
        )
        gemm2_weights_shuffled = convert_to_block_layout(
            gemm2_weights_shuffled.view(torch.uint8), block_k
        )
        gemm1_weights_mxint4_shuffled.append(gemm1_weights_shuffled)
        gemm2_weights_mxint4_shuffled.append(gemm2_weights_shuffled)
    # Re-assemble the per-expert results into [num_experts, ...] tensors.
    gemm1_weights_mxint4_shuffled = torch.stack(gemm1_weights_mxint4_shuffled)
    gemm2_weights_mxint4_shuffled = torch.stack(gemm2_weights_mxint4_shuffled)
    # The interleaved scales are reinterpreted as bfloat16 for the kernel.
    gemm1_scales_shuffled = torch.stack(gemm1_scales_shuffled).view(torch.bfloat16)
    gemm2_scales_shuffled = torch.stack(gemm2_scales_shuffled).view(torch.bfloat16)
    return {
        "gemm1_weights": gemm1_weights_mxint4_shuffled,
        "gemm1_scales": gemm1_scales_shuffled,
        "gemm2_weights": gemm2_weights_mxint4_shuffled,
        "gemm2_scales": gemm2_scales_shuffled,
    }
def flashinfer_trtllm_mxint4_moe(
    x: torch.Tensor,
    router_logits: torch.Tensor,
    w13_weight_packed: torch.Tensor,
    w13_weight_scale: torch.Tensor,
    w2_weight_packed: torch.Tensor,
    w2_weight_scale: torch.Tensor,
    global_num_experts: int,
    top_k: int,
    intermediate_size_per_partition: int,
    local_num_experts: int,
    ep_rank: int = 0,
    num_expert_group: int | None = None,
    topk_group: int | None = None,
    e_score_correction_bias: torch.Tensor | None = None,
    routing_method_type: int | None = None,
) -> torch.Tensor:
    """
    Run the FlashInfer TensorRT-LLM MxInt4 block-scale MoE kernel.

    Args:
        x: Hidden states; must be bfloat16.
        router_logits: Router logits. Cast to float32 when the routing
            method is DeepSeekV3, otherwise passed through unchanged.
        w13_weight_packed: Packed gate+up weights; must be uint8.
        w13_weight_scale: Scales for gate+up weights; must be bfloat16.
        w2_weight_packed: Packed down weights; must be uint8.
        w2_weight_scale: Scales for down weights; must be bfloat16.
        global_num_experts: Total number of experts across all ranks.
        top_k: Number of experts selected per token.
        intermediate_size_per_partition: Intermediate size per partition.
        local_num_experts: Number of experts held by this rank.
        ep_rank: Expert-parallel rank; the local expert offset passed to the
            kernel is ``ep_rank * local_num_experts``.
        num_expert_group: Number of expert groups; ``None`` is sent as 0.
        topk_group: Top-k within groups; ``None`` is sent as 0.
        e_score_correction_bias: Optional routing bias; cast to bfloat16
            before being handed to the kernel.
        routing_method_type: FlashInfer RoutingMethodType enum value.

    Returns:
        The kernel output cast to ``x.dtype``.
    """
    from flashinfer import RoutingMethodType
    from flashinfer.fused_moe import trtllm_mxint4_block_scale_moe

    # Dtype contract of the kernel inputs.
    assert x.dtype == torch.bfloat16, f"x dtype must be bfloat16, got {x.dtype}"
    assert w13_weight_packed.dtype == torch.uint8, (
        f"w13_weight_packed dtype must be uint8, got {w13_weight_packed.dtype}"
    )
    assert w13_weight_scale.dtype == torch.bfloat16, (
        f"w13_weight_scale dtype must be bfloat16, got {w13_weight_scale.dtype}"
    )
    assert w2_weight_packed.dtype == torch.uint8, (
        f"w2_weight_packed dtype must be uint8, got {w2_weight_packed.dtype}"
    )
    assert w2_weight_scale.dtype == torch.bfloat16, (
        f"w2_weight_scale dtype must be bfloat16, got {w2_weight_scale.dtype}"
    )

    routing_bias = (
        None
        if e_score_correction_bias is None
        else e_score_correction_bias.to(torch.bfloat16)
    )

    if routing_method_type == RoutingMethodType.DeepSeekV3:
        router_logits = router_logits.to(torch.float32)

    # Missing grouping parameters are sent to the kernel as 0.
    n_group = 0 if num_expert_group is None else num_expert_group
    group_top_k = 0 if topk_group is None else topk_group
    first_local_expert = ep_rank * local_num_experts

    kernel_output = trtllm_mxint4_block_scale_moe(
        routing_logits=router_logits,
        routing_bias=routing_bias,
        hidden_states=x,
        gemm1_weights=w13_weight_packed.data,
        gemm1_weights_scale=w13_weight_scale.data,
        gemm1_alpha=None,
        gemm1_beta=None,
        gemm1_clamp_limit=None,
        gemm2_weights=w2_weight_packed.data,
        gemm2_weights_scale=w2_weight_scale.data,
        num_experts=global_num_experts,
        top_k=top_k,
        n_group=n_group,
        topk_group=group_top_k,
        intermediate_size=intermediate_size_per_partition,
        local_expert_offset=first_local_expert,
        local_num_experts=local_num_experts,
        routed_scaling_factor=None,
        routing_method_type=routing_method_type,
        enable_pdl=None,
        output=None,
        tune_max_num_tokens=8192,
    )
    return kernel_output.to(x.dtype)
vllm/utils/flashinfer.py
View file @
c3a9752b
...
...
@@ -129,12 +129,11 @@ scaled_fp4_grouped_quantize = _lazy_import_wrapper(
"flashinfer"
,
"scaled_fp4_grouped_quantize"
)
nvfp4_block_scale_interleave
=
_lazy_import_wrapper
(
"flashinfer
"
,
"nvfp4_
block_scale_interleave"
"flashinfer
.fp4_quantization"
,
"
block_scale_interleave"
)
trtllm_fp4_block_scale_moe
=
_lazy_import_wrapper
(
"flashinfer"
,
"trtllm_fp4_block_scale_moe"
)
# Special case for autotune since it returns a context manager
autotune
=
_lazy_import_wrapper
(
"flashinfer.autotuner"
,
...
...
@@ -196,6 +195,7 @@ def has_flashinfer_trtllm_fused_moe() -> bool:
(
"flashinfer.fused_moe"
,
"trtllm_fp8_block_scale_moe"
),
(
"flashinfer.fused_moe"
,
"trtllm_fp8_per_tensor_scale_moe"
),
(
"flashinfer.fused_moe"
,
"trtllm_fp4_block_scale_moe"
),
(
"flashinfer.fused_moe"
,
"trtllm_mxint4_block_scale_moe"
),
]
for
module_name
,
attr_name
in
required_functions
:
mod
=
_get_submodule
(
module_name
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment