Unverified Commit 3f875fb5 authored by Phuong Nguyen's avatar Phuong Nguyen Committed by GitHub
Browse files

[JAX] Restore Shardy Rule with CompoundFactor (#2167)



* Rework shardy rules

* WAR for compound factor=1
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
parent a92a0ad2
...@@ -410,27 +410,28 @@ class ActLuPrimitive(BasePrimitive): ...@@ -410,27 +410,28 @@ class ActLuPrimitive(BasePrimitive):
result_types, result_types,
): ):
del out_dtype, act_enum, act_len, scale_dtype, is_outer, mesh, result_types del out_dtype, act_enum, act_len, scale_dtype, is_outer, mesh, result_types
prefix = "ActLuPrimitive_" prefix = "ActLu_"
x_rank = len(value_types[0].shape) input_shape = value_types[0].shape
output_shape = input_shape[:-2] + input_shape[-1:]
# Here we pass len of output so that the scales are propagated correctly
scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
x_rank - 1, unique_var=prefix + "x", flatten_axis=-2 output_shape, unique_var=prefix + "x", flatten_axis=-1
) )
x_axes = scale_rules.input_spec + (prefix + f"x{x_rank - 1}",) x_axes = scale_rules.input_spec
out = (*x_axes[:-2], x_axes[-1]) # Correct input spec with act dim
scale_inv = scale_rules.rowwise_rule x_axes = x_axes[:-1] + (prefix + "_act_dim",) + x_axes[-1:]
out = scale_rules.input_spec
colwise_out = (prefix + "out_colwise",) colwise_out = (prefix + "out_colwise",)
colwise_scale_inv = (prefix + "scale_inv_colwise",) colwise_scale_inv = (prefix + "scale_inv_colwise",)
if is_2x: if is_2x:
colwise_scale_inv = scale_rules.colwise_rule colwise_scale_inv = scale_rules.colwise_rule
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
colwise_out = tuple( colwise_out = multidim_transpose(out, transpose_axis=-1)
multidim_transpose(x_axes, static_axis_boundary=-1, transpose_axis=-2)
)
else: else:
colwise_out = out colwise_out = out
colwise_scale_inv = scale_rules.colwise_rule
# amax is always a unit tensor.
amax = (prefix + "amax",) amax = (prefix + "amax",)
return SdyShardingRule( return SdyShardingRule(
...@@ -438,7 +439,8 @@ class ActLuPrimitive(BasePrimitive): ...@@ -438,7 +439,8 @@ class ActLuPrimitive(BasePrimitive):
x_axes, x_axes,
("…1",), ("…1",),
), ),
(out, colwise_out, scale_inv, colwise_scale_inv, amax), (out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax),
**scale_rules.factor_sizes,
) )
...@@ -883,26 +885,30 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -883,26 +885,30 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
result_types, result_types,
): ):
del out_dtype, scale_dtype, act_enum, act_len, is_outer, mesh, result_types del out_dtype, scale_dtype, act_enum, act_len, is_outer, mesh, result_types
prefix = "BaseDActLuDBiasQuantizePrimitive_" prefix = "DActLuDBias_"
scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
len(value_types[1].shape), unique_var=prefix + "x", flatten_axis=-2 value_types[1].shape, unique_var=prefix + "x", flatten_axis=-2
) )
x_axes = scale_rules.input_spec x_axes = scale_rules.input_spec
dz_axes = (*x_axes[:-2], x_axes[-1]) dz_axes = (*x_axes[:-2], x_axes[-1])
out = x_axes out = x_axes
colwise_out = (prefix + "out_colwise",) colwise_out = (prefix + "out_colwise",)
colwise_scale_inv = (prefix + "scale_inv_colwise",)
if is_2x: if is_2x:
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=-2)) colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=-2))
else: else:
colwise_out = out colwise_out = out
colwise_scale_inv = scale_rules.colwise_rule
dbias = x_axes[-2:] if is_dbias else (prefix + "dbias",) dbias = x_axes[-2:] if is_dbias else (prefix + "dbias",)
amax = (prefix + "amax",) amax = (prefix + "amax",)
return SdyShardingRule( return SdyShardingRule(
(dz_axes, x_axes, ("…2",)), (dz_axes, x_axes, ("…2",)),
(out, colwise_out, scale_rules.rowwise_rule, scale_rules.colwise_rule, amax, dbias), (out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax, dbias),
**scale_rules.factor_sizes,
) )
......
...@@ -712,7 +712,7 @@ class GemmPrimitive(BasePrimitive): ...@@ -712,7 +712,7 @@ class GemmPrimitive(BasePrimitive):
del out_dtype, grad, use_split_accumulator del out_dtype, grad, use_split_accumulator
del mesh, result_types del mesh, result_types
prefix = "GemmPrimitive_" prefix = "Gemm_"
warnings.warn( warnings.warn(
"Known issues with TE GemmPrimitives when Shardy propagation is enabled. For now," "Known issues with TE GemmPrimitives when Shardy propagation is enabled. For now,"
...@@ -746,13 +746,8 @@ class GemmPrimitive(BasePrimitive): ...@@ -746,13 +746,8 @@ class GemmPrimitive(BasePrimitive):
lhs_scale_specs = ("…1",) lhs_scale_specs = ("…1",)
rhs_scale_specs = ("…2",) rhs_scale_specs = ("…2",)
if scaling_mode.is_1d_block_scaling(): if scaling_mode.is_1d_block_scaling():
# Shardy rules for MXFP8 scales cannot be related to the operands because of the lhs_scale_specs = lhs_specs
# global-unpadding and local-padding workflow. This can potentially insert expensive rhs_scale_specs = rhs_specs
# re-shards in the partition call later if the scales are not already sharded correctly.
lhs_scale_specs, rhs_scale_specs = map(
lambda specs: tuple(spec.replace(prefix, prefix + "scale_inv_") for spec in specs),
(lhs_specs, rhs_specs),
)
lhs_non_cspec = tuple(lhs_specs[i] for i in range(operand_ndims[0]) if i not in lhs_cdims) lhs_non_cspec = tuple(lhs_specs[i] for i in range(operand_ndims[0]) if i not in lhs_cdims)
rhs_non_cspec = tuple(rhs_specs[i] for i in range(operand_ndims[1]) if i not in rhs_cdims) rhs_non_cspec = tuple(rhs_specs[i] for i in range(operand_ndims[1]) if i not in rhs_cdims)
......
...@@ -581,9 +581,9 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -581,9 +581,9 @@ class NormFwdPrimitive(BasePrimitive):
result_types, result_types,
) )
prefix = "NormFwdPrimitive_" prefix = "NormFwd_"
scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
len(value_types[0].shape), unique_var=prefix + "x", flatten_axis=-1 value_types[0].shape, unique_var=prefix + "x", flatten_axis=-1
) )
x_axes = scale_rules.input_spec x_axes = scale_rules.input_spec
...@@ -604,6 +604,7 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -604,6 +604,7 @@ class NormFwdPrimitive(BasePrimitive):
mu, mu,
rsigma, rsigma,
), ),
**scale_rules.factor_sizes,
) )
......
...@@ -495,9 +495,9 @@ class BaseDBiasQuantizePrimitive(BasePrimitive): ...@@ -495,9 +495,9 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
): ):
del out_dtype, scale_dtype, is_outer, mesh, result_types del out_dtype, scale_dtype, is_outer, mesh, result_types
prefix = "BaseDBiasQuantizePrimitive_" prefix = "DBiasQuantize_"
scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
len(value_types[0].shape), value_types[0].shape,
unique_var=prefix + "x", unique_var=prefix + "x",
flatten_axis=flatten_axis, flatten_axis=flatten_axis,
) )
...@@ -519,6 +519,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive): ...@@ -519,6 +519,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
return SdyShardingRule( return SdyShardingRule(
(x_axes, ("…1",), amax), (x_axes, ("…1",), amax),
(out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax, dbias), (out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax, dbias),
**scale_rules.factor_sizes,
) )
......
...@@ -17,7 +17,7 @@ from functools import reduce, lru_cache ...@@ -17,7 +17,7 @@ from functools import reduce, lru_cache
import operator import operator
import numpy as np import numpy as np
from jax.experimental.custom_partitioning import BATCHING from jax.experimental.custom_partitioning import BATCHING, CompoundFactor
from jax.tree_util import register_pytree_node_class from jax.tree_util import register_pytree_node_class
import jax.numpy as jnp import jax.numpy as jnp
...@@ -152,12 +152,15 @@ class ScalingModeMetadataImpl(ABC): ...@@ -152,12 +152,15 @@ class ScalingModeMetadataImpl(ABC):
@abstractmethod @abstractmethod
def get_shardy_sharding_rules( def get_shardy_sharding_rules(
self, input_rank, unique_var, flatten_axis self,
input_shape,
unique_var,
flatten_axis,
) -> QuantizeShardyRules: ) -> QuantizeShardyRules:
"""Sharding rules for the input and (row, col)wise scale tensors. """Sharding rules for the input and (row, col)wise scale tensors.
Args: Args:
input_rank: The rank of the input tensor (for which we produce the scale tensor) input_shape: The shape of the input tensor (for which we produce the scale tensor)
unique_var: An otherwise unused Shardy variable name prefix unique_var: An otherwise unused Shardy variable name prefix
flatten_axis: Axis along which data can be flattened to 2D for quantization. flatten_axis: Axis along which data can be flattened to 2D for quantization.
...@@ -232,12 +235,15 @@ class NoScalingModeMetadataImpl(ScalingModeMetadataImpl): ...@@ -232,12 +235,15 @@ class NoScalingModeMetadataImpl(ScalingModeMetadataImpl):
return (n_groups,) return (n_groups,)
def get_shardy_sharding_rules( def get_shardy_sharding_rules(
self, input_rank, unique_var, flatten_axis self,
input_shape,
unique_var,
flatten_axis,
) -> QuantizeShardyRules: ) -> QuantizeShardyRules:
"""Sharding rules for the input and (row, col)wise scale tensors. """Sharding rules for the input and (row, col)wise scale tensors.
Args: Args:
input_rank: The rank of the input tensor (for which we produce the scale tensor) input_shape: The shape of the input tensor (for which we produce the scale tensor)
unique_var: An otherwise unused Shardy variable name prefix unique_var: An otherwise unused Shardy variable name prefix
flatten_axis: Axis along which data can be flattened to 2D for quantization. flatten_axis: Axis along which data can be flattened to 2D for quantization.
...@@ -245,7 +251,7 @@ class NoScalingModeMetadataImpl(ScalingModeMetadataImpl): ...@@ -245,7 +251,7 @@ class NoScalingModeMetadataImpl(ScalingModeMetadataImpl):
The Shardy rules for the scaling mode The Shardy rules for the scaling mode
""" """
del flatten_axis del flatten_axis
input_spec = tuple(f"{unique_var}{i}" for i in range(input_rank)) input_spec = tuple(f"{unique_var}{i}" for i in range(len(input_shape)))
scale_var = BATCHING + unique_var + "_scale_inv" scale_var = BATCHING + unique_var + "_scale_inv"
return QuantizeShardyRules(input_spec, (scale_var,), (scale_var,), {}) return QuantizeShardyRules(input_spec, (scale_var,), (scale_var,), {})
...@@ -323,20 +329,23 @@ class CurrentScalingModeMetadataImpl(ScalingModeMetadataImpl): ...@@ -323,20 +329,23 @@ class CurrentScalingModeMetadataImpl(ScalingModeMetadataImpl):
return (n_groups,) return (n_groups,)
def get_shardy_sharding_rules( def get_shardy_sharding_rules(
self, input_rank, unique_var, flatten_axis self,
input_shape,
unique_var,
flatten_axis,
) -> QuantizeShardyRules: ) -> QuantizeShardyRules:
"""Sharding rules for the input and (row, col)wise scale tensors. """Sharding rules for the input and (row, col)wise scale tensors.
Args: Args:
input_rank: The rank of the input tensor (for which we produce the scale tensor) input_shape: The shape of the input tensor (for which we produce the scale tensor)
unique_var: An otherwise unused Shardy variable name prefix unique_var: An otherwise unused Shardy variable name prefix
flatten_axis: Axis along which data can be flattened to 2D for quantization. flatten_axis: Axis along which data can be flattened to 2D for quantization
Returns: Returns:
The Shardy rules for the scaling mode The Shardy rules for the scaling mode
""" """
del flatten_axis del flatten_axis
input_spec = tuple(f"{unique_var}{i}" for i in range(input_rank)) input_spec = tuple(f"{unique_var}{i}" for i in range(len(input_shape)))
scale_var = BATCHING + unique_var + "_scale_inv" scale_var = BATCHING + unique_var + "_scale_inv"
return QuantizeShardyRules(input_spec, (scale_var,), (scale_var,), {}) return QuantizeShardyRules(input_spec, (scale_var,), (scale_var,), {})
...@@ -562,52 +571,55 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl): ...@@ -562,52 +571,55 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
return (n_block_x * n_block_y,) return (n_block_x * n_block_y,)
def get_shardy_sharding_rules( def get_shardy_sharding_rules(
self, input_rank, unique_var, flatten_axis self,
input_shape,
unique_var,
flatten_axis,
) -> QuantizeShardyRules: ) -> QuantizeShardyRules:
"""Sharding rules for the input and (row, col)wise scale tensors. """Sharding rules for the input and (row, col)wise scale tensors.
Args: Args:
input_rank: The rank of the input tensor (for which we produce the scale tensor) input_shape: The shape of the input tensor (for which we produce the scale tensor)
unique_var: An otherwise unused Shardy variable name prefix unique_var: An otherwise unused Shardy variable name prefix
flatten_axis: Axis along which data can be flattened to 2D for quantization
Returns: Returns:
The Shardy rules for the scaling mode The Shardy rules for the scaling mode
""" """
del flatten_axis input_rank = len(input_shape)
input_spec = [f"{unique_var}{i}" for i in range(input_rank)] input_spec = [f"{unique_var}_{i}" for i in range(input_rank)]
rowwise = [f"{unique_var}scale_inv_rowwise{i}" for i in range(input_rank)] flatten_axis = (flatten_axis + input_rank) % input_rank
colwise = [f"{unique_var}scale_inv_colwise{i}" for i in range(input_rank)]
# This implementation needs to be updated for different block dims.
# NOTE (Alp): Padding the scales breaks the size relationship in CompoundFactors. assert self._block_dims == (1, 32)
# Unfortunately, because Shardy rules are applied to the inner primitive, the
# only way to preserve the relationship is to lower unpadded scales to the # We have to use two different factors in the two CompoundFactors because of Shardy
# underlying custom call and pad them in C++. Until that's implemented, the # verifier requirements, even though they are the same.
# Shardy rules for block scales have to be completely disconnected from the blocksizes = {}
# Shardy rules for the tensor they belong to. colwise_var = f"{unique_var}_None"
rowwise_var = f"{unique_var}_None"
# # We have to use two different factors in the two CompoundFactors because of Shardy if not input_shape[-1] == 32:
# # verifier requirements, even though they are the same. rowwise_var = input_spec[-1] + "_compound"
# rowwise_var = unique_var input_spec[-1] = CompoundFactor(rowwise_var, "blocksize_x")
# colwise_var = f"{unique_var}_" blocksizes["blocksize_x"] = 32
# input_spec[flatten_axis - 1] = CompoundFactor(colwise_var, "block_size_colwise") if not input_shape[flatten_axis - 1] == 32:
# input_spec[-1] = CompoundFactor(rowwise_var, "block_size_rowwise") colwise_var = input_spec[flatten_axis - 1] + "_compound"
input_spec[flatten_axis - 1] = CompoundFactor(colwise_var, "blocksize_y")
# # The rowwise and colwise scale tensors should be sharded the same way as the input. blocksizes["blocksize_y"] = 32
# # However, we need to adjust the dimensions where the block scaling factor applies.
# rowwise = input_spec.copy() # The rowwise and colwise scale tensors should be sharded the same way as the input.
# rowwise[-1] = rowwise_var # However, we need to adjust the dimensions where the block scaling factor applies.
rowwise = input_spec.copy()
# colwise = input_spec.copy() rowwise[-1] = rowwise_var
# colwise[flatten_axis - 1] = colwise_var
colwise = input_spec.copy()
# # This implementation needs to be updated for different block dims. colwise[flatten_axis - 1] = colwise_var
# assert self._block_dims == (1, 32)
return QuantizeShardyRules( return QuantizeShardyRules(
tuple(input_spec), tuple(input_spec),
tuple(rowwise), tuple(rowwise),
tuple(colwise), tuple(colwise),
{}, # {"block_size_rowwise": 32, "block_size_colwise": 32}, blocksizes,
) )
...@@ -697,18 +709,22 @@ class ScalingMode(Enum): ...@@ -697,18 +709,22 @@ class ScalingMode(Enum):
return self._get_impl().get_quantize_layout(usage) return self._get_impl().get_quantize_layout(usage)
def get_shardy_sharding_rules( def get_shardy_sharding_rules(
self, input_rank, unique_var, flatten_axis=-1 self,
input_shape,
unique_var,
flatten_axis=-1,
) -> Tuple[Tuple[str]]: ) -> Tuple[Tuple[str]]:
"""Sharding rules for the input and (row, col)wise scale tensors. """Sharding rules for the input and (row, col)wise scale tensors.
Args: Args:
input_rank: The rank of the input tensor (for which we produce the scale tensor) input_shape: The shape of the input tensor (for which we produce the scale tensor)
unique_var: An otherwise unused Shardy variable name prefix unique_var: An otherwise unused Shardy variable name prefix
flatten_axis: Axis along which data can be flattened to 2D for quantization.
Returns: Returns:
The Shardy rules for the scaling mode The Shardy rules for the scaling mode
""" """
return self._get_impl().get_shardy_sharding_rules(input_rank, unique_var, flatten_axis) return self._get_impl().get_shardy_sharding_rules(input_shape, unique_var, flatten_axis)
def get_grouped_scale_shape_2x( def get_grouped_scale_shape_2x(
self, data_shape, n_groups, group_axis, is_padded=True, flatten_axis=-1 self, data_shape, n_groups, group_axis, is_padded=True, flatten_axis=-1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment