Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2c2b140a
Unverified
Commit
2c2b140a
authored
Aug 26, 2025
by
czhu-cohere
Committed by
GitHub
Aug 26, 2025
Browse files
[quantization] use channel scales for w4a8 + misc fixes (#23570)
Signed-off-by:
czhu-cohere
<
conway.zhu@cohere.com
>
parent
c7c80af0
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
63 additions
and
14 deletions
+63
-14
tests/quantization/test_compressed_tensors.py
tests/quantization/test_compressed_tensors.py
+40
-4
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a8_fp8.py
...compressed_tensors/schemes/compressed_tensors_w4a8_fp8.py
+11
-2
vllm/model_executor/layers/quantization/kernels/mixed_precision/MPLinearKernel.py
...rs/quantization/kernels/mixed_precision/MPLinearKernel.py
+1
-0
vllm/model_executor/layers/quantization/kernels/mixed_precision/cutlass.py
...or/layers/quantization/kernels/mixed_precision/cutlass.py
+11
-8
No files found.
tests/quantization/test_compressed_tensors.py
View file @
2c2b140a
...
@@ -14,10 +14,10 @@ from compressed_tensors.quantization import QuantizationType
...
@@ -14,10 +14,10 @@ from compressed_tensors.quantization import QuantizationType
from
tests.models.utils
import
check_logprobs_close
from
tests.models.utils
import
check_logprobs_close
from
vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors
import
(
# noqa: E501
from
vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors
import
(
# noqa: E501
CompressedTensors24
,
CompressedTensorsLinearMethod
,
CompressedTensors24
,
CompressedTensorsLinearMethod
,
CompressedTensorsW4A4Fp4
,
CompressedTensorsW4A
16
Fp
4
,
CompressedTensorsW4A4Fp4
,
CompressedTensorsW4A
8
Fp
8
,
CompressedTensorsW4A16
Sparse2
4
,
CompressedTensorsW
8A8Fp8
,
CompressedTensorsW4A16
Fp
4
,
CompressedTensorsW
4A16Sparse24
,
CompressedTensorsW8A8
Int
8
,
CompressedTensorsW8A
16Fp
8
,
CompressedTensorsW8A8
Fp
8
,
CompressedTensorsW8A
8Int
8
,
CompressedTensorsWNA16
)
CompressedTensorsW8A16Fp8
,
CompressedTensorsWNA16
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
cutlass_fp4_supported
)
cutlass_fp4_supported
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
...
@@ -683,3 +683,39 @@ def test_compressed_tensors_nvfp4(vllm_runner, args):
...
@@ -683,3 +683,39 @@ def test_compressed_tensors_nvfp4(vllm_runner, args):
output
=
llm
.
generate_greedy
(
"Hello my name is"
,
max_tokens
=
20
)
output
=
llm
.
generate_greedy
(
"Hello my name is"
,
max_tokens
=
20
)
print
(
output
)
print
(
output
)
assert
output
assert
output
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_cuda
()
or
not
current_platform
.
has_device_capability
(
90
),
reason
=
"W4A8 FP8 is not yet supported on this GPU type."
,
)
@
pytest
.
mark
.
parametrize
(
"args"
,
[
(
"czhu-cohere/TinyLlama-1.1B-Chat-v1.0-W4A8-e2e"
,
CompressedTensorsW4A8Fp8
)
])
def
test_compressed_tensors_w4a8_fp8
(
vllm_runner
,
args
):
model
,
scheme
=
args
with
vllm_runner
(
model
,
enforce_eager
=
True
)
as
llm
:
def
check_model
(
model
):
layer
=
model
.
model
.
layers
[
0
]
qkv_proj
=
layer
.
self_attn
.
qkv_proj
o_proj
=
layer
.
self_attn
.
o_proj
gate_up_proj
=
layer
.
mlp
.
gate_up_proj
down_proj
=
layer
.
mlp
.
down_proj
for
proj
in
(
qkv_proj
,
o_proj
,
gate_up_proj
,
down_proj
):
assert
isinstance
(
proj
.
quant_method
,
CompressedTensorsLinearMethod
)
assert
isinstance
(
proj
.
scheme
,
scheme
)
assert
proj
.
weight_packed
.
dtype
is
torch
.
int32
assert
proj
.
weight_scale
.
dtype
is
torch
.
float8_e4m3fn
assert
proj
.
weight_chan_scale
.
dtype
is
torch
.
float32
assert
proj
.
scheme
.
group_size
==
128
llm
.
apply_model
(
check_model
)
output
=
llm
.
generate_greedy
(
"Hello my name is"
,
max_tokens
=
20
)
print
(
output
)
assert
output
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a8_fp8.py
View file @
2c2b140a
...
@@ -79,7 +79,8 @@ class CompressedTensorsW4A8Fp8(CompressedTensorsScheme):
...
@@ -79,7 +79,8 @@ class CompressedTensorsW4A8Fp8(CompressedTensorsScheme):
act_type
=
torch
.
float8_e4m3fn
,
# always use fp8(e4m3)
act_type
=
torch
.
float8_e4m3fn
,
# always use fp8(e4m3)
group_size
=
self
.
group_size
,
group_size
=
self
.
group_size
,
zero_points
=
not
self
.
symmetric
,
zero_points
=
not
self
.
symmetric
,
has_g_idx
=
self
.
has_g_idx
has_g_idx
=
self
.
has_g_idx
,
out_type
=
params_dtype
)
)
kernel_type
=
choose_mp_linear_kernel
(
mp_linear_kernel_config
)
kernel_type
=
choose_mp_linear_kernel
(
mp_linear_kernel_config
)
...
@@ -122,7 +123,7 @@ class CompressedTensorsW4A8Fp8(CompressedTensorsScheme):
...
@@ -122,7 +123,7 @@ class CompressedTensorsW4A8Fp8(CompressedTensorsScheme):
torch
.
empty
(
torch
.
empty
(
output_size_per_partition
,
output_size_per_partition
,
scales_and_zp_size
,
scales_and_zp_size
,
dtype
=
params_dtype
,
dtype
=
torch
.
float8_e4m3fn
,
)
)
}
}
...
@@ -140,9 +141,17 @@ class CompressedTensorsW4A8Fp8(CompressedTensorsScheme):
...
@@ -140,9 +141,17 @@ class CompressedTensorsW4A8Fp8(CompressedTensorsScheme):
dtype
=
torch
.
int64
),
dtype
=
torch
.
int64
),
weight_loader
=
weight_loader
)
weight_loader
=
weight_loader
)
# per-channel scales
weight_chan_scale
=
ChannelQuantScaleParameter
(
data
=
torch
.
empty
((
output_size_per_partition
,
1
),
dtype
=
torch
.
float32
),
output_dim
=
0
,
weight_loader
=
weight_loader
)
layer
.
register_parameter
(
"weight_packed"
,
weight
)
layer
.
register_parameter
(
"weight_packed"
,
weight
)
layer
.
register_parameter
(
"weight_scale"
,
weight_scale
)
layer
.
register_parameter
(
"weight_scale"
,
weight_scale
)
layer
.
register_parameter
(
"weight_shape"
,
weight_shape
)
layer
.
register_parameter
(
"weight_shape"
,
weight_shape
)
layer
.
register_parameter
(
"weight_chan_scale"
,
weight_chan_scale
)
self
.
kernel
=
kernel_type
(
mp_linear_kernel_config
,
self
.
kernel
=
kernel_type
(
mp_linear_kernel_config
,
w_q_param_name
=
"weight_packed"
,
w_q_param_name
=
"weight_packed"
,
...
...
vllm/model_executor/layers/quantization/kernels/mixed_precision/MPLinearKernel.py
View file @
2c2b140a
...
@@ -20,6 +20,7 @@ class MPLinearLayerConfig:
...
@@ -20,6 +20,7 @@ class MPLinearLayerConfig:
group_size
:
int
group_size
:
int
zero_points
:
bool
zero_points
:
bool
has_g_idx
:
bool
has_g_idx
:
bool
out_type
:
Optional
[
torch
.
dtype
]
=
None
class
MPLinearKernel
(
ABC
):
class
MPLinearKernel
(
ABC
):
...
...
vllm/model_executor/layers/quantization/kernels/mixed_precision/cutlass.py
View file @
2c2b140a
...
@@ -60,13 +60,17 @@ class CutlassW4A8LinearKernel(MPLinearKernel):
...
@@ -60,13 +60,17 @@ class CutlassW4A8LinearKernel(MPLinearKernel):
if
in_features
%
128
or
out_features
%
128
:
if
in_features
%
128
or
out_features
%
128
:
return
False
,
"K and N must be divisible by 128, got "
\
return
False
,
"K and N must be divisible by 128, got "
\
f
"
{
c
.
partition_weight_shape
}
"
f
"
{
c
.
partition_weight_shape
}
"
if
c
.
out_type
!=
torch
.
bfloat16
:
return
False
,
"Only bfloat16 output type currently supported"
\
f
"got
{
c
.
out_type
=
}
"
return
True
,
None
return
True
,
None
# note assumes that
# note assumes that
# `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
# `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
# `weight_scale` is: {input_dim = 0, output_dim = 1}
# `weight_scale` is: {input_dim = 0, output_dim = 1}
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
):
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
):
c
=
self
.
config
# TODO(czhu): optimize speed/mem usage
# TODO(czhu): optimize speed/mem usage
def
transform_w_q
(
x
):
def
transform_w_q
(
x
):
...
@@ -86,19 +90,15 @@ class CutlassW4A8LinearKernel(MPLinearKernel):
...
@@ -86,19 +90,15 @@ class CutlassW4A8LinearKernel(MPLinearKernel):
# Encode/reorder weights and pack scales
# Encode/reorder weights and pack scales
self
.
_transform_param
(
layer
,
self
.
w_q_name
,
transform_w_q
)
self
.
_transform_param
(
layer
,
self
.
w_q_name
,
transform_w_q
)
self
.
_transform_param
(
layer
,
self
.
w_s_name
,
transform_w_s
)
self
.
_transform_param
(
layer
,
self
.
w_s_name
,
transform_w_s
)
self
.
_transform_param
(
layer
,
"weight_chan_scale"
,
lambda
x
:
x
)
# TODO(czhu): support loading channel scales
self
.
w_ch_s
=
torch
.
ones
((
c
.
partition_weight_shape
[
1
],
),
dtype
=
torch
.
float32
,
device
=
'cuda'
)
def
apply_weights
(
self
,
def
apply_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
assert
bias
is
None
,
"bias not supported by CUTLASS W4A8"
c
=
self
.
config
c
=
self
.
config
w_q
,
w_s
,
_
,
_
=
self
.
_get_weight_params
(
layer
)
w_q
,
w_s
,
_
,
_
=
self
.
_get_weight_params
(
layer
)
w_ch_s
=
layer
.
weight_chan_scale
x_2d
=
x
.
reshape
(
-
1
,
x
.
shape
[
-
1
])
x_2d
=
x
.
reshape
(
-
1
,
x
.
shape
[
-
1
])
out_shape
=
x
.
shape
[:
-
1
]
+
(
c
.
partition_weight_shape
[
1
],
)
out_shape
=
x
.
shape
[:
-
1
]
+
(
c
.
partition_weight_shape
[
1
],
)
...
@@ -109,6 +109,9 @@ class CutlassW4A8LinearKernel(MPLinearKernel):
...
@@ -109,6 +109,9 @@ class CutlassW4A8LinearKernel(MPLinearKernel):
b_group_scales
=
w_s
,
b_group_scales
=
w_s
,
b_group_size
=
c
.
group_size
,
b_group_size
=
c
.
group_size
,
a_token_scales
=
act_scales
,
a_token_scales
=
act_scales
,
b_channel_scales
=
self
.
w_ch_s
)
b_channel_scales
=
w_ch_s
)
if
bias
is
not
None
:
output
.
add_
(
bias
)
# In-place add
return
output
.
reshape
(
out_shape
)
return
output
.
reshape
(
out_shape
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment