Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f6227c22
Unverified
Commit
f6227c22
authored
Dec 08, 2025
by
czhu-cohere
Committed by
GitHub
Dec 08, 2025
Browse files
[Kernel]Support W4A8 Grouped GEMM on Hopper (#29691)
Signed-off-by:
czhu-cohere
<
conway.zhu@cohere.com
>
parent
ea657f20
Changes
22
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
64 additions
and
5 deletions
+64
-5
vllm/model_executor/layers/quantization/kernels/mixed_precision/cutlass.py
...or/layers/quantization/kernels/mixed_precision/cutlass.py
+15
-4
vllm/model_executor/layers/quantization/utils/quant_utils.py
vllm/model_executor/layers/quantization/utils/quant_utils.py
+49
-1
No files found.
vllm/model_executor/layers/quantization/kernels/mixed_precision/cutlass.py
View file @
f6227c22
...
@@ -6,7 +6,11 @@ import torch
...
@@ -6,7 +6,11 @@ import torch
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.input_quant_fp8
import
QuantFP8
from
vllm.model_executor.layers.quantization.input_quant_fp8
import
QuantFP8
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
GroupShape
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
GroupShape
,
convert_bf16_scales_to_fp8
,
convert_packed_uint4b8_to_signed_int4_inplace
,
)
from
vllm.model_executor.parameter
import
BasevLLMParameter
,
permute_param_layout_
from
vllm.model_executor.parameter
import
BasevLLMParameter
,
permute_param_layout_
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.scalar_type
import
scalar_types
from
vllm.scalar_type
import
scalar_types
...
@@ -48,7 +52,6 @@ class CutlassW4A8LinearKernel(MPLinearKernel):
...
@@ -48,7 +52,6 @@ class CutlassW4A8LinearKernel(MPLinearKernel):
"CUTLASS W4A8, only supported int4"
,
"CUTLASS W4A8, only supported int4"
,
)
)
# TODO(czhu): support -1 (column-wise)
if
c
.
group_size
!=
128
:
if
c
.
group_size
!=
128
:
return
False
,
"Only group_size 128 is supported"
return
False
,
"Only group_size 128 is supported"
...
@@ -71,9 +74,9 @@ class CutlassW4A8LinearKernel(MPLinearKernel):
...
@@ -71,9 +74,9 @@ class CutlassW4A8LinearKernel(MPLinearKernel):
# `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
# `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
# `weight_scale` is: {input_dim = 0, output_dim = 1}
# `weight_scale` is: {input_dim = 0, output_dim = 1}
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
):
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
):
# TODO(czhu): optimize speed/mem usage
def
transform_w_q
(
x
):
def
transform_w_q
(
x
):
assert
isinstance
(
x
,
BasevLLMParameter
)
assert
isinstance
(
x
,
BasevLLMParameter
)
convert_packed_uint4b8_to_signed_int4_inplace
(
x
.
data
)
permute_param_layout_
(
x
,
input_dim
=
0
,
output_dim
=
1
,
packed_dim
=
0
)
permute_param_layout_
(
x
,
input_dim
=
0
,
output_dim
=
1
,
packed_dim
=
0
)
x
.
data
=
ops
.
cutlass_encode_and_reorder_int4b
(
x
.
data
.
t
().
contiguous
().
t
())
x
.
data
=
ops
.
cutlass_encode_and_reorder_int4b
(
x
.
data
.
t
().
contiguous
().
t
())
return
x
return
x
...
@@ -85,10 +88,18 @@ class CutlassW4A8LinearKernel(MPLinearKernel):
...
@@ -85,10 +88,18 @@ class CutlassW4A8LinearKernel(MPLinearKernel):
x
.
data
=
ops
.
cutlass_pack_scale_fp8
(
x
.
data
)
x
.
data
=
ops
.
cutlass_pack_scale_fp8
(
x
.
data
)
return
x
return
x
w_s
=
getattr
(
layer
,
self
.
w_s_name
)
fp8_scales
,
chan_scales
=
convert_bf16_scales_to_fp8
(
self
.
quant_fp8
,
w_s
.
data
)
w_s
.
data
=
fp8_scales
# register per-channel scales
layer
.
register_parameter
(
"weight_chan_scale"
,
torch
.
nn
.
Parameter
(
chan_scales
,
requires_grad
=
False
)
)
# Encode/reorder weights and pack scales
# Encode/reorder weights and pack scales
self
.
_transform_param
(
layer
,
self
.
w_q_name
,
transform_w_q
)
self
.
_transform_param
(
layer
,
self
.
w_q_name
,
transform_w_q
)
self
.
_transform_param
(
layer
,
self
.
w_s_name
,
transform_w_s
)
self
.
_transform_param
(
layer
,
self
.
w_s_name
,
transform_w_s
)
self
.
_transform_param
(
layer
,
"weight_chan_scale"
,
lambda
x
:
x
)
def
apply_weights
(
def
apply_weights
(
self
,
self
,
...
...
vllm/model_executor/layers/quantization/utils/quant_utils.py
View file @
f6227c22
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""This file is used for /tests and /benchmarks"""
"""This file is used for /tests and /benchmarks"""
from
collections.abc
import
Mapping
from
collections.abc
import
Callable
,
Mapping
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
types
import
MappingProxyType
from
types
import
MappingProxyType
from
typing
import
ClassVar
,
NamedTuple
from
typing
import
ClassVar
,
NamedTuple
...
@@ -691,3 +691,51 @@ def cutlass_fp4_supported() -> bool:
...
@@ -691,3 +691,51 @@ def cutlass_fp4_supported() -> bool:
capability_tuple
=
current_platform
.
get_device_capability
()
capability_tuple
=
current_platform
.
get_device_capability
()
capability
=
-
1
if
capability_tuple
is
None
else
capability_tuple
.
to_int
()
capability
=
-
1
if
capability_tuple
is
None
else
capability_tuple
.
to_int
()
return
cutlass_scaled_mm_supports_fp4
(
capability
)
return
cutlass_scaled_mm_supports_fp4
(
capability
)
def convert_bf16_scales_to_fp8(
    quant_fp8: Callable, scales: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Turn a BF16 scale tensor into the (fp8_scales, channel_scales) pair
    expected by W4A8 GEMM kernels.

    All leading dimensions of ``scales`` are flattened (the last dimension
    is kept) before ``quant_fp8`` is invoked; ``quant_fp8`` must return a
    ``(quantized, scale)`` pair. The quantized output is divided by 8 and
    cast to ``torch.float8_e4m3fn``, while the second output is multiplied
    by 8 in place, i.e. the factor of 8 is moved from one output to the
    other.

    Args:
        quant_fp8: Callable applied to the flattened 2-D scales.
        scales: Contiguous scale tensor located on a CUDA device.

    Returns:
        ``(fp8_scales, chan_scales)``; ``fp8_scales`` is viewed back to the
        shape of ``scales``.
    """
    assert scales.is_contiguous(), (
        f"scale tensor must be contiguous, got {scales.stride()=}"
    )
    assert scales.is_cuda, "scales must be on gpu"

    full_shape = scales.shape
    # Collapse every leading dimension so quant_fp8 sees a 2-D tensor whose
    # rows hold the scales along the last dimension.
    rows_by_group = scales.view(-1, full_shape[-1])

    quantized, per_channel = quant_fp8(rows_by_group)
    # Shift a factor of 8 out of the FP8 values and into the second output.
    quantized = quantized.float().div(8.0).to(torch.float8_e4m3fn)
    per_channel *= 8.0

    # Undo the flattening.
    # NOTE(review): the second view receives a torch.Size followed by -1
    # (not an unpacked shape) -- confirm the resulting shape is the intended
    # one before changing this call.
    return quantized.view(full_shape), per_channel.view(full_shape[:-1], -1)
def convert_packed_uint4b8_to_signed_int4_inplace(t: torch.Tensor) -> torch.Tensor:
    """
    Convert int4b8 (packed to int32) to signed int4, in place.

    Every int32 element holds eight 4-bit nibbles. An int4b8 nibble stores
    ``value + 8`` in the range [0..15]; the signed int4 (two's complement)
    nibble for the same value is ``(nibble - 8) & 0xF``. Subtracting 8
    modulo 16 is the same as flipping bit 3 of the nibble, so all eight
    nibbles are converted with a single in-place XOR instead of a
    per-nibble extract/clear/insert loop. This avoids temporaries the size
    of the weight tensor.

    The operation is a pure bit flip and does not depend on the device the
    tensor lives on.

    Args:
        t: Packed weights with dtype ``torch.int32``; modified in place.

    Returns:
        The same tensor object ``t``, now holding signed int4 nibbles.

    Raises:
        AssertionError: If ``t`` is not of dtype ``torch.int32``.
    """
    assert t.dtype == torch.int32, f"expected int32 packed weights but got {t.dtype}"

    # 0x88888888 has bit 3 set in each of the eight nibbles. It is written as
    # the equivalent negative two's-complement value so the Python scalar
    # fits in a signed int32.
    nibble_high_bits = 0x88888888 - (1 << 32)
    t ^= nibble_high_bits
    return t
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment