Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
148117ea
Unverified
Commit
148117ea
authored
Jan 20, 2026
by
vllmellm
Committed by
GitHub
Jan 20, 2026
Browse files
[Refactor] Make FP8 Linear Ops use kernel abstraction (#27814)
Signed-off-by:
vllmellm
<
vllm.ellm@embeddedllm.com
>
parent
e9c83cdc
Changes
30
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
487 additions
and
469 deletions
+487
-469
vllm/model_executor/layers/quantization/kernels/scaled_mm/flashinfer.py
...cutor/layers/quantization/kernels/scaled_mm/flashinfer.py
+57
-0
vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py
...executor/layers/quantization/kernels/scaled_mm/pytorch.py
+221
-0
vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py
...el_executor/layers/quantization/kernels/scaled_mm/rocm.py
+117
-0
vllm/model_executor/layers/quantization/kernels/scaled_mm/triton.py
..._executor/layers/quantization/kernels/scaled_mm/triton.py
+23
-18
vllm/model_executor/layers/quantization/modelopt.py
vllm/model_executor/layers/quantization/modelopt.py
+18
-19
vllm/model_executor/layers/quantization/ptpc_fp8.py
vllm/model_executor/layers/quantization/ptpc_fp8.py
+10
-13
vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py
...cutor/layers/quantization/quark/schemes/quark_w8a8_fp8.py
+28
-19
vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py
...utor/layers/quantization/quark/schemes/quark_w8a8_int8.py
+10
-22
vllm/model_executor/layers/quantization/utils/quant_utils.py
vllm/model_executor/layers/quantization/utils/quant_utils.py
+3
-0
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+0
-378
No files found.
vllm/model_executor/layers/quantization/kernels/scaled_mm/flashinfer.py
0 → 100644
View file @
148117ea
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
torch
from
vllm.platforms
import
current_platform
from
vllm.utils.flashinfer
import
flashinfer_scaled_fp8_mm
,
has_flashinfer
from
.ScaledMMLinearKernel
import
(
FP8ScaledMMLinearKernel
,
FP8ScaledMMLinearLayerConfig
,
)
class FlashInferFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
    """FP8 scaled-matmul linear kernel that calls ``flashinfer_scaled_fp8_mm``.

    The kernel reports itself as supported only on CUDA with FlashInfer
    installed, and it only implements configurations where both the
    activation scale and the weight scale are per-tensor.
    """

    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        """Check whether the current platform can run this kernel.

        Args:
            compute_capability: Device compute capability as an integer.
                When ``None`` the capability check is skipped.

        Returns:
            ``(True, None)`` if supported, else ``(False, reason)``.
        """
        # The checks are ordered; only the first failing one is reported.
        if not current_platform.is_cuda():
            reason = "requires CUDA."
        elif not has_flashinfer():
            reason = "requires FlashInfer to be installed."
        elif compute_capability is not None and compute_capability < 100:
            reason = "requires compute capability 100 and above."
        else:
            return True, None
        return False, reason

    @classmethod
    def can_implement(
        cls, c: FP8ScaledMMLinearLayerConfig
    ) -> tuple[bool, str | None]:
        """Check whether layer config ``c`` can be served by this kernel.

        Returns:
            ``(True, None)`` when both activation and weight scales are
            per-tensor, else ``(False, reason)``.
        """
        act_is_per_tensor = c.activation_quant_key.scale.group_shape.is_per_tensor()
        weight_is_per_tensor = c.weight_quant_key.scale.group_shape.is_per_tensor()

        if act_is_per_tensor and weight_is_per_tensor:
            return True, None
        return False, "requires per tensor activation and weight scales."

    def apply_scaled_mm(
        self,
        *,
        A: torch.Tensor,
        B: torch.Tensor,
        out_dtype: torch.dtype,
        As: torch.Tensor,
        Bs: torch.Tensor,
        bias: torch.Tensor | None,
        output_shape: list,
    ) -> torch.Tensor:
        """Run the scaled matmul of ``A`` and ``B`` through FlashInfer.

        ``As`` / ``Bs`` are forwarded as ``scale_a`` / ``scale_b``.
        ``output_shape`` is accepted for signature parity with the other
        kernels but is not used here: the FlashInfer result is returned
        as-is.
        """
        result = flashinfer_scaled_fp8_mm(
            A,
            B,
            out_dtype=out_dtype,
            scale_a=As,
            scale_b=Bs,
            bias=bias,
        )
        return result
vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py
0 → 100644
View file @
148117ea
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
torch
from
packaging
import
version
from
vllm.config
import
CompilationMode
,
get_current_vllm_config
from
vllm.platforms
import
current_platform
from
.ScaledMMLinearKernel
import
(
FP8ScaledMMLinearKernel
,
FP8ScaledMMLinearLayerConfig
,
)
class TorchFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
    """
    Common base for the FP8 linear kernels implemented with Torch.

    Every subclass is one kernel variant, targeting particular device
    capabilities and torch versions.
    """

    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        """Check whether the current platform can run a Torch FP8 kernel.

        Args:
            compute_capability: Device compute capability as an integer.
                When ``None`` the capability check is skipped.

        Returns:
            ``(True, None)`` if supported, else ``(False, reason)``.
        """
        platform_ok = current_platform.is_cuda_alike() or current_platform.is_cpu()
        if not platform_ok:
            return False, "requires ROCm, CUDA or CPU."

        if compute_capability is not None and compute_capability < 89:
            return False, "requires compute capability 89 and above."

        return True, None

    def get_output_padding(self) -> int | None:
        """Return the padding value to use (17), or ``None`` for no padding.

        Padding is enabled only when the current compilation mode is below
        ``CompilationMode.VLLM_COMPILE``.
        """
        # Note: we pad the input because torch._scaled_mm is more performant
        # for matrices with batch dimension > 16.
        # This could change in the future.
        # We also don't pad when using torch.compile,
        # as it breaks with dynamic shapes.
        #
        # The perf gain is still relevant as of 16/1/2026
        # torch version == 2.9.0. More details in the link below:
        # https://github.com/vllm-project/vllm/issues/32269
        compilation_config = get_current_vllm_config().compilation_config
        if compilation_config.mode < CompilationMode.VLLM_COMPILE:
            return 17
        return None
class PerTensorTorchFP8ScaledMMLinearKernel(TorchFP8ScaledMMLinearKernel):
    """Torch FP8 kernel for per-tensor activation and weight scales.

    Uses a single fused ``torch._scaled_mm`` call with both scales passed
    straight through.
    """

    @classmethod
    def can_implement(
        cls, c: FP8ScaledMMLinearLayerConfig
    ) -> tuple[bool, str | None]:
        """Check whether layer config ``c`` can be served by this kernel.

        Returns:
            ``(True, None)`` when both activation and weight scales are
            per-tensor, else ``(False, reason)``.
        """
        per_tensor_activation_scales = (
            c.activation_quant_key.scale.group_shape.is_per_tensor()
        )
        per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor()

        if not (per_tensor_activation_scales and per_tensor_weight_scales):
            return False, "requires per tensor activation and weight scales."
        return True, None

    def apply_scaled_mm(
        self,
        *,
        A: torch.Tensor,
        B: torch.Tensor,
        out_dtype: torch.dtype,
        As: torch.Tensor,
        Bs: torch.Tensor,
        bias: torch.Tensor | None,
        output_shape: list,
    ) -> torch.Tensor:
        """Compute the scaled matmul of ``A`` and ``B`` and reshape the result.

        Args:
            A: Left operand passed to ``torch._scaled_mm``.
            B: Right operand passed to ``torch._scaled_mm``.
            out_dtype: Output dtype requested from ``torch._scaled_mm``.
            As: Scale for ``A`` (forwarded as ``scale_a``).
            Bs: Scale for ``B`` (forwarded as ``scale_b``).
            bias: Optional bias forwarded to ``torch._scaled_mm``.
            output_shape: Final shape of the returned tensor. All dims but
                the last are treated as flattened rows of the matmul output.

        Returns:
            The matmul result, with any extra trailing rows dropped, viewed
            as ``output_shape``.
        """
        output = torch._scaled_mm(
            A, B, out_dtype=out_dtype, scale_a=As, scale_b=Bs, bias=bias
        )
        # A fix for discrepancy in scaled_mm which returns tuple
        # for torch < 2.5 and a single value in torch >= 2.5
        if isinstance(output, tuple) and len(output) == 2:
            output = output[0]

        # The matmul output is 2-D, so the number of rows to keep is the
        # product of every leading dim of output_shape. Using output_shape[0]
        # alone is only right for a 2-D output_shape; for N-D shapes it would
        # keep too few rows and the view below would fail.
        num_rows = 1
        for dim in output_shape[:-1]:
            num_rows *= dim

        return torch.narrow(output, 0, 0, num_rows).view(*output_shape)
class RowWiseTorchFP8ScaledMMLinearKernel(TorchFP8ScaledMMLinearKernel):
    """Torch FP8 kernel using row-wise scaling in ``torch._scaled_mm``.

    Reported as supported only on ROCm MI3xx with torch >= 2.7, and only for
    configurations where neither the activation nor the weight scale is
    per-tensor.
    """

    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        """Check whether the current platform can run this kernel.

        Args:
            compute_capability: Device compute capability as an integer.
                When ``None`` the capability check is skipped.

        Returns:
            ``(True, None)`` if supported, else ``(False, reason)``.
        """
        if not current_platform.is_rocm():
            return False, "requires ROCm."

        # Imported lazily, after the ROCm check has passed.
        from vllm.platforms.rocm import on_mi3xx

        if not on_mi3xx():
            return False, "requires MI3xx."

        if compute_capability is not None and compute_capability < 94:
            return False, "requires compute capability 94 and above."

        if version.parse(torch.__version__) < version.parse("2.7"):
            return False, "requires pytorch version >=2.7."

        return True, None

    @classmethod
    def can_implement(
        cls, c: FP8ScaledMMLinearLayerConfig
    ) -> tuple[bool, str | None]:
        """Check whether layer config ``c`` can be served by this kernel.

        Rejects float16 output and rejects the config when either the
        activation scale or the weight scale is per-tensor.

        Returns:
            ``(True, None)`` if implementable, else ``(False, reason)``.
        """
        per_tensor_activation_scales = (
            c.activation_quant_key.scale.group_shape.is_per_tensor()
        )
        per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor()

        if c.out_dtype == torch.float16:
            # hipblaslt rowwise _scaled_mm only supports BFloat16
            # NOTE(review): only float16 is rejected here; other non-bfloat16
            # output dtypes pass this check -- confirm that is intended.
            return False, "supports BFloat16 output data type only."

        # NOTE(review): the condition is "either one is per-tensor", while the
        # message reads "activation and weight" -- confirm the wording.
        if per_tensor_activation_scales or per_tensor_weight_scales:
            return False, "cannot be used with per tensor activation and weight scales."
        return True, None

    def apply_scaled_mm(
        self,
        *,
        A: torch.Tensor,
        B: torch.Tensor,
        out_dtype: torch.dtype,
        As: torch.Tensor,
        Bs: torch.Tensor,
        bias: torch.Tensor | None,
        output_shape: list,
    ) -> torch.Tensor:
        """Compute the row-wise scaled matmul of ``A`` and ``B``.

        Args:
            A: Left operand passed to ``torch._scaled_mm``.
            B: Right operand passed to ``torch._scaled_mm``.
            out_dtype: Output dtype requested from ``torch._scaled_mm``.
            As: Scale for ``A`` (forwarded as ``scale_a``).
            Bs: Scale for ``B``; its transpose is forwarded as ``scale_b``.
            bias: Optional bias forwarded to ``torch._scaled_mm``.
            output_shape: Final shape of the returned tensor. All dims but
                the last are treated as flattened rows of the matmul output.

        Returns:
            The matmul result, with any extra trailing rows dropped, viewed
            as ``output_shape``.
        """
        # Note:
        # For now it has only been validated on ROCm platform.
        # fp8 rowwise scaling in torch._scaled_mm is introduced in
        # https://github.com/pytorch/pytorch/pull/144432 using
        # hipBLASLt and ROCm 6.3, which only exists in torch 2.7 and above.
        #
        # For CUDA platform please validate if the torch._scaled_mm supports
        # rowwise scaled GEMM before using it

        # Fused GEMM_DQ Rowwise GEMM
        output = torch._scaled_mm(
            A,
            B,
            out_dtype=out_dtype,
            scale_a=As,
            scale_b=Bs.t(),
            bias=bias,
        )

        # The matmul output is 2-D, so the number of rows to keep is the
        # product of every leading dim of output_shape. Using output_shape[0]
        # alone is only right for a 2-D output_shape; for N-D shapes it would
        # keep too few rows and the view below would fail.
        num_rows = 1
        for dim in output_shape[:-1]:
            num_rows *= dim

        return torch.narrow(output, 0, 0, num_rows).view(*output_shape)
class ChannelWiseTorchFP8ScaledMMLinearKernel(TorchFP8ScaledMMLinearKernel):
    """Torch FP8 kernel that applies the scales after an unscaled matmul.

    The GEMM itself runs with dummy unit scales in float32; the activation
    and weight scales, and the bias, are applied afterwards as separate ops.
    """

    @classmethod
    def can_implement(
        cls, c: FP8ScaledMMLinearLayerConfig
    ) -> tuple[bool, str | None]:
        """Check whether layer config ``c`` can be served by this kernel.

        Returns:
            ``(False, reason)`` when both activation and weight scales are
            per-tensor, else ``(True, None)``.
        """
        per_tensor_activation_scales = (
            c.activation_quant_key.scale.group_shape.is_per_tensor()
        )
        per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor()

        if per_tensor_activation_scales and per_tensor_weight_scales:
            return False, "cannot be used with per tensor activation and weight scales."
        return True, None

    def apply_scaled_mm(
        self,
        *,
        A: torch.Tensor,
        B: torch.Tensor,
        out_dtype: torch.dtype,
        As: torch.Tensor,
        Bs: torch.Tensor,
        bias: torch.Tensor | None,
        output_shape: list,
    ) -> torch.Tensor:
        """Compute ``(A @ B) * As * Bs.t() (+ bias)`` with an unfused dequant.

        Args:
            A: Left operand passed to ``torch._scaled_mm``.
            B: Right operand passed to ``torch._scaled_mm``.
            out_dtype: Dtype the final result is cast to.
            As: Activation scale; narrowed along dim 0 to the kept row count.
            Bs: Weight scale; its transpose is multiplied into the output.
            bias: Optional bias added after scaling.
            output_shape: Final shape of the returned tensor. All dims but
                the last are treated as flattened rows of the matmul output.

        Returns:
            The dequantized result cast to ``out_dtype`` and viewed as
            ``output_shape``.
        """
        # Use unfused DQ due to limitations with scaled_mm

        # Symmetric quantized GEMM by definition computes the following:
        #   C = (s_x * X) (s_w * W) + bias
        # This is equivalent to dequantizing the weights and activations
        # before applying a GEMM.
        #
        # In order to compute quantized operands, a quantized kernel
        # will rewrite the above like so:
        #   C = s_w * s_x * (X * W) + bias
        #
        # For the scaled_mm fallback case, we break this down, since it
        # does not support s_w being a vector.

        # Input scaling factors are no longer optional in _scaled_mm starting
        # from pytorch 2.5. Allocating a dummy tensor to pass as scales
        dummy_tensor = torch.ones(1, dtype=torch.float32, device=A.device)

        # GEMM
        # This computes C = (X * W).
        # Output in fp32 to allow subsequent ops to happen in-place
        output = torch._scaled_mm(
            A,
            B,
            scale_a=dummy_tensor,
            scale_b=dummy_tensor,
            out_dtype=torch.float32,
        )
        # A fix for discrepancy in scaled_mm which returns tuple
        # for torch < 2.5 and a single value in torch >= 2.5
        if isinstance(output, tuple) and len(output) == 2:
            output = output[0]

        # The matmul output is 2-D, so the number of rows to keep is the
        # product of every leading dim of output_shape. Using output_shape[0]
        # alone is only right for a 2-D output_shape; for N-D shapes it would
        # keep too few rows and the final view would fail.
        num_rows = 1
        for dim in output_shape[:-1]:
            num_rows *= dim

        # Unpad (undo num_token_padding)
        output = torch.narrow(output, 0, 0, num_rows)
        # NOTE(review): this assumes As has at least num_rows entries along
        # dim 0 -- confirm for configs with a per-tensor activation scale.
        x_scale = torch.narrow(As, 0, 0, num_rows)

        # DQ
        # C = sw * sx * (X * W) + bias
        output = output * x_scale * Bs.t()
        if bias is not None:
            output = output + bias
        return output.to(out_dtype).view(*output_shape)
vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py
0 → 100644
View file @
148117ea
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
torch
import
vllm.envs
as
envs
from
vllm
import
_custom_ops
as
ops
from
vllm.platforms
import
current_platform
from
vllm.utils.platform_utils
import
get_cu_count
from
vllm.utils.torch_utils
import
direct_register_custom_op
from
.ScaledMMLinearKernel
import
(
FP8ScaledMMLinearKernel
,
FP8ScaledMMLinearLayerConfig
,
)
def rocm_per_tensor_float_w8a8_scaled_mm_impl(
    A: torch.Tensor,
    B: torch.Tensor,
    out_dtype: torch.dtype,
    As: torch.Tensor,
    Bs: torch.Tensor,
    bias: torch.Tensor,
) -> torch.Tensor:
    """Per-tensor scaled matmul for ROCm with a special single-row path.

    ``ops.wvSplitKQ`` is used when ``A`` has exactly one row, ``B.shape[1]``
    is a multiple of 16, and ``bias`` is either ``None`` or already of dtype
    ``out_dtype``. Every other case falls back to ``torch._scaled_mm``.
    """
    # Kept as one short-circuiting expression so later terms are evaluated
    # only when the earlier ones hold.
    use_wv_split_kq = (
        A.shape[0] == 1
        and B.shape[1] % 16 == 0
        and (bias is None or bias.dtype == out_dtype)
    )
    if use_wv_split_kq:
        return ops.wvSplitKQ(
            B.t(),
            A,
            out_dtype,
            As,
            Bs,
            get_cu_count(),
            bias,
        )

    # Fallback
    return torch._scaled_mm(
        A,
        B,
        out_dtype=out_dtype,
        scale_a=As,
        scale_b=Bs,
        bias=bias,
    )
def rocm_per_tensor_float_w8a8_scaled_mm_fake(
    A: torch.Tensor,
    B: torch.Tensor,
    out_dtype: torch.dtype,
    As: torch.Tensor,
    Bs: torch.Tensor,
    bias: torch.Tensor,
) -> torch.Tensor:
    """Fake implementation: allocate an uninitialized output, compute nothing.

    The result has the leading dims of ``A`` followed by ``B.shape[1]``, with
    dtype ``out_dtype``. ``As``, ``Bs`` and ``bias`` are ignored.
    """
    fake_shape = (*A.shape[:-1], B.shape[1])
    return A.new_empty(fake_shape, dtype=out_dtype)
# Register the custom op (with its fake implementation) on ROCm only; on any
# other platform nothing is registered.
if current_platform.is_rocm():
    direct_register_custom_op(
        op_name="rocm_per_tensor_float_w8a8_scaled_mm_impl",
        op_func=rocm_per_tensor_float_w8a8_scaled_mm_impl,
        fake_impl=rocm_per_tensor_float_w8a8_scaled_mm_fake,
    )
class ROCmFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
    """FP8 linear kernel that dispatches to the registered ROCm custom op.

    Reported as supported only on ROCm MI3xx with
    ``VLLM_ROCM_USE_SKINNY_GEMM`` enabled, and only for per-tensor activation
    and weight scales.
    """

    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        """Check whether the current platform can run this kernel.

        ``compute_capability`` is accepted but not consulted.

        Returns:
            ``(True, None)`` if supported, else ``(False, reason)``.
        """
        if not current_platform.is_rocm():
            return False, "requires ROCm."

        # Imported lazily, after the ROCm check has passed.
        from vllm.platforms.rocm import on_mi3xx

        if not on_mi3xx():
            return False, "requires MI3xx."

        if not envs.VLLM_ROCM_USE_SKINNY_GEMM:
            return False, "requires VLLM_ROCM_USE_SKINNY_GEMM to be enabled."

        return True, None

    @classmethod
    def can_implement(
        cls, c: FP8ScaledMMLinearLayerConfig
    ) -> tuple[bool, str | None]:
        """Check whether layer config ``c`` can be served by this kernel.

        Returns:
            ``(True, None)`` when both activation and weight scales are
            per-tensor, else ``(False, reason)``.
        """
        act_is_per_tensor = c.activation_quant_key.scale.group_shape.is_per_tensor()
        weight_is_per_tensor = c.weight_quant_key.scale.group_shape.is_per_tensor()

        if act_is_per_tensor and weight_is_per_tensor:
            return True, None
        return False, "requires per tensor activation and weight scales."

    def apply_scaled_mm(
        self,
        *,
        A: torch.Tensor,
        B: torch.Tensor,
        out_dtype: torch.dtype,
        As: torch.Tensor,
        Bs: torch.Tensor,
        bias: torch.Tensor | None,
        output_shape: list,
    ) -> torch.Tensor:
        """Run the ROCm per-tensor scaled matmul op and reshape the result.

        The op output is narrowed along dim 0 to ``A.shape[0]`` rows and then
        viewed as ``output_shape``.
        """
        mm_out = torch.ops.vllm.rocm_per_tensor_float_w8a8_scaled_mm_impl(
            A, B, out_dtype, As, Bs, bias
        )
        kept_rows = torch.narrow(mm_out, 0, 0, A.shape[0])
        return kept_rows.view(*output_shape)
vllm/model_executor/layers/quantization/kernels/scaled_mm/triton.py
View file @
148117ea
...
@@ -14,30 +14,35 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
...
@@ -14,30 +14,35 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
)
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
.ScaledMMLinearKernel
import
ScaledMMLinearKernel
,
ScaledMMLinearLayerConfig
from
.cutlass
import
CutlassInt8ScaledMMLinearKernel
from
.ScaledMMLinearKernel
import
(
Int8ScaledMMLinearLayerConfig
,
)
class
TritonScaledMMLinearKernel
(
ScaledMMLinearKernel
):
class
Triton
Int8
ScaledMMLinearKernel
(
CutlassInt8
ScaledMMLinearKernel
):
@
classmethod
@
classmethod
def
is_supported
(
def
is_supported
(
cls
,
compute_capability
:
int
|
None
=
None
cls
,
compute_capability
:
int
|
None
=
None
)
->
tuple
[
bool
,
str
|
None
]:
)
->
tuple
[
bool
,
str
|
None
]:
if
current_platform
.
is_cuda_alike
():
if
current_platform
.
is_cuda_alike
():
return
True
,
None
return
True
,
None
return
False
,
"
R
equires ROCm or CUDA."
return
False
,
"
r
equires ROCm or CUDA."
@
classmethod
@
classmethod
def
can_implement
(
cls
,
c
:
ScaledMMLinearLayerConfig
)
->
tuple
[
bool
,
str
|
None
]:
def
can_implement
(
cls
,
c
:
Int8
ScaledMMLinearLayerConfig
)
->
tuple
[
bool
,
str
|
None
]:
if
not
c
.
input_symmetric
:
if
not
c
.
input_symmetric
:
return
False
,
"
Only
symmetric input
is supported
."
return
False
,
"
supports
symmetric input
only
."
return
True
,
None
return
True
,
None
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
weight
=
getattr
(
layer
,
self
.
w_q_name
)
w_q
,
_
,
i_s
,
_
,
_
=
self
.
_get_layer_params
(
layer
)
w_q_name
,
w_s_name
,
i_s_name
,
i_zp_name
,
azp_adj_name
=
self
.
layer_param_names
replace_parameter
(
replace_parameter
(
layer
,
layer
,
self
.
w_q_name
,
w_q_name
,
torch
.
nn
.
Parameter
(
w
eight
.
t
().
data
,
requires_grad
=
False
),
torch
.
nn
.
Parameter
(
w
_q
.
t
().
data
,
requires_grad
=
False
),
)
)
# WEIGHT SCALE
# WEIGHT SCALE
...
@@ -45,29 +50,29 @@ class TritonScaledMMLinearKernel(ScaledMMLinearKernel):
...
@@ -45,29 +50,29 @@ class TritonScaledMMLinearKernel(ScaledMMLinearKernel):
# If we have a fused module (QKV, MLP) with per tensor scales (thus N
# If we have a fused module (QKV, MLP) with per tensor scales (thus N
# scales being passed to the kernel), convert to the per-channel case.
# scales being passed to the kernel), convert to the per-channel case.
is_fused_module
=
len
(
layer
.
logical_widths
)
>
1
is_fused_module
=
len
(
layer
.
logical_widths
)
>
1
weight_scale
=
getattr
(
layer
,
self
.
w_s_name
)
weight_scale
=
getattr
(
layer
,
w_s_name
)
if
is_fused_module
and
not
self
.
config
.
is_channelwise
:
if
is_fused_module
and
not
self
.
config
.
is_channelwise
:
weight_scale
=
convert_to_channelwise
(
weight_scale
,
layer
.
logical_widths
)
weight_scale
=
convert_to_channelwise
(
weight_scale
,
layer
.
logical_widths
)
replace_parameter
(
replace_parameter
(
layer
,
layer
,
self
.
w_s_name
,
w_s_name
,
torch
.
nn
.
Parameter
(
weight_scale
.
data
,
requires_grad
=
False
),
torch
.
nn
.
Parameter
(
weight_scale
.
data
,
requires_grad
=
False
),
)
)
# INPUT SCALE
# INPUT SCALE
if
self
.
config
.
is_static_input_scheme
:
if
self
.
config
.
is_static_input_scheme
:
input_scale
=
getattr
(
layer
,
self
.
i_s_name
)
assert
i_s
is
not
None
replace_parameter
(
replace_parameter
(
layer
,
layer
,
self
.
i_s_name
,
i_s_name
,
torch
.
nn
.
Parameter
(
i
nput_scale
.
max
(),
requires_grad
=
False
),
torch
.
nn
.
Parameter
(
i
_s
.
max
(),
requires_grad
=
False
),
)
)
setattr
(
layer
,
self
.
i_zp_name
,
None
)
setattr
(
layer
,
i_zp_name
,
None
)
else
:
else
:
setattr
(
layer
,
self
.
i_s_name
,
None
)
setattr
(
layer
,
i_s_name
,
None
)
setattr
(
layer
,
self
.
i_zp_name
,
None
)
setattr
(
layer
,
i_zp_name
,
None
)
setattr
(
layer
,
self
.
azp_adj_name
,
None
)
setattr
(
layer
,
azp_adj_name
,
None
)
def
apply_weights
(
def
apply_weights
(
self
,
self
,
...
@@ -75,7 +80,7 @@ class TritonScaledMMLinearKernel(ScaledMMLinearKernel):
...
@@ -75,7 +80,7 @@ class TritonScaledMMLinearKernel(ScaledMMLinearKernel):
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
|
None
=
None
,
bias
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
w_q
,
w_s
,
i_s
,
i_zp
,
azp_adj
=
self
.
_get_
weight
_params
(
layer
)
w_q
,
w_s
,
i_s
,
i_zp
,
_
=
self
.
_get_
layer
_params
(
layer
)
x_q
,
x_s
,
x_zp
=
ops
.
scaled_int8_quant
(
x_q
,
x_s
,
x_zp
=
ops
.
scaled_int8_quant
(
x
.
contiguous
(),
i_s
,
i_zp
,
symmetric
=
True
x
.
contiguous
(),
i_s
,
i_zp
,
symmetric
=
True
...
...
vllm/model_executor/layers/quantization/modelopt.py
View file @
148117ea
...
@@ -49,6 +49,9 @@ from vllm.model_executor.layers.quantization.base_config import (
...
@@ -49,6 +49,9 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig
,
QuantizationConfig
,
QuantizeMethodBase
,
QuantizeMethodBase
,
)
)
from
vllm.model_executor.layers.quantization.kernels.scaled_mm
import
(
init_fp8_linear_kernel
,
)
from
vllm.model_executor.layers.quantization.kv_cache
import
BaseKVCacheMethod
from
vllm.model_executor.layers.quantization.kv_cache
import
BaseKVCacheMethod
from
vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe
import
(
from
vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe
import
(
build_flashinfer_fp4_cutlass_moe_prepare_finalize
,
build_flashinfer_fp4_cutlass_moe_prepare_finalize
,
...
@@ -78,10 +81,12 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
...
@@ -78,10 +81,12 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape
,
GroupShape
,
cutlass_fp4_supported
,
cutlass_fp4_supported
,
is_layer_skipped
,
is_layer_skipped
,
kFp8DynamicTokenSym
,
kFp8StaticTensorSym
,
kFp8StaticTokenSym
,
swizzle_blockscale
,
swizzle_blockscale
,
)
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
Fp8LinearOp
,
cutlass_block_fp8_supported
,
cutlass_block_fp8_supported
,
requantize_with_max_scale
,
requantize_with_max_scale
,
)
)
...
@@ -438,8 +443,11 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
...
@@ -438,8 +443,11 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
def
__init__
(
self
,
quant_config
:
ModelOptFp8Config
)
->
None
:
def
__init__
(
self
,
quant_config
:
ModelOptFp8Config
)
->
None
:
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
self
.
fp8_linear
=
Fp8LinearOp
(
self
.
fp8_linear
=
init_fp8_linear_kernel
(
act_quant_static
=
True
,
act_quant_group_shape
=
GroupShape
.
PER_TENSOR
activation_quant_key
=
kFp8StaticTensorSym
,
weight_quant_key
=
kFp8StaticTensorSym
,
out_dtype
=
torch
.
get_default_dtype
(),
module_name
=
self
.
__class__
.
__name__
,
)
)
def
create_weights
(
def
create_weights
(
...
@@ -507,13 +515,7 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
...
@@ -507,13 +515,7 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
|
None
=
None
,
bias
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
return
self
.
fp8_linear
.
apply
(
return
self
.
fp8_linear
.
apply_weights
(
layer
,
x
,
bias
)
input
=
x
,
weight
=
layer
.
weight
,
weight_scale
=
layer
.
weight_scale
,
input_scale
=
layer
.
input_scale
,
bias
=
bias
,
)
class
ModelOptFp8PcPtLinearMethod
(
LinearMethodBase
):
class
ModelOptFp8PcPtLinearMethod
(
LinearMethodBase
):
...
@@ -527,8 +529,11 @@ class ModelOptFp8PcPtLinearMethod(LinearMethodBase):
...
@@ -527,8 +529,11 @@ class ModelOptFp8PcPtLinearMethod(LinearMethodBase):
def
__init__
(
self
,
quant_config
:
ModelOptFp8Config
)
->
None
:
def
__init__
(
self
,
quant_config
:
ModelOptFp8Config
)
->
None
:
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
self
.
fp8_linear
=
Fp8LinearOp
(
self
.
fp8_linear
=
init_fp8_linear_kernel
(
act_quant_static
=
False
,
act_quant_group_shape
=
GroupShape
.
PER_TOKEN
activation_quant_key
=
kFp8DynamicTokenSym
,
weight_quant_key
=
kFp8StaticTokenSym
,
out_dtype
=
torch
.
get_default_dtype
(),
module_name
=
self
.
__class__
.
__name__
,
)
)
def
create_weights
(
def
create_weights
(
...
@@ -585,13 +590,7 @@ class ModelOptFp8PcPtLinearMethod(LinearMethodBase):
...
@@ -585,13 +590,7 @@ class ModelOptFp8PcPtLinearMethod(LinearMethodBase):
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
|
None
=
None
,
bias
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
return
self
.
fp8_linear
.
apply
(
return
self
.
fp8_linear
.
apply_weights
(
layer
,
x
,
bias
)
input
=
x
,
weight
=
layer
.
weight
,
weight_scale
=
layer
.
weight_scale
,
input_scale
=
None
,
bias
=
bias
,
)
class
ModelOptFp8PbWoLinearMethod
(
LinearMethodBase
):
class
ModelOptFp8PbWoLinearMethod
(
LinearMethodBase
):
...
...
vllm/model_executor/layers/quantization/ptpc_fp8.py
View file @
148117ea
...
@@ -17,11 +17,13 @@ from vllm.model_executor.layers.quantization.fp8 import (
...
@@ -17,11 +17,13 @@ from vllm.model_executor.layers.quantization.fp8 import (
Fp8KVCacheMethod
,
Fp8KVCacheMethod
,
Fp8LinearMethod
,
Fp8LinearMethod
,
)
)
from
vllm.model_executor.layers.quantization.kernels.scaled_mm
import
(
init_fp8_linear_kernel
,
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
GroupShape
,
is_layer_skipped
,
is_layer_skipped
,
kFp8DynamicTokenSym
,
)
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
Fp8LinearOp
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
ACTIVATION_SCHEMES
=
[
"static"
,
"dynamic"
]
ACTIVATION_SCHEMES
=
[
"static"
,
"dynamic"
]
...
@@ -97,9 +99,11 @@ class PTPCFp8LinearMethod(Fp8LinearMethod):
...
@@ -97,9 +99,11 @@ class PTPCFp8LinearMethod(Fp8LinearMethod):
)
)
super
().
__init__
(
quant_config
=
quant_config
)
super
().
__init__
(
quant_config
=
quant_config
)
# Force weight quantization
# Force weight quantization
self
.
quant_config
.
is_checkpoint_fp8_serialized
=
False
self
.
fp8_linear
=
init_fp8_linear_kernel
(
self
.
fp8_linear
=
Fp8LinearOp
(
activation_quant_key
=
kFp8DynamicTokenSym
,
act_quant_static
=
False
,
act_quant_group_shape
=
GroupShape
.
PER_TOKEN
weight_quant_key
=
kFp8DynamicTokenSym
,
out_dtype
=
torch
.
get_default_dtype
(),
module_name
=
self
.
__class__
.
__name__
,
)
)
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
...
@@ -130,11 +134,4 @@ class PTPCFp8LinearMethod(Fp8LinearMethod):
...
@@ -130,11 +134,4 @@ class PTPCFp8LinearMethod(Fp8LinearMethod):
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
|
None
=
None
,
bias
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
return
self
.
fp8_linear
.
apply
(
return
self
.
fp8_linear
.
apply_weights
(
layer
,
x
,
bias
)
input
=
x
,
weight
=
layer
.
weight
,
weight_scale
=
layer
.
weight_scale
,
input_scale
=
None
,
input_scale_ub
=
None
,
bias
=
bias
,
)
vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py
View file @
148117ea
...
@@ -7,10 +7,18 @@ from typing import Any, cast
...
@@ -7,10 +7,18 @@ from typing import Any, cast
import
torch
import
torch
from
torch.nn
import
Parameter
from
torch.nn
import
Parameter
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization.kernels.scaled_mm
import
(
init_fp8_linear_kernel
,
)
from
vllm.model_executor.layers.quantization.quark.schemes
import
QuarkScheme
from
vllm.model_executor.layers.quantization.quark.schemes
import
QuarkScheme
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
GroupShape
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
GroupShape
,
kFp8DynamicTokenSym
,
kFp8StaticTensorSym
,
kFp8StaticTokenSym
,
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
Fp8LinearOp
,
normalize_e4m3fn_to_e4m3fnuz
,
normalize_e4m3fn_to_e4m3fnuz
,
requantize_with_max_scale
,
requantize_with_max_scale
,
)
)
...
@@ -23,6 +31,8 @@ from vllm.platforms import current_platform
...
@@ -23,6 +31,8 @@ from vllm.platforms import current_platform
__all__
=
[
"QuarkW8A8Fp8"
]
__all__
=
[
"QuarkW8A8Fp8"
]
logger
=
init_logger
(
__name__
)
class
QuarkW8A8Fp8
(
QuarkScheme
):
class
QuarkW8A8Fp8
(
QuarkScheme
):
def
__init__
(
def
__init__
(
...
@@ -35,15 +45,16 @@ class QuarkW8A8Fp8(QuarkScheme):
...
@@ -35,15 +45,16 @@ class QuarkW8A8Fp8(QuarkScheme):
self
.
is_static_input_scheme
=
not
cast
(
bool
,
input_config
.
get
(
"is_dynamic"
))
self
.
is_static_input_scheme
=
not
cast
(
bool
,
input_config
.
get
(
"is_dynamic"
))
self
.
input_qscheme
=
cast
(
str
,
input_config
.
get
(
"qscheme"
))
self
.
input_qscheme
=
cast
(
str
,
input_config
.
get
(
"qscheme"
))
per_token
=
(
per_token
_activation
=
(
not
self
.
is_static_input_scheme
and
self
.
input_qscheme
==
"per_channel"
not
self
.
is_static_input_scheme
and
self
.
input_qscheme
==
"per_channel"
)
)
self
.
act_quant_group_shape
=
(
per_token_weight
=
self
.
weight_qscheme
==
"per_channel"
GroupShape
.
PER_TOKEN
if
per_token
else
GroupShape
.
PER_TENSOR
self
.
activation_quant_key
=
(
kFp8DynamicTokenSym
if
per_token_activation
else
kFp8StaticTensorSym
)
)
self
.
fp8_linear
=
Fp8LinearOp
(
self
.
weight_quant_key
=
(
act_quant_static
=
self
.
is_static_input_scheme
,
kFp8StaticTokenSym
if
per_token_weight
else
kFp8StaticTensorSym
act_quant_group_shape
=
self
.
act_quant_group_shape
,
)
)
self
.
out_dtype
=
torch
.
get_default_dtype
()
self
.
out_dtype
=
torch
.
get_default_dtype
()
...
@@ -94,7 +105,7 @@ class QuarkW8A8Fp8(QuarkScheme):
...
@@ -94,7 +105,7 @@ class QuarkW8A8Fp8(QuarkScheme):
layer
.
input_scale
=
Parameter
(
input_scale
,
requires_grad
=
False
)
layer
.
input_scale
=
Parameter
(
input_scale
,
requires_grad
=
False
)
else
:
else
:
weight_scale
=
layer
.
weight_scale
.
data
weight_scale
=
layer
.
weight_scale
.
data
if
self
.
act
_quant_
group_shape
==
GroupShape
.
PER_TOKEN
:
if
self
.
act
ivation_quant_key
.
scale
.
group_shape
==
GroupShape
.
PER_TOKEN
:
weight_scale
=
weight_scale
.
view
(
-
1
,
1
)
weight_scale
=
weight_scale
.
view
(
-
1
,
1
)
layer
.
weight
=
Parameter
(
weight
.
t
(),
requires_grad
=
False
)
layer
.
weight
=
Parameter
(
weight
.
t
(),
requires_grad
=
False
)
# required by torch.compile to be torch.nn.Parameter
# required by torch.compile to be torch.nn.Parameter
...
@@ -106,8 +117,6 @@ class QuarkW8A8Fp8(QuarkScheme):
...
@@ -106,8 +117,6 @@ class QuarkW8A8Fp8(QuarkScheme):
# INPUT SCALE
# INPUT SCALE
if
self
.
is_static_input_scheme
:
if
self
.
is_static_input_scheme
:
layer
.
input_scale
=
Parameter
(
layer
.
input_scale
.
max
(),
requires_grad
=
False
)
layer
.
input_scale
=
Parameter
(
layer
.
input_scale
.
max
(),
requires_grad
=
False
)
else
:
layer
.
input_scale
=
None
def
create_weights
(
def
create_weights
(
self
,
self
,
...
@@ -163,17 +172,17 @@ class QuarkW8A8Fp8(QuarkScheme):
...
@@ -163,17 +172,17 @@ class QuarkW8A8Fp8(QuarkScheme):
input_scale
[:]
=
torch
.
finfo
(
torch
.
float32
).
min
input_scale
[:]
=
torch
.
finfo
(
torch
.
float32
).
min
layer
.
register_parameter
(
"input_scale"
,
input_scale
)
layer
.
register_parameter
(
"input_scale"
,
input_scale
)
self
.
fp8_linear
=
init_fp8_linear_kernel
(
activation_quant_key
=
self
.
activation_quant_key
,
weight_quant_key
=
self
.
weight_quant_key
,
out_dtype
=
torch
.
get_default_dtype
(),
module_name
=
self
.
__class__
.
__name__
,
)
def
apply_weights
(
def
apply_weights
(
self
,
self
,
layer
:
torch
.
nn
.
Module
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
|
None
=
None
,
bias
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
return
self
.
fp8_linear
.
apply
(
return
self
.
fp8_linear
.
apply_weights
(
layer
,
x
,
bias
)
input
=
x
,
weight
=
layer
.
weight
,
weight_scale
=
layer
.
weight_scale
,
out_dtype
=
self
.
out_dtype
,
input_scale
=
layer
.
input_scale
,
bias
=
bias
,
)
vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py
View file @
148117ea
...
@@ -7,8 +7,7 @@ import torch
...
@@ -7,8 +7,7 @@ import torch
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization.kernels.scaled_mm
import
(
from
vllm.model_executor.layers.quantization.kernels.scaled_mm
import
(
ScaledMMLinearLayerConfig
,
init_int8_linear_kernel
,
choose_scaled_mm_linear_kernel
,
)
)
from
vllm.model_executor.layers.quantization.quark.schemes
import
QuarkScheme
from
vllm.model_executor.layers.quantization.quark.schemes
import
QuarkScheme
from
vllm.model_executor.parameter
import
(
from
vllm.model_executor.parameter
import
(
...
@@ -22,8 +21,6 @@ logger = init_logger(__name__)
...
@@ -22,8 +21,6 @@ logger = init_logger(__name__)
class
QuarkW8A8Int8
(
QuarkScheme
):
class
QuarkW8A8Int8
(
QuarkScheme
):
_kernel_backends_being_used
:
set
[
str
]
=
set
()
def
__init__
(
def
__init__
(
self
,
self
,
qscheme
:
str
,
qscheme
:
str
,
...
@@ -50,18 +47,13 @@ class QuarkW8A8Int8(QuarkScheme):
...
@@ -50,18 +47,13 @@ class QuarkW8A8Int8(QuarkScheme):
):
):
layer
.
logical_widths
=
output_partition_sizes
layer
.
logical_widths
=
output_partition_sizes
s
caled_mm_linear_kernel_config
=
ScaledMMLinearLayerConfig
(
s
elf
.
kernel
=
init_int8_linear_kernel
(
is_channelwise
=
(
self
.
qscheme
==
"per_channel"
),
is_channelwise
=
(
self
.
qscheme
==
"per_channel"
),
is_static_input_scheme
=
(
self
.
is_static_input_scheme
is
True
),
is_static_input_scheme
=
(
self
.
is_static_input_scheme
is
True
),
input_symmetric
=
(
self
.
input_symmetric
is
True
),
input_symmetric
=
(
self
.
input_symmetric
is
True
),
module_name
=
self
.
__class__
.
__name__
,
)
)
kernel_type
=
choose_scaled_mm_linear_kernel
(
scaled_mm_linear_kernel_config
)
if
kernel_type
.
__name__
not
in
self
.
_kernel_backends_being_used
:
logger
.
info
(
"Using %s for QuarkW8A8Int8"
,
kernel_type
.
__name__
)
self
.
_kernel_backends_being_used
.
add
(
kernel_type
.
__name__
)
# WEIGHT
# WEIGHT
weight
=
ModelWeightParameter
(
weight
=
ModelWeightParameter
(
data
=
torch
.
empty
(
data
=
torch
.
empty
(
...
@@ -102,25 +94,21 @@ class QuarkW8A8Int8(QuarkScheme):
...
@@ -102,25 +94,21 @@ class QuarkW8A8Int8(QuarkScheme):
layer
.
register_parameter
(
"weight_zero_point"
,
weight_zero_point
)
layer
.
register_parameter
(
"weight_zero_point"
,
weight_zero_point
)
# INPUT SCALE
# INPUT SCALE
input_zero_point
=
None
input_scale
=
None
if
self
.
is_static_input_scheme
:
if
self
.
is_static_input_scheme
:
input_scale
=
BasevLLMParameter
(
input_scale
=
BasevLLMParameter
(
data
=
torch
.
empty
(
1
,
dtype
=
torch
.
float32
),
weight_loader
=
weight_loader
data
=
torch
.
empty
(
1
,
dtype
=
torch
.
float32
),
weight_loader
=
weight_loader
)
)
layer
.
register_parameter
(
"input_scale"
,
input_scale
)
input_zero_point
=
BasevLLMParameter
(
input_zero_point
=
BasevLLMParameter
(
data
=
torch
.
empty
(
1
,
dtype
=
torch
.
int8
),
weight_loader
=
weight_loader
data
=
torch
.
empty
(
1
,
dtype
=
torch
.
int8
),
weight_loader
=
weight_loader
)
)
layer
.
register_parameter
(
"input_zero_point"
,
input_zero_point
)
layer
.
register_parameter
(
"input_scale"
,
input_scale
)
self
.
kernel
=
kernel_type
(
layer
.
register_parameter
(
"input_zero_point"
,
input_zero_point
)
c
=
scaled_mm_linear_kernel_config
,
if
not
hasattr
(
layer
,
"azp_adj"
):
w_q_param_name
=
"weight"
,
layer
.
register_parameter
(
"azp_adj"
,
None
)
w_s_param_name
=
"weight_scale"
,
i_s_param_name
=
"input_scale"
,
i_zp_param_name
=
"input_zero_point"
,
azp_adj_param_name
=
"azp_adj"
,
)
# Checkpoints are serialized in quark format, which is
# Checkpoints are serialized in quark format, which is
# different from the format the kernel may want. Handle repacking here.
# different from the format the kernel may want. Handle repacking here.
...
...
vllm/model_executor/layers/quantization/utils/quant_utils.py
View file @
148117ea
...
@@ -123,6 +123,9 @@ kFp8StaticTensorSym = QuantKey(FP8_DTYPE, kStaticTensorScale, symmetric=True)
...
@@ -123,6 +123,9 @@ kFp8StaticTensorSym = QuantKey(FP8_DTYPE, kStaticTensorScale, symmetric=True)
kDynamicTensorScale
=
ScaleDesc
(
torch
.
float32
,
False
,
GroupShape
.
PER_TENSOR
)
kDynamicTensorScale
=
ScaleDesc
(
torch
.
float32
,
False
,
GroupShape
.
PER_TENSOR
)
kFp8DynamicTensorSym
=
QuantKey
(
FP8_DTYPE
,
kDynamicTensorScale
,
symmetric
=
True
)
kFp8DynamicTensorSym
=
QuantKey
(
FP8_DTYPE
,
kDynamicTensorScale
,
symmetric
=
True
)
kStaticTokenScale
=
ScaleDesc
(
torch
.
float32
,
True
,
GroupShape
.
PER_TOKEN
)
kFp8StaticTokenSym
=
QuantKey
(
FP8_DTYPE
,
kStaticTokenScale
,
symmetric
=
True
)
kDynamicTokenScale
=
ScaleDesc
(
torch
.
float32
,
False
,
GroupShape
.
PER_TOKEN
)
kDynamicTokenScale
=
ScaleDesc
(
torch
.
float32
,
False
,
GroupShape
.
PER_TOKEN
)
kFp8DynamicTokenSym
=
QuantKey
(
FP8_DTYPE
,
kDynamicTokenScale
,
symmetric
=
True
)
kFp8DynamicTokenSym
=
QuantKey
(
FP8_DTYPE
,
kDynamicTokenScale
,
symmetric
=
True
)
...
...
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
View file @
148117ea
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Callable
import
torch
import
torch
from
packaging
import
version
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm
import
envs
from
vllm.config
import
CompilationMode
,
get_current_vllm_config
from
vllm.model_executor.layers.quantization.input_quant_fp8
import
QuantFP8
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
GroupShape
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils.flashinfer
import
flashinfer_scaled_fp8_mm
,
has_flashinfer
from
vllm.utils.platform_utils
import
get_cu_count
from
vllm.utils.torch_utils
import
direct_register_custom_op
# Input scaling factors are no longer optional in _scaled_mm starting
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
TORCH_DEVICE_IDENTITY
=
None
# The condition to determine if it is on a platform that supports
# torch._scaled_mm rowwise feature.
# The condition is determined once as the operations
# are time-consuming.
USE_ROWWISE_TORCH_SCALED_MM
=
(
current_platform
.
is_rocm
()
and
version
.
parse
(
torch
.
__version__
)
>=
version
.
parse
(
"2.7"
)
and
current_platform
.
has_device_capability
(
94
)
)
def
sparse_cutlass_supported
()
->
bool
:
def
sparse_cutlass_supported
()
->
bool
:
...
@@ -140,361 +117,6 @@ def requantize_with_max_scale(
...
@@ -140,361 +117,6 @@ def requantize_with_max_scale(
return
max_w_scale
,
weight
return
max_w_scale
,
weight
def
maybe_create_device_identity
():
# Allocate dummy ones tensor for torch._scaled_mm
global
TORCH_DEVICE_IDENTITY
if
TORCH_DEVICE_IDENTITY
is
None
:
TORCH_DEVICE_IDENTITY
=
torch
.
ones
(
1
,
dtype
=
torch
.
float32
)
def
cutlass_w8a8_scaled_mm
(
*
,
qinput
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
out_dtype
:
torch
.
dtype
,
scale_a
:
torch
.
Tensor
,
scale_b
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
,
output_shape
:
list
,
**
kwargs
,
)
->
torch
.
Tensor
:
# Fused GEMM_DQ
output
=
ops
.
cutlass_scaled_mm
(
qinput
,
weight
,
out_dtype
=
out_dtype
,
scale_a
=
scale_a
,
scale_b
=
scale_b
,
bias
=
bias
)
return
output
.
view
(
*
output_shape
)
def
flashinfer_w8a8_scaled_mm
(
*
,
qinput
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
out_dtype
:
torch
.
dtype
,
scale_a
:
torch
.
Tensor
,
scale_b
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
,
output_shape
:
list
,
**
kwargs
,
)
->
torch
.
Tensor
:
return
flashinfer_scaled_fp8_mm
(
qinput
,
weight
,
out_dtype
=
out_dtype
,
scale_a
=
scale_a
,
scale_b
=
scale_b
,
bias
=
bias
)
def
rocm_per_tensor_w8a8_scaled_mm_impl
(
qinput
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
out_dtype
:
torch
.
dtype
,
scale_a
:
torch
.
Tensor
,
scale_b
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
from
vllm.platforms.rocm
import
on_mi3xx
if
(
envs
.
VLLM_ROCM_USE_SKINNY_GEMM
and
on_mi3xx
()
and
qinput
.
shape
[
0
]
==
1
and
qinput
.
shape
[
1
]
%
16
==
0
and
((
bias
is
None
)
or
(
bias
.
dtype
==
out_dtype
))
):
output
=
ops
.
wvSplitKQ
(
weight
.
t
(),
qinput
,
out_dtype
,
scale_a
,
scale_b
,
get_cu_count
(),
bias
,
)
else
:
output
=
torch
.
_scaled_mm
(
qinput
,
weight
,
out_dtype
=
out_dtype
,
scale_a
=
scale_a
,
scale_b
=
scale_b
,
bias
=
bias
,
)
return
output
def
rocm_per_tensor_w8a8_scaled_mm_fake
(
qinput
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
out_dtype
:
torch
.
dtype
,
scale_a
:
torch
.
Tensor
,
scale_b
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
return
qinput
.
new_empty
((
*
qinput
.
shape
[:
-
1
],
weight
.
shape
[
1
]),
dtype
=
out_dtype
)
def
rocm_per_tensor_w8a8_scaled_mm
(
*
,
qinput
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
out_dtype
:
torch
.
dtype
,
scale_a
:
torch
.
Tensor
,
scale_b
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
,
output_shape
:
list
,
)
->
torch
.
Tensor
:
output
=
torch
.
ops
.
vllm
.
rocm_per_tensor_w8a8_scaled_mm_impl
(
qinput
,
weight
,
out_dtype
,
scale_a
,
scale_b
,
bias
)
return
torch
.
narrow
(
output
,
0
,
0
,
qinput
.
shape
[
0
]).
view
(
*
output_shape
)
direct_register_custom_op
(
op_name
=
"rocm_per_tensor_w8a8_scaled_mm_impl"
,
op_func
=
rocm_per_tensor_w8a8_scaled_mm_impl
,
fake_impl
=
rocm_per_tensor_w8a8_scaled_mm_fake
,
)
def
torch_per_tensor_w8a8_scaled_mm
(
*
,
qinput
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
out_dtype
:
torch
.
dtype
,
scale_a
:
torch
.
Tensor
,
scale_b
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
,
output_shape
:
list
,
)
->
torch
.
Tensor
:
output
=
torch
.
_scaled_mm
(
qinput
,
weight
,
out_dtype
=
out_dtype
,
scale_a
=
scale_a
,
scale_b
=
scale_b
,
bias
=
bias
)
# A fix for discrepancy in scaled_mm which returns tuple
# for torch < 2.5 and a single value in torch >= 2.5
if
type
(
output
)
is
tuple
and
len
(
output
)
==
2
:
output
=
output
[
0
]
return
torch
.
narrow
(
output
,
0
,
0
,
qinput
.
shape
[
0
]).
view
(
*
output_shape
)
def
torch_per_token_w8a8_scaled_mm
(
*
,
qinput
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
out_dtype
:
torch
.
dtype
,
scale_a
:
torch
.
Tensor
,
scale_b
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
,
output_shape
:
list
,
**
kwargs
,
)
->
torch
.
Tensor
:
# Note: Callers of this function should check USE_ROWWISE_TORCH_SCALED_MM
# when using it.
# For now it has only been validated on ROCm platform.
# fp8 rowwise scaling in torch._scaled_mm is introduced in
# https://github.com/pytorch/pytorch/pull/144432 using
# hipBLASLt and ROCm 6.3, which only exists in torch 2.7 and above.
#
# For CUDA platform please validate if the torch._scaled_mm supports
# rowwise scaled GEMM before using it
# Fused GEMM_DQ Rowwise GEMM
output
=
torch
.
_scaled_mm
(
qinput
,
weight
,
out_dtype
=
out_dtype
,
scale_a
=
scale_a
,
scale_b
=
scale_b
.
t
(),
bias
=
bias
,
)
output
=
torch
.
narrow
(
output
,
0
,
0
,
qinput
.
shape
[
0
])
output
=
output
.
view
(
*
output_shape
)
return
output
def
torch_channelwise_w8a8_scaled_mm
(
*
,
qinput
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
out_dtype
:
torch
.
dtype
,
scale_a
:
torch
.
Tensor
,
scale_b
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
,
output_shape
:
list
,
**
kwargs
,
)
->
torch
.
Tensor
:
# Use unfused DQ due to limitations with scaled_mm
# Symmetric quantized GEMM by definition computes the following:
# C = (s_x * X) (s_w * W) + bias
# This is equivalent to dequantizing the weights and activations
# before applying a GEMM.
#
# In order to compute quantized operands, a quantized kernel
# will rewrite the above like so:
# C = s_w * s_x * (X * W) + bias
#
# For the scaled_mm fallback case, we break this down, since it
# does not support s_w being a vector.
# GEMM
# This computes C = (X * W).
# Output in fp32 to allow subsequent ops to happen in-place
output
=
torch
.
_scaled_mm
(
qinput
,
weight
,
scale_a
=
TORCH_DEVICE_IDENTITY
,
scale_b
=
TORCH_DEVICE_IDENTITY
,
out_dtype
=
torch
.
float32
,
)
# A fix for discrepancy in scaled_mm which returns tuple
# for torch < 2.5 and a single value in torch >= 2.5
if
type
(
output
)
is
tuple
and
len
(
output
)
==
2
:
output
=
output
[
0
]
# Unpad (undo num_token_padding)
output
=
torch
.
narrow
(
output
,
0
,
0
,
qinput
.
shape
[
0
])
x_scale
=
torch
.
narrow
(
scale_a
,
0
,
0
,
qinput
.
shape
[
0
])
# DQ
# C = sw * sx * (X * W) + bias
output
=
output
*
x_scale
*
scale_b
.
t
()
if
bias
is
not
None
:
output
=
output
+
bias
return
output
.
to
(
out_dtype
).
view
(
*
output_shape
)
def
dispatch_w8a8_scaled_mm
(
preferred_backend
:
str
,
per_tensor_weights
:
bool
,
per_tensor_activations
:
bool
)
->
Callable
[...,
torch
.
Tensor
]:
if
per_tensor_weights
and
per_tensor_activations
:
if
preferred_backend
==
"rocm"
:
return
rocm_per_tensor_w8a8_scaled_mm
if
preferred_backend
==
"flashinfer"
:
return
flashinfer_w8a8_scaled_mm
if
preferred_backend
==
"cutlass"
:
return
cutlass_w8a8_scaled_mm
return
torch_per_tensor_w8a8_scaled_mm
# cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
if
preferred_backend
==
"cutlass"
or
preferred_backend
==
"flashinfer"
:
return
cutlass_w8a8_scaled_mm
# If torch.scaled_mm supports per-channel (weights) per-token (inputs)
if
(
not
per_tensor_weights
and
not
per_tensor_activations
and
USE_ROWWISE_TORCH_SCALED_MM
):
return
torch_per_token_w8a8_scaled_mm
# Normally, torch.scaled_mm supports per tensor weights + activations only
# so fallback to naive if per channel or per token
return
torch_channelwise_w8a8_scaled_mm
# TODO(luka): follow similar pattern for marlin and block-fp8-linear
# https://github.com/vllm-project/vllm/issues/14397
class
Fp8LinearOp
:
"""
This class executes a FP8 linear layer using cutlass if supported and
torch.scaled_mm otherwise.
It needs to be a class instead of a method so that config can be read
in the __init__ method, as reading config is not allowed inside forward.
"""
def
__init__
(
self
,
act_quant_static
:
bool
,
act_quant_group_shape
:
GroupShape
=
GroupShape
.
PER_TENSOR
,
pad_output
:
bool
|
None
=
None
,
):
if
current_platform
.
is_rocm
():
self
.
preferred_backend
=
"rocm"
elif
current_platform
.
is_cuda
()
and
cutlass_fp8_supported
():
if
has_flashinfer
()
and
current_platform
.
has_device_capability
(
100
):
self
.
preferred_backend
=
"flashinfer"
else
:
self
.
preferred_backend
=
"cutlass"
else
:
self
.
preferred_backend
=
"torch"
# Note: we pad the input because torch._scaled_mm is more performant
# for matrices with batch dimension > 16.
# This could change in the future.
# We also don't pad when using torch.compile,
# as it breaks with dynamic shapes.
if
pad_output
is
None
:
config
=
get_current_vllm_config
().
compilation_config
pad_output
=
(
config
.
mode
<
CompilationMode
.
VLLM_COMPILE
and
self
.
preferred_backend
==
"torch"
)
self
.
output_padding
=
17
if
pad_output
else
None
self
.
act_quant_static
=
act_quant_static
self
.
act_quant_group_shape
=
act_quant_group_shape
self
.
quant_fp8
=
QuantFP8
(
static
=
act_quant_static
,
group_shape
=
act_quant_group_shape
,
num_token_padding
=
self
.
output_padding
,
)
def
apply
(
self
,
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
weight_scale
:
torch
.
Tensor
,
out_dtype
:
torch
.
dtype
|
None
=
None
,
input_scale
:
torch
.
Tensor
|
None
=
None
,
input_scale_ub
:
torch
.
Tensor
|
None
=
None
,
bias
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
# ops.scaled_fp8_quant supports both dynamic and static quant.
# If dynamic, layer.input_scale is None and x_scale computed from x.
# If static, layer.input_scale is scalar and x_scale is input_scale.
# View input as 2D matrix for fp8 methods
input_2d
=
input
.
view
(
-
1
,
input
.
shape
[
-
1
])
output_shape
=
[
*
input
.
shape
[:
-
1
],
weight
.
shape
[
1
]]
if
out_dtype
is
None
:
out_dtype
=
input
.
dtype
# If input not quantized
# TODO(luka) remove this path if not used anymore
if
input
.
dtype
!=
current_platform
.
fp8_dtype
():
qinput
,
x_scale
=
self
.
quant_fp8
(
input_2d
,
input_scale
,
input_scale_ub
,
)
else
:
qinput
,
x_scale
=
input_2d
,
input_scale
# Must have dim() conditions
# In per-token quant scenario, when the number of token is 1,
# the scale will only have 1 elements.
# Without checking the dim(),
# we cannot distingushes between per-tensor and per-token quant.
# Example:
# When the number of token is 1, per-token scale is [[1]]
# When per-tensor scale is [1] or ().
per_tensor_weights
=
weight_scale
.
numel
()
==
1
per_tensor_activations
=
(
x_scale
.
numel
()
==
1
)
and
x_scale
.
dim
()
<
2
# TODO(luka) do this dispatch during init (after ScaledMM refactor)
w8a8_scaled_mm_func
=
dispatch_w8a8_scaled_mm
(
self
.
preferred_backend
,
per_tensor_weights
,
per_tensor_activations
)
return
w8a8_scaled_mm_func
(
qinput
=
qinput
,
weight
=
weight
,
out_dtype
=
out_dtype
,
scale_a
=
x_scale
,
scale_b
=
weight_scale
,
bias
=
bias
,
output_shape
=
output_shape
,
)
def
normalize_e4m3fn_to_e4m3fnuz
(
def
normalize_e4m3fn_to_e4m3fnuz
(
weight
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
weight_scale
:
torch
.
Tensor
,
weight_scale
:
torch
.
Tensor
,
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment