Unverified Commit 15985680 authored by Robert Shaw's avatar Robert Shaw Committed by GitHub
Browse files

[ Misc ] Rs/compressed tensors cleanup (#5432)


Co-authored-by: default avatarmgoin <michael@neuralmagic.com>
Co-authored-by: default avatarDipika Sikka <dipikasikka1@gmail.com>
parent d74674bb
...@@ -26,7 +26,7 @@ class CompressedTensorsConfig(QuantizationConfig): ...@@ -26,7 +26,7 @@ class CompressedTensorsConfig(QuantizationConfig):
return [] return []
def get_supported_act_dtypes(cls) -> List[torch.dtype]: def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.float16] return [torch.float16, torch.bfloat16]
# Need to figure it out # Need to figure it out
def get_min_capability(self) -> int: def get_min_capability(self) -> int:
......
...@@ -64,10 +64,9 @@ class CompressedTensorsW4A16(CompressedTensorsScheme): ...@@ -64,10 +64,9 @@ class CompressedTensorsW4A16(CompressedTensorsScheme):
"input_dim": 1, "input_dim": 1,
"output_dim": 0, "output_dim": 0,
"packed_dim": 1, "packed_dim": 1,
"pack_factor": pack_factor "pack_factor": pack_factor,
"weight_loader": weight_loader
}) })
set_weight_attrs(weight, {"weight_loader": weight_loader})
layer.register_parameter("weight_packed", weight) layer.register_parameter("weight_packed", weight)
weight_scale = Parameter( weight_scale = Parameter(
...@@ -79,8 +78,9 @@ class CompressedTensorsW4A16(CompressedTensorsScheme): ...@@ -79,8 +78,9 @@ class CompressedTensorsW4A16(CompressedTensorsScheme):
requires_grad=False, requires_grad=False,
) )
set_weight_attrs(weight_scale, {"weight_loader": weight_loader}) set_weight_attrs(
set_weight_attrs(weight_scale, { weight_scale, {
"weight_loader": weight_loader,
"input_dim": weight_scale_dim, "input_dim": weight_scale_dim,
"output_dim": 0 "output_dim": 0
}) })
...@@ -92,7 +92,10 @@ class CompressedTensorsW4A16(CompressedTensorsScheme): ...@@ -92,7 +92,10 @@ class CompressedTensorsW4A16(CompressedTensorsScheme):
requires_grad=False) requires_grad=False)
layer.register_parameter("weight_shape", weight_shape) layer.register_parameter("weight_shape", weight_shape)
set_weight_attrs(weight_shape, {"weight_loader": weight_loader}) set_weight_attrs(weight_shape, {
"weight_loader": weight_loader,
"ignore_warning": True,
})
layer.input_size_per_partition = input_size_per_partition layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition layer.output_size_per_partition = output_size_per_partition
......
...@@ -48,9 +48,6 @@ class CompressedTensorsW8A8DynamicToken(CompressedTensorsScheme): ...@@ -48,9 +48,6 @@ class CompressedTensorsW8A8DynamicToken(CompressedTensorsScheme):
weight_scale_dim = sum( weight_scale_dim = sum(
output_partition_sizes) if is_tensor_partitioned else 1 output_partition_sizes) if is_tensor_partitioned else 1
weight_zero_point = Parameter(torch.empty(1, dtype=torch.int8),
requires_grad=False)
weight_scale = Parameter(torch.empty(weight_scale_dim, weight_scale = Parameter(torch.empty(weight_scale_dim,
dtype=torch.float32), dtype=torch.float32),
requires_grad=False) requires_grad=False)
...@@ -61,21 +58,22 @@ class CompressedTensorsW8A8DynamicToken(CompressedTensorsScheme): ...@@ -61,21 +58,22 @@ class CompressedTensorsW8A8DynamicToken(CompressedTensorsScheme):
requires_grad=False) requires_grad=False)
layer.register_parameter("weight", weight) layer.register_parameter("weight", weight)
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) set_weight_attrs(
set_weight_attrs(weight, {"weight_loader": weight_loader}) weight, {
set_weight_attrs(weight, {"logical_widths": output_partition_sizes}) "input_dim": 1,
"output_dim": 0,
"weight_loader": weight_loader,
"logical_widths": output_partition_sizes
})
layer.register_parameter("weight_scale", weight_scale) layer.register_parameter("weight_scale", weight_scale)
set_weight_attrs(weight_scale, {"weight_loader": weight_loader})
set_weight_attrs( set_weight_attrs(
weight_scale, { weight_scale, {
"weight_loader": weight_loader,
"shard_splitter": self.scales_shard_splitter, "shard_splitter": self.scales_shard_splitter,
"logical_widths": output_partition_sizes "logical_widths": output_partition_sizes
}) })
layer.register_parameter("weight_zero_point", weight_zero_point)
set_weight_attrs(weight_zero_point, {"weight_loader": weight_loader})
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor): def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
weight = layer.weight weight = layer.weight
weight_scale = layer.weight_scale weight_scale = layer.weight_scale
......
...@@ -39,22 +39,16 @@ class CompressedTensorsW8A8StaticTensor(CompressedTensorsScheme): ...@@ -39,22 +39,16 @@ class CompressedTensorsW8A8StaticTensor(CompressedTensorsScheme):
params_dtype: torch.dtype, weight_loader: Callable, params_dtype: torch.dtype, weight_loader: Callable,
**kwargs): **kwargs):
# TODO: remove zero_point parameters once the configs given remove them
is_tensor_partitioned = len(output_partition_sizes) != 1 is_tensor_partitioned = len(output_partition_sizes) != 1
weight_scale_dim = sum( weight_scale_dim = sum(
output_partition_sizes) if is_tensor_partitioned else 1 output_partition_sizes) if is_tensor_partitioned else 1
input_scale = Parameter(torch.empty(1, dtype=torch.float32), input_scale = Parameter(torch.empty(1, dtype=torch.float32),
requires_grad=False) requires_grad=False)
input_zero_point = Parameter(torch.empty(1, dtype=torch.int8),
requires_grad=False)
weight_scale = Parameter(torch.empty(weight_scale_dim, weight_scale = Parameter(torch.empty(weight_scale_dim,
dtype=torch.float32), dtype=torch.float32),
requires_grad=False) requires_grad=False)
weight_zero_point = Parameter(torch.empty(1, dtype=torch.int8),
requires_grad=False)
weight = Parameter(torch.empty(sum(output_partition_sizes), weight = Parameter(torch.empty(sum(output_partition_sizes),
input_size_per_partition, input_size_per_partition,
...@@ -72,11 +66,6 @@ class CompressedTensorsW8A8StaticTensor(CompressedTensorsScheme): ...@@ -72,11 +66,6 @@ class CompressedTensorsW8A8StaticTensor(CompressedTensorsScheme):
"weight_loader": weight_loader, "weight_loader": weight_loader,
"ignore_warning": True, "ignore_warning": True,
}) })
layer.register_parameter("input_zero_point", input_zero_point)
set_weight_attrs(input_zero_point, {
"weight_loader": weight_loader,
"ignore_warning": True,
})
layer.register_parameter("weight_scale", weight_scale) layer.register_parameter("weight_scale", weight_scale)
set_weight_attrs( set_weight_attrs(
weight_scale, { weight_scale, {
...@@ -85,11 +74,6 @@ class CompressedTensorsW8A8StaticTensor(CompressedTensorsScheme): ...@@ -85,11 +74,6 @@ class CompressedTensorsW8A8StaticTensor(CompressedTensorsScheme):
"logical_widths": output_partition_sizes, "logical_widths": output_partition_sizes,
"ignore_warning": True, "ignore_warning": True,
}) })
layer.register_parameter("weight_zero_point", weight_zero_point)
set_weight_attrs(weight_zero_point, {
"weight_loader": weight_loader,
"ignore_warning": True
})
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor): def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
weight = layer.weight weight = layer.weight
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment