Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9909726d
Unverified
Commit
9909726d
authored
Jul 01, 2025
by
czhu-cohere
Committed by
GitHub
Jul 01, 2025
Browse files
Enable ZP Support for Machete (#20268)
Signed-off-by:
czhu-cohere
<
conway.zhu@cohere.com
>
parent
22e9d420
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
19 additions
and
5 deletions
+19
-5
benchmarks/kernels/benchmark_machete.py
benchmarks/kernels/benchmark_machete.py
+2
-0
tests/kernels/quantization/test_machete_mm.py
tests/kernels/quantization/test_machete_mm.py
+1
-1
vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py
...or/layers/quantization/kernels/mixed_precision/machete.py
+16
-4
No files found.
benchmarks/kernels/benchmark_machete.py
View file @
9909726d
...
@@ -234,8 +234,10 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable:
...
@@ -234,8 +234,10 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable:
fn
=
lambda
:
ops
.
gptq_marlin_gemm
(
fn
=
lambda
:
ops
.
gptq_marlin_gemm
(
a
=
bt
.
a
,
a
=
bt
.
a
,
c
=
None
,
b_q_weight
=
w_q
,
b_q_weight
=
w_q
,
b_scales
=
w_s
,
b_scales
=
w_s
,
global_scale
=
None
,
b_zeros
=
w_zp
,
b_zeros
=
w_zp
,
g_idx
=
g_idx
,
g_idx
=
g_idx
,
perm
=
sort_indices
,
perm
=
sort_indices
,
...
...
tests/kernels/quantization/test_machete_mm.py
View file @
9909726d
...
@@ -139,7 +139,7 @@ def maybe_convert_zeropoints(zps: Optional[torch.Tensor], s: torch.Tensor):
...
@@ -139,7 +139,7 @@ def maybe_convert_zeropoints(zps: Optional[torch.Tensor], s: torch.Tensor):
def
group_size_valid
(
shape
:
tuple
[
int
,
int
,
int
],
def
group_size_valid
(
shape
:
tuple
[
int
,
int
,
int
],
group_size
:
Optional
[
int
])
->
bool
:
group_size
:
Optional
[
int
])
->
bool
:
return
group_size
is
None
or
group_size
==
-
1
or
group_size
%
shape
[
2
]
==
0
return
group_size
is
None
or
group_size
==
-
1
or
shape
[
2
]
%
group_size
==
0
def
machete_quantize_and_pack
(
atype
:
torch
.
dtype
,
def
machete_quantize_and_pack
(
atype
:
torch
.
dtype
,
...
...
vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py
View file @
9909726d
...
@@ -33,8 +33,6 @@ class MacheteLinearKernel(MPLinearKernel):
...
@@ -33,8 +33,6 @@ class MacheteLinearKernel(MPLinearKernel):
return
False
,
"Act reordering currently not supported by Machete, "
\
return
False
,
"Act reordering currently not supported by Machete, "
\
"when the input features are partitioned across "
\
"when the input features are partitioned across "
\
"devices"
"devices"
if
c
.
zero_points
:
return
False
,
"Zero points currently not supported by Machete"
if
c
.
weight_type
not
in
query_machete_supported_quant_types
(
if
c
.
weight_type
not
in
query_machete_supported_quant_types
(
c
.
zero_points
):
c
.
zero_points
):
...
@@ -53,6 +51,7 @@ class MacheteLinearKernel(MPLinearKernel):
...
@@ -53,6 +51,7 @@ class MacheteLinearKernel(MPLinearKernel):
# note assumes that
# note assumes that
# `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
# `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
# `weight_scale` is: {input_dim = 0, output_dim = 1}
# `weight_scale` is: {input_dim = 0, output_dim = 1}
# `weight_zp` is: {input_dim = 0, output_dim = 1, packed_dim = 1}
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
):
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
):
c
=
self
.
config
c
=
self
.
config
...
@@ -90,16 +89,29 @@ class MacheteLinearKernel(MPLinearKernel):
...
@@ -90,16 +89,29 @@ class MacheteLinearKernel(MPLinearKernel):
x
.
data
=
x
.
data
.
contiguous
()
x
.
data
=
x
.
data
.
contiguous
()
return
x
return
x
def
transform_w_zp
(
x
):
assert
isinstance
(
x
,
BasevLLMParameter
)
permute_param_layout_
(
x
,
input_dim
=
0
,
output_dim
=
1
,
packed_dim
=
1
)
x_unpacked
=
unpack_quantized_values_into_int32
(
x
.
data
,
c
.
weight_type
,
packed_dim
=
1
)
w_s
=
getattr
(
layer
,
self
.
w_s_name
).
data
# pre-apply scales to zero-points
x
.
data
=
(
-
1.0
*
w_s
*
(
x_unpacked
.
to
(
w_s
.
dtype
))).
contiguous
()
return
x
# Repack weights and scales for Machete
# Repack weights and scales for Machete
self
.
_transform_param
(
layer
,
self
.
w_q_name
,
transform_w_q
)
self
.
_transform_param
(
layer
,
self
.
w_q_name
,
transform_w_q
)
self
.
_transform_param
(
layer
,
self
.
w_s_name
,
transform_w_s
)
self
.
_transform_param
(
layer
,
self
.
w_s_name
,
transform_w_s
)
if
c
.
zero_points
:
self
.
_transform_param
(
layer
,
self
.
w_zp_name
,
transform_w_zp
)
def
apply_weights
(
self
,
def
apply_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
c
=
self
.
config
c
=
self
.
config
w_q
,
w_s
,
_
,
_
=
self
.
_get_weight_params
(
layer
)
w_q
,
w_s
,
w_zp
,
_
=
self
.
_get_weight_params
(
layer
)
x_2d
=
x
.
reshape
(
-
1
,
x
.
shape
[
-
1
])
x_2d
=
x
.
reshape
(
-
1
,
x
.
shape
[
-
1
])
out_shape
=
x
.
shape
[:
-
1
]
+
(
c
.
partition_weight_shape
[
1
],
)
out_shape
=
x
.
shape
[:
-
1
]
+
(
c
.
partition_weight_shape
[
1
],
)
...
@@ -110,7 +122,7 @@ class MacheteLinearKernel(MPLinearKernel):
...
@@ -110,7 +122,7 @@ class MacheteLinearKernel(MPLinearKernel):
output
=
ops
.
machete_mm
(
a
=
x_2d
,
output
=
ops
.
machete_mm
(
a
=
x_2d
,
b_q
=
w_q
,
b_q
=
w_q
,
b_type
=
c
.
weight_type
,
b_type
=
c
.
weight_type
,
b_group_zeros
=
None
,
b_group_zeros
=
w_zp
,
b_group_scales
=
w_s
,
b_group_scales
=
w_s
,
b_group_size
=
c
.
group_size
)
b_group_size
=
c
.
group_size
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment