Unverified commit 7f21e805, authored by rongfu.leng and committed by GitHub
Browse files

[Misc] add group_size is -1 in awq quantization (#18910)


Signed-off-by: rongfu.leng <rongfu.leng@daocloud.io>
parent 5a864163
...@@ -101,7 +101,13 @@ class AWQLinearMethod(LinearMethodBase): ...@@ -101,7 +101,13 @@ class AWQLinearMethod(LinearMethodBase):
output_partition_sizes: list[int], input_size: int, output_partition_sizes: list[int], input_size: int,
output_size: int, params_dtype: torch.dtype, output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs): **extra_weight_attrs):
if input_size_per_partition % self.quant_config.group_size != 0: # Normalize group_size
if self.quant_config.group_size != -1:
group_size = self.quant_config.group_size
else:
group_size = input_size
if input_size_per_partition % group_size != 0:
raise ValueError( raise ValueError(
"The input size is not aligned with the quantized " "The input size is not aligned with the quantized "
"weight shape. This can be caused by too large " "weight shape. This can be caused by too large "
...@@ -127,9 +133,11 @@ class AWQLinearMethod(LinearMethodBase): ...@@ -127,9 +133,11 @@ class AWQLinearMethod(LinearMethodBase):
packed_factor=self.quant_config.pack_factor, packed_factor=self.quant_config.pack_factor,
weight_loader=weight_loader) weight_loader=weight_loader)
num_groups = input_size_per_partition // group_size
qzeros = PackedvLLMParameter( qzeros = PackedvLLMParameter(
data=torch.empty( data=torch.empty(
input_size_per_partition // self.quant_config.group_size, num_groups,
output_size_per_partition // self.quant_config.pack_factor, output_size_per_partition // self.quant_config.pack_factor,
dtype=torch.int32, dtype=torch.int32,
), ),
...@@ -140,7 +148,7 @@ class AWQLinearMethod(LinearMethodBase): ...@@ -140,7 +148,7 @@ class AWQLinearMethod(LinearMethodBase):
weight_loader=weight_loader) weight_loader=weight_loader)
scales = GroupQuantScaleParameter(data=torch.empty( scales = GroupQuantScaleParameter(data=torch.empty(
input_size_per_partition // self.quant_config.group_size, num_groups,
output_size_per_partition, output_size_per_partition,
dtype=params_dtype, dtype=params_dtype,
), ),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment