Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
94c8a620
Commit
94c8a620
authored
Jan 21, 2026
by
wanglong3
Committed by
zhuwenwen
Jan 21, 2026
Browse files
feat: Supprot fp8 channle-wise matmul.
parent
4dcfd0ae
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
109 additions
and
71 deletions
+109
-71
vllm/_custom_ops.py
vllm/_custom_ops.py
+59
-61
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe_marlin.py
...ation/compressed_tensors/compressed_tensors_moe_marlin.py
+1
-0
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
...compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
+3
-1
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+46
-9
No files found.
vllm/_custom_ops.py
View file @
94c8a620
...
@@ -16,6 +16,7 @@ from vllm.utils import direct_register_custom_op
...
@@ -16,6 +16,7 @@ from vllm.utils import direct_register_custom_op
try
:
try
:
from
lmslim
import
quant_ops
from
lmslim
import
quant_ops
from
lmslim
import
quant_tools
from
lmslim
import
quant_tools
from
lmslim.layers.gemm.fp8_utils
import
per_token_quant_fp8
except
Exception
:
except
Exception
:
print
(
"INFO: Please install lmslim if you want to infer gptq or awq or w8a8 model.
\n
"
)
print
(
"INFO: Please install lmslim if you want to infer gptq or awq or w8a8 model.
\n
"
)
try
:
try
:
...
@@ -1350,70 +1351,67 @@ def scaled_fp4_experts_quant(
...
@@ -1350,70 +1351,67 @@ def scaled_fp4_experts_quant(
output_scales
=
output_scales
.
view
(
torch
.
float8_e4m3fn
)
output_scales
=
output_scales
.
view
(
torch
.
float8_e4m3fn
)
return
output
,
output_scales
return
output
,
output_scales
def
scaled_fp8_quant
(
input
:
torch
.
Tensor
,
scale
:
Optional
[
torch
.
Tensor
]
=
None
,
num_token_padding
:
Optional
[
int
]
=
None
,
scale_ub
:
Optional
[
torch
.
Tensor
]
=
None
,
use_per_token_if_dynamic
:
bool
=
False
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Quantize input tensor to FP8 and return quantized tensor and scale.
# fp8
This function supports both static and dynamic quantization: If you
# def scaled_fp8_quant(
provide the scale, it will use static scaling and if you omit it,
# input: torch.Tensor,
the scale will be determined dynamically. The function also allows
# scale: Optional[torch.Tensor] = None,
optional padding of the output tensors for downstream kernels that
# num_token_padding: Optional[int] = None,
will benefit from padding.
# scale_ub: Optional[torch.Tensor] = None,
# use_per_token_if_dynamic: bool = False,
# output: Optional[torch.Tensor] = None,
# ) -> tuple[torch.Tensor, torch.Tensor]:
# """
# Quantize input tensor to FP8 and return quantized tensor and scale.
# This function supports both static and dynamic quantization: If you
# provide the scale, it will use static scaling and if you omit it,
# the scale will be determined dynamically. The function also allows
# optional padding of the output tensors for downstream kernels that
# will benefit from padding.
# Args:
# input: The input tensor to be quantized to FP8
# scale: Optional scaling factor for the FP8 quantization
# scale_ub: Optional upper bound for scaling factor in dynamic
# per token case
# num_token_padding: If specified, pad the first dimension
# of the output to at least this value.
# use_per_token_if_dynamic: Whether to do per_tensor or per_token
# in the dynamic quantization case.
# Returns:
# tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
# scaling factor.
# """
# # This code assumes batch_dim and num_tokens are flattened
# assert (input.ndim == 2)
# shape: Union[tuple[int, int], torch.Size] = input.shape
# # For ROCm on MI300, the output fp8 dtype is torch.float_e3m3fnuz
# out_dtype: torch.dtype = current_platform.fp8_dtype()
# if num_token_padding:
# shape = (max(num_token_padding, input.shape[0]), shape[1])
# if output is None:
# output = torch.empty(shape, device=input.device, dtype=out_dtype)
# else:
# assert num_token_padding is None, \
# "padding not supported if output passed in"
# assert output.dtype == out_dtype
# if scale is None:
# if use_per_token_if_dynamic:
# scale = torch.empty((shape[0], 1),
# device=input.device,
# dtype=torch.float32)
# torch.ops._C.dynamic_per_token_scaled_fp8_quant(
# output, input, scale, scale_ub)
# else:
# scale = torch.empty(1, device=input.device, dtype=torch.float32)
# torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
# else:
# assert scale.numel() == 1, f"{scale.shape}"
# torch.ops._C.static_scaled_fp8_quant(output, input, scale)
# return output, scale
Args:
input: The input tensor to be quantized to FP8
scale: Optional scaling factor for the FP8 quantization
scale_ub: Optional upper bound for scaling factor in dynamic
per token case
num_token_padding: If specified, pad the first dimension
of the output to at least this value.
use_per_token_if_dynamic: Whether to do per_tensor or per_token
in the dynamic quantization case.
Returns:
tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
scaling factor.
"""
# This code assumes batch_dim and num_tokens are flattened
assert
(
input
.
ndim
==
2
)
shape
:
Union
[
tuple
[
int
,
int
],
torch
.
Size
]
=
input
.
shape
# For ROCm on MI300, the output fp8 dtype is torch.float_e3m3fnuz
out_dtype
:
torch
.
dtype
=
current_platform
.
fp8_dtype
()
if
num_token_padding
:
shape
=
(
max
(
num_token_padding
,
input
.
shape
[
0
]),
shape
[
1
])
if
output
is
None
:
output
=
torch
.
empty
(
shape
,
device
=
input
.
device
,
dtype
=
out_dtype
)
else
:
assert
num_token_padding
is
None
,
\
"padding not supported if output passed in"
assert
output
.
dtype
==
out_dtype
if
scale
is
None
:
if
use_per_token_if_dynamic
:
scale
=
torch
.
empty
((
shape
[
0
],
1
),
device
=
input
.
device
,
dtype
=
torch
.
float32
)
# torch.ops._C.dynamic_per_token_scaled_fp8_quant(
# output, input.contiguous(), scale, scale_ub)
output
,
scale
=
per_token_quant_fp8
(
input
.
contiguous
())
else
:
scale
=
torch
.
zeros
(
1
,
device
=
input
.
device
,
dtype
=
torch
.
float32
)
torch
.
ops
.
_C
.
dynamic_scaled_fp8_quant
(
output
,
input
,
scale
)
else
:
assert
scale
.
numel
()
==
1
,
f
"
{
scale
.
shape
}
"
torch
.
ops
.
_C
.
static_scaled_fp8_quant
(
output
,
input
,
scale
)
return
output
,
scale
# gptq allspark
# gptq allspark
def
allspark_repack_weight
(
def
allspark_repack_weight
(
qweight
:
torch
.
Tensor
,
qweight
:
torch
.
Tensor
,
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe_marlin.py
View file @
94c8a620
...
@@ -171,6 +171,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
...
@@ -171,6 +171,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
logical_to_physical_map
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_to_physical_map
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_replica_count
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_replica_count
:
Optional
[
torch
.
Tensor
]
=
None
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
**
_
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
enable_eplb
:
if
enable_eplb
:
raise
NotImplementedError
(
raise
NotImplementedError
(
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
View file @
94c8a620
...
@@ -139,7 +139,9 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
...
@@ -139,7 +139,9 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
def
apply_weights
(
self
,
def
apply_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
input_quant_args
:
Optional
[
list
[
torch
.
Tensor
]]
=
None
,
silu_quant_args
:
Optional
[
list
[
torch
.
Tensor
]]
=
None
,
**
_
,
)
->
torch
.
Tensor
:
if
layer
.
weight_block_size
is
not
None
:
if
layer
.
weight_block_size
is
not
None
:
return
apply_fp8_block_linear
(
return
apply_fp8_block_linear
(
...
...
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
View file @
94c8a620
...
@@ -18,6 +18,7 @@ from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer
...
@@ -18,6 +18,7 @@ from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer
try
:
try
:
from
lmslim.layers.gemm.int8_utils
import
per_token_quant_int8
from
lmslim.layers.gemm.int8_utils
import
per_token_quant_int8
from
lmslim.quantize
import
quant_ops
except
Exception
:
except
Exception
:
print
(
"INFO: Please install lmslim if you want to infer the quantitative model of moe.
\n
"
)
print
(
"INFO: Please install lmslim if you want to infer the quantitative model of moe.
\n
"
)
...
@@ -291,7 +292,37 @@ def torch_per_token_w8a8_scaled_mm(*, qinput: torch.Tensor,
...
@@ -291,7 +292,37 @@ def torch_per_token_w8a8_scaled_mm(*, qinput: torch.Tensor,
output
=
output
.
view
(
*
output_shape
)
output
=
output
.
view
(
*
output_shape
)
return
output
return
output
def
hipblaslt_w8a8_channelwise_scaled_mm
(
qinput
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
out_dtype
:
torch
.
dtype
,
scale_a
:
torch
.
Tensor
,
scale_b
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
,
output_shape
:
list
,
)
->
torch
.
Tensor
:
assert
qinput
.
is_contiguous
()
and
weight
.
is_contiguous
()
assert
qinput
.
shape
[
-
1
]
==
weight
.
shape
[
-
1
]
assert
qinput
.
dtype
==
weight
.
dtype
m
=
qinput
.
shape
[
0
]
k
=
qinput
.
shape
[
1
]
n
=
weight
.
shape
[
0
]
success
,
output
=
quant_ops
.
hipblaslt_w8a8_channelwise_gemm
(
a
=
qinput
,
b
=
weight
,
scale_a
=
scale_a
,
scale_b
=
scale_b
,
m
=
m
,
n
=
n
,
k
=
k
,
transpose_flag
=
"NT"
,
out_dtype
=
out_dtype
,
bias
=
bias
,
)
return
output
.
view
(
m
,
n
)
def
torch_channelwise_w8a8_scaled_mm
(
*
,
qinput
:
torch
.
Tensor
,
def
torch_channelwise_w8a8_scaled_mm
(
*
,
qinput
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
out_dtype
:
torch
.
dtype
,
out_dtype
:
torch
.
dtype
,
...
@@ -336,11 +367,9 @@ def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor,
...
@@ -336,11 +367,9 @@ def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor,
output
=
output
+
bias
output
=
output
+
bias
return
output
.
to
(
out_dtype
).
view
(
*
output_shape
)
return
output
.
to
(
out_dtype
).
view
(
*
output_shape
)
def
dispatch_w8a8_scaled_mm
(
def
dispatch_w8a8_scaled_mm
(
preferred_backend
:
str
,
per_tensor_weights
:
bool
,
preferred_backend
:
str
,
per_tensor_weights
:
bool
,
per_tensor_activations
:
bool
)
->
Callable
[...,
torch
.
Tensor
]:
per_tensor_activations
:
bool
)
->
Callable
[...,
torch
.
Tensor
]:
if
per_tensor_weights
and
per_tensor_activations
:
if
per_tensor_weights
and
per_tensor_activations
:
if
preferred_backend
==
"rocm"
:
if
preferred_backend
==
"rocm"
:
return
rocm_per_tensor_w8a8_scaled_mm
return
rocm_per_tensor_w8a8_scaled_mm
...
@@ -353,6 +382,9 @@ def dispatch_w8a8_scaled_mm(
...
@@ -353,6 +382,9 @@ def dispatch_w8a8_scaled_mm(
# cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
# cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
if
preferred_backend
==
"cutlass"
or
preferred_backend
==
"flashinfer"
:
if
preferred_backend
==
"cutlass"
or
preferred_backend
==
"flashinfer"
:
return
cutlass_w8a8_scaled_mm
return
cutlass_w8a8_scaled_mm
if
preferred_backend
==
"blaslt"
:
return
hipblaslt_w8a8_channelwise_scaled_mm
# If torch.scaled_mm supports per-channel (weights) per-token (inputs)
# If torch.scaled_mm supports per-channel (weights) per-token (inputs)
if
not
per_tensor_weights
and
not
per_tensor_activations
\
if
not
per_tensor_weights
and
not
per_tensor_activations
\
...
@@ -378,7 +410,11 @@ class Fp8LinearOp:
...
@@ -378,7 +410,11 @@ class Fp8LinearOp:
act_quant_group_shape
:
GroupShape
=
GroupShape
.
PER_TENSOR
,
act_quant_group_shape
:
GroupShape
=
GroupShape
.
PER_TENSOR
,
pad_output
:
Optional
[
bool
]
=
None
):
pad_output
:
Optional
[
bool
]
=
None
):
if
current_platform
.
is_rocm
():
if
current_platform
.
is_rocm
():
self
.
preferred_backend
=
"rocm"
if
envs
.
VLLM_W8A8_BACKEND
==
3
:
self
.
preferred_backend
=
"blaslt"
else
:
self
.
preferred_backend
=
"rocm"
elif
current_platform
.
is_cuda
()
and
cutlass_fp8_supported
():
elif
current_platform
.
is_cuda
()
and
cutlass_fp8_supported
():
if
has_flashinfer
()
and
current_platform
.
has_device_capability
(
if
has_flashinfer
()
and
current_platform
.
has_device_capability
(
100
):
100
):
...
@@ -429,11 +465,12 @@ class Fp8LinearOp:
...
@@ -429,11 +465,12 @@ class Fp8LinearOp:
# If input not quantized
# If input not quantized
# TODO(luka) remove this path if not used anymore
# TODO(luka) remove this path if not used anymore
if
input
.
dtype
!=
current_platform
.
fp8_dtype
():
if
input
.
dtype
!=
current_platform
.
fp8_dtype
():
qinput
,
x_scale
=
self
.
quant_fp8
(
qinput
,
x_scale
=
ops
.
scaled_fp8_quant
(
input_2d
,
input
=
input_2d
,
input_scale
,
scale
=
input_scale
,
input_scale_ub
,
num_token_padding
=
self
.
output_padding
,
)
scale_ub
=
input_scale_ub
,
use_per_token_if_dynamic
=
True
)
else
:
else
:
qinput
,
x_scale
=
input_2d
,
input_scale
qinput
,
x_scale
=
input_2d
,
input_scale
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment