Unverified Commit ac797994 authored by Selali's avatar Selali Committed by GitHub
Browse files

[Bugfix] Fix for ROCM compressed tensor support (#11561)

parent dde1fa18
...@@ -41,10 +41,12 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): ...@@ -41,10 +41,12 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
) )
if current_platform.is_rocm(): if current_platform.is_rocm():
input_scale = getattr(layer, 'input_scale', None)
weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz( weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
weight=weight, weight=weight,
weight_scale=max_w_scale, weight_scale=max_w_scale,
input_scale=layer.input_scale) input_scale=input_scale)
if input_scale is not None: if input_scale is not None:
layer.input_scale = Parameter(input_scale, layer.input_scale = Parameter(input_scale,
requires_grad=False) requires_grad=False)
...@@ -57,11 +59,13 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): ...@@ -57,11 +59,13 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
weight = layer.weight weight = layer.weight
if current_platform.is_rocm(): if current_platform.is_rocm():
input_scale = getattr(layer, 'input_scale', None)
weight, weight_scale, input_scale = \ weight, weight_scale, input_scale = \
normalize_e4m3fn_to_e4m3fnuz( normalize_e4m3fn_to_e4m3fnuz(
weight=weight, weight=weight,
weight_scale=layer.weight_scale, weight_scale=layer.weight_scale,
input_scale=layer.input_scale) input_scale=input_scale)
if input_scale is not None: if input_scale is not None:
layer.input_scale = Parameter(input_scale, layer.input_scale = Parameter(input_scale,
requires_grad=False) requires_grad=False)
...@@ -76,7 +80,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): ...@@ -76,7 +80,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
raise ValueError(f"Unknown quantization strategy {self.strategy}") raise ValueError(f"Unknown quantization strategy {self.strategy}")
# INPUT SCALE # INPUT SCALE
if self.is_static_input_scheme: if self.is_static_input_scheme and hasattr(layer, 'input_scale'):
layer.input_scale = Parameter(layer.input_scale.max(), layer.input_scale = Parameter(layer.input_scale.max(),
requires_grad=False) requires_grad=False)
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment