Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9609b1f1
Unverified
Commit
9609b1f1
authored
Feb 24, 2026
by
danisereb
Committed by
GitHub
Feb 24, 2026
Browse files
Integrate flashinfer mm_mxfp8 in ModelOpt MXFP8 (#35053)
Signed-off-by:
Daniel Serebrenik
<
daserebrenik@nvidia.com
>
parent
a0c70816
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
230 additions
and
11 deletions
+230
-11
vllm/model_executor/layers/quantization/modelopt.py
vllm/model_executor/layers/quantization/modelopt.py
+50
-10
vllm/model_executor/layers/quantization/utils/mxfp8_utils.py
vllm/model_executor/layers/quantization/utils/mxfp8_utils.py
+103
-1
vllm/utils/flashinfer.py
vllm/utils/flashinfer.py
+77
-0
No files found.
vllm/model_executor/layers/quantization/modelopt.py
View file @
9609b1f1
...
...
@@ -70,6 +70,7 @@ from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
MXFP8_VALUE_DTYPE
,
Mxfp8LinearBackend
,
Mxfp8LinearOp
,
swizzle_mxfp8_scale
,
)
from
vllm.model_executor.layers.quantization.utils.nvfp4_utils
import
(
apply_nvfp4_linear
,
...
...
@@ -1689,9 +1690,9 @@ class ModelOptMxFp8LinearMethod(LinearMethodBase):
"Dynamic quantization is not supported."
)
backend
:
Mxfp8LinearBackend
=
Mxfp8LinearBackend
.
EMULATION
self
.
mxfp8_linear_op
=
Mxfp8LinearOp
(
backend
=
backend
)
logger
.
info_once
(
"Using %s backend for MXFP8 GEMM"
,
backend
.
value
)
self
.
backend
:
Mxfp8LinearBackend
=
Mxfp8LinearBackend
.
FLASHINFER_CUTLASS
self
.
mxfp8_linear_op
=
Mxfp8LinearOp
(
backend
=
self
.
backend
)
logger
.
info_once
(
"Using %s backend for MXFP8 GEMM"
,
self
.
backend
.
value
)
def
create_weights
(
self
,
...
...
@@ -1749,7 +1750,38 @@ class ModelOptMxFp8LinearMethod(LinearMethodBase):
)
layer
.
register_parameter
(
"weight_scale"
,
weight_scale
)
def
_process_weights_after_loading_scale_2d
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
"""Not swizzled - MXFP8 GEMM emulation"""
weight
=
layer
.
weight
.
data
# [N, K]
N
,
K
=
weight
.
shape
scale_k
=
K
//
MXFP8_BLOCK_SIZE
# Slice weight_scale to match weight dimensions (handles padding)
weight_scale
=
layer
.
weight_scale
.
data
[:
N
,
:
scale_k
].
contiguous
()
layer
.
weight
=
Parameter
(
weight
.
contiguous
(),
requires_grad
=
False
)
layer
.
weight_scale
=
Parameter
(
weight_scale
,
requires_grad
=
False
)
def
_process_weights_after_loading_scale_1d
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
"""Swizzled - MXFP8 GEMM Flashinfer CUTLASS"""
weight
=
layer
.
weight
.
data
# [N, K]
N
,
K
=
weight
.
shape
# 2D weight scale
weight_scale
=
layer
.
weight_scale
.
data
# Swizzle the weight scales
scale_k
=
K
//
MXFP8_BLOCK_SIZE
weight_scale_2d
=
weight_scale
[:
N
,
:
scale_k
].
contiguous
()
weight_scale_swizzled
=
swizzle_mxfp8_scale
(
weight_scale_2d
,
M
=
N
,
K
=
K
)
layer
.
weight
=
Parameter
(
weight
.
contiguous
(),
requires_grad
=
False
)
layer
.
weight_scale
=
Parameter
(
weight_scale_swizzled
.
contiguous
(),
requires_grad
=
False
)
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
# Validate weight tensor
if
layer
.
weight
.
ndim
!=
2
:
raise
ValueError
(
f
"MXFP8 weight must be 2D tensor [N, K], got
{
layer
.
weight
.
ndim
}
D "
...
...
@@ -1763,15 +1795,23 @@ class ModelOptMxFp8LinearMethod(LinearMethodBase):
f
"quantized with MXFP8."
)
weight
=
layer
.
weight
.
data
# [N, K]
N
,
K
=
weight
.
shape
scale_k
=
K
//
MXFP8_BLOCK_SIZE
# Validate weight scale tensor (should be 2D, not swizzled)
assert
layer
.
weight_scale
.
ndim
==
2
,
(
f
"MXFP8 weight scale must be 2D, got
{
layer
.
weight_scale
.
ndim
}
D"
)
assert
layer
.
weight_scale
.
dtype
==
MXFP8_SCALE_DTYPE
,
(
f
"MXFP8 weight scale must be
{
MXFP8_SCALE_DTYPE
}
,"
f
" got
{
layer
.
weight_scale
.
dtype
}
"
)
# Slice weight_scale to match weight dimensions (handles padding)
weight_scale
=
layer
.
weight_scale
.
data
[:
N
,
:
scale_k
].
contiguous
()
if
self
.
backend
==
Mxfp8LinearBackend
.
EMULATION
:
# Swizzled layout is not used
self
.
_process_weights_after_loading_scale_2d
(
layer
)
return
layer
.
weight
=
Parameter
(
weight
.
contiguous
(),
requires_grad
=
False
)
layer
.
weight_scale
=
Parameter
(
weight_scale
,
requires_grad
=
False
)
assert
self
.
backend
==
Mxfp8LinearBackend
.
FLASHINFER_CUTLASS
# Swizzled layout is required for Flashinfer CUTLASS
self
.
_process_weights_after_loading_scale_1d
(
layer
)
def
apply
(
self
,
...
...
vllm/model_executor/layers/quantization/utils/mxfp8_utils.py
View file @
9609b1f1
...
...
@@ -6,6 +6,7 @@ from enum import Enum
import
torch
from
vllm.logger
import
init_logger
from
vllm.utils
import
flashinfer
as
vllm_flashinfer
from
vllm.utils.torch_utils
import
direct_register_custom_op
logger
=
init_logger
(
__name__
)
...
...
@@ -13,6 +14,7 @@ logger = init_logger(__name__)
class
Mxfp8LinearBackend
(
Enum
):
EMULATION
=
"emulation"
FLASHINFER_CUTLASS
=
"flashinfer-cutlass"
# MXFP8 constants
...
...
@@ -21,6 +23,30 @@ MXFP8_SCALE_DTYPE = torch.uint8
MXFP8_BLOCK_SIZE
=
32
def
swizzle_mxfp8_scale
(
sf
:
torch
.
Tensor
,
M
:
int
,
K
:
int
)
->
torch
.
Tensor
:
"""Swizzle MXFP8 scales from row-major 2D to F8_128x4 layout."""
scaling_vector_size
=
MXFP8_BLOCK_SIZE
# 32 for MXFP8
factor
=
scaling_vector_size
*
4
# 128
num_m_tiles
=
(
M
+
127
)
//
128
num_k_tiles
=
(
K
+
factor
-
1
)
//
factor
m_padded
=
num_m_tiles
*
128
k_scale_padded
=
num_k_tiles
*
4
scale_cols
=
K
//
scaling_vector_size
sf_padded
=
torch
.
zeros
(
(
m_padded
,
k_scale_padded
),
dtype
=
sf
.
dtype
,
device
=
sf
.
device
)
sf_padded
[:
M
,
:
scale_cols
]
=
sf
sf_reshaped
=
sf_padded
.
view
(
num_m_tiles
,
4
,
32
,
num_k_tiles
,
4
)
sf_swizzled
=
sf_reshaped
.
transpose
(
1
,
3
)
return
sf_swizzled
.
contiguous
().
view
(
-
1
)
def
_mxfp8_e4m3_quantize_impl
(
x
:
torch
.
Tensor
,
is_sf_swizzled_layout
:
bool
=
False
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
...
...
@@ -108,7 +134,7 @@ class Mxfp8LinearOp:
self
.
backend
=
backend
def
apply
(
def
_
apply
_emulation
(
self
,
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
...
...
@@ -132,3 +158,79 @@ class Mxfp8LinearOp:
output
=
torch
.
nn
.
functional
.
linear
(
input
,
weight_bf16
,
bias
)
return
output
.
to
(
out_dtype
)
def
_apply_flashinfer_cutlass
(
self
,
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
weight_scale
:
torch
.
Tensor
,
out_dtype
:
torch
.
dtype
,
bias
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
N
,
K
=
weight
.
shape
input_shape
=
input
.
shape
input_2d
=
input
.
view
(
-
1
,
K
)
M_orig
=
input_2d
.
shape
[
0
]
# Minimum dimension size for F8_128x4 block scaling layout
min_dim
=
128
assert
min_dim
<=
K
,
(
f
"mm_mxfp8 requires K >=
{
min_dim
}
, got K=
{
K
}
. "
f
"in_features is too small for mm_mxfp8."
)
assert
K
%
MXFP8_BLOCK_SIZE
==
0
,
(
f
"mm_mxfp8 requires K to be divisible by
{
MXFP8_BLOCK_SIZE
}
, got K=
{
K
}
."
)
assert
min_dim
<=
N
,
(
f
"mm_mxfp8 requires N >=
{
min_dim
}
, got N=
{
N
}
. "
f
"out_features is too small for mm_mxfp8."
)
M_padded
=
((
M_orig
+
min_dim
-
1
)
//
min_dim
)
*
min_dim
if
M_padded
!=
M_orig
:
pad_rows
=
M_padded
-
M_orig
input_2d
=
torch
.
nn
.
functional
.
pad
(
input_2d
,
(
0
,
0
,
0
,
pad_rows
))
input_mxfp8
,
input_scale
=
mxfp8_e4m3_quantize
(
input_2d
,
is_sf_swizzled_layout
=
True
,
# Swizzled for best accuracy
)
if
not
weight
.
is_contiguous
():
weight
=
weight
.
contiguous
()
output
=
vllm_flashinfer
.
mm_mxfp8
(
input_mxfp8
,
weight
.
t
(),
input_scale
,
weight_scale
,
out_dtype
=
out_dtype
,
backend
=
"cutlass"
,
)
if
M_padded
!=
M_orig
:
output
=
output
[:
M_orig
,
:]
if
bias
is
not
None
:
output
=
output
+
bias
output_shape
=
(
*
input_shape
[:
-
1
],
N
)
return
output
.
view
(
output_shape
)
def
apply
(
self
,
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
weight_scale
:
torch
.
Tensor
,
out_dtype
:
torch
.
dtype
,
bias
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
if
self
.
backend
==
Mxfp8LinearBackend
.
EMULATION
:
return
self
.
_apply_emulation
(
input
,
weight
,
weight_scale
,
out_dtype
,
bias
)
assert
self
.
backend
==
Mxfp8LinearBackend
.
FLASHINFER_CUTLASS
return
self
.
_apply_flashinfer_cutlass
(
input
,
weight
,
weight_scale
,
out_dtype
,
bias
)
vllm/utils/flashinfer.py
View file @
9609b1f1
...
...
@@ -553,6 +553,83 @@ if has_flashinfer():
rounded_m
,
rounded_n
,
dtype
=
torch
.
uint8
,
device
=
a
.
device
)
@
torch
.
library
.
custom_op
(
"vllm::mm_mxfp8"
,
mutates_args
=
[],
device_types
=
"cuda"
,
)
def
mm_mxfp8
(
A
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
A_scale
:
torch
.
Tensor
,
B_scale
:
torch
.
Tensor
,
out_dtype
:
torch
.
dtype
,
backend
:
str
=
"cutlass"
,
)
->
torch
.
Tensor
:
from
flashinfer
import
mm_mxfp8
as
mm_mxfp8_
return
mm_mxfp8_
(
A
,
B
,
A_scale
,
B_scale
,
out
=
None
,
out_dtype
=
out_dtype
,
backend
=
backend
,
)
@
torch
.
library
.
register_fake
(
"vllm::mm_mxfp8"
,
)
def
mm_mxfp8_fake
(
A
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
A_scale
:
torch
.
Tensor
,
B_scale
:
torch
.
Tensor
,
out_dtype
:
torch
.
dtype
,
backend
:
str
=
"cutlass"
,
)
->
torch
.
Tensor
:
# A is [m, k], B is [k, n] -> output [m, n]
return
torch
.
empty
(
A
.
shape
[
0
],
B
.
shape
[
1
],
dtype
=
out_dtype
,
device
=
A
.
device
)
def
flashinfer_mm_mxfp8
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
block_scale_a
:
torch
.
Tensor
,
block_scale_b
:
torch
.
Tensor
,
out_dtype
:
torch
.
dtype
,
backend
:
str
=
"cutlass"
,
)
->
torch
.
Tensor
:
"""MXFP8 MM helper - mirrors flashinfer_scaled_fp4_mm API.
Takes non-transposed weights and handles transpose internally.
CRITICAL: mm_mxfp8 CUTLASS kernel requires SWIZZLED 1D scales for optimal
performance and accuracy. Both input and weight scales should be in
swizzled format from FlashInfer's mxfp8_quantize(is_sf_swizzled_layout=True).
"""
# a shape [M, K]
# b shape [K, N]
assert
a
.
ndim
==
2
and
b
.
ndim
==
2
assert
a
.
shape
[
1
]
==
b
.
shape
[
1
]
# K dimension must match
if
block_scale_b
.
ndim
!=
1
:
raise
ValueError
(
"mm_mxfp8 expects 1D swizzled weight scales for CUTLASS; "
f
"got shape=
{
tuple
(
block_scale_b
.
shape
)
}
"
)
# Output tensor [M, N]
return
mm_mxfp8
(
a
,
b
.
t
(),
# Transpose weight: [N, K] -> [K, N]
block_scale_a
,
block_scale_b
,
out_dtype
,
backend
=
backend
,
)
def
flashinfer_scaled_fp4_mm
(
a
:
torch
.
Tensor
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment