Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
94c8a620
Commit
94c8a620
authored
Jan 21, 2026
by
wanglong3
Committed by
zhuwenwen
Jan 21, 2026
Browse files
feat: Support fp8 channel-wise matmul.
parent
4dcfd0ae
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
109 additions
and
71 deletions
+109
-71
vllm/_custom_ops.py
vllm/_custom_ops.py
+59
-61
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe_marlin.py
...ation/compressed_tensors/compressed_tensors_moe_marlin.py
+1
-0
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
...compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
+3
-1
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+46
-9
No files found.
vllm/_custom_ops.py
View file @
94c8a620
...
...
@@ -16,6 +16,7 @@ from vllm.utils import direct_register_custom_op
try
:
from
lmslim
import
quant_ops
from
lmslim
import
quant_tools
from
lmslim.layers.gemm.fp8_utils
import
per_token_quant_fp8
except
Exception
:
print
(
"INFO: Please install lmslim if you want to infer gptq or awq or w8a8 model.
\n
"
)
try
:
...
...
@@ -1350,70 +1351,67 @@ def scaled_fp4_experts_quant(
output_scales
=
output_scales
.
view
(
torch
.
float8_e4m3fn
)
return
output
,
output_scales
def scaled_fp8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
    num_token_padding: Optional[int] = None,
    scale_ub: Optional[torch.Tensor] = None,
    use_per_token_if_dynamic: bool = False,
    output: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize input tensor to FP8 and return quantized tensor and scale.

    This function supports both static and dynamic quantization: If you
    provide the scale, it will use static scaling and if you omit it,
    the scale will be determined dynamically. The function also allows
    optional padding of the output tensors for downstream kernels that
    will benefit from padding.

    Args:
        input: The input tensor to be quantized to FP8. Must be 2-D.
        scale: Optional scaling factor for the FP8 quantization. When
            given it must hold exactly one element (static per-tensor).
        scale_ub: Optional upper bound for scaling factor in dynamic
            per token case. It is accepted for interface compatibility
            but is not forwarded to any kernel by this implementation.
        num_token_padding: If specified, pad the first dimension
            of the output to at least this value. Only applied on the
            paths where this function allocates the output itself
            (static and dynamic per-tensor).
        use_per_token_if_dynamic: Whether to do per_tensor or per_token
            in the dynamic quantization case.
        output: Optional preallocated FP8 tensor to write into. Its dtype
            must equal the platform FP8 dtype and it cannot be combined
            with ``num_token_padding``. It is written in place on the
            static and dynamic per-tensor paths only.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
            scaling factor.
    """
    # This code assumes batch_dim and num_tokens are flattened
    assert (input.ndim == 2)
    shape: Union[tuple[int, int], torch.Size] = input.shape
    # The FP8 dtype is platform specific (e.g. ROCm on MI300 uses
    # torch.float8_e4m3fnuz).
    out_dtype: torch.dtype = current_platform.fp8_dtype()
    if num_token_padding:
        shape = (max(num_token_padding, input.shape[0]), shape[1])
    if output is not None:
        assert num_token_padding is None, \
            "padding not supported if output passed in"
        assert output.dtype == out_dtype

    if scale is None and use_per_token_if_dynamic:
        # Dynamic per-token path: per_token_quant_fp8 produces both the
        # quantized tensor and the scales itself, so no output/scale buffers
        # are allocated here (they would be thrown away immediately).
        # NOTE(review): a caller-supplied `output`, `num_token_padding` and
        # `scale_ub` are not passed to per_token_quant_fp8 -- confirm that
        # callers on this path do not rely on them.
        output, scale = per_token_quant_fp8(input.contiguous())
        return output, scale

    if output is None:
        output = torch.empty(shape, device=input.device, dtype=out_dtype)

    if scale is None:
        # Dynamic per-tensor path: the kernel receives a zero-initialised
        # one-element scale buffer alongside the output buffer.
        scale = torch.zeros(1, device=input.device, dtype=torch.float32)
        torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
    else:
        # Static path: only a single (per-tensor) scale is supported.
        assert scale.numel() == 1, f"{scale.shape}"
        torch.ops._C.static_scaled_fp8_quant(output, input, scale)

    return output, scale
# gptq allspark
def
allspark_repack_weight
(
qweight
:
torch
.
Tensor
,
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe_marlin.py
View file @
94c8a620
...
...
@@ -171,6 +171,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
logical_to_physical_map
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_replica_count
:
Optional
[
torch
.
Tensor
]
=
None
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
**
_
,
)
->
torch
.
Tensor
:
if
enable_eplb
:
raise
NotImplementedError
(
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
View file @
94c8a620
...
...
@@ -139,7 +139,9 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
def
apply_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
input_quant_args
:
Optional
[
list
[
torch
.
Tensor
]]
=
None
,
silu_quant_args
:
Optional
[
list
[
torch
.
Tensor
]]
=
None
,
**
_
,
)
->
torch
.
Tensor
:
if
layer
.
weight_block_size
is
not
None
:
return
apply_fp8_block_linear
(
...
...
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
View file @
94c8a620
...
...
@@ -18,6 +18,7 @@ from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer
try
:
from
lmslim.layers.gemm.int8_utils
import
per_token_quant_int8
from
lmslim.quantize
import
quant_ops
except
Exception
:
print
(
"INFO: Please install lmslim if you want to infer the quantitative model of moe.
\n
"
)
...
...
@@ -291,7 +292,37 @@ def torch_per_token_w8a8_scaled_mm(*, qinput: torch.Tensor,
output
=
output
.
view
(
*
output_shape
)
return
output
def hipblaslt_w8a8_channelwise_scaled_mm(
    qinput: torch.Tensor,
    weight: torch.Tensor,
    out_dtype: torch.dtype,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    bias: torch.Tensor,
    output_shape: list,
) -> torch.Tensor:
    """
    Run a scaled W8A8 matmul via quant_ops.hipblaslt_w8a8_channelwise_gemm.

    Args:
        qinput: Quantized activations; must be contiguous. Its first two
            dimensions are used as (m, k).
        weight: Quantized weights; must be contiguous, share qinput's dtype
            and have the same last dimension as qinput. Its first dimension
            is used as n.
        out_dtype: Output dtype handed to the GEMM.
        scale_a: Scale tensor for qinput, forwarded unchanged.
        scale_b: Scale tensor for weight, forwarded unchanged.
        bias: Bias forwarded unchanged to the GEMM.
        output_shape: Accepted for signature parity with the other scaled-mm
            backends; not used here, the result is always viewed as (m, n).

    Returns:
        The GEMM result viewed as an (m, n) tensor.
    """
    # Layout / dtype preconditions for the kernel call below.
    assert qinput.is_contiguous() and weight.is_contiguous()
    assert qinput.shape[-1] == weight.shape[-1]
    assert qinput.dtype == weight.dtype

    m, k = qinput.shape[0], qinput.shape[1]
    n = weight.shape[0]

    gemm_kwargs = dict(
        a=qinput,
        b=weight,
        scale_a=scale_a,
        scale_b=scale_b,
        m=m,
        n=n,
        k=k,
        # weight is laid out as (n, k), matching the shape checks above.
        transpose_flag="NT",
        out_dtype=out_dtype,
        bias=bias,
    )
    # The GEMM returns a pair; the first element (named `success` by the
    # original author) is not inspected.
    # NOTE(review): confirm whether a failed status should raise.
    _status, result = quant_ops.hipblaslt_w8a8_channelwise_gemm(**gemm_kwargs)
    return result.view(m, n)
def
torch_channelwise_w8a8_scaled_mm
(
*
,
qinput
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
out_dtype
:
torch
.
dtype
,
...
...
@@ -336,11 +367,9 @@ def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor,
output
=
output
+
bias
return
output
.
to
(
out_dtype
).
view
(
*
output_shape
)
def
dispatch_w8a8_scaled_mm
(
preferred_backend
:
str
,
per_tensor_weights
:
bool
,
per_tensor_activations
:
bool
)
->
Callable
[...,
torch
.
Tensor
]:
if
per_tensor_weights
and
per_tensor_activations
:
if
preferred_backend
==
"rocm"
:
return
rocm_per_tensor_w8a8_scaled_mm
...
...
@@ -353,6 +382,9 @@ def dispatch_w8a8_scaled_mm(
# cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
if
preferred_backend
==
"cutlass"
or
preferred_backend
==
"flashinfer"
:
return
cutlass_w8a8_scaled_mm
if
preferred_backend
==
"blaslt"
:
return
hipblaslt_w8a8_channelwise_scaled_mm
# If torch.scaled_mm supports per-channel (weights) per-token (inputs)
if
not
per_tensor_weights
and
not
per_tensor_activations
\
...
...
@@ -378,7 +410,11 @@ class Fp8LinearOp:
act_quant_group_shape
:
GroupShape
=
GroupShape
.
PER_TENSOR
,
pad_output
:
Optional
[
bool
]
=
None
):
if
current_platform
.
is_rocm
():
self
.
preferred_backend
=
"rocm"
if
envs
.
VLLM_W8A8_BACKEND
==
3
:
self
.
preferred_backend
=
"blaslt"
else
:
self
.
preferred_backend
=
"rocm"
elif
current_platform
.
is_cuda
()
and
cutlass_fp8_supported
():
if
has_flashinfer
()
and
current_platform
.
has_device_capability
(
100
):
...
...
@@ -429,11 +465,12 @@ class Fp8LinearOp:
# If input not quantized
# TODO(luka) remove this path if not used anymore
if
input
.
dtype
!=
current_platform
.
fp8_dtype
():
qinput
,
x_scale
=
self
.
quant_fp8
(
input_2d
,
input_scale
,
input_scale_ub
,
)
qinput
,
x_scale
=
ops
.
scaled_fp8_quant
(
input
=
input_2d
,
scale
=
input_scale
,
num_token_padding
=
self
.
output_padding
,
scale_ub
=
input_scale_ub
,
use_per_token_if_dynamic
=
True
)
else
:
qinput
,
x_scale
=
input_2d
,
input_scale
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment