Unverified Commit ff884e20 authored by Phuong Nguyen, committed by GitHub

[JAX] Flatten_axis for quantization and Sharding propagation fixes (#1644)



* rename QuantizeAxis to QuantizeLayout, get_layout to get_data_layout, q_axis to q_layout

* add flatten_axis option (see the sketch below)

* added gated act to test encoder

* sharding constraint fixes

* fix padding when flattening first dim needs to be padded

* update test sizes so that padding is tested

* rm output sharding as it can be done in the flax module

* sharding scale_inv for mxfp8

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent be1f647c
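The flatten_axis option added in this commit tells the block-scaling code where a higher-rank tensor is split when it is viewed as 2D for quantization: the product of the dimensions before the axis becomes the flattened first dimension, the product of the remaining dimensions becomes the flattened last dimension, and the MXFP8 scale shape follows those two products. A minimal sketch of the rowwise case, assuming 1x32 blocks; the helper name and block size are illustrative, not the library API:

from functools import reduce
import operator

def mxfp8_rowwise_unpadded_scale_shape(data_shape, flatten_axis=-1, block=32):
    # Sketch: unpadded scale shape when data_shape is viewed as 2D at flatten_axis.
    # Each group of `block` elements along the flattened last dim shares one scale.
    if flatten_axis < 0:
        flatten_axis += len(data_shape)
    assert 0 < flatten_axis < len(data_shape)
    first = reduce(operator.mul, data_shape[:flatten_axis], 1)
    last = reduce(operator.mul, data_shape[flatten_axis:], 1)
    assert last % block == 0, "flattened last dim must be divisible by the block size"
    return (first, last // block)

print(mxfp8_rowwise_unpadded_scale_shape((2, 128, 256), flatten_axis=-1))  # (256, 8)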
......@@ -40,7 +40,11 @@ class ScalingModeMetadataImpl(ABC):
@abstractmethod
def get_scale_shape(
self, data_shape: Tuple[int, ...], is_colwise: bool = False, is_padded: bool = True
self,
data_shape: Tuple[int, ...],
is_colwise: bool = False,
is_padded: bool = True,
flatten_axis: int = -1,
) -> Tuple[int, ...]:
"""Get the shape for scale tensors.
......@@ -48,7 +52,7 @@ class ScalingModeMetadataImpl(ABC):
data_shape: The shape of the tensor being quantized
is_colwise: Whether the scaling is column-wise
is_padded: Whether to return padded shape
flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1.
Returns:
The shape for scale tensors
"""
......@@ -69,7 +73,11 @@ class DelayedScalingModeMetadataImpl(ScalingModeMetadataImpl):
return jnp.float32
def get_scale_shape(
self, data_shape: Tuple[int, ...], is_colwise: bool = False, is_padded: bool = True
self,
data_shape: Tuple[int, ...],
is_colwise: bool = False,
is_padded: bool = True,
flatten_axis: int = -1,
) -> Tuple[int, ...]:
"""Get the shape for scale tensors in delayed scaling.
......@@ -77,6 +85,7 @@ class DelayedScalingModeMetadataImpl(ScalingModeMetadataImpl):
data_shape: The shape of the tensor being scaled
is_colwise: Whether the scaling is column-wise
is_padded: Whether to return padded shape
flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1.
Returns:
The shape for scale tensors - (1,)
......@@ -113,8 +122,35 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
"""
return jnp.float8_e8m0fnu
def _apply_scale_shape_correction(self, data_shape, n_scale_blocks, scale_block_dim):
"""Remove excess padding from the scale shape and return the shape with respect to the original data shape."""
if len(data_shape) > 1:
# handle last dim
assert data_shape[-1] % scale_block_dim == 0
last = data_shape[-1] // scale_block_dim
scale_shape = (last,)
assert n_scale_blocks % last == 0
n_scale_blocks //= last
# handle middle dim, exclude first and last
for mid in reversed(data_shape[1:-1]):
scale_shape = (mid,) + scale_shape
assert n_scale_blocks % mid == 0
n_scale_blocks //= mid
scale_shape = (n_scale_blocks,) + scale_shape
else:
scale_shape = (n_scale_blocks,)
assert len(scale_shape) == len(
data_shape
), f"scale_shape {scale_shape}, data_shape {data_shape}"
return scale_shape
def get_scale_shape(
self, data_shape: Tuple[int, ...], is_colwise: bool = False, is_padded: bool = True
self,
data_shape: Tuple[int, ...],
is_colwise: bool = False,
is_padded: bool = True,
flatten_axis: int = -1,
) -> Tuple[int, ...]:
"""Get the shape for scale tensors in block scaling.
......@@ -122,6 +158,7 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
data_shape: The shape of the tensor being quantized
is_colwise: Whether the scaling is column-wise
is_padded: Whether to return padded shape
flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1.
Returns:
The shape for scale tensors
......@@ -135,35 +172,48 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
block_x, block_y = self._block_dims
alignment_x, alignment_y = block_alignment
seq_axis = len(data_shape) - 2
if flatten_axis < 0:
flatten_axis = len(data_shape) + flatten_axis
assert (
data_shape[seq_axis] % block_x == 0
), f"Input data of shape {data_shape} should be padded by {block_x} in axes={seq_axis}"
0 < flatten_axis < len(data_shape)
), f"flatten_axis {flatten_axis} is out of bounds for shape {data_shape}"
assert data_shape[flatten_axis - 1] % block_x == 0, (
f"Data shape {data_shape} should be divisible by block_x {block_x} in axis"
f" {flatten_axis - 1}"
)
assert (
data_shape[-1] % block_y == 0
), f"Input data of shape {data_shape} should be padded by {block_y} in axis -1"
), f"Data shape {data_shape} should be divisible by block_y {block_y} in axis -1"
# NOTE: this overpads if dim > 2 and dims before seq_axis are greater than 1
n_block_seq = data_shape[seq_axis] // block_x
n_block_y = data_shape[-1] // block_y
flattened_first_dim = reduce(operator.mul, data_shape[:flatten_axis], 1)
flattened_last_dim = reduce(operator.mul, data_shape[flatten_axis:], 1)
n_flat_first_dim = reduce(operator.mul, data_shape[:seq_axis], 1) * n_block_seq
assert flattened_first_dim % block_x == 0, (
f"Flattened first dim - mutiplication of axes={tuple(range(0, flatten_axis))} of shape"
f" {data_shape} - should be divisible by block_x {block_x}"
)
assert flattened_last_dim % block_y == 0, (
"Flattened last dim - mutiplication of"
f" axes={tuple(range(flatten_axis, len(data_shape)))} of shape {data_shape} - should be"
f" divisible by block_y {block_y}"
)
# Padding
n_flat_first_dim = ((n_flat_first_dim + alignment_x - 1) // alignment_x) * alignment_x
n_block_y = ((n_block_y + alignment_y - 1) // alignment_y) * alignment_y
n_block_x = int(flattened_first_dim / block_x)
n_block_y = int(flattened_last_dim / block_y)
out_shape = ()
for i in range(seq_axis):
d = data_shape[i]
out_shape += (d,)
assert n_flat_first_dim % d == 0
n_flat_first_dim //= d
# padding
n_block_x = int(((n_block_x + alignment_x - 1) // alignment_x) * alignment_x)
n_block_y = int(((n_block_y + alignment_y - 1) // alignment_y) * alignment_y)
out_shape += (n_flat_first_dim, n_block_y)
first_dim_scale_shape = self._apply_scale_shape_correction(
data_shape[:flatten_axis], n_block_x, block_x
)
last_dim_scale_shape = self._apply_scale_shape_correction(
data_shape[flatten_axis:], n_block_y, block_y
)
return out_shape
return (*first_dim_scale_shape, *last_dim_scale_shape)
# (Phuong: Map the NVTEScalingMode value to the ScalingMode
......@@ -208,34 +258,40 @@ class ScalingMode(Enum):
"""
return self._get_impl().get_scale_dtype()
def get_scale_shape_2x(self, data_shape, is_padded=True) -> Tuple[Tuple[int]]:
def get_scale_shape_2x(self, data_shape, is_padded=True, flatten_axis=-1) -> Tuple[Tuple[int]]:
"""Get shapes for both row-wise and column-wise scaling.
Args:
data_shape: Shape of the data tensor
is_padded: Whether to use padded shapes
flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1.
Returns:
Tuple of (rowwise_scale_shape, colwise_scale_shape)
"""
rowwise_scale_shape = self.get_scale_shape(
data_shape, is_colwise=False, is_padded=is_padded
data_shape, is_colwise=False, is_padded=is_padded, flatten_axis=flatten_axis
)
colwise_scale_shape = self.get_scale_shape(
data_shape, is_colwise=True, is_padded=is_padded, flatten_axis=flatten_axis
)
colwise_scale_shape = self.get_scale_shape(data_shape, is_colwise=True, is_padded=is_padded)
return (rowwise_scale_shape, colwise_scale_shape)
def get_scale_shape(self, data_shape, is_colwise, is_padded=True) -> Tuple[int]:
def get_scale_shape(
self, data_shape, is_colwise, is_padded=True, flatten_axis=-1
) -> Tuple[int]:
"""Get the shape for scale tensors in this mode.
Args:
data_shape: Shape of the data tensor
is_colwise: Whether to use column-wise scaling
is_padded: Whether to use padded shapes
flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1.
Returns:
The shape for scale tensors
"""
return self._get_impl().get_scale_shape(data_shape, is_colwise, is_padded)
return self._get_impl().get_scale_shape(data_shape, is_colwise, is_padded, flatten_axis)
def __eq__(self, other):
"""Compare this scaling mode with another.
......
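In the padded path above, the scale-block counts along the flattened first and last dimensions are rounded up to the swizzle alignment (alignment_x, alignment_y from block_alignment) before being reshaped back onto the original leading dimensions. A short sketch of that rounding; the alignment values below are assumed for illustration only:

def round_up_to_multiple(n, multiple):
    # Same ceil-to-multiple formula as the padding code above.
    return ((n + multiple - 1) // multiple) * multiple

n_block_x, n_block_y = 9, 3         # e.g. 9 and 3 scale blocks before padding
alignment_x, alignment_y = 128, 4   # assumed alignment values, illustration only
print(round_up_to_multiple(n_block_x, alignment_x))  # 128
print(round_up_to_multiple(n_block_y, alignment_y))  # 4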
......@@ -15,7 +15,7 @@ from abc import ABC, abstractmethod
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class
from transformer_engine_jax import QuantizeAxis
from transformer_engine_jax import QuantizeLayout
from .scaling_modes import ScalingMode
from .dequantizer import Dequantizer
......@@ -84,6 +84,17 @@ class ScaledTensor(ABC):
ValueError: If called on a tensor that doesn't support column-wise access
"""
@abstractmethod
def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]):
"""Applies sharding constraints to a tensor based on logical axis names.
Args:
logical_axis_names: Tuple of logical axis names for sharding
Returns:
The tensor with applied sharding constraints
"""
@register_pytree_node_class
@dataclass
......@@ -100,7 +111,8 @@ class ScaledTensor1x(ScaledTensor):
dq_dtype: The data type for dequantized values
_dq_func: The dequantization function
is_colwise: Whether the tensor uses column-wise quantization
layout: The layout specification for the tensor
data_layout: The data layout specification for the tensor
flatten_axis: The axis along which the data is flattened to 2D for quantization
"""
data: jnp.ndarray
......@@ -109,7 +121,8 @@ class ScaledTensor1x(ScaledTensor):
dq_dtype: jnp.dtype
_dq_func: Callable
is_colwise: bool
layout: str
data_layout: str
flatten_axis: int = -1
def __post_init__(self):
"""Validates and adjusts the scale_inv shape after initialization.
......@@ -117,11 +130,22 @@ class ScaledTensor1x(ScaledTensor):
Ensures the scale_inv shape matches the expected shape based on the scaling mode
and quantization direction. Pads the scale_inv if necessary.
"""
flatten_axis = (
len(self.data.shape) + self.flatten_axis if self.flatten_axis < 0 else self.flatten_axis
)
assert (
0 < flatten_axis < len(self.data.shape)
), f"flatten_axis {flatten_axis} is out of bounds for shape {self.data.shape}"
if self.data_layout == "T":
flatten_axis = self.data.ndim - flatten_axis
self.flatten_axis = flatten_axis
expected_scale_shape = self.scaling_mode.get_scale_shape(
self.data.shape, self.is_colwise, is_padded=True
self.data.shape, self.is_colwise, is_padded=True, flatten_axis=flatten_axis
)
expected_unpadded_scale_shape = self.scaling_mode.get_scale_shape(
self.data.shape, self.is_colwise, is_padded=False
self.data.shape, self.is_colwise, is_padded=False, flatten_axis=flatten_axis
)
if self.scale_inv.shape != expected_scale_shape:
assert self.scale_inv.shape == expected_unpadded_scale_shape, (
......@@ -144,7 +168,14 @@ class ScaledTensor1x(ScaledTensor):
A tuple containing (children, aux_data) for tree operations
"""
children = (self.data, self.scale_inv)
aux_data = (self.scaling_mode, self.dq_dtype, self._dq_func, self.is_colwise, self.layout)
aux_data = (
self.scaling_mode,
self.dq_dtype,
self._dq_func,
self.is_colwise,
self.data_layout,
self.flatten_axis,
)
return (children, aux_data)
def dequantize(self):
......@@ -183,6 +214,46 @@ class ScaledTensor1x(ScaledTensor):
raise ValueError("Calling get_colwise_tensor() from a rowwise ScaledTensor1x!")
def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]):
"""Applies sharding constraints to a tensor based on logical axis names.
Args:
logical_axis_names: Tuple of logical axis names for sharding
Returns:
The tensor with applied sharding constraints
"""
if not logical_axis_names:
return self
# axis_names are given for the N layout, so they need to be transposed for the T layout
if self.data_layout == "T":
assert self.flatten_axis > 0
flatten_axis = -self.flatten_axis
axis_names = (*logical_axis_names[flatten_axis:], *logical_axis_names[:flatten_axis])
else:
axis_names = logical_axis_names
data = with_sharding_constraint_by_logical_axes(self.data, axis_names)
if self.scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING:
# TODO(Phuong): Handle padding !?
scale_inv = with_sharding_constraint_by_logical_axes(self.scale_inv, axis_names)
else:
scale_inv = self.scale_inv
# TODO(Phuong): constrain padded scale_inv?
return ScaledTensor1x(
data=data,
scale_inv=scale_inv,
scaling_mode=self.scaling_mode,
dq_dtype=self.dq_dtype,
_dq_func=self._dq_func,
is_colwise=self.is_colwise,
data_layout=self.data_layout,
flatten_axis=self.flatten_axis,
)
@register_pytree_node_class
@dataclass
......@@ -233,6 +304,27 @@ class ScaledTensor2x(ScaledTensor):
"""
return self.colwise_tensor
def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]):
"""Applies sharding constraints to a tensor based on logical axis names.
Args:
logical_axis_names: Tuple of logical axis names for sharding
Returns:
The tensor with applied sharding constraints
"""
if not logical_axis_names:
return self
rowwise_tensor = self.rowwise_tensor.apply_sharding_constraint_by_logical_axes(
logical_axis_names
)
colwise_tensor = self.colwise_tensor.apply_sharding_constraint_by_logical_axes(
logical_axis_names
)
return ScaledTensor2x(rowwise_tensor, colwise_tensor)
@dataclass
class ScaledTensorFactory:
......@@ -244,7 +336,13 @@ class ScaledTensorFactory:
@staticmethod
def create_1x(
data, scale_inv, scaling_mode, dq_dtype=jnp.bfloat16, is_colwise=False, layout="N"
data,
scale_inv,
scaling_mode,
dq_dtype=jnp.bfloat16,
is_colwise=False,
data_layout="N",
flatten_axis=-1,
):
"""Creates a single-scale quantized tensor.
......@@ -254,13 +352,16 @@ class ScaledTensorFactory:
scaling_mode: The scaling mode for quantization
dq_dtype: The data type for dequantized values (default: bfloat16)
is_colwise: Whether to use column-wise quantization (default: False)
layout: The layout specification (default: "N")
data_layout: The data_layout specification (default: "N")
flatten_axis: The quantization axis for the tensor
Returns:
A ScaledTensor1x instance
"""
dq_func = Dequantizer.funcs.get(scaling_mode)
return ScaledTensor1x(data, scale_inv, scaling_mode, dq_dtype, dq_func, is_colwise, layout)
return ScaledTensor1x(
data, scale_inv, scaling_mode, dq_dtype, dq_func, is_colwise, data_layout, flatten_axis
)
@staticmethod
def create_2x(
......@@ -270,7 +371,8 @@ class ScaledTensorFactory:
colwise_scale_inv,
scaling_mode,
dq_dtype=jnp.bfloat16,
layout="NN",
data_layout="NN",
flatten_axis=-1,
):
"""Creates a double-scale quantized tensor.
......@@ -281,7 +383,8 @@ class ScaledTensorFactory:
colwise_scale_inv: The column-wise inverse scaling factors
scaling_mode: The scaling mode for quantization
dq_dtype: The data type for dequantized values (default: bfloat16)
layout: The layout specification (default: "NN")
data_layout: The data_layout specification (default: "NN")
flatten_axis: The quantization axis for the tensor
Returns:
A ScaledTensor2x instance
......@@ -294,7 +397,8 @@ class ScaledTensorFactory:
dq_dtype,
dq_func,
is_colwise=False,
layout=layout[0],
data_layout=data_layout[0],
flatten_axis=flatten_axis,
)
colwise_tensor = ScaledTensor1x(
colwise_data,
......@@ -303,7 +407,8 @@ class ScaledTensorFactory:
dq_dtype,
dq_func,
is_colwise=True,
layout=layout[1],
data_layout=data_layout[1],
flatten_axis=flatten_axis,
)
return ScaledTensor2x(rowwise_tensor, colwise_tensor)
......@@ -315,8 +420,9 @@ class ScaledTensorFactory:
colwise_scale_inv: jnp.ndarray,
scaling_mode: ScalingMode,
dq_dtype: jnp.dtype = jnp.bfloat16,
layout: str = "NN",
q_axis: QuantizeAxis = QuantizeAxis.ROWWISE,
data_layout: str = "NN",
q_layout: QuantizeLayout = QuantizeLayout.ROWWISE,
flatten_axis: int = -1,
):
"""Creates a scaled tensor based on the quantization axis.
......@@ -327,13 +433,13 @@ class ScaledTensorFactory:
colwise_scale_inv: The column-wise inverse scaling factors
scaling_mode: The scaling mode for quantization
dq_dtype: The data type for dequantized values (default: bfloat16)
layout: The layout specification (default: "NN")
q_axis: The quantization axis (default: ROWWISE)
data_layout: The data_layout specification (default: "NN")
q_layout: The quantization axis (default: ROWWISE)
Returns:
Either a ScaledTensor1x or ScaledTensor2x instance depending on q_axis
Either a ScaledTensor1x or ScaledTensor2x instance depending on q_layout
"""
if q_axis == QuantizeAxis.ROWWISE_COLWISE:
if q_layout == QuantizeLayout.ROWWISE_COLWISE:
return ScaledTensorFactory.create_2x(
data,
scale_inv,
......@@ -341,12 +447,19 @@ class ScaledTensorFactory:
colwise_scale_inv,
scaling_mode,
dq_dtype,
layout=layout,
data_layout=data_layout,
flatten_axis=flatten_axis,
)
is_colwise = q_axis == QuantizeAxis.COLWISE
is_colwise = q_layout == QuantizeLayout.COLWISE
return ScaledTensorFactory.create_1x(
data, scale_inv, scaling_mode, dq_dtype, is_colwise=is_colwise, layout=layout[0]
data,
scale_inv,
scaling_mode,
dq_dtype,
is_colwise=is_colwise,
data_layout=data_layout[0],
flatten_axis=flatten_axis,
)
......@@ -360,24 +473,7 @@ def with_sharding_constraint_by_logical_axes(x, logical_axis_names: Tuple[str, .
Returns:
The tensor with applied sharding constraints
"""
if isinstance(x, ScaledTensor1x):
return ScaledTensor1x(
data=with_sharding_constraint_by_logical_axes(x.data, logical_axis_names),
scale_inv=x.scale_inv,
scaling_mode=x.scaling_mode,
dq_dtype=x.dq_dtype,
_dq_func=x._dq_func,
is_colwise=x.is_colwise,
layout=x.layout,
)
if isinstance(x, ScaledTensor2x):
return ScaledTensor2x(
rowwise_tensor=with_sharding_constraint_by_logical_axes(
x.rowwise_tensor, logical_axis_names
),
colwise_tensor=with_sharding_constraint_by_logical_axes(
x.colwise_tensor, logical_axis_names
),
)
if isinstance(x, ScaledTensor):
return x.apply_sharding_constraint_by_logical_axes(logical_axis_names)
return original_with_sharding_constraint_by_logical_axes(x, logical_axis_names)
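Since logical axis names are always supplied for the untransposed ("N") layout, the "T" branch of ScaledTensor1x.apply_sharding_constraint_by_logical_axes above rotates them by the stored flatten_axis, which __post_init__ has already flipped to count from the transposed shape. A worked sketch with placeholder axis names, mirroring that logic:

logical_axis_names = ("batch", "seq", "hidden")  # names for the N-layout shape (B, S, H)
ndim, flatten_axis_n = 3, 2                      # N-layout flatten boundary

flatten_axis_t = ndim - flatten_axis_n           # boundary seen from the T-layout data (H, B, S) -> 1
shift = -flatten_axis_t
axis_names_t = (*logical_axis_names[shift:], *logical_axis_names[:shift])
print(axis_names_t)                              # ('hidden', 'batch', 'seq')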
......@@ -89,7 +89,11 @@ def generate_pspec(logical_axis_names):
Convert logical axes to PartitionSpec
"""
rules = get_sharding_map_logic_axis_to_mesh_axis()
mesh_axis_names = [rules[name] for name in logical_axis_names]
mesh_axis_names = []
for name in logical_axis_names:
axis_name = rules[name] if name in rules else None
mesh_axis_names.append(axis_name)
pspec = jax.sharding.PartitionSpec(*mesh_axis_names)
return pspec
......@@ -112,7 +116,7 @@ def with_sharding_constraint_by_logical_axes(x: jnp.array, logical_axis_names: t
"""
A wrapper function to jax.lax.with_sharding_constraint to accept logical axes.
"""
if logical_axis_names is None:
if not logical_axis_names:
return x
assert len(x.shape) == len(logical_axis_names)
......@@ -315,3 +319,25 @@ class ShardingType(Enum):
TP_ROW = (MajorShardingType.TP, "tp_row")
DP_TP_COL = (MajorShardingType.DPTP, "dp_tp_col")
DP_TP_ROW = (MajorShardingType.DPTP, "dp_tp_row")
def get_non_contracting_logical_axes(ndim, logical_axes, contracting_dims):
"""Get logical axes for non-contracting dimensions.
Args:
ndim: Number of dimensions in the tensor.
logical_axes: Tuple of logical axes for each dimension.
contracting_dims: Set of dimensions that are being contracted.
Returns:
Tuple of logical axes for non-contracting dimensions.
"""
if not logical_axes:
logical_axes = (None,) * ndim
elif len(logical_axes) < ndim:
logical_axes = logical_axes + (None,) * (ndim - len(logical_axes))
assert len(logical_axes) == ndim
non_contracting_dims = [i for i in range(ndim) if i not in contracting_dims]
non_contracting_logical_axes = tuple(logical_axes[i] for i in non_contracting_dims)
return non_contracting_logical_axes
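A usage sketch for get_non_contracting_logical_axes defined above; the axis names and contracting dimensions are placeholders. The helper pads the logical axes to ndim with None and drops the contracted dimensions, which is what the sharding propagation needs when constraining a GEMM output:

ndim = 3
logical_axes = ("batch", "seq")   # shorter than ndim: padded with None internally
contracting_dims = {2}            # the last dimension is contracted in the GEMM
print(get_non_contracting_logical_axes(ndim, logical_axes, contracting_dims))
# ('batch', 'seq')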