Unverified commit 3d0ea80a authored by Phuong Nguyen, committed by GitHub
Browse files

[JAX] `ScaledTensor1x` to store `amax` (#2117)



* added amax as an optional arg
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent 07db17b5
......@@ -1037,7 +1037,7 @@ def act_lu(
out = out.reshape(output_shape)
if noop_scaled_tensor:
return ScaledTensorFactory.create_2x(
out, None, out, None, ScalingMode.NO_SCALING, dq_dtype=out.dtype
out, None, out, None, scaling_mode=ScalingMode.NO_SCALING, dq_dtype=out.dtype
)
return out
......
......@@ -1324,7 +1324,12 @@ def normalization_fwd(
if quantizer is None and noop_scaled_tensor:
return (
ScaledTensorFactory.create_2x(
output, None, output, None, ScalingMode.NO_SCALING, dq_dtype=output.dtype
output,
None,
output,
None,
scaling_mode=ScalingMode.NO_SCALING,
dq_dtype=output.dtype,
),
mu,
rsigma,
......
......@@ -591,7 +591,7 @@ def _quantize_dbias_impl(
None,
x,
None,
ScalingMode.NO_SCALING,
scaling_mode=ScalingMode.NO_SCALING,
dq_dtype=x.dtype,
data_layout="NN",
flatten_axis=flatten_axis,
......
......@@ -494,7 +494,7 @@ class BlockScaleQuantizer(Quantizer):
return ScaledTensorFactory.create_1x(
x_q,
scales_q,
self.scaling_mode,
scaling_mode=self.scaling_mode,
is_colwise=is_colwise,
dq_dtype=dq_dtype,
flatten_axis=flatten_axis,
......@@ -640,11 +640,11 @@ class GroupedQuantizer(Quantizer):
return ScaledTensorFactory.create_1x(
grouped_data,
grouped_scale_inv,
self.scaling_mode,
tensor_list[0].dq_dtype,
tensor_list[0].is_colwise,
tensor_list[0].data_layout,
tensor_list[0].flatten_axis,
scaling_mode=self.scaling_mode,
dq_dtype=tensor_list[0].dq_dtype,
is_colwise=tensor_list[0].is_colwise,
data_layout=tensor_list[0].data_layout,
flatten_axis=tensor_list[0].flatten_axis,
group_sizes=group_sizes,
original_shape=original_shape,
group_axis=group_axis,
......
......@@ -104,6 +104,7 @@ class ScaledTensor1x(ScaledTensor):
Attributes:
data: The quantized tensor data
scale_inv: The inverse scaling factors
amax: The maximum absolute value of the tensor
scaling_mode: The scaling mode used for quantization
dq_dtype: The data type for dequantized values
_dq_func: The dequantization function
......@@ -114,6 +115,7 @@ class ScaledTensor1x(ScaledTensor):
data: jnp.ndarray
scale_inv: jnp.ndarray
amax: jnp.ndarray
scaling_mode: ScalingMode
dq_dtype: jnp.dtype
_dq_func: Callable
......@@ -152,7 +154,7 @@ class ScaledTensor1x(ScaledTensor):
Returns:
A tuple containing (children, aux_data) for tree operations
"""
children = (self.data, self.scale_inv)
children = (self.data, self.scale_inv, self.amax)
aux_data = (
self.scaling_mode,
self.dq_dtype,
......@@ -224,6 +226,7 @@ class ScaledTensor1x(ScaledTensor):
return ScaledTensor1x(
data=data,
scale_inv=scale_inv,
amax=self.amax,
scaling_mode=self.scaling_mode,
dq_dtype=self.dq_dtype,
_dq_func=self._dq_func,
......@@ -255,6 +258,7 @@ class GroupedScaledTensor1x(ScaledTensor1x):
self,
data,
scale_inv,
amax,
group_sizes,
scaling_mode,
dq_dtype,
......@@ -270,7 +274,15 @@ class GroupedScaledTensor1x(ScaledTensor1x):
self.original_shape = original_shape
self.group_axis = group_axis
super().__init__(
data, scale_inv, scaling_mode, dq_dtype, _dq_func, is_colwise, data_layout, flatten_axis
data,
scale_inv,
amax,
scaling_mode,
dq_dtype,
_dq_func,
is_colwise,
data_layout,
flatten_axis,
)
def __post_init__(self):
......@@ -308,7 +320,7 @@ class GroupedScaledTensor1x(ScaledTensor1x):
Returns:
A tuple containing (children, aux_data) for tree operations
"""
children = (self.data, self.scale_inv, self.group_sizes)
children = (self.data, self.scale_inv, self.amax, self.group_sizes)
aux_data = (
self.scaling_mode,
self.dq_dtype,
......@@ -413,7 +425,8 @@ class ScaledTensorFactory:
def create_1x(
data,
scale_inv,
scaling_mode,
amax=None,
scaling_mode=ScalingMode.NO_SCALING,
dq_dtype=jnp.bfloat16,
is_colwise=False,
data_layout="N",
......@@ -427,18 +440,22 @@ class ScaledTensorFactory:
Args:
data: The quantized tensor data
scale_inv: The inverse scaling factors
amax: The maximum absolute value of the tensor
scaling_mode: The scaling mode for quantization
dq_dtype: The data type for dequantized values (default: bfloat16)
is_colwise: Whether to use column-wise quantization (default: False)
data_layout: The data_layout specification (default: "N")
flatten_axis: The quantization axis for the tensor
group_sizes: Arra of ints containing the size of each group (default: None)
group_sizes: Array of ints containing the size of each group (default: None)
original_shape: The original shape of the tensor before grouping (default: None)
group_axis: The axis along which grouping is performed (default: 0)
Returns:
A ScaledTensor1x or GroupedScaledTensor1x instance depending on whether group_sizes is provided
"""
if amax is None:
amax = jnp.empty((1,), dtype=jnp.float32)
dequantizer = ScalingModeToDequantizerMap.get(scaling_mode)
if group_sizes is not None:
......@@ -468,6 +485,7 @@ class ScaledTensorFactory:
return GroupedScaledTensor1x(
data=data,
scale_inv=scale_inv,
amax=amax,
scaling_mode=scaling_mode,
dq_dtype=dq_dtype,
_dq_func=dequantizer.grouped_dequantize,
......@@ -487,6 +505,7 @@ class ScaledTensorFactory:
return ScaledTensor1x(
data,
scale_inv,
amax,
scaling_mode,
dq_dtype,
dequantizer.dequantize,
......@@ -501,7 +520,8 @@ class ScaledTensorFactory:
scale_inv,
colwise_data,
colwise_scale_inv,
scaling_mode,
amax=None,
scaling_mode=ScalingMode.NO_SCALING,
dq_dtype=jnp.bfloat16,
data_layout="NN",
flatten_axis=-1,
......@@ -516,6 +536,7 @@ class ScaledTensorFactory:
scale_inv: The row-wise inverse scaling factors
colwise_data: The column-wise quantized data
colwise_scale_inv: The column-wise inverse scaling factors
amax: The maximum absolute value of the tensor
scaling_mode: The scaling mode for quantization
dq_dtype: The data type for dequantized values (default: bfloat16)
data_layout: The data_layout specification (default: "NN")
......@@ -527,10 +548,14 @@ class ScaledTensorFactory:
Returns:
A ScaledTensor2x instance
"""
if amax is None:
amax = jnp.empty((1,), dtype=jnp.float32)
assert len(data_layout) == 2, f"Expect 2 layouts, got {data_layout}"
rowwise_tensor = ScaledTensorFactory.create_1x(
data,
scale_inv,
amax,
scaling_mode,
dq_dtype,
is_colwise=False,
......@@ -543,6 +568,7 @@ class ScaledTensorFactory:
colwise_tensor = ScaledTensorFactory.create_1x(
colwise_data,
colwise_scale_inv,
amax,
scaling_mode,
dq_dtype,
is_colwise=True,
......@@ -560,7 +586,8 @@ class ScaledTensorFactory:
scale_inv: jnp.ndarray,
colwise_data: jnp.ndarray,
colwise_scale_inv: jnp.ndarray,
scaling_mode: ScalingMode,
amax=None,
scaling_mode: ScalingMode = ScalingMode.NO_SCALING,
dq_dtype: jnp.dtype = jnp.bfloat16,
data_layout: str = "NN",
q_layout: QuantizeLayout = QuantizeLayout.ROWWISE,
......@@ -594,6 +621,7 @@ class ScaledTensorFactory:
scale_inv,
colwise_data,
colwise_scale_inv,
amax,
scaling_mode,
dq_dtype,
data_layout=data_layout,
......@@ -608,6 +636,7 @@ class ScaledTensorFactory:
return ScaledTensorFactory.create_1x(
colwise_data,
colwise_scale_inv,
amax,
scaling_mode,
dq_dtype,
is_colwise=is_colwise,
......@@ -621,6 +650,7 @@ class ScaledTensorFactory:
return ScaledTensorFactory.create_1x(
data,
scale_inv,
amax,
scaling_mode,
dq_dtype,
is_colwise=is_colwise,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment