# Note: amax is only updated once here, because the "quantize" method implementation inherited from CurrentScaleQuantizer calls _quantize_func a single time and then transposes the result for colwise quantization. So we don't have to worry about update being called twice for 2x2x quantization.
self.update(amax)
self.update(amax)
returnScaledTensorFactory.create_1x(
returnScaledTensorFactory.create_1x(
data=clipped_scaled_x,
data=clipped_scaled_x,
...
@@ -494,7 +509,7 @@ class BlockScaleQuantizer(Quantizer):
...
@@ -494,7 +509,7 @@ class BlockScaleQuantizer(Quantizer):
# We cast block_scale_inv to scale_dtype here to account for any rounding during the cast. This will ensure the quantized data incorporates the rounded scale value into its computation so dequantization is accurate.
# Note: under JIT, JAX removes this intermediate cast, which leads to slightly incorrect results during DQ and worse convergence to the original tensor over many samples of Q+SR->DQ. So we use reduce_precision to simulate the cast to scale_dtype.
assertscale_dtype==jnp.float8_e4m3fn,"Only float8_e4m3fn is supported for scale_dtype"
@@ -100,10 +100,19 @@ class ScalingModeMetadataImpl(ABC):
...
@@ -100,10 +100,19 @@ class ScalingModeMetadataImpl(ABC):
The data type used for scale tensors
The data type used for scale tensors
"""
"""
@abstractmethod
def get_data_layout(self) -> str:
    """Get the data layout for rowwise and colwise scaling.

    Returns:
        The data layout as a two-character string, e.g. "NT", where each
        character is either "N" (default) or "T" for transposed. The first
        character refers to the rowwise layout and the second refers to the
        colwise layout.
    """
@abstractmethod
@abstractmethod
defget_scale_shape(
defget_scale_shape(
self,
self,
data_shape:Tuple[int,...],
data_shape:Tuple[int,...],
data_layout:str="N",
is_colwise:bool=False,
is_colwise:bool=False,
is_padded:bool=True,
is_padded:bool=True,
flatten_axis:int=-1,
flatten_axis:int=-1,
...
@@ -112,6 +121,7 @@ class ScalingModeMetadataImpl(ABC):
...
@@ -112,6 +121,7 @@ class ScalingModeMetadataImpl(ABC):
Args:
Args:
data_shape: The shape of the tensor being quantized
data_shape: The shape of the tensor being quantized
data_layout: Layout of the data shape, either "N" (default) or "T" for transposed.
is_colwise: Whether the scaling is column-wise
is_colwise: Whether the scaling is column-wise
is_padded: Whether to return padded shape
is_padded: Whether to return padded shape
flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1)
flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1)
...
@@ -152,14 +162,19 @@ class ScalingModeMetadataImpl(ABC):
...
@@ -152,14 +162,19 @@ class ScalingModeMetadataImpl(ABC):
@abstractmethod
@abstractmethod
defget_shardy_sharding_rules(
defget_shardy_sharding_rules(
self,input_rank,unique_var,flatten_axis
self,
input_shape,
unique_var,
flatten_axis,
broadcast_2d_scale_shape_to_1d,
)->QuantizeShardyRules:
)->QuantizeShardyRules:
"""Sharding rules for the input and (row, col)wise scale tensors.
"""Sharding rules for the input and (row, col)wise scale tensors.
Args:
Args:
input_rank: The rank of the input tensor (for which we produce the scale tensor)
input_shape: The shape of the input tensor (for which we produce the scale tensor)
unique_var: An otherwise unused Shardy variable name prefix
unique_var: An otherwise unused Shardy variable name prefix
flatten_axis: Axis along which data can be flattened to 2D for quantization.
flatten_axis: Axis along which data can be flattened to 2D for quantization
broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D.
Returns:
Returns:
The Shardy rules for the scaling mode
The Shardy rules for the scaling mode
...
@@ -180,12 +195,22 @@ class NoScalingModeMetadataImpl(ScalingModeMetadataImpl):
...
@@ -180,12 +195,22 @@ class NoScalingModeMetadataImpl(ScalingModeMetadataImpl):
"""
"""
returnjnp.float32
returnjnp.float32
def get_data_layout(self) -> str:
    """Get the data layout for rowwise and colwise scaling.

    Returns:
        The two-character layout string "NN": the first character is the
        rowwise layout and the second is the colwise layout, where "N" is the
        default layout and "T" would mean transposed. Both are "N" here.
    """
    return "NN"
defget_scale_shape(
defget_scale_shape(
self,
self,
data_shape:Tuple[int,...],
data_shape:Tuple[int,...],
data_layout:str="N",
is_colwise:bool=False,
is_colwise:bool=False,
is_padded:bool=True,
is_padded:bool=True,
flatten_axis:int=-1,
flatten_axis:int=-1,
broadcast_2d_scale_shape_to_1d:bool=True,
)->Tuple[int,...]:
)->Tuple[int,...]:
"""Get the shape for scale tensors. This always returns an empty shape because this mode applies no scaling.
"""Get the shape for scale tensors. This always returns an empty shape because this mode applies no scaling.
...
@@ -198,7 +223,14 @@ class NoScalingModeMetadataImpl(ScalingModeMetadataImpl):
...
@@ -198,7 +223,14 @@ class NoScalingModeMetadataImpl(ScalingModeMetadataImpl):
Returns:
Returns:
The shape for scale tensors - (0,)
The shape for scale tensors - (0,)
"""
"""
deldata_shape,is_colwise,is_padded,flatten_axis
del(
data_shape,
data_layout,
is_colwise,
is_padded,
flatten_axis,
broadcast_2d_scale_shape_to_1d,
)
return(0,)
return(0,)
@lru_cache(maxsize=4)
@lru_cache(maxsize=4)
...
@@ -232,20 +264,25 @@ class NoScalingModeMetadataImpl(ScalingModeMetadataImpl):
...
@@ -232,20 +264,25 @@ class NoScalingModeMetadataImpl(ScalingModeMetadataImpl):
return(n_groups,)
return(n_groups,)
defget_shardy_sharding_rules(
defget_shardy_sharding_rules(
self,input_rank,unique_var,flatten_axis
self,
input_shape,
unique_var,
flatten_axis,
broadcast_2d_scale_shape_to_1d,
)->QuantizeShardyRules:
)->QuantizeShardyRules:
"""Sharding rules for the input and (row, col)wise scale tensors.
"""Sharding rules for the input and (row, col)wise scale tensors.
Args:
Args:
input_rank: The rank of the input tensor (for which we produce the scale tensor)
input_shape: The shape of the input tensor (for which we produce the scale tensor)
unique_var: An otherwise unused Shardy variable name prefix
unique_var: An otherwise unused Shardy variable name prefix
flatten_axis: Axis along which data can be flattened to 2D for quantization.
flatten_axis: Axis along which data can be flattened to 2D for quantization
broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D.
@@ -264,25 +301,37 @@ class CurrentScalingModeMetadataImpl(ScalingModeMetadataImpl):
...
@@ -264,25 +301,37 @@ class CurrentScalingModeMetadataImpl(ScalingModeMetadataImpl):
"""
"""
returnjnp.float32
returnjnp.float32
def get_data_layout(self) -> str:
    """Get the data layout for rowwise and colwise scaling.

    Returns:
        The two-character layout string "NT": the first character is the
        rowwise layout ("N", the default layout) and the second is the colwise
        layout ("T", transposed).
    """
    return "NT"
defget_scale_shape(
defget_scale_shape(
self,
self,
data_shape:Tuple[int,...],
data_shape:Tuple[int,...],
data_layout:str="N",
is_colwise:bool=False,
is_colwise:bool=False,
is_padded:bool=True,
is_padded:bool=True,
flatten_axis:int=-1,
flatten_axis:int=-1,
broadcast_2d_scale_shape_to_1d:bool=True,
)->Tuple[int,...]:
)->Tuple[int,...]:
"""Get the shape for scale tensors in delayed scaling.
"""Get the shape for scale tensors in delayed scaling.
Args:
Args:
data_shape: The shape of the tensor being scaled
data_shape: The shape of the tensor being scaled
data_layout: Layout of the data shape, either "N" (default) or "T" for transposed.
is_colwise: Whether the scaling is column-wise
is_colwise: Whether the scaling is column-wise
is_padded: Whether to return padded shape
is_padded: Whether to return padded shape
flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1.
flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1.
broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D. Defaults to True.
data_layout: Layout for rowwise and colwise scaling, two characters, e.g. "NT", where each is either "N" (default) or "T" for transposed. The first character refers to the rowwise layout and the second refers to the colwise layout.
"""
"""
self._block_dims=block_dims
self._block_dims=block_dims
self._scale_dtype=scale_dtype
self._block_alignment=(128,4)
self._block_alignment=(128,4)
self._data_layout=data_layout
defget_scale_dtype(self)->jnp.dtype:
defget_scale_dtype(self)->jnp.dtype:
"""Get the data type for scale tensors in block scaling.
"""Get the data type for scale tensors in block scaling.
...
@@ -374,7 +432,15 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
...
@@ -374,7 +432,15 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
Returns:
Returns:
The data type used for scale tensors (float8_e8m0fnu)
The data type used for scale tensors (float8_e8m0fnu)
"""
"""
returnjnp.float8_e8m0fnu
returnself._scale_dtype
defget_data_layout(self)->str:
"""Get the data layout for rowwise and colwise scaling.
Returns:
The data layout, two characters, e.g. "NT", where each is either "N" (default) or "T" for transposed. The first character refers to the rowwise layout and the second refers to the colwise layout.