Unverified commit 3d0ea80a authored by Phuong Nguyen, committed by GitHub
Browse files

[JAX] `ScaledTensor1x` to store `amax` (#2117)



* added amax as an optional arg
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent 07db17b5
...@@ -1037,7 +1037,7 @@ def act_lu( ...@@ -1037,7 +1037,7 @@ def act_lu(
out = out.reshape(output_shape) out = out.reshape(output_shape)
if noop_scaled_tensor: if noop_scaled_tensor:
return ScaledTensorFactory.create_2x( return ScaledTensorFactory.create_2x(
out, None, out, None, ScalingMode.NO_SCALING, dq_dtype=out.dtype out, None, out, None, scaling_mode=ScalingMode.NO_SCALING, dq_dtype=out.dtype
) )
return out return out
......
...@@ -1324,7 +1324,12 @@ def normalization_fwd( ...@@ -1324,7 +1324,12 @@ def normalization_fwd(
if quantizer is None and noop_scaled_tensor: if quantizer is None and noop_scaled_tensor:
return ( return (
ScaledTensorFactory.create_2x( ScaledTensorFactory.create_2x(
output, None, output, None, ScalingMode.NO_SCALING, dq_dtype=output.dtype output,
None,
output,
None,
scaling_mode=ScalingMode.NO_SCALING,
dq_dtype=output.dtype,
), ),
mu, mu,
rsigma, rsigma,
......
...@@ -591,7 +591,7 @@ def _quantize_dbias_impl( ...@@ -591,7 +591,7 @@ def _quantize_dbias_impl(
None, None,
x, x,
None, None,
ScalingMode.NO_SCALING, scaling_mode=ScalingMode.NO_SCALING,
dq_dtype=x.dtype, dq_dtype=x.dtype,
data_layout="NN", data_layout="NN",
flatten_axis=flatten_axis, flatten_axis=flatten_axis,
......
...@@ -494,7 +494,7 @@ class BlockScaleQuantizer(Quantizer): ...@@ -494,7 +494,7 @@ class BlockScaleQuantizer(Quantizer):
return ScaledTensorFactory.create_1x( return ScaledTensorFactory.create_1x(
x_q, x_q,
scales_q, scales_q,
self.scaling_mode, scaling_mode=self.scaling_mode,
is_colwise=is_colwise, is_colwise=is_colwise,
dq_dtype=dq_dtype, dq_dtype=dq_dtype,
flatten_axis=flatten_axis, flatten_axis=flatten_axis,
...@@ -640,11 +640,11 @@ class GroupedQuantizer(Quantizer): ...@@ -640,11 +640,11 @@ class GroupedQuantizer(Quantizer):
return ScaledTensorFactory.create_1x( return ScaledTensorFactory.create_1x(
grouped_data, grouped_data,
grouped_scale_inv, grouped_scale_inv,
self.scaling_mode, scaling_mode=self.scaling_mode,
tensor_list[0].dq_dtype, dq_dtype=tensor_list[0].dq_dtype,
tensor_list[0].is_colwise, is_colwise=tensor_list[0].is_colwise,
tensor_list[0].data_layout, data_layout=tensor_list[0].data_layout,
tensor_list[0].flatten_axis, flatten_axis=tensor_list[0].flatten_axis,
group_sizes=group_sizes, group_sizes=group_sizes,
original_shape=original_shape, original_shape=original_shape,
group_axis=group_axis, group_axis=group_axis,
......
...@@ -104,6 +104,7 @@ class ScaledTensor1x(ScaledTensor): ...@@ -104,6 +104,7 @@ class ScaledTensor1x(ScaledTensor):
Attributes: Attributes:
data: The quantized tensor data data: The quantized tensor data
scale_inv: The inverse scaling factors scale_inv: The inverse scaling factors
amax: The maximum absolute value of the tensor
scaling_mode: The scaling mode used for quantization scaling_mode: The scaling mode used for quantization
dq_dtype: The data type for dequantized values dq_dtype: The data type for dequantized values
_dq_func: The dequantization function _dq_func: The dequantization function
...@@ -114,6 +115,7 @@ class ScaledTensor1x(ScaledTensor): ...@@ -114,6 +115,7 @@ class ScaledTensor1x(ScaledTensor):
data: jnp.ndarray data: jnp.ndarray
scale_inv: jnp.ndarray scale_inv: jnp.ndarray
amax: jnp.ndarray
scaling_mode: ScalingMode scaling_mode: ScalingMode
dq_dtype: jnp.dtype dq_dtype: jnp.dtype
_dq_func: Callable _dq_func: Callable
...@@ -152,7 +154,7 @@ class ScaledTensor1x(ScaledTensor): ...@@ -152,7 +154,7 @@ class ScaledTensor1x(ScaledTensor):
Returns: Returns:
A tuple containing (children, aux_data) for tree operations A tuple containing (children, aux_data) for tree operations
""" """
children = (self.data, self.scale_inv) children = (self.data, self.scale_inv, self.amax)
aux_data = ( aux_data = (
self.scaling_mode, self.scaling_mode,
self.dq_dtype, self.dq_dtype,
...@@ -224,6 +226,7 @@ class ScaledTensor1x(ScaledTensor): ...@@ -224,6 +226,7 @@ class ScaledTensor1x(ScaledTensor):
return ScaledTensor1x( return ScaledTensor1x(
data=data, data=data,
scale_inv=scale_inv, scale_inv=scale_inv,
amax=self.amax,
scaling_mode=self.scaling_mode, scaling_mode=self.scaling_mode,
dq_dtype=self.dq_dtype, dq_dtype=self.dq_dtype,
_dq_func=self._dq_func, _dq_func=self._dq_func,
...@@ -255,6 +258,7 @@ class GroupedScaledTensor1x(ScaledTensor1x): ...@@ -255,6 +258,7 @@ class GroupedScaledTensor1x(ScaledTensor1x):
self, self,
data, data,
scale_inv, scale_inv,
amax,
group_sizes, group_sizes,
scaling_mode, scaling_mode,
dq_dtype, dq_dtype,
...@@ -270,7 +274,15 @@ class GroupedScaledTensor1x(ScaledTensor1x): ...@@ -270,7 +274,15 @@ class GroupedScaledTensor1x(ScaledTensor1x):
self.original_shape = original_shape self.original_shape = original_shape
self.group_axis = group_axis self.group_axis = group_axis
super().__init__( super().__init__(
data, scale_inv, scaling_mode, dq_dtype, _dq_func, is_colwise, data_layout, flatten_axis data,
scale_inv,
amax,
scaling_mode,
dq_dtype,
_dq_func,
is_colwise,
data_layout,
flatten_axis,
) )
def __post_init__(self): def __post_init__(self):
...@@ -308,7 +320,7 @@ class GroupedScaledTensor1x(ScaledTensor1x): ...@@ -308,7 +320,7 @@ class GroupedScaledTensor1x(ScaledTensor1x):
Returns: Returns:
A tuple containing (children, aux_data) for tree operations A tuple containing (children, aux_data) for tree operations
""" """
children = (self.data, self.scale_inv, self.group_sizes) children = (self.data, self.scale_inv, self.amax, self.group_sizes)
aux_data = ( aux_data = (
self.scaling_mode, self.scaling_mode,
self.dq_dtype, self.dq_dtype,
...@@ -413,7 +425,8 @@ class ScaledTensorFactory: ...@@ -413,7 +425,8 @@ class ScaledTensorFactory:
def create_1x( def create_1x(
data, data,
scale_inv, scale_inv,
scaling_mode, amax=None,
scaling_mode=ScalingMode.NO_SCALING,
dq_dtype=jnp.bfloat16, dq_dtype=jnp.bfloat16,
is_colwise=False, is_colwise=False,
data_layout="N", data_layout="N",
...@@ -427,18 +440,22 @@ class ScaledTensorFactory: ...@@ -427,18 +440,22 @@ class ScaledTensorFactory:
Args: Args:
data: The quantized tensor data data: The quantized tensor data
scale_inv: The inverse scaling factors scale_inv: The inverse scaling factors
amax: The maximum absolute value of the tensor
scaling_mode: The scaling mode for quantization scaling_mode: The scaling mode for quantization
dq_dtype: The data type for dequantized values (default: bfloat16) dq_dtype: The data type for dequantized values (default: bfloat16)
is_colwise: Whether to use column-wise quantization (default: False) is_colwise: Whether to use column-wise quantization (default: False)
data_layout: The data_layout specification (default: "N") data_layout: The data_layout specification (default: "N")
flatten_axis: The quantization axis for the tensor flatten_axis: The quantization axis for the tensor
group_sizes: Arra of ints containing the size of each group (default: None) group_sizes: Array of ints containing the size of each group (default: None)
original_shape: The original shape of the tensor before grouping (default: None) original_shape: The original shape of the tensor before grouping (default: None)
group_axis: The axis along which grouping is performed (default: 0) group_axis: The axis along which grouping is performed (default: 0)
Returns: Returns:
A ScaledTensor1x or GroupedScaledTensor1x instance depending on whether group_sizes is provided A ScaledTensor1x or GroupedScaledTensor1x instance depending on whether group_sizes is provided
""" """
if amax is None:
amax = jnp.empty((1,), dtype=jnp.float32)
dequantizer = ScalingModeToDequantizerMap.get(scaling_mode) dequantizer = ScalingModeToDequantizerMap.get(scaling_mode)
if group_sizes is not None: if group_sizes is not None:
...@@ -468,6 +485,7 @@ class ScaledTensorFactory: ...@@ -468,6 +485,7 @@ class ScaledTensorFactory:
return GroupedScaledTensor1x( return GroupedScaledTensor1x(
data=data, data=data,
scale_inv=scale_inv, scale_inv=scale_inv,
amax=amax,
scaling_mode=scaling_mode, scaling_mode=scaling_mode,
dq_dtype=dq_dtype, dq_dtype=dq_dtype,
_dq_func=dequantizer.grouped_dequantize, _dq_func=dequantizer.grouped_dequantize,
...@@ -487,6 +505,7 @@ class ScaledTensorFactory: ...@@ -487,6 +505,7 @@ class ScaledTensorFactory:
return ScaledTensor1x( return ScaledTensor1x(
data, data,
scale_inv, scale_inv,
amax,
scaling_mode, scaling_mode,
dq_dtype, dq_dtype,
dequantizer.dequantize, dequantizer.dequantize,
...@@ -501,7 +520,8 @@ class ScaledTensorFactory: ...@@ -501,7 +520,8 @@ class ScaledTensorFactory:
scale_inv, scale_inv,
colwise_data, colwise_data,
colwise_scale_inv, colwise_scale_inv,
scaling_mode, amax=None,
scaling_mode=ScalingMode.NO_SCALING,
dq_dtype=jnp.bfloat16, dq_dtype=jnp.bfloat16,
data_layout="NN", data_layout="NN",
flatten_axis=-1, flatten_axis=-1,
...@@ -516,6 +536,7 @@ class ScaledTensorFactory: ...@@ -516,6 +536,7 @@ class ScaledTensorFactory:
scale_inv: The row-wise inverse scaling factors scale_inv: The row-wise inverse scaling factors
colwise_data: The column-wise quantized data colwise_data: The column-wise quantized data
colwise_scale_inv: The column-wise inverse scaling factors colwise_scale_inv: The column-wise inverse scaling factors
amax: The maximum absolute value of the tensor
scaling_mode: The scaling mode for quantization scaling_mode: The scaling mode for quantization
dq_dtype: The data type for dequantized values (default: bfloat16) dq_dtype: The data type for dequantized values (default: bfloat16)
data_layout: The data_layout specification (default: "NN") data_layout: The data_layout specification (default: "NN")
...@@ -527,10 +548,14 @@ class ScaledTensorFactory: ...@@ -527,10 +548,14 @@ class ScaledTensorFactory:
Returns: Returns:
A ScaledTensor2x instance A ScaledTensor2x instance
""" """
if amax is None:
amax = jnp.empty((1,), dtype=jnp.float32)
assert len(data_layout) == 2, f"Expect 2 layouts, got {data_layout}" assert len(data_layout) == 2, f"Expect 2 layouts, got {data_layout}"
rowwise_tensor = ScaledTensorFactory.create_1x( rowwise_tensor = ScaledTensorFactory.create_1x(
data, data,
scale_inv, scale_inv,
amax,
scaling_mode, scaling_mode,
dq_dtype, dq_dtype,
is_colwise=False, is_colwise=False,
...@@ -543,6 +568,7 @@ class ScaledTensorFactory: ...@@ -543,6 +568,7 @@ class ScaledTensorFactory:
colwise_tensor = ScaledTensorFactory.create_1x( colwise_tensor = ScaledTensorFactory.create_1x(
colwise_data, colwise_data,
colwise_scale_inv, colwise_scale_inv,
amax,
scaling_mode, scaling_mode,
dq_dtype, dq_dtype,
is_colwise=True, is_colwise=True,
...@@ -560,7 +586,8 @@ class ScaledTensorFactory: ...@@ -560,7 +586,8 @@ class ScaledTensorFactory:
scale_inv: jnp.ndarray, scale_inv: jnp.ndarray,
colwise_data: jnp.ndarray, colwise_data: jnp.ndarray,
colwise_scale_inv: jnp.ndarray, colwise_scale_inv: jnp.ndarray,
scaling_mode: ScalingMode, amax=None,
scaling_mode: ScalingMode = ScalingMode.NO_SCALING,
dq_dtype: jnp.dtype = jnp.bfloat16, dq_dtype: jnp.dtype = jnp.bfloat16,
data_layout: str = "NN", data_layout: str = "NN",
q_layout: QuantizeLayout = QuantizeLayout.ROWWISE, q_layout: QuantizeLayout = QuantizeLayout.ROWWISE,
...@@ -594,6 +621,7 @@ class ScaledTensorFactory: ...@@ -594,6 +621,7 @@ class ScaledTensorFactory:
scale_inv, scale_inv,
colwise_data, colwise_data,
colwise_scale_inv, colwise_scale_inv,
amax,
scaling_mode, scaling_mode,
dq_dtype, dq_dtype,
data_layout=data_layout, data_layout=data_layout,
...@@ -608,6 +636,7 @@ class ScaledTensorFactory: ...@@ -608,6 +636,7 @@ class ScaledTensorFactory:
return ScaledTensorFactory.create_1x( return ScaledTensorFactory.create_1x(
colwise_data, colwise_data,
colwise_scale_inv, colwise_scale_inv,
amax,
scaling_mode, scaling_mode,
dq_dtype, dq_dtype,
is_colwise=is_colwise, is_colwise=is_colwise,
...@@ -621,6 +650,7 @@ class ScaledTensorFactory: ...@@ -621,6 +650,7 @@ class ScaledTensorFactory:
return ScaledTensorFactory.create_1x( return ScaledTensorFactory.create_1x(
data, data,
scale_inv, scale_inv,
amax,
scaling_mode, scaling_mode,
dq_dtype, dq_dtype,
is_colwise=is_colwise, is_colwise=is_colwise,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment