Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
11e2375f
Unverified
Commit
11e2375f
authored
Apr 10, 2026
by
Michael Goin
Committed by
GitHub
Apr 10, 2026
Browse files
[Refactor] Move MXFP8 GEMM management into MxFp8LinearKernel (#39205)
Signed-off-by:
mgoin
<
mgoin64@gmail.com
>
parent
fc645f1a
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
365 additions
and
258 deletions
+365
-258
vllm/model_executor/kernels/linear/__init__.py
vllm/model_executor/kernels/linear/__init__.py
+69
-0
vllm/model_executor/kernels/linear/mxfp8/Mxfp8LinearKernel.py
.../model_executor/kernels/linear/mxfp8/Mxfp8LinearKernel.py
+56
-0
vllm/model_executor/kernels/linear/mxfp8/__init__.py
vllm/model_executor/kernels/linear/mxfp8/__init__.py
+12
-0
vllm/model_executor/kernels/linear/mxfp8/emulation.py
vllm/model_executor/kernels/linear/mxfp8/emulation.py
+60
-0
vllm/model_executor/kernels/linear/mxfp8/flashinfer.py
vllm/model_executor/kernels/linear/mxfp8/flashinfer.py
+103
-0
vllm/model_executor/kernels/linear/mxfp8/marlin.py
vllm/model_executor/kernels/linear/mxfp8/marlin.py
+53
-0
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+4
-4
vllm/model_executor/layers/quantization/modelopt.py
vllm/model_executor/layers/quantization/modelopt.py
+4
-13
vllm/model_executor/layers/quantization/mxfp8.py
vllm/model_executor/layers/quantization/mxfp8.py
+4
-14
vllm/model_executor/layers/quantization/utils/mxfp8_utils.py
vllm/model_executor/layers/quantization/utils/mxfp8_utils.py
+0
-227
No files found.
vllm/model_executor/kernels/linear/__init__.py
View file @
11e2375f
...
@@ -58,6 +58,19 @@ from vllm.model_executor.kernels.linear.mixed_precision.xpu import (
...
@@ -58,6 +58,19 @@ from vllm.model_executor.kernels.linear.mixed_precision.xpu import (
XPUW4A8IntLinearKernel
,
XPUW4A8IntLinearKernel
,
XPUwNa16LinearKernel
,
XPUwNa16LinearKernel
,
)
)
from
vllm.model_executor.kernels.linear.mxfp8
import
(
Mxfp8LinearKernel
,
Mxfp8LinearLayerConfig
,
)
from
vllm.model_executor.kernels.linear.mxfp8.emulation
import
(
EmulationMxfp8LinearKernel
,
)
from
vllm.model_executor.kernels.linear.mxfp8.flashinfer
import
(
FlashInferCutlassMxfp8LinearKernel
,
)
from
vllm.model_executor.kernels.linear.mxfp8.marlin
import
(
MarlinMxfp8LinearKernel
,
)
from
vllm.model_executor.kernels.linear.nvfp4
import
(
from
vllm.model_executor.kernels.linear.nvfp4
import
(
NvFp4LinearKernel
,
NvFp4LinearKernel
,
NvFp4LinearLayerConfig
,
NvFp4LinearLayerConfig
,
...
@@ -221,6 +234,17 @@ _POSSIBLE_KERNELS: dict[PlatformEnum, list[type[MPLinearKernel]]] = {
...
@@ -221,6 +234,17 @@ _POSSIBLE_KERNELS: dict[PlatformEnum, list[type[MPLinearKernel]]] = {
}
}
# in priority/performance order (when available)
# in priority/performance order (when available)
_POSSIBLE_MXFP8_KERNELS
:
dict
[
PlatformEnum
,
list
[
type
[
Mxfp8LinearKernel
]]]
=
{
PlatformEnum
.
CUDA
:
[
FlashInferCutlassMxfp8LinearKernel
,
MarlinMxfp8LinearKernel
,
EmulationMxfp8LinearKernel
,
],
PlatformEnum
.
ROCM
:
[
EmulationMxfp8LinearKernel
,
],
}
_POSSIBLE_NVFP4_KERNELS
:
dict
[
PlatformEnum
,
list
[
type
[
NvFp4LinearKernel
]]]
=
{
_POSSIBLE_NVFP4_KERNELS
:
dict
[
PlatformEnum
,
list
[
type
[
NvFp4LinearKernel
]]]
=
{
PlatformEnum
.
CUDA
:
[
PlatformEnum
.
CUDA
:
[
FlashInferCutlassNvFp4LinearKernel
,
FlashInferCutlassNvFp4LinearKernel
,
...
@@ -482,6 +506,41 @@ def choose_mp_linear_kernel(
...
@@ -482,6 +506,41 @@ def choose_mp_linear_kernel(
)
)
def
init_mxfp8_linear_kernel
()
->
Mxfp8LinearKernel
:
"""Select and instantiate the best MXFP8 linear kernel for the
current platform."""
config
=
Mxfp8LinearLayerConfig
()
platform
=
current_platform
.
_enum
possible
=
_POSSIBLE_MXFP8_KERNELS
.
get
(
platform
,
[])
failure_reasons
=
[]
for
kernel_cls
in
possible
:
if
kernel_cls
.
__name__
in
envs
.
VLLM_DISABLED_KERNELS
:
failure_reasons
.
append
(
f
"
{
kernel_cls
.
__name__
}
disabled by environment variable"
)
continue
is_supported
,
reason
=
kernel_cls
.
is_supported
()
if
not
is_supported
:
failure_reasons
.
append
(
f
"
{
kernel_cls
.
__name__
}
:
{
reason
}
"
)
continue
can_implement
,
reason
=
kernel_cls
.
can_implement
(
config
)
if
not
can_implement
:
failure_reasons
.
append
(
f
"
{
kernel_cls
.
__name__
}
:
{
reason
}
"
)
continue
logger
.
info_once
(
"Using %s for MXFP8 GEMM"
,
kernel_cls
.
__name__
)
return
kernel_cls
(
config
)
raise
ValueError
(
"Failed to find a kernel that can implement the "
"MXFP8 linear layer. Reasons:
\n
"
+
"
\n
"
.
join
(
failure_reasons
)
)
def
init_wfp8_a16_linear_kernel
(
def
init_wfp8_a16_linear_kernel
(
weight_quant_key
:
QuantKey
,
weight_quant_key
:
QuantKey
,
activation_quant_key
:
QuantKey
,
activation_quant_key
:
QuantKey
,
...
@@ -628,6 +687,10 @@ def register_linear_kernel(
...
@@ -628,6 +687,10 @@ def register_linear_kernel(
if
platform
not
in
_POSSIBLE_FP8_KERNELS
:
if
platform
not
in
_POSSIBLE_FP8_KERNELS
:
_POSSIBLE_FP8_KERNELS
[
platform
]
=
[]
_POSSIBLE_FP8_KERNELS
[
platform
]
=
[]
_POSSIBLE_FP8_KERNELS
[
platform
].
append
(
kernel_class
)
_POSSIBLE_FP8_KERNELS
[
platform
].
append
(
kernel_class
)
elif
kernel_type
==
"mxfp8"
:
if
platform
not
in
_POSSIBLE_MXFP8_KERNELS
:
_POSSIBLE_MXFP8_KERNELS
[
platform
]
=
[]
_POSSIBLE_MXFP8_KERNELS
[
platform
].
append
(
kernel_class
)
elif
kernel_type
==
"nvfp4"
:
elif
kernel_type
==
"nvfp4"
:
if
platform
not
in
_POSSIBLE_NVFP4_KERNELS
:
if
platform
not
in
_POSSIBLE_NVFP4_KERNELS
:
_POSSIBLE_NVFP4_KERNELS
[
platform
]
=
[]
_POSSIBLE_NVFP4_KERNELS
[
platform
]
=
[]
...
@@ -674,6 +737,12 @@ __all__ = [
...
@@ -674,6 +737,12 @@ __all__ = [
"TritonW4A16LinearKernel"
,
"TritonW4A16LinearKernel"
,
"XPUW4A8IntLinearKernel"
,
"XPUW4A8IntLinearKernel"
,
"XPUwNa16LinearKernel"
,
"XPUwNa16LinearKernel"
,
"init_mxfp8_linear_kernel"
,
"Mxfp8LinearKernel"
,
"Mxfp8LinearLayerConfig"
,
"FlashInferCutlassMxfp8LinearKernel"
,
"MarlinMxfp8LinearKernel"
,
"EmulationMxfp8LinearKernel"
,
"CutlassNvFp4LinearKernel"
,
"CutlassNvFp4LinearKernel"
,
"EmulationNvFp4LinearKernel"
,
"EmulationNvFp4LinearKernel"
,
"FbgemmNvFp4LinearKernel"
,
"FbgemmNvFp4LinearKernel"
,
...
...
vllm/model_executor/kernels/linear/mxfp8/Mxfp8LinearKernel.py
0 → 100644
View file @
11e2375f
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
abc
import
ABC
,
abstractmethod
from
dataclasses
import
dataclass
import
torch
@
dataclass
class
Mxfp8LinearLayerConfig
:
"""Configuration for an MXFP8 linear layer.
All MXFP8 layers share the same structure: FP8-E4M3 weights with
uint8 (E8M0) per-block scales at block size 32.
"""
pass
class
Mxfp8LinearKernel
(
ABC
):
"""Base class for MXFP8 quantized linear kernels.
Each subclass implements a specific GEMM backend (FlashInfer CUTLASS,
Marlin, emulation).
"""
def
__init__
(
self
,
c
:
Mxfp8LinearLayerConfig
)
->
None
:
assert
self
.
can_implement
(
c
)[
0
]
assert
self
.
is_supported
()[
0
]
self
.
config
=
c
@
classmethod
@
abstractmethod
def
is_supported
(
cls
,
compute_capability
:
int
|
None
=
None
)
->
tuple
[
bool
,
str
|
None
]:
raise
NotImplementedError
@
classmethod
@
abstractmethod
def
can_implement
(
cls
,
c
:
Mxfp8LinearLayerConfig
)
->
tuple
[
bool
,
str
|
None
]:
raise
NotImplementedError
@
abstractmethod
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
raise
NotImplementedError
@
abstractmethod
def
apply_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
raise
NotImplementedError
vllm/model_executor/kernels/linear/mxfp8/__init__.py
0 → 100644
View file @
11e2375f
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
vllm.model_executor.kernels.linear.mxfp8.Mxfp8LinearKernel
import
(
Mxfp8LinearKernel
,
Mxfp8LinearLayerConfig
,
)
__all__
=
[
"Mxfp8LinearKernel"
,
"Mxfp8LinearLayerConfig"
,
]
vllm/model_executor/kernels/linear/mxfp8/emulation.py
0 → 100644
View file @
11e2375f
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
torch
from
torch.nn.parameter
import
Parameter
from
vllm.model_executor.layers.quantization.utils.mxfp8_utils
import
(
MXFP8_BLOCK_SIZE
,
MXFP8_SCALE_DTYPE
,
dequant_mxfp8_to_bf16
,
)
from
.Mxfp8LinearKernel
import
Mxfp8LinearKernel
,
Mxfp8LinearLayerConfig
class
EmulationMxfp8LinearKernel
(
Mxfp8LinearKernel
):
"""Software emulation fallback for MXFP8 (dequant to BF16)."""
@
classmethod
def
is_supported
(
cls
,
compute_capability
:
int
|
None
=
None
)
->
tuple
[
bool
,
str
|
None
]:
return
True
,
None
@
classmethod
def
can_implement
(
cls
,
c
:
Mxfp8LinearLayerConfig
)
->
tuple
[
bool
,
str
|
None
]:
return
True
,
None
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
weight
=
layer
.
weight
.
data
# [N, K]
N
,
K
=
weight
.
shape
scale_k
=
K
//
MXFP8_BLOCK_SIZE
weight_scale
=
layer
.
weight_scale
.
data
[:
N
,
:
scale_k
].
contiguous
()
layer
.
weight
=
Parameter
(
weight
.
contiguous
(),
requires_grad
=
False
)
layer
.
weight_scale
=
Parameter
(
weight_scale
,
requires_grad
=
False
)
def
apply_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
weight_scale
=
layer
.
weight_scale
if
weight_scale
.
dtype
!=
MXFP8_SCALE_DTYPE
:
raise
ValueError
(
f
"Emulation backend requires
{
MXFP8_SCALE_DTYPE
}
"
f
"weight_scale dtype, got
{
weight_scale
.
dtype
}
."
)
if
weight_scale
.
ndim
!=
2
:
raise
ValueError
(
f
"Emulation backend requires 2D weight_scale, "
f
"got
{
weight_scale
.
ndim
}
D. "
f
"Ensure process_weights_after_loading was called."
)
weight_bf16
=
dequant_mxfp8_to_bf16
(
layer
.
weight
,
weight_scale
)
output
=
torch
.
nn
.
functional
.
linear
(
x
,
weight_bf16
,
bias
)
return
output
.
to
(
x
.
dtype
)
vllm/model_executor/kernels/linear/mxfp8/flashinfer.py
0 → 100644
View file @
11e2375f
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
torch
from
torch.nn.parameter
import
Parameter
from
vllm.model_executor.layers.quantization.utils.mxfp8_utils
import
(
MXFP8_BLOCK_SIZE
,
mxfp8_e4m3_quantize
,
swizzle_mxfp8_scale
,
)
from
vllm.platforms
import
current_platform
from
vllm.utils
import
flashinfer
as
vllm_flashinfer
from
.Mxfp8LinearKernel
import
Mxfp8LinearKernel
,
Mxfp8LinearLayerConfig
class
FlashInferCutlassMxfp8LinearKernel
(
Mxfp8LinearKernel
):
"""MXFP8 W8A8 GEMM via FlashInfer CUTLASS (SM100+)."""
@
classmethod
def
is_supported
(
cls
,
compute_capability
:
int
|
None
=
None
)
->
tuple
[
bool
,
str
|
None
]:
if
current_platform
.
has_device_capability
(
100
):
return
True
,
None
return
False
,
"requires >=sm_100 (Blackwell)"
@
classmethod
def
can_implement
(
cls
,
c
:
Mxfp8LinearLayerConfig
)
->
tuple
[
bool
,
str
|
None
]:
return
True
,
None
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
weight
=
layer
.
weight
.
data
# [N, K]
N
,
K
=
weight
.
shape
scale_k
=
K
//
MXFP8_BLOCK_SIZE
weight_scale_2d
=
layer
.
weight_scale
.
data
[:
N
,
:
scale_k
].
contiguous
()
weight_scale_swizzled
=
swizzle_mxfp8_scale
(
weight_scale_2d
,
M
=
N
,
K
=
K
)
layer
.
weight
=
Parameter
(
weight
.
contiguous
(),
requires_grad
=
False
)
layer
.
weight_scale
=
Parameter
(
weight_scale_swizzled
.
contiguous
(),
requires_grad
=
False
)
def
apply_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
weight
=
layer
.
weight
weight_scale
=
layer
.
weight_scale
out_dtype
=
x
.
dtype
N
,
K
=
weight
.
shape
input_shape
=
x
.
shape
input_2d
=
x
.
view
(
-
1
,
K
)
M_orig
=
input_2d
.
shape
[
0
]
min_dim
=
128
assert
min_dim
<=
K
,
(
f
"mm_mxfp8 requires K >=
{
min_dim
}
, got K=
{
K
}
. "
f
"in_features is too small for mm_mxfp8."
)
assert
K
%
MXFP8_BLOCK_SIZE
==
0
,
(
f
"mm_mxfp8 requires K to be divisible by
{
MXFP8_BLOCK_SIZE
}
, got K=
{
K
}
."
)
assert
min_dim
<=
N
,
(
f
"mm_mxfp8 requires N >=
{
min_dim
}
, got N=
{
N
}
. "
f
"out_features is too small for mm_mxfp8."
)
M_padded
=
((
M_orig
+
min_dim
-
1
)
//
min_dim
)
*
min_dim
if
M_padded
!=
M_orig
:
pad_rows
=
M_padded
-
M_orig
input_2d
=
torch
.
nn
.
functional
.
pad
(
input_2d
,
(
0
,
0
,
0
,
pad_rows
))
input_mxfp8
,
input_scale
=
mxfp8_e4m3_quantize
(
input_2d
,
is_sf_swizzled_layout
=
True
)
if
not
weight
.
is_contiguous
():
weight
=
weight
.
contiguous
()
output
=
vllm_flashinfer
.
mm_mxfp8
(
input_mxfp8
,
weight
.
t
(),
input_scale
,
weight_scale
,
out_dtype
=
out_dtype
,
backend
=
"cutlass"
,
)
if
M_padded
!=
M_orig
:
output
=
output
[:
M_orig
,
:]
if
bias
is
not
None
:
output
=
output
+
bias
output_shape
=
(
*
input_shape
[:
-
1
],
N
)
return
output
.
view
(
output_shape
)
vllm/model_executor/kernels/linear/mxfp8/marlin.py
0 → 100644
View file @
11e2375f
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
torch
from
.Mxfp8LinearKernel
import
Mxfp8LinearKernel
,
Mxfp8LinearLayerConfig
class
MarlinMxfp8LinearKernel
(
Mxfp8LinearKernel
):
"""MXFP8 W8A16 GEMM via Marlin (SM80+)."""
@
classmethod
def
is_supported
(
cls
,
compute_capability
:
int
|
None
=
None
)
->
tuple
[
bool
,
str
|
None
]:
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp8
import
(
is_fp8_marlin_supported
,
)
if
is_fp8_marlin_supported
():
return
True
,
None
return
False
,
"Marlin FP8 not available"
@
classmethod
def
can_implement
(
cls
,
c
:
Mxfp8LinearLayerConfig
)
->
tuple
[
bool
,
str
|
None
]:
return
True
,
None
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp8
import
(
prepare_mxfp8_layer_for_marlin
,
)
prepare_mxfp8_layer_for_marlin
(
layer
)
def
apply_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp8
import
(
apply_mxfp8_marlin_linear
,
)
return
apply_mxfp8_marlin_linear
(
input
=
x
,
weight
=
layer
.
weight
,
weight_scale
=
layer
.
weight_scale
,
workspace
=
layer
.
workspace
,
size_n
=
layer
.
output_size_per_partition
,
size_k
=
layer
.
input_size_per_partition
,
bias
=
bias
,
)
vllm/model_executor/layers/quantization/fp8.py
View file @
11e2375f
...
@@ -517,10 +517,10 @@ class Fp8OnlineLinearMethod(Fp8LinearMethod):
...
@@ -517,10 +517,10 @@ class Fp8OnlineLinearMethod(Fp8LinearMethod):
# TODO: remove this check once the following RFC is resolved.
# TODO: remove this check once the following RFC is resolved.
# https://github.com/vllm-project/vllm/issues/33314
# https://github.com/vllm-project/vllm/issues/33314
#
This check is required because
Mxfp8OnlineLinearMethod
inherits from
#
Subclasses (e.g.
Mxfp8OnlineLinearMethod
) only need the weight
#
Fp8OnlineLinearMethod but only calls super().create_weights(), so we must
#
registration above and manage their own kernel, so skip fp8_linear
#
skip the fp8_linear
kernel creation.
# kernel creation
for them
.
if
hasattr
(
self
,
"mxfp8_linear"
)
:
if
type
(
self
)
is
not
Fp8OnlineLinearMethod
:
return
return
self
.
fp8_linear
=
init_fp8_linear_kernel
(
self
.
fp8_linear
=
init_fp8_linear_kernel
(
...
...
vllm/model_executor/layers/quantization/modelopt.py
View file @
11e2375f
...
@@ -12,6 +12,7 @@ from vllm.config import get_current_vllm_config
...
@@ -12,6 +12,7 @@ from vllm.config import get_current_vllm_config
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.kernels.linear
import
(
from
vllm.model_executor.kernels.linear
import
(
init_fp8_linear_kernel
,
init_fp8_linear_kernel
,
init_mxfp8_linear_kernel
,
init_nvfp4_linear_kernel
,
init_nvfp4_linear_kernel
,
)
)
from
vllm.model_executor.layers.attention
import
Attention
,
MLAAttention
from
vllm.model_executor.layers.attention
import
Attention
,
MLAAttention
...
@@ -70,7 +71,6 @@ from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
...
@@ -70,7 +71,6 @@ from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
MXFP8_BLOCK_SIZE
,
MXFP8_BLOCK_SIZE
,
MXFP8_SCALE_DTYPE
,
MXFP8_SCALE_DTYPE
,
MXFP8_VALUE_DTYPE
,
MXFP8_VALUE_DTYPE
,
Mxfp8LinearOp
,
mxfp8_e4m3_quantize
,
mxfp8_e4m3_quantize
,
)
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
...
@@ -1576,7 +1576,7 @@ class ModelOptMxFp8LinearMethod(LinearMethodBase):
...
@@ -1576,7 +1576,7 @@ class ModelOptMxFp8LinearMethod(LinearMethodBase):
"Dynamic quantization is not supported."
"Dynamic quantization is not supported."
)
)
self
.
mxfp8_linear_op
=
Mxfp8LinearOp
()
self
.
kernel
=
init_mxfp8_linear_kernel
()
def
create_weights
(
def
create_weights
(
self
,
self
,
...
@@ -1658,7 +1658,7 @@ class ModelOptMxFp8LinearMethod(LinearMethodBase):
...
@@ -1658,7 +1658,7 @@ class ModelOptMxFp8LinearMethod(LinearMethodBase):
f
" got
{
layer
.
weight_scale
.
dtype
}
"
f
" got
{
layer
.
weight_scale
.
dtype
}
"
)
)
self
.
mxfp8_linear_op
.
process_weights
(
layer
)
self
.
kernel
.
process_weights
_after_loading
(
layer
)
def
apply
(
def
apply
(
self
,
self
,
...
@@ -1666,16 +1666,7 @@ class ModelOptMxFp8LinearMethod(LinearMethodBase):
...
@@ -1666,16 +1666,7 @@ class ModelOptMxFp8LinearMethod(LinearMethodBase):
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
|
None
=
None
,
bias
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
return
self
.
mxfp8_linear_op
.
apply
(
return
self
.
kernel
.
apply_weights
(
layer
,
x
,
bias
)
input
=
x
,
weight
=
layer
.
weight
,
weight_scale
=
layer
.
weight_scale
,
out_dtype
=
x
.
dtype
,
bias
=
bias
,
workspace
=
getattr
(
layer
,
"workspace"
,
None
),
size_n
=
layer
.
output_size_per_partition
,
size_k
=
layer
.
input_size_per_partition
,
)
class
ModelOptMxFp8FusedMoE
(
FusedMoEMethodBase
):
class
ModelOptMxFp8FusedMoE
(
FusedMoEMethodBase
):
...
...
vllm/model_executor/layers/quantization/mxfp8.py
View file @
11e2375f
...
@@ -9,6 +9,7 @@ import torch
...
@@ -9,6 +9,7 @@ import torch
from
torch.nn
import
Module
from
torch.nn
import
Module
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.kernels.linear
import
init_mxfp8_linear_kernel
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.model_executor.layers.fused_moe
import
(
from
vllm.model_executor.layers.fused_moe
import
(
FusedMoE
,
FusedMoE
,
...
@@ -34,7 +35,6 @@ from vllm.model_executor.layers.quantization.fp8 import (
...
@@ -34,7 +35,6 @@ from vllm.model_executor.layers.quantization.fp8 import (
)
)
from
vllm.model_executor.layers.quantization.utils.mxfp8_utils
import
(
from
vllm.model_executor.layers.quantization.utils.mxfp8_utils
import
(
MXFP8_BLOCK_SIZE
,
MXFP8_BLOCK_SIZE
,
Mxfp8LinearOp
,
mxfp8_e4m3_quantize
,
mxfp8_e4m3_quantize
,
)
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
...
@@ -126,8 +126,7 @@ class Mxfp8OnlineLinearMethod(Fp8OnlineLinearMethod):
...
@@ -126,8 +126,7 @@ class Mxfp8OnlineLinearMethod(Fp8OnlineLinearMethod):
def
__init__
(
self
,
quant_config
:
"Mxfp8Config"
):
def
__init__
(
self
,
quant_config
:
"Mxfp8Config"
):
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
self
.
out_dtype
=
torch
.
get_default_dtype
()
self
.
kernel
=
init_mxfp8_linear_kernel
()
self
.
mxfp8_linear
=
Mxfp8LinearOp
()
def
create_weights
(
def
create_weights
(
self
,
self
,
...
@@ -166,7 +165,7 @@ class Mxfp8OnlineLinearMethod(Fp8OnlineLinearMethod):
...
@@ -166,7 +165,7 @@ class Mxfp8OnlineLinearMethod(Fp8OnlineLinearMethod):
replace_parameter
(
layer
,
"weight"
,
weight_fp8
.
data
)
replace_parameter
(
layer
,
"weight"
,
weight_fp8
.
data
)
replace_parameter
(
layer
,
"weight_scale"
,
weight_scale
.
data
)
replace_parameter
(
layer
,
"weight_scale"
,
weight_scale
.
data
)
self
.
mxfp8_linear
.
process_weights
(
layer
)
self
.
kernel
.
process_weights
_after_loading
(
layer
)
layer
.
_already_called_process_weights_after_loading
=
True
layer
.
_already_called_process_weights_after_loading
=
True
...
@@ -176,16 +175,7 @@ class Mxfp8OnlineLinearMethod(Fp8OnlineLinearMethod):
...
@@ -176,16 +175,7 @@ class Mxfp8OnlineLinearMethod(Fp8OnlineLinearMethod):
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
|
None
=
None
,
bias
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
return
self
.
mxfp8_linear
.
apply
(
return
self
.
kernel
.
apply_weights
(
layer
,
x
,
bias
)
input
=
x
,
weight
=
layer
.
weight
,
weight_scale
=
layer
.
weight_scale
,
out_dtype
=
self
.
out_dtype
,
bias
=
bias
,
workspace
=
getattr
(
layer
,
"workspace"
,
None
),
size_n
=
layer
.
output_size_per_partition
,
size_k
=
layer
.
input_size_per_partition
,
)
class
Mxfp8OnlineMoEMethod
(
Fp8OnlineMoEMethod
):
class
Mxfp8OnlineMoEMethod
(
Fp8OnlineMoEMethod
):
...
...
vllm/model_executor/layers/quantization/utils/mxfp8_utils.py
View file @
11e2375f
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
enum
import
Enum
import
torch
import
torch
from
torch.nn.parameter
import
Parameter
from
vllm.logger
import
init_logger
from
vllm.utils
import
flashinfer
as
vllm_flashinfer
from
vllm.utils.torch_utils
import
direct_register_custom_op
from
vllm.utils.torch_utils
import
direct_register_custom_op
logger
=
init_logger
(
__name__
)
class
Mxfp8LinearBackend
(
Enum
):
EMULATION
=
"emulation"
FLASHINFER_CUTLASS
=
"flashinfer-cutlass"
MARLIN
=
"marlin"
# MXFP8 constants
# MXFP8 constants
MXFP8_VALUE_DTYPE
=
torch
.
float8_e4m3fn
MXFP8_VALUE_DTYPE
=
torch
.
float8_e4m3fn
MXFP8_SCALE_DTYPE
=
torch
.
uint8
MXFP8_SCALE_DTYPE
=
torch
.
uint8
MXFP8_BLOCK_SIZE
=
32
MXFP8_BLOCK_SIZE
=
32
def
select_mxfp8_linear_backend
()
->
Mxfp8LinearBackend
:
"""Select the best MXFP8 linear backend for the current device.
- SM100+ (Blackwell): FLASHINFER_CUTLASS (native MXFP8 W8A8 GEMM)
- SM80+ (Ampere/Ada): MARLIN (MXFP8 W8A16 GEMM)
- Otherwise: EMULATION (dequant to BF16 fallback)
"""
from
vllm.platforms
import
current_platform
if
current_platform
.
has_device_capability
(
100
):
return
Mxfp8LinearBackend
.
FLASHINFER_CUTLASS
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp8
import
(
is_fp8_marlin_supported
,
)
if
is_fp8_marlin_supported
():
return
Mxfp8LinearBackend
.
MARLIN
return
Mxfp8LinearBackend
.
EMULATION
def
swizzle_mxfp8_scale
(
sf
:
torch
.
Tensor
,
M
:
int
,
K
:
int
)
->
torch
.
Tensor
:
def
swizzle_mxfp8_scale
(
sf
:
torch
.
Tensor
,
M
:
int
,
K
:
int
)
->
torch
.
Tensor
:
"""Swizzle MXFP8 scales from row-major 2D to F8_128x4 layout."""
"""Swizzle MXFP8 scales from row-major 2D to F8_128x4 layout."""
scaling_vector_size
=
MXFP8_BLOCK_SIZE
# 32 for MXFP8
scaling_vector_size
=
MXFP8_BLOCK_SIZE
# 32 for MXFP8
...
@@ -209,194 +173,3 @@ def xpu_mxfp8_quantize(
...
@@ -209,194 +173,3 @@ def xpu_mxfp8_quantize(
x
:
torch
.
Tensor
,
dtype
:
torch
.
dtype
|
None
=
None
x
:
torch
.
Tensor
,
dtype
:
torch
.
dtype
|
None
=
None
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
return
torch
.
ops
.
vllm
.
xpu_mxfp8_quantize
(
x
,
dtype
)
return
torch
.
ops
.
vllm
.
xpu_mxfp8_quantize
(
x
,
dtype
)
class
Mxfp8LinearOp
:
def
__init__
(
self
):
self
.
backend
=
select_mxfp8_linear_backend
()
logger
.
info_once
(
"Using %s backend for MXFP8 GEMM"
,
self
.
backend
)
def
process_weights
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
"""Process MXFP8 weights after loading into backend-specific format."""
if
self
.
backend
==
Mxfp8LinearBackend
.
MARLIN
:
self
.
_process_weights_marlin
(
layer
)
elif
self
.
backend
==
Mxfp8LinearBackend
.
FLASHINFER_CUTLASS
:
self
.
_process_weights_flashinfer_cutlass
(
layer
)
else
:
self
.
_process_weights_emulation
(
layer
)
def
_process_weights_emulation
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
"""Keep scales as 2D uint8 for dequant-to-BF16 emulation."""
weight
=
layer
.
weight
.
data
# [N, K]
N
,
K
=
weight
.
shape
scale_k
=
K
//
MXFP8_BLOCK_SIZE
weight_scale
=
layer
.
weight_scale
.
data
[:
N
,
:
scale_k
].
contiguous
()
layer
.
weight
=
Parameter
(
weight
.
contiguous
(),
requires_grad
=
False
)
layer
.
weight_scale
=
Parameter
(
weight_scale
,
requires_grad
=
False
)
def
_process_weights_flashinfer_cutlass
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
"""Swizzle scales to F8_128x4 layout for flashinfer CUTLASS."""
weight
=
layer
.
weight
.
data
# [N, K]
N
,
K
=
weight
.
shape
scale_k
=
K
//
MXFP8_BLOCK_SIZE
weight_scale_2d
=
layer
.
weight_scale
.
data
[:
N
,
:
scale_k
].
contiguous
()
weight_scale_swizzled
=
swizzle_mxfp8_scale
(
weight_scale_2d
,
M
=
N
,
K
=
K
)
layer
.
weight
=
Parameter
(
weight
.
contiguous
(),
requires_grad
=
False
)
layer
.
weight_scale
=
Parameter
(
weight_scale_swizzled
.
contiguous
(),
requires_grad
=
False
)
def
_process_weights_marlin
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
"""Repack MXFP8 weights and scales into Marlin kernel format."""
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp8
import
(
prepare_mxfp8_layer_for_marlin
,
)
prepare_mxfp8_layer_for_marlin
(
layer
)
def
_apply_emulation
(
self
,
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
weight_scale
:
torch
.
Tensor
,
out_dtype
:
torch
.
dtype
,
bias
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
if
weight_scale
.
dtype
!=
MXFP8_SCALE_DTYPE
:
raise
ValueError
(
f
"TORCH backend requires
{
MXFP8_SCALE_DTYPE
}
weight_scale dtype, "
f
"got
{
weight_scale
.
dtype
}
."
)
if
weight_scale
.
ndim
!=
2
:
raise
ValueError
(
f
"TORCH backend requires 2D weight_scale, got
{
weight_scale
.
ndim
}
D. "
f
"Ensure process_weights_after_loading was called."
)
weight_bf16
=
dequant_mxfp8_to_bf16
(
weight
,
weight_scale
)
output
=
torch
.
nn
.
functional
.
linear
(
input
,
weight_bf16
,
bias
)
return
output
.
to
(
out_dtype
)
def
_apply_flashinfer_cutlass
(
self
,
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
weight_scale
:
torch
.
Tensor
,
out_dtype
:
torch
.
dtype
,
bias
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
N
,
K
=
weight
.
shape
input_shape
=
input
.
shape
input_2d
=
input
.
view
(
-
1
,
K
)
M_orig
=
input_2d
.
shape
[
0
]
# Minimum dimension size for F8_128x4 block scaling layout
min_dim
=
128
assert
min_dim
<=
K
,
(
f
"mm_mxfp8 requires K >=
{
min_dim
}
, got K=
{
K
}
. "
f
"in_features is too small for mm_mxfp8."
)
assert
K
%
MXFP8_BLOCK_SIZE
==
0
,
(
f
"mm_mxfp8 requires K to be divisible by
{
MXFP8_BLOCK_SIZE
}
, got K=
{
K
}
."
)
assert
min_dim
<=
N
,
(
f
"mm_mxfp8 requires N >=
{
min_dim
}
, got N=
{
N
}
. "
f
"out_features is too small for mm_mxfp8."
)
M_padded
=
((
M_orig
+
min_dim
-
1
)
//
min_dim
)
*
min_dim
if
M_padded
!=
M_orig
:
pad_rows
=
M_padded
-
M_orig
input_2d
=
torch
.
nn
.
functional
.
pad
(
input_2d
,
(
0
,
0
,
0
,
pad_rows
))
input_mxfp8
,
input_scale
=
mxfp8_e4m3_quantize
(
input_2d
,
is_sf_swizzled_layout
=
True
,
# Swizzled for best accuracy
)
if
not
weight
.
is_contiguous
():
weight
=
weight
.
contiguous
()
output
=
vllm_flashinfer
.
mm_mxfp8
(
input_mxfp8
,
weight
.
t
(),
input_scale
,
weight_scale
,
out_dtype
=
out_dtype
,
backend
=
"cutlass"
,
)
if
M_padded
!=
M_orig
:
output
=
output
[:
M_orig
,
:]
if
bias
is
not
None
:
output
=
output
+
bias
output_shape
=
(
*
input_shape
[:
-
1
],
N
)
return
output
.
view
(
output_shape
)
def
_apply_marlin
(
self
,
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
weight_scale
:
torch
.
Tensor
,
out_dtype
:
torch
.
dtype
,
bias
:
torch
.
Tensor
|
None
=
None
,
*
,
workspace
:
torch
.
Tensor
,
size_n
:
int
,
size_k
:
int
,
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp8
import
(
apply_mxfp8_marlin_linear
,
)
return
apply_mxfp8_marlin_linear
(
input
=
input
,
weight
=
weight
,
weight_scale
=
weight_scale
,
workspace
=
workspace
,
size_n
=
size_n
,
size_k
=
size_k
,
bias
=
bias
,
)
def
apply
(
self
,
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
weight_scale
:
torch
.
Tensor
,
out_dtype
:
torch
.
dtype
,
bias
:
torch
.
Tensor
|
None
=
None
,
*
,
workspace
:
torch
.
Tensor
|
None
=
None
,
size_n
:
int
=
0
,
size_k
:
int
=
0
,
)
->
torch
.
Tensor
:
if
self
.
backend
==
Mxfp8LinearBackend
.
EMULATION
:
return
self
.
_apply_emulation
(
input
,
weight
,
weight_scale
,
out_dtype
,
bias
)
if
self
.
backend
==
Mxfp8LinearBackend
.
MARLIN
:
assert
workspace
is
not
None
return
self
.
_apply_marlin
(
input
,
weight
,
weight_scale
,
out_dtype
,
bias
,
workspace
=
workspace
,
size_n
=
size_n
,
size_k
=
size_k
,
)
assert
self
.
backend
==
Mxfp8LinearBackend
.
FLASHINFER_CUTLASS
return
self
.
_apply_flashinfer_cutlass
(
input
,
weight
,
weight_scale
,
out_dtype
,
bias
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment