Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0636f239
Commit
0636f239
authored
Feb 10, 2026
by
lixh6
Browse files
feat:适配Blaslt Channelwise gemm
parent
440222e9
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
105 additions
and
49 deletions
+105
-49
vllm/_custom_ops.py
vllm/_custom_ops.py
+62
-0
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
...compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
+8
-4
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+2
-1
vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py
...executor/layers/quantization/kernels/scaled_mm/pytorch.py
+33
-44
No files found.
vllm/_custom_ops.py
View file @
0636f239
...
@@ -19,6 +19,7 @@ from vllm.utils.torch_utils import direct_register_custom_op
...
@@ -19,6 +19,7 @@ from vllm.utils.torch_utils import direct_register_custom_op
try
:
try
:
from
lmslim
import
quant_ops
from
lmslim
import
quant_ops
from
lmslim
import
quant_tools
from
lmslim
import
quant_tools
from
lmslim.layers.gemm.fp8_utils
import
per_token_quant_fp8
except
Exception
:
except
Exception
:
print
(
"INFO: Please install lmslim if you want to infer gptq or awq or w8a8 model.
\n
"
)
print
(
"INFO: Please install lmslim if you want to infer gptq or awq or w8a8 model.
\n
"
)
try
:
try
:
...
@@ -1878,6 +1879,67 @@ def scaled_fp4_experts_quant(
...
@@ -1878,6 +1879,67 @@ def scaled_fp4_experts_quant(
output_scales
=
output_scales
.
view
(
torch
.
float8_e4m3fn
)
output_scales
=
output_scales
.
view
(
torch
.
float8_e4m3fn
)
return
output
,
output_scales
return
output
,
output_scales
def
scaled_fp8_quant
(
input
:
torch
.
Tensor
,
scale
:
Optional
[
torch
.
Tensor
]
=
None
,
num_token_padding
:
Optional
[
int
]
=
None
,
scale_ub
:
Optional
[
torch
.
Tensor
]
=
None
,
use_per_token_if_dynamic
:
bool
=
False
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
group_shape
:
Optional
[
tuple
[
int
,
int
]]
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Quantize input tensor to FP8 and return quantized tensor and scale.
This function supports both static and dynamic quantization: If you
provide the scale, it will use static scaling and if you omit it,
the scale will be determined dynamically. The function also allows
optional padding of the output tensors for downstream kernels that
will benefit from padding.
Args:
input: The input tensor to be quantized to FP8
scale: Optional scaling factor for the FP8 quantization
scale_ub: Optional upper bound for scaling factor in dynamic
per token case
num_token_padding: If specified, pad the first dimension
of the output to at least this value.
use_per_token_if_dynamic: Whether to do per_tensor or per_token
in the dynamic quantization case.
Returns:
tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
scaling factor.
"""
# This code assumes batch_dim and num_tokens are flattened
assert
(
input
.
ndim
==
2
)
shape
:
Union
[
tuple
[
int
,
int
],
torch
.
Size
]
=
input
.
shape
# For ROCm on MI300, the output fp8 dtype is torch.float_e3m3fnuz
out_dtype
:
torch
.
dtype
=
current_platform
.
fp8_dtype
()
if
num_token_padding
:
shape
=
(
max
(
num_token_padding
,
input
.
shape
[
0
]),
shape
[
1
])
if
output
is
None
:
output
=
torch
.
empty
(
shape
,
device
=
input
.
device
,
dtype
=
out_dtype
)
else
:
assert
num_token_padding
is
None
,
\
"padding not supported if output passed in"
assert
output
.
dtype
==
out_dtype
if
scale
is
None
:
if
use_per_token_if_dynamic
:
scale
=
torch
.
empty
((
shape
[
0
],
1
),
device
=
input
.
device
,
dtype
=
torch
.
float32
)
# torch.ops._C.dynamic_per_token_scaled_fp8_quant(
# output, input.contiguous(), scale, scale_ub)
output
,
scale
=
per_token_quant_fp8
(
input
.
contiguous
())
else
:
scale
=
torch
.
zeros
(
1
,
device
=
input
.
device
,
dtype
=
torch
.
float32
)
torch
.
ops
.
_C
.
dynamic_scaled_fp8_quant
(
output
,
input
,
scale
)
else
:
assert
scale
.
numel
()
==
1
,
f
"
{
scale
.
shape
}
"
torch
.
ops
.
_C
.
static_scaled_fp8_quant
(
output
,
input
,
scale
)
return
output
,
scale
def
silu_and_mul_scaled_fp4_experts_quant
(
def
silu_and_mul_scaled_fp4_experts_quant
(
input_tensor
:
torch
.
Tensor
,
input_tensor
:
torch
.
Tensor
,
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
View file @
0636f239
...
@@ -2,7 +2,8 @@
...
@@ -2,7 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Callable
from
collections.abc
import
Callable
from
typing
import
Optional
from
vllm
import
envs
import
torch
import
torch
from
compressed_tensors.quantization
import
QuantizationArgs
,
QuantizationStrategy
from
compressed_tensors.quantization
import
QuantizationArgs
,
QuantizationStrategy
from
torch.nn
import
Parameter
from
torch.nn
import
Parameter
...
@@ -40,7 +41,6 @@ from vllm.model_executor.parameter import (
...
@@ -40,7 +41,6 @@ from vllm.model_executor.parameter import (
ChannelQuantScaleParameter
,
ChannelQuantScaleParameter
,
PerTensorScaleParameter
,
PerTensorScaleParameter
,
)
)
__all__
=
[
"CompressedTensorsW8A8Fp8"
]
__all__
=
[
"CompressedTensorsW8A8Fp8"
]
strategy_to_parameter_type
=
{
strategy_to_parameter_type
=
{
...
@@ -159,8 +159,10 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
...
@@ -159,8 +159,10 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
weight
,
weight_scale
,
input_scale
=
process_fp8_weight_channel_strategy
(
weight
,
weight_scale
,
input_scale
=
process_fp8_weight_channel_strategy
(
layer
.
weight
,
layer
.
weight_scale
,
getattr
(
layer
,
"input_scale"
,
None
)
layer
.
weight
,
layer
.
weight_scale
,
getattr
(
layer
,
"input_scale"
,
None
)
)
)
if
envs
.
VLLM_W8A8_BACKEND
==
3
:
weight
=
weight
.
t
().
contiguous
()
else
:
weight
=
weight
.
t
()
weight
=
weight
.
t
()
elif
self
.
strategy
==
QuantizationStrategy
.
BLOCK
:
elif
self
.
strategy
==
QuantizationStrategy
.
BLOCK
:
assert
self
.
is_static_input_scheme
is
False
assert
self
.
is_static_input_scheme
is
False
weight
,
weight_scale
=
process_fp8_weight_block_strategy
(
weight
,
weight_scale
=
process_fp8_weight_block_strategy
(
...
@@ -193,6 +195,8 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
...
@@ -193,6 +195,8 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
layer
:
torch
.
nn
.
Module
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
|
None
=
None
,
bias
:
torch
.
Tensor
|
None
=
None
,
input_quant_args
:
Optional
[
list
[
torch
.
Tensor
]]
=
None
,
silu_quant_args
:
Optional
[
list
[
torch
.
Tensor
]]
=
None
,
**
_
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
self
.
weight_block_size
is
not
None
:
if
self
.
weight_block_size
is
not
None
:
return
self
.
w8a8_block_fp8_linear
.
apply
(
return
self
.
w8a8_block_fp8_linear
.
apply
(
...
...
vllm/model_executor/layers/quantization/fp8.py
View file @
0636f239
...
@@ -6,7 +6,6 @@ from typing import TYPE_CHECKING, Any, Optional
...
@@ -6,7 +6,6 @@ from typing import TYPE_CHECKING, Any, Optional
import
torch
import
torch
from
torch.nn
import
Module
from
torch.nn
import
Module
from
torch.utils._python_dispatch
import
TorchDispatchMode
from
torch.utils._python_dispatch
import
TorchDispatchMode
import
vllm.envs
as
envs
import
vllm.envs
as
envs
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
...
@@ -1027,6 +1026,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
...
@@ -1027,6 +1026,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
use_nn_moe
:
bool
|
None
=
False
,
use_fused_gate
:
bool
|
None
=
False
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
self
.
kernel
is
not
None
assert
self
.
kernel
is
not
None
assert
not
self
.
is_monolithic
assert
not
self
.
is_monolithic
...
...
vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py
View file @
0636f239
...
@@ -12,7 +12,11 @@ from .ScaledMMLinearKernel import (
...
@@ -12,7 +12,11 @@ from .ScaledMMLinearKernel import (
FP8ScaledMMLinearKernel
,
FP8ScaledMMLinearKernel
,
FP8ScaledMMLinearLayerConfig
,
FP8ScaledMMLinearLayerConfig
,
)
)
try
:
from
lmslim.quantize.quant_ops
import
hipblaslt_w8a8_channelwise_gemm
except
ImportError
:
print
(
"INFO: Please updata lmslim if you want to use fp8_utils.
\n
"
)
from
vllm
import
envs
class
TorchFP8ScaledMMLinearKernel
(
FP8ScaledMMLinearKernel
):
class
TorchFP8ScaledMMLinearKernel
(
FP8ScaledMMLinearKernel
):
"""
"""
...
@@ -176,46 +180,31 @@ class ChannelWiseTorchFP8ScaledMMLinearKernel(TorchFP8ScaledMMLinearKernel):
...
@@ -176,46 +180,31 @@ class ChannelWiseTorchFP8ScaledMMLinearKernel(TorchFP8ScaledMMLinearKernel):
bias
:
torch
.
Tensor
|
None
,
bias
:
torch
.
Tensor
|
None
,
output_shape
:
list
,
output_shape
:
list
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
# Use unfused DQ due to limitations with scaled_mm
m
=
A
.
shape
[
0
]
k
=
A
.
shape
[
1
]
# Symmetric quantized GEMM by definition computes the following:
n
=
B
.
shape
[
0
]
# C = (s_x * X) (s_w * W) + bias
# This is equivalent to dequantizing the weights and activations
if
envs
.
VLLM_W8A8_BACKEND
==
3
:
# before applying a GEMM.
_
,
output
=
hipblaslt_w8a8_channelwise_gemm
(
#
a
=
A
,
# In order to compute quantized operands, a quantized kernel
b
=
B
,
# will rewrite the above like so:
scale_a
=
As
,
# C = s_w * s_x * (X * W) + bias
scale_b
=
Bs
,
#
m
=
m
,
# For the scaled_mm fallback case, we break this down, since it
n
=
n
,
# does not support s_w being a vector.
k
=
k
,
transpose_flag
=
"NT"
,
# Input scaling factors are no longer optional in _scaled_mm starting
out_dtype
=
out_dtype
,
# from pytorch 2.5. Allocating a dummy tensor to pass as scales
bias
=
bias
,
dummy_tensor
=
torch
.
ones
(
1
,
dtype
=
torch
.
float32
,
device
=
A
.
device
)
)
return
output
.
view
(
m
,
n
)
# GEMM
else
:
# This computes C = (X * W).
output
=
triton_scaled_mm_fp8
(
# Output in fp32 to allow subsequent ops to happen in-place
output
=
torch
.
_scaled_mm
(
A
,
A
,
B
,
B
,
scale_a
=
dummy_tensor
,
scale_a
=
As
,
scale_b
=
dummy_tensor
,
scale_b
=
Bs
,
out_dtype
=
torch
.
float32
,
out_dtype
=
out_dtype
,
bias
=
bias
,
)
)
# A fix for discrepancy in scaled_mm which returns tuple
return
output
.
view
(
*
output_shape
)
# for torch < 2.5 and a single value in torch >= 2.5
if
type
(
output
)
is
tuple
and
len
(
output
)
==
2
:
output
=
output
[
0
]
# Unpad (undo num_token_padding)
output
=
torch
.
narrow
(
output
,
0
,
0
,
output_shape
[
0
])
x_scale
=
torch
.
narrow
(
As
,
0
,
0
,
output_shape
[
0
])
# DQ
# C = sw * sx * (X * W) + bias
output
=
output
*
x_scale
*
Bs
.
t
()
if
bias
is
not
None
:
output
=
output
+
bias
return
output
.
to
(
out_dtype
).
view
(
*
output_shape
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment