Unverified Commit ff884e20 authored by Phuong Nguyen, committed by GitHub

[JAX] Flatten_axis for quantization and Sharding propagation fixes (#1644)



* rename QuantizeAxis to QuantizeLayout, get_layout to get_data_layout, q_axis to q_layout

* add flatten_axis option

* add gated activation to test encoder

* sharding constraint fixes

* fix padding when flattening first dim needs to be padded

* update test sizes so that padding is tested

* rm output sharding as it can be done in the flax module

* shard scale_inv for MXFP8

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent be1f647c
@@ -40,7 +40,11 @@ class ScalingModeMetadataImpl(ABC):
     @abstractmethod
     def get_scale_shape(
-        self, data_shape: Tuple[int, ...], is_colwise: bool = False, is_padded: bool = True
+        self,
+        data_shape: Tuple[int, ...],
+        is_colwise: bool = False,
+        is_padded: bool = True,
+        flatten_axis: int = -1,
     ) -> Tuple[int, ...]:
         """Get the shape for scale tensors.
@@ -48,7 +52,7 @@ class ScalingModeMetadataImpl(ABC):
             data_shape: The shape of the tensor being quantized
             is_colwise: Whether the scaling is column-wise
             is_padded: Whether to return padded shape
+            flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1.
 
         Returns:
             The shape for scale tensors
         """
@@ -69,7 +73,11 @@ class DelayedScalingModeMetadataImpl(ScalingModeMetadataImpl):
         return jnp.float32
 
     def get_scale_shape(
-        self, data_shape: Tuple[int, ...], is_colwise: bool = False, is_padded: bool = True
+        self,
+        data_shape: Tuple[int, ...],
+        is_colwise: bool = False,
+        is_padded: bool = True,
+        flatten_axis: int = -1,
     ) -> Tuple[int, ...]:
         """Get the shape for scale tensors in delayed scaling.
 
@@ -77,6 +85,7 @@ class DelayedScalingModeMetadataImpl(ScalingModeMetadataImpl):
             data_shape: The shape of the tensor being scaled
             is_colwise: Whether the scaling is column-wise
             is_padded: Whether to return padded shape
+            flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1.
 
         Returns:
             The shape for scale tensors - (1,)
@@ -113,8 +122,35 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
         """
         return jnp.float8_e8m0fnu
 
+    def _apply_scale_shape_correction(self, data_shape, n_scale_blocks, scale_block_dim):
+        """Remove excess padding from the scale shape and return the shape with respect to the original data shape."""
+        if len(data_shape) > 1:
+            # handle the last dim
+            assert data_shape[-1] % scale_block_dim == 0
+            last = data_shape[-1] // scale_block_dim
+            scale_shape = (last,)
+            assert n_scale_blocks % last == 0
+            n_scale_blocks //= last
+            # handle the middle dims, excluding the first and last
+            for mid in reversed(data_shape[1:-1]):
+                scale_shape = (mid,) + scale_shape
+                assert n_scale_blocks % mid == 0
+                n_scale_blocks //= mid
+            scale_shape = (n_scale_blocks,) + scale_shape
+        else:
+            scale_shape = (n_scale_blocks,)
+        assert len(scale_shape) == len(
+            data_shape
+        ), f"scale_shape {scale_shape}, data_shape {data_shape}"
+        return scale_shape
+
     def get_scale_shape(
-        self, data_shape: Tuple[int, ...], is_colwise: bool = False, is_padded: bool = True
+        self,
+        data_shape: Tuple[int, ...],
+        is_colwise: bool = False,
+        is_padded: bool = True,
+        flatten_axis: int = -1,
     ) -> Tuple[int, ...]:
         """Get the shape for scale tensors in block scaling.
 
@@ -122,6 +158,7 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
             data_shape: The shape of the tensor being quantized
             is_colwise: Whether the scaling is column-wise
             is_padded: Whether to return padded shape
+            flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1.
 
         Returns:
             The shape for scale tensors
@@ -135,35 +172,48 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
         block_x, block_y = self._block_dims
         alignment_x, alignment_y = block_alignment
 
-        seq_axis = len(data_shape) - 2
-        assert (
-            data_shape[seq_axis] % block_x == 0
-        ), f"Input data of shape {data_shape} should be padded by {block_x} in axes={seq_axis}"
-        assert (
-            data_shape[-1] % block_y == 0
-        ), f"Input data of shape {data_shape} should be padded by {block_y} in axis -1"
-        # NOTE: this overpads if dim > 2 and dims before seq_axis are greater than 1
-        n_block_seq = data_shape[seq_axis] // block_x
-        n_block_y = data_shape[-1] // block_y
-        n_flat_first_dim = reduce(operator.mul, data_shape[:seq_axis], 1) * n_block_seq
-
-        # Padding
-        n_flat_first_dim = ((n_flat_first_dim + alignment_x - 1) // alignment_x) * alignment_x
-        n_block_y = ((n_block_y + alignment_y - 1) // alignment_y) * alignment_y
-
-        out_shape = ()
-        for i in range(seq_axis):
-            d = data_shape[i]
-            out_shape += (d,)
-            assert n_flat_first_dim % d == 0
-            n_flat_first_dim //= d
-        out_shape += (n_flat_first_dim, n_block_y)
-
-        return out_shape
+        if flatten_axis < 0:
+            flatten_axis = len(data_shape) + flatten_axis
+        assert (
+            0 < flatten_axis < len(data_shape)
+        ), f"flatten_axis {flatten_axis} is out of bounds for shape {data_shape}"
+
+        assert data_shape[flatten_axis - 1] % block_x == 0, (
+            f"Data shape {data_shape} should be divisible by block_x {block_x} in axis"
+            f" {flatten_axis - 1}"
+        )
+        assert (
+            data_shape[-1] % block_y == 0
+        ), f"Data shape {data_shape} should be divisible by block_y {block_y} in axis -1"
+
+        flattened_first_dim = reduce(operator.mul, data_shape[:flatten_axis], 1)
+        flattened_last_dim = reduce(operator.mul, data_shape[flatten_axis:], 1)
+
+        assert flattened_first_dim % block_x == 0, (
+            f"Flattened first dim - multiplication of axes={tuple(range(0, flatten_axis))} of shape"
+            f" {data_shape} - should be divisible by block_x {block_x}"
+        )
+        assert flattened_last_dim % block_y == 0, (
+            "Flattened last dim - multiplication of"
+            f" axes={tuple(range(flatten_axis, len(data_shape)))} of shape {data_shape} - should be"
+            f" divisible by block_y {block_y}"
+        )
+
+        n_block_x = int(flattened_first_dim / block_x)
+        n_block_y = int(flattened_last_dim / block_y)
+
+        # padding
+        n_block_x = int(((n_block_x + alignment_x - 1) // alignment_x) * alignment_x)
+        n_block_y = int(((n_block_y + alignment_y - 1) // alignment_y) * alignment_y)
+
+        first_dim_scale_shape = self._apply_scale_shape_correction(
+            data_shape[:flatten_axis], n_block_x, block_x
+        )
+        last_dim_scale_shape = self._apply_scale_shape_correction(
+            data_shape[flatten_axis:], n_block_y, block_y
+        )
+
+        return (*first_dim_scale_shape, *last_dim_scale_shape)
 
     # (Phuong: Map the NVTEScalingMode value to the ScalingMode
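To make the hunk above easier to read, here is a minimal, self-contained sketch that mirrors the new `_apply_scale_shape_correction` helper as a free function and traces one example; the 1x32 row-wise block size used in the comment is an assumption for illustration, not something stated in this diff.

# Illustrative mirror of _apply_scale_shape_correction (standalone, for reading the hunk above).
def apply_scale_shape_correction(data_shape, n_scale_blocks, scale_block_dim):
    """Fold a flat count of scale blocks back into a shape aligned with data_shape."""
    if len(data_shape) > 1:
        # the innermost axis absorbs whole blocks of size scale_block_dim
        assert data_shape[-1] % scale_block_dim == 0
        last = data_shape[-1] // scale_block_dim
        scale_shape = (last,)
        assert n_scale_blocks % last == 0
        n_scale_blocks //= last
        # middle axes keep their original extents
        for mid in reversed(data_shape[1:-1]):
            scale_shape = (mid,) + scale_shape
            assert n_scale_blocks % mid == 0
            n_scale_blocks //= mid
        # whatever remains lands on the leading axis
        scale_shape = (n_scale_blocks,) + scale_shape
    else:
        scale_shape = (n_scale_blocks,)
    return scale_shape

# Assuming 1x32 row-wise blocks: data_shape[flatten_axis:] == (64, 128) flattens to
# 8192 elements, i.e. 256 blocks of 32, which fold back into a per-axis scale shape.
print(apply_scale_shape_correction((64, 128), 8192 // 32, 32))  # -> (64, 4)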
@@ -208,34 +258,40 @@ class ScalingMode(Enum):
         """
         return self._get_impl().get_scale_dtype()
 
-    def get_scale_shape_2x(self, data_shape, is_padded=True) -> Tuple[Tuple[int]]:
+    def get_scale_shape_2x(self, data_shape, is_padded=True, flatten_axis=-1) -> Tuple[Tuple[int]]:
         """Get shapes for both row-wise and column-wise scaling.
 
         Args:
             data_shape: Shape of the data tensor
             is_padded: Whether to use padded shapes
+            flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1.
 
         Returns:
             Tuple of (rowwise_scale_shape, colwise_scale_shape)
         """
         rowwise_scale_shape = self.get_scale_shape(
-            data_shape, is_colwise=False, is_padded=is_padded
+            data_shape, is_colwise=False, is_padded=is_padded, flatten_axis=flatten_axis
         )
-        colwise_scale_shape = self.get_scale_shape(data_shape, is_colwise=True, is_padded=is_padded)
+        colwise_scale_shape = self.get_scale_shape(
+            data_shape, is_colwise=True, is_padded=is_padded, flatten_axis=flatten_axis
+        )
         return (rowwise_scale_shape, colwise_scale_shape)
 
-    def get_scale_shape(self, data_shape, is_colwise, is_padded=True) -> Tuple[int]:
+    def get_scale_shape(
+        self, data_shape, is_colwise, is_padded=True, flatten_axis=-1
+    ) -> Tuple[int]:
         """Get the shape for scale tensors in this mode.
 
         Args:
             data_shape: Shape of the data tensor
             is_colwise: Whether to use column-wise scaling
            is_padded: Whether to use padded shapes
+            flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1.
 
         Returns:
             The shape for scale tensors
         """
-        return self._get_impl().get_scale_shape(data_shape, is_colwise, is_padded)
+        return self._get_impl().get_scale_shape(data_shape, is_colwise, is_padded, flatten_axis)
 
     def __eq__(self, other):
         """Compare this scaling mode with another.
...
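As a usage sketch of the new flatten_axis argument: the shapes below are hypothetical, and the 1x32 row-wise / 32x1 column-wise block sizes as well as the import path are assumptions, not shown in this diff.

from transformer_engine.jax.quantize import ScalingMode  # import path assumed

data_shape = (32, 64, 128)  # e.g. (batch, seq, hidden); flatten_axis=-1 splits it as (32*64 | 128)
rowwise, colwise = ScalingMode.NVTE_MXFP8_1D_SCALING.get_scale_shape_2x(
    data_shape, is_padded=False, flatten_axis=-1
)
# Under the assumed block sizes, the unpadded results would be
# rowwise == (32, 64, 4) and colwise == (32, 2, 128).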
@@ -15,7 +15,7 @@ from abc import ABC, abstractmethod
 import jax.numpy as jnp
 from jax.tree_util import register_pytree_node_class
-from transformer_engine_jax import QuantizeAxis
+from transformer_engine_jax import QuantizeLayout
 
 from .scaling_modes import ScalingMode
 from .dequantizer import Dequantizer
 
@@ -84,6 +84,17 @@ class ScaledTensor(ABC):
             ValueError: If called on a tensor that doesn't support column-wise access
         """
 
+    @abstractmethod
+    def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]):
+        """Applies sharding constraints to a tensor based on logical axis names.
+
+        Args:
+            logical_axis_names: Tuple of logical axis names for sharding
+
+        Returns:
+            The tensor with applied sharding constraints
+        """
+
 
 @register_pytree_node_class
 @dataclass
@@ -100,7 +111,8 @@ class ScaledTensor1x(ScaledTensor):
         dq_dtype: The data type for dequantized values
         _dq_func: The dequantization function
         is_colwise: Whether the tensor uses column-wise quantization
-        layout: The layout specification for the tensor
+        data_layout: The data layout specification for the tensor
+        flatten_axis: The quantization axis for the tensor
     """
 
     data: jnp.ndarray
@@ -109,7 +121,8 @@ class ScaledTensor1x(ScaledTensor):
     dq_dtype: jnp.dtype
     _dq_func: Callable
     is_colwise: bool
-    layout: str
+    data_layout: str
+    flatten_axis: int = -1
 
     def __post_init__(self):
         """Validates and adjusts the scale_inv shape after initialization.
@@ -117,11 +130,22 @@ class ScaledTensor1x(ScaledTensor):
         Ensures the scale_inv shape matches the expected shape based on the scaling mode
         and quantization direction. Pads the scale_inv if necessary.
         """
+        flatten_axis = (
+            len(self.data.shape) + self.flatten_axis if self.flatten_axis < 0 else self.flatten_axis
+        )
+        assert (
+            0 < flatten_axis < len(self.data.shape)
+        ), f"flatten_axis {flatten_axis} is out of bounds for shape {self.data.shape}"
+
+        if self.data_layout == "T":
+            flatten_axis = self.data.ndim - flatten_axis
+        self.flatten_axis = flatten_axis
+
         expected_scale_shape = self.scaling_mode.get_scale_shape(
-            self.data.shape, self.is_colwise, is_padded=True
+            self.data.shape, self.is_colwise, is_padded=True, flatten_axis=flatten_axis
         )
         expected_unpadded_scale_shape = self.scaling_mode.get_scale_shape(
-            self.data.shape, self.is_colwise, is_padded=False
+            self.data.shape, self.is_colwise, is_padded=False, flatten_axis=flatten_axis
         )
         if self.scale_inv.shape != expected_scale_shape:
             assert self.scale_inv.shape == expected_unpadded_scale_shape, (
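The data_layout == "T" branch above mirrors the flatten point because a transposed tensor stores the two flattened groups in the opposite order; a small sketch with hypothetical shapes:

shape_n = (2, 4, 8, 16)   # "N" layout; flatten_axis=-1 normalizes to 3, i.e. split (2*4*8 | 16)
flatten_axis = 3
shape_t = shape_n[flatten_axis:] + shape_n[:flatten_axis]   # "T" layout: (16, 2, 4, 8)
flatten_axis_t = len(shape_n) - flatten_axis                # 1 -> same split (16 | 2*4*8)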
@@ -144,7 +168,14 @@ class ScaledTensor1x(ScaledTensor):
             A tuple containing (children, aux_data) for tree operations
         """
         children = (self.data, self.scale_inv)
-        aux_data = (self.scaling_mode, self.dq_dtype, self._dq_func, self.is_colwise, self.layout)
+        aux_data = (
+            self.scaling_mode,
+            self.dq_dtype,
+            self._dq_func,
+            self.is_colwise,
+            self.data_layout,
+            self.flatten_axis,
+        )
         return (children, aux_data)
 
     def dequantize(self):
@@ -183,6 +214,46 @@ class ScaledTensor1x(ScaledTensor):
             raise ValueError("Calling get_colwise_tensor() from a rowwise ScaledTensor1x!")
 
+    def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]):
+        """Applies sharding constraints to a tensor based on logical axis names.
+
+        Args:
+            logical_axis_names: Tuple of logical axis names for sharding
+
+        Returns:
+            The tensor with applied sharding constraints
+        """
+        if not logical_axis_names:
+            return self
+
+        # axis_names were given for the N layout, so they need to be transposed for the T layout
+        if self.data_layout == "T":
+            assert self.flatten_axis > 0
+            flatten_axis = -self.flatten_axis
+            axis_names = (*logical_axis_names[flatten_axis:], *logical_axis_names[:flatten_axis])
+        else:
+            axis_names = logical_axis_names
+
+        data = with_sharding_constraint_by_logical_axes(self.data, axis_names)
+
+        if self.scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING:
+            # TODO(Phuong): Handle padding !?
+            scale_inv = with_sharding_constraint_by_logical_axes(self.scale_inv, axis_names)
+        else:
+            scale_inv = self.scale_inv
+            # TODO(Phuong): constrain the padded scale_inv?
+
+        return ScaledTensor1x(
+            data=data,
+            scale_inv=scale_inv,
+            scaling_mode=self.scaling_mode,
+            dq_dtype=self.dq_dtype,
+            _dq_func=self._dq_func,
+            is_colwise=self.is_colwise,
+            data_layout=self.data_layout,
+            flatten_axis=self.flatten_axis,
+        )
+
 
 @register_pytree_node_class
 @dataclass
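In the "T" branch of apply_sharding_constraint_by_logical_axes, the caller still passes axis names in "N"-layout order, so the names are rotated to follow the transposed data; a sketch with hypothetical logical axis names:

logical_axis_names = ("nvte_batch", "nvte_seq", "nvte_hidden")  # given for the "N" layout
flatten_axis = 1      # value stored by __post_init__ for a "T" tensor (ndim - original flatten point)
neg = -flatten_axis
axis_names = (*logical_axis_names[neg:], *logical_axis_names[:neg])
# ("nvte_hidden", "nvte_batch", "nvte_seq") - matches the transposed data laid out as (hidden | batch, seq)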
@@ -233,6 +304,27 @@ class ScaledTensor2x(ScaledTensor):
         """
         return self.colwise_tensor
 
+    def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]):
+        """Applies sharding constraints to a tensor based on logical axis names.
+
+        Args:
+            logical_axis_names: Tuple of logical axis names for sharding
+
+        Returns:
+            The tensor with applied sharding constraints
+        """
+        if not logical_axis_names:
+            return self
+
+        rowwise_tensor = self.rowwise_tensor.apply_sharding_constraint_by_logical_axes(
+            logical_axis_names
+        )
+        colwise_tensor = self.colwise_tensor.apply_sharding_constraint_by_logical_axes(
+            logical_axis_names
+        )
+        return ScaledTensor2x(rowwise_tensor, colwise_tensor)
+
 
 @dataclass
 class ScaledTensorFactory:
@@ -244,7 +336,13 @@ class ScaledTensorFactory:
     @staticmethod
     def create_1x(
-        data, scale_inv, scaling_mode, dq_dtype=jnp.bfloat16, is_colwise=False, layout="N"
+        data,
+        scale_inv,
+        scaling_mode,
+        dq_dtype=jnp.bfloat16,
+        is_colwise=False,
+        data_layout="N",
+        flatten_axis=-1,
     ):
         """Creates a single-scale quantized tensor.
 
@@ -254,13 +352,16 @@ class ScaledTensorFactory:
             scaling_mode: The scaling mode for quantization
             dq_dtype: The data type for dequantized values (default: bfloat16)
             is_colwise: Whether to use column-wise quantization (default: False)
-            layout: The layout specification (default: "N")
+            data_layout: The data layout specification (default: "N")
+            flatten_axis: The quantization axis for the tensor
 
         Returns:
             A ScaledTensor1x instance
         """
         dq_func = Dequantizer.funcs.get(scaling_mode)
-        return ScaledTensor1x(data, scale_inv, scaling_mode, dq_dtype, dq_func, is_colwise, layout)
+        return ScaledTensor1x(
+            data, scale_inv, scaling_mode, dq_dtype, dq_func, is_colwise, data_layout, flatten_axis
+        )
 
     @staticmethod
     def create_2x(
@@ -270,7 +371,8 @@ class ScaledTensorFactory:
         colwise_scale_inv,
         scaling_mode,
         dq_dtype=jnp.bfloat16,
-        layout="NN",
+        data_layout="NN",
+        flatten_axis=-1,
     ):
         """Creates a double-scale quantized tensor.
 
@@ -281,7 +383,8 @@ class ScaledTensorFactory:
             colwise_scale_inv: The column-wise inverse scaling factors
             scaling_mode: The scaling mode for quantization
             dq_dtype: The data type for dequantized values (default: bfloat16)
-            layout: The layout specification (default: "NN")
+            data_layout: The data layout specification (default: "NN")
+            flatten_axis: The quantization axis for the tensor
 
         Returns:
             A ScaledTensor2x instance
 
@@ -294,7 +397,8 @@ class ScaledTensorFactory:
             dq_dtype,
             dq_func,
             is_colwise=False,
-            layout=layout[0],
+            data_layout=data_layout[0],
+            flatten_axis=flatten_axis,
         )
         colwise_tensor = ScaledTensor1x(
             colwise_data,
@@ -303,7 +407,8 @@ class ScaledTensorFactory:
             dq_dtype,
             dq_func,
             is_colwise=True,
-            layout=layout[1],
+            data_layout=data_layout[1],
+            flatten_axis=flatten_axis,
         )
         return ScaledTensor2x(rowwise_tensor, colwise_tensor)
@@ -315,8 +420,9 @@ class ScaledTensorFactory:
         colwise_scale_inv: jnp.ndarray,
         scaling_mode: ScalingMode,
         dq_dtype: jnp.dtype = jnp.bfloat16,
-        layout: str = "NN",
-        q_axis: QuantizeAxis = QuantizeAxis.ROWWISE,
+        data_layout: str = "NN",
+        q_layout: QuantizeLayout = QuantizeLayout.ROWWISE,
+        flatten_axis: int = -1,
     ):
         """Creates a scaled tensor based on the quantization axis.
 
@@ -327,13 +433,13 @@ class ScaledTensorFactory:
             colwise_scale_inv: The column-wise inverse scaling factors
             scaling_mode: The scaling mode for quantization
             dq_dtype: The data type for dequantized values (default: bfloat16)
-            layout: The layout specification (default: "NN")
-            q_axis: The quantization axis (default: ROWWISE)
+            data_layout: The data layout specification (default: "NN")
+            q_layout: The quantization layout (default: ROWWISE)
 
         Returns:
-            Either a ScaledTensor1x or ScaledTensor2x instance depending on q_axis
+            Either a ScaledTensor1x or ScaledTensor2x instance depending on q_layout
         """
-        if q_axis == QuantizeAxis.ROWWISE_COLWISE:
+        if q_layout == QuantizeLayout.ROWWISE_COLWISE:
             return ScaledTensorFactory.create_2x(
                 data,
                 scale_inv,
@@ -341,12 +447,19 @@ class ScaledTensorFactory:
                 colwise_scale_inv,
                 scaling_mode,
                 dq_dtype,
-                layout=layout,
+                data_layout=data_layout,
+                flatten_axis=flatten_axis,
             )
 
-        is_colwise = q_axis == QuantizeAxis.COLWISE
+        is_colwise = q_layout == QuantizeLayout.COLWISE
         return ScaledTensorFactory.create_1x(
-            data, scale_inv, scaling_mode, dq_dtype, is_colwise=is_colwise, layout=layout[0]
+            data,
+            scale_inv,
+            scaling_mode,
+            dq_dtype,
+            is_colwise=is_colwise,
+            data_layout=data_layout[0],
+            flatten_axis=flatten_axis,
         )
@@ -360,24 +473,7 @@ def with_sharding_constraint_by_logical_axes(x, logical_axis_names: Tuple[str, ...]):
     Returns:
         The tensor with applied sharding constraints
     """
-    if isinstance(x, ScaledTensor1x):
-        return ScaledTensor1x(
-            data=with_sharding_constraint_by_logical_axes(x.data, logical_axis_names),
-            scale_inv=x.scale_inv,
-            scaling_mode=x.scaling_mode,
-            dq_dtype=x.dq_dtype,
-            _dq_func=x._dq_func,
-            is_colwise=x.is_colwise,
-            layout=x.layout,
-        )
-
-    if isinstance(x, ScaledTensor2x):
-        return ScaledTensor2x(
-            rowwise_tensor=with_sharding_constraint_by_logical_axes(
-                x.rowwise_tensor, logical_axis_names
-            ),
-            colwise_tensor=with_sharding_constraint_by_logical_axes(
-                x.colwise_tensor, logical_axis_names
-            ),
-        )
+    if isinstance(x, ScaledTensor):
+        return x.apply_sharding_constraint_by_logical_axes(logical_axis_names)
 
     return original_with_sharding_constraint_by_logical_axes(x, logical_axis_names)
@@ -89,7 +89,11 @@ def generate_pspec(logical_axis_names):
     Convert logical axes to PartitionSpec
     """
     rules = get_sharding_map_logic_axis_to_mesh_axis()
-    mesh_axis_names = [rules[name] for name in logical_axis_names]
+    # mesh_axis_names = [rules[name] for name in logical_axis_names]
+    mesh_axis_names = []
+    for name in logical_axis_names:
+        axis_name = rules[name] if name in rules else None
+        mesh_axis_names.append(axis_name)
     pspec = jax.sharding.PartitionSpec(*mesh_axis_names)
     return pspec
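The rewritten loop tolerates logical names that have no sharding rule instead of raising a KeyError; a sketch with a hypothetical rules mapping and axis names:

rules = {"nvte_batch": "dp", "nvte_hidden": "tp"}   # hypothetical contents of the sharding map
logical_axis_names = ("nvte_batch", "act_embed", "nvte_hidden")
mesh_axis_names = [rules[name] if name in rules else None for name in logical_axis_names]
# ["dp", None, "tp"] -> jax.sharding.PartitionSpec("dp", None, "tp"); the unknown axis stays unsharded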
@@ -112,7 +116,7 @@ def with_sharding_constraint_by_logical_axes(x: jnp.array, logical_axis_names: tuple):
     """
     A wrapper function to jax.lax.with_sharding_constraint to accept logical axes.
     """
-    if logical_axis_names is None:
+    if not logical_axis_names:
         return x
 
     assert len(x.shape) == len(logical_axis_names)
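Switching the guard from `is None` to `not logical_axis_names` also makes an empty tuple a no-op, which previously would have tripped the length assertion below; a minimal sketch, assuming the wrapper is importable from transformer_engine.jax.sharding:

import jax.numpy as jnp
from transformer_engine.jax.sharding import with_sharding_constraint_by_logical_axes  # path assumed

x = jnp.zeros((4, 8))
y1 = with_sharding_constraint_by_logical_axes(x, None)  # returns x unchanged
y2 = with_sharding_constraint_by_logical_axes(x, ())    # now also returns x unchanged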
@@ -315,3 +319,25 @@ class ShardingType(Enum):
     TP_ROW = (MajorShardingType.TP, "tp_row")
     DP_TP_COL = (MajorShardingType.DPTP, "dp_tp_col")
     DP_TP_ROW = (MajorShardingType.DPTP, "dp_tp_row")
+
+
+def get_non_contracting_logical_axes(ndim, logical_axes, contracting_dims):
+    """Get logical axes for non-contracting dimensions.
+
+    Args:
+        ndim: Number of dimensions in the tensor.
+        logical_axes: Tuple of logical axes for each dimension.
+        contracting_dims: Set of dimensions that are being contracted.
+
+    Returns:
+        Tuple of logical axes for non-contracting dimensions.
+    """
+    if not logical_axes:
+        logical_axes = (None,) * ndim
+    elif len(logical_axes) < ndim:
+        logical_axes = logical_axes + (None,) * (ndim - len(logical_axes))
+    assert len(logical_axes) == ndim
+
+    non_contracting_dims = [i for i in range(ndim) if i not in contracting_dims]
+    non_contracting_logical_axes = tuple(logical_axes[i] for i in non_contracting_dims)
+    return non_contracting_logical_axes
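Usage sketch for the new helper (the axis names are hypothetical): contracting a 3D tensor over its last dimension keeps the first two logical axes, and short tuples are right-padded with None before filtering.

axes = get_non_contracting_logical_axes(
    ndim=3,
    logical_axes=("nvte_batch", "nvte_seq", "nvte_hidden"),
    contracting_dims=(2,),
)
# -> ("nvte_batch", "nvte_seq")

padded = get_non_contracting_logical_axes(3, ("nvte_batch",), (0, 2))
# -> (None,)   the missing names were padded with None, then axes 0 and 2 were dropped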