"docs/vscode:/vscode.git/clone" did not exist on "540c0368b14ddd8d3efac0b182761bf6600f104f"
Unverified Commit 9909726d authored by czhu-cohere's avatar czhu-cohere Committed by GitHub
Browse files

Enable ZP Support for Machete (#20268)


Signed-off-by: default avatarczhu-cohere <conway.zhu@cohere.com>
parent 22e9d420
...@@ -234,8 +234,10 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable: ...@@ -234,8 +234,10 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable:
fn = lambda: ops.gptq_marlin_gemm( fn = lambda: ops.gptq_marlin_gemm(
a=bt.a, a=bt.a,
c=None,
b_q_weight=w_q, b_q_weight=w_q,
b_scales=w_s, b_scales=w_s,
global_scale=None,
b_zeros=w_zp, b_zeros=w_zp,
g_idx=g_idx, g_idx=g_idx,
perm=sort_indices, perm=sort_indices,
......
...@@ -139,7 +139,7 @@ def maybe_convert_zeropoints(zps: Optional[torch.Tensor], s: torch.Tensor): ...@@ -139,7 +139,7 @@ def maybe_convert_zeropoints(zps: Optional[torch.Tensor], s: torch.Tensor):
def group_size_valid(shape: tuple[int, int, int], def group_size_valid(shape: tuple[int, int, int],
group_size: Optional[int]) -> bool: group_size: Optional[int]) -> bool:
return group_size is None or group_size == -1 or group_size % shape[2] == 0 return group_size is None or group_size == -1 or shape[2] % group_size == 0
def machete_quantize_and_pack(atype: torch.dtype, def machete_quantize_and_pack(atype: torch.dtype,
......
...@@ -33,8 +33,6 @@ class MacheteLinearKernel(MPLinearKernel): ...@@ -33,8 +33,6 @@ class MacheteLinearKernel(MPLinearKernel):
return False, "Act reordering currently not supported by Machete, "\ return False, "Act reordering currently not supported by Machete, "\
"when the input features are partitioned across "\ "when the input features are partitioned across "\
"devices" "devices"
if c.zero_points:
return False, "Zero points currently not supported by Machete"
if c.weight_type not in query_machete_supported_quant_types( if c.weight_type not in query_machete_supported_quant_types(
c.zero_points): c.zero_points):
...@@ -53,6 +51,7 @@ class MacheteLinearKernel(MPLinearKernel): ...@@ -53,6 +51,7 @@ class MacheteLinearKernel(MPLinearKernel):
# note assumes that # note assumes that
# `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0} # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
# `weight_scale` is: {input_dim = 0, output_dim = 1} # `weight_scale` is: {input_dim = 0, output_dim = 1}
# `weight_zp` is: {input_dim = 0, output_dim = 1, packed_dim = 1}
def process_weights_after_loading(self, layer: torch.nn.Module): def process_weights_after_loading(self, layer: torch.nn.Module):
c = self.config c = self.config
...@@ -90,16 +89,29 @@ class MacheteLinearKernel(MPLinearKernel): ...@@ -90,16 +89,29 @@ class MacheteLinearKernel(MPLinearKernel):
x.data = x.data.contiguous() x.data = x.data.contiguous()
return x return x
def transform_w_zp(x):
assert isinstance(x, BasevLLMParameter)
permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=1)
x_unpacked = unpack_quantized_values_into_int32(x.data,
c.weight_type,
packed_dim=1)
w_s = getattr(layer, self.w_s_name).data
# pre-apply scales to zero-points
x.data = (-1.0 * w_s * (x_unpacked.to(w_s.dtype))).contiguous()
return x
# Repack weights and scales for Machete # Repack weights and scales for Machete
self._transform_param(layer, self.w_q_name, transform_w_q) self._transform_param(layer, self.w_q_name, transform_w_q)
self._transform_param(layer, self.w_s_name, transform_w_s) self._transform_param(layer, self.w_s_name, transform_w_s)
if c.zero_points:
self._transform_param(layer, self.w_zp_name, transform_w_zp)
def apply_weights(self, def apply_weights(self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor: bias: Optional[torch.Tensor] = None) -> torch.Tensor:
c = self.config c = self.config
w_q, w_s, _, _ = self._get_weight_params(layer) w_q, w_s, w_zp, _ = self._get_weight_params(layer)
x_2d = x.reshape(-1, x.shape[-1]) x_2d = x.reshape(-1, x.shape[-1])
out_shape = x.shape[:-1] + (c.partition_weight_shape[1], ) out_shape = x.shape[:-1] + (c.partition_weight_shape[1], )
...@@ -110,7 +122,7 @@ class MacheteLinearKernel(MPLinearKernel): ...@@ -110,7 +122,7 @@ class MacheteLinearKernel(MPLinearKernel):
output = ops.machete_mm(a=x_2d, output = ops.machete_mm(a=x_2d,
b_q=w_q, b_q=w_q,
b_type=c.weight_type, b_type=c.weight_type,
b_group_zeros=None, b_group_zeros=w_zp,
b_group_scales=w_s, b_group_scales=w_s,
b_group_size=c.group_size) b_group_size=c.group_size)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment