Unverified commit 3f875fb5, authored by Phuong Nguyen and committed by GitHub

[JAX] Restore Shardy Rule with CompoundFactor (#2167)



* Rework shardy rules

* WAR for compound factor=1
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent a92a0ad2
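Editor's note: the restored rules are built on `SdyShardingRule` and `CompoundFactor` from `jax.experimental.custom_partitioning`. A minimal sketch of the idea, with illustrative factor names and shapes (not the TE code): a block-scaled dimension is declared as a compound of a free factor and a fixed block-size factor, so the scale tensor can reuse the free factor and shard together with its data.

```python
from jax.experimental.custom_partitioning import CompoundFactor, SdyShardingRule

# data:  (batch, hidden) with hidden = n_blocks * 32
# scale: (batch, n_blocks), one scaling factor per 32-element block
rule = SdyShardingRule(
    # operand mappings: the data tensor's last dim is n_blocks * blocksize
    (("batch", CompoundFactor("n_blocks", "blocksize")),),
    # result mappings: the quantized data, then its rowwise scales
    (
        ("batch", CompoundFactor("n_blocks", "blocksize")),
        ("batch", "n_blocks"),
    ),
    # fixed factor sizes, passed as keywords (cf. **scale_rules.factor_sizes below)
    blocksize=32,
)
```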
@@ -410,27 +410,28 @@ class ActLuPrimitive(BasePrimitive):
         result_types,
     ):
         del out_dtype, act_enum, act_len, scale_dtype, is_outer, mesh, result_types
-        prefix = "ActLuPrimitive_"
-        x_rank = len(value_types[0].shape)
+        prefix = "ActLu_"
+        input_shape = value_types[0].shape
+        output_shape = input_shape[:-2] + input_shape[-1:]
+        # Here we pass the output shape so that the scales are propagated correctly
         scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
-            x_rank - 1, unique_var=prefix + "x", flatten_axis=-2
+            output_shape, unique_var=prefix + "x", flatten_axis=-1
         )
-        x_axes = scale_rules.input_spec + (prefix + f"x{x_rank - 1}",)
-        out = (*x_axes[:-2], x_axes[-1])
-        scale_inv = scale_rules.rowwise_rule
+        x_axes = scale_rules.input_spec
+        # Correct the input spec with the act dim
+        x_axes = x_axes[:-1] + (prefix + "_act_dim",) + x_axes[-1:]
+        out = scale_rules.input_spec

         colwise_out = (prefix + "out_colwise",)
         colwise_scale_inv = (prefix + "scale_inv_colwise",)
         if is_2x:
-            colwise_scale_inv = scale_rules.colwise_rule
             if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
-                colwise_out = tuple(
-                    multidim_transpose(x_axes, static_axis_boundary=-1, transpose_axis=-2)
-                )
+                colwise_out = multidim_transpose(out, transpose_axis=-1)
+            else:
+                colwise_out = out
+            colwise_scale_inv = scale_rules.colwise_rule

         # amax is always a unit tensor.
         amax = (prefix + "amax",)
         return SdyShardingRule(
@@ -438,7 +439,8 @@ class ActLuPrimitive(BasePrimitive):
                 x_axes,
                 ("…1",),
             ),
-            (out, colwise_out, scale_inv, colwise_scale_inv, amax),
+            (out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax),
+            **scale_rules.factor_sizes,
         )
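To make the spec arithmetic above concrete, a hedged walk-through for a rank-3 activation input `(batch, act, hidden)`; the shapes and factor names are illustrative:

```python
# Input x has shape (batch, act_len, hidden); the activation output drops the
# act dim, so the scale rules are derived from the output shape.
input_shape = (64, 2, 4096)
output_shape = input_shape[:-2] + input_shape[-1:]     # (64, 4096)

# Suppose the scaling mode names the output dims like this:
spec = ("ActLu_x_0", "ActLu_x_1")                      # scale_rules.input_spec
out = spec                                             # rowwise output spec

# The input spec re-inserts a dedicated factor for the act dim:
x_axes = spec[:-1] + ("ActLu__act_dim",) + spec[-1:]
# -> ("ActLu_x_0", "ActLu__act_dim", "ActLu_x_1")
```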
@@ -883,26 +885,30 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
         result_types,
     ):
         del out_dtype, scale_dtype, act_enum, act_len, is_outer, mesh, result_types
-        prefix = "BaseDActLuDBiasQuantizePrimitive_"
+        prefix = "DActLuDBias_"
         scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
-            len(value_types[1].shape), unique_var=prefix + "x", flatten_axis=-2
+            value_types[1].shape, unique_var=prefix + "x", flatten_axis=-2
         )
         x_axes = scale_rules.input_spec
         dz_axes = (*x_axes[:-2], x_axes[-1])
         out = x_axes

         colwise_out = (prefix + "out_colwise",)
         colwise_scale_inv = (prefix + "scale_inv_colwise",)
         if is_2x:
             if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                 colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=-2))
+            else:
+                colwise_out = out
+            colwise_scale_inv = scale_rules.colwise_rule

         dbias = x_axes[-2:] if is_dbias else (prefix + "dbias",)
         amax = (prefix + "amax",)

         return SdyShardingRule(
             (dz_axes, x_axes, ("…2",)),
-            (out, colwise_out, scale_rules.rowwise_rule, scale_rules.colwise_rule, amax, dbias),
+            (out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax, dbias),
+            **scale_rules.factor_sizes,
         )
@@ -712,7 +712,7 @@ class GemmPrimitive(BasePrimitive):
         del out_dtype, grad, use_split_accumulator
         del mesh, result_types

-        prefix = "GemmPrimitive_"
+        prefix = "Gemm_"

         warnings.warn(
             "Known issues with TE GemmPrimitives when Shardy propagation is enabled. For now,"
@@ -746,13 +746,8 @@ class GemmPrimitive(BasePrimitive):
         lhs_scale_specs = ("…1",)
         rhs_scale_specs = ("…2",)
         if scaling_mode.is_1d_block_scaling():
-            # Shardy rules for MXFP8 scales cannot be related to the operands because of the
-            # global-unpadding and local-padding workflow. This can potentially insert expensive
-            # re-shards in the partition call later if the scales are not already sharded correctly.
-            lhs_scale_specs, rhs_scale_specs = map(
-                lambda specs: tuple(spec.replace(prefix, prefix + "scale_inv_") for spec in specs),
-                (lhs_specs, rhs_specs),
-            )
+            lhs_scale_specs = lhs_specs
+            rhs_scale_specs = rhs_specs

         lhs_non_cspec = tuple(lhs_specs[i] for i in range(operand_ndims[0]) if i not in lhs_cdims)
         rhs_non_cspec = tuple(rhs_specs[i] for i in range(operand_ndims[1]) if i not in rhs_cdims)
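The effect of this simplification, sketched with made-up specs: the MXFP8 scale operands now reuse the data operand's factors outright, so Shardy propagates the operand's sharding to its scales instead of treating them as unrelated tensors.

```python
prefix = "Gemm_"
lhs_specs = (prefix + "m", prefix + "k")

# Before: renamed copies, i.e. factors unrelated to the operand.
old_scale_specs = tuple(s.replace(prefix, prefix + "scale_inv_") for s in lhs_specs)
# -> ("Gemm_scale_inv_m", "Gemm_scale_inv_k")

# After: the scales share the operand's factors and shard with it.
new_scale_specs = lhs_specs                            # ("Gemm_m", "Gemm_k")
```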
@@ -581,9 +581,9 @@ class NormFwdPrimitive(BasePrimitive):
             result_types,
         )
-        prefix = "NormFwdPrimitive_"
+        prefix = "NormFwd_"
         scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
-            len(value_types[0].shape), unique_var=prefix + "x", flatten_axis=-1
+            value_types[0].shape, unique_var=prefix + "x", flatten_axis=-1
         )
         x_axes = scale_rules.input_spec
@@ -604,6 +604,7 @@ class NormFwdPrimitive(BasePrimitive):
                 mu,
                 rsigma,
             ),
+            **scale_rules.factor_sizes,
         )
@@ -495,9 +495,9 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
     ):
         del out_dtype, scale_dtype, is_outer, mesh, result_types
-        prefix = "BaseDBiasQuantizePrimitive_"
+        prefix = "DBiasQuantize_"
         scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
-            len(value_types[0].shape),
+            value_types[0].shape,
             unique_var=prefix + "x",
             flatten_axis=flatten_axis,
         )
@@ -519,6 +519,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
         return SdyShardingRule(
             (x_axes, ("…1",), amax),
             (out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax, dbias),
+            **scale_rules.factor_sizes,
         )
@@ -17,7 +17,7 @@ from functools import reduce, lru_cache
 import operator

 import numpy as np
-from jax.experimental.custom_partitioning import BATCHING
+from jax.experimental.custom_partitioning import BATCHING, CompoundFactor
 from jax.tree_util import register_pytree_node_class
 import jax.numpy as jnp
@@ -152,12 +152,15 @@ class ScalingModeMetadataImpl(ABC):
     @abstractmethod
     def get_shardy_sharding_rules(
-        self, input_rank, unique_var, flatten_axis
+        self,
+        input_shape,
+        unique_var,
+        flatten_axis,
     ) -> QuantizeShardyRules:
         """Sharding rules for the input and (row, col)wise scale tensors.

         Args:
-            input_rank: The rank of the input tensor (for which we produce the scale tensor)
+            input_shape: The shape of the input tensor (for which we produce the scale tensor)
             unique_var: An otherwise unused Shardy variable name prefix
             flatten_axis: Axis along which data can be flattened to 2D for quantization.
@@ -232,12 +235,15 @@ class NoScalingModeMetadataImpl(ScalingModeMetadataImpl):
         return (n_groups,)

     def get_shardy_sharding_rules(
-        self, input_rank, unique_var, flatten_axis
+        self,
+        input_shape,
+        unique_var,
+        flatten_axis,
     ) -> QuantizeShardyRules:
         """Sharding rules for the input and (row, col)wise scale tensors.

         Args:
-            input_rank: The rank of the input tensor (for which we produce the scale tensor)
+            input_shape: The shape of the input tensor (for which we produce the scale tensor)
             unique_var: An otherwise unused Shardy variable name prefix
             flatten_axis: Axis along which data can be flattened to 2D for quantization.
@@ -245,7 +251,7 @@ class NoScalingModeMetadataImpl(ScalingModeMetadataImpl):
             The Shardy rules for the scaling mode
         """
         del flatten_axis
-        input_spec = tuple(f"{unique_var}{i}" for i in range(input_rank))
+        input_spec = tuple(f"{unique_var}{i}" for i in range(len(input_shape)))
         scale_var = BATCHING + unique_var + "_scale_inv"
         return QuantizeShardyRules(input_spec, (scale_var,), (scale_var,), {})
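For the non-block modes, the scale is a unit tensor and its single dimension gets a fresh variable prefixed with `BATCHING`. As read here (treat this as an assumption about the Shardy convention, not a confirmed semantic), that marks the factor as a free batching-style factor, so the scale never constrains how the data tensor is sharded:

```python
from jax.experimental.custom_partitioning import BATCHING

unique_var = "x"
input_spec = tuple(f"{unique_var}{i}" for i in range(3))  # ("x0", "x1", "x2")
scale_var = BATCHING + unique_var + "_scale_inv"          # "…x_scale_inv"
# Both the rowwise and colwise rules reuse this one variable: the unit scale
# tensor is effectively unconstrained and can follow any sharding of the data.
```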
@@ -323,20 +329,23 @@ class CurrentScalingModeMetadataImpl(ScalingModeMetadataImpl):
         return (n_groups,)

     def get_shardy_sharding_rules(
-        self, input_rank, unique_var, flatten_axis
+        self,
+        input_shape,
+        unique_var,
+        flatten_axis,
     ) -> QuantizeShardyRules:
         """Sharding rules for the input and (row, col)wise scale tensors.

         Args:
-            input_rank: The rank of the input tensor (for which we produce the scale tensor)
+            input_shape: The shape of the input tensor (for which we produce the scale tensor)
             unique_var: An otherwise unused Shardy variable name prefix
-            flatten_axis: Axis along which data can be flattened to 2D for quantization.
+            flatten_axis: Axis along which data can be flattened to 2D for quantization

         Returns:
             The Shardy rules for the scaling mode
         """
         del flatten_axis
-        input_spec = tuple(f"{unique_var}{i}" for i in range(input_rank))
+        input_spec = tuple(f"{unique_var}{i}" for i in range(len(input_shape)))
         scale_var = BATCHING + unique_var + "_scale_inv"
         return QuantizeShardyRules(input_spec, (scale_var,), (scale_var,), {})
@@ -562,52 +571,55 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
         return (n_block_x * n_block_y,)

     def get_shardy_sharding_rules(
-        self, input_rank, unique_var, flatten_axis
+        self,
+        input_shape,
+        unique_var,
+        flatten_axis,
     ) -> QuantizeShardyRules:
         """Sharding rules for the input and (row, col)wise scale tensors.

         Args:
-            input_rank: The rank of the input tensor (for which we produce the scale tensor)
+            input_shape: The shape of the input tensor (for which we produce the scale tensor)
             unique_var: An otherwise unused Shardy variable name prefix
             flatten_axis: Axis along which data can be flattened to 2D for quantization

         Returns:
             The Shardy rules for the scaling mode
         """
-        del flatten_axis
-        input_spec = [f"{unique_var}{i}" for i in range(input_rank)]
-        rowwise = [f"{unique_var}scale_inv_rowwise{i}" for i in range(input_rank)]
-        colwise = [f"{unique_var}scale_inv_colwise{i}" for i in range(input_rank)]
-        # NOTE (Alp): Padding the scales breaks the size relationship in CompoundFactors.
-        # Unfortunately, because Shardy rules are applied to the inner primitive, the
-        # only way to preserve the relationship is to lower unpadded scales to the
-        # underlying custom call and pad them in C++. Until that's implemented, the
-        # Shardy rules for block scales have to be completely disconnected from the
-        # Shardy rules for the tensor they belong to.
-        # # We have to use two different factors in the two CompoundFactors because of Shardy
-        # # verifier requirements, even though they are the same.
-        # rowwise_var = unique_var
-        # colwise_var = f"{unique_var}_"
-        # input_spec[flatten_axis - 1] = CompoundFactor(colwise_var, "block_size_colwise")
-        # input_spec[-1] = CompoundFactor(rowwise_var, "block_size_rowwise")
-        # # The rowwise and colwise scale tensors should be sharded the same way as the input.
-        # # However, we need to adjust the dimensions where the block scaling factor applies.
-        # rowwise = input_spec.copy()
-        # rowwise[-1] = rowwise_var
-        # colwise = input_spec.copy()
-        # colwise[flatten_axis - 1] = colwise_var
-        # # This implementation needs to be updated for different block dims.
-        # assert self._block_dims == (1, 32)
+        input_rank = len(input_shape)
+        input_spec = [f"{unique_var}_{i}" for i in range(input_rank)]
+        flatten_axis = (flatten_axis + input_rank) % input_rank
+
+        # This implementation needs to be updated for different block dims.
+        assert self._block_dims == (1, 32)
+
+        # We have to use two different factors in the two CompoundFactors because of Shardy
+        # verifier requirements, even though they are the same.
+        blocksizes = {}
+        colwise_var = f"{unique_var}_None"
+        rowwise_var = f"{unique_var}_None"
+        if not input_shape[-1] == 32:
+            rowwise_var = input_spec[-1] + "_compound"
+            input_spec[-1] = CompoundFactor(rowwise_var, "blocksize_x")
+            blocksizes["blocksize_x"] = 32
+        if not input_shape[flatten_axis - 1] == 32:
+            colwise_var = input_spec[flatten_axis - 1] + "_compound"
+            input_spec[flatten_axis - 1] = CompoundFactor(colwise_var, "blocksize_y")
+            blocksizes["blocksize_y"] = 32
+
+        # The rowwise and colwise scale tensors should be sharded the same way as the input.
+        # However, we need to adjust the dimensions where the block scaling factor applies.
+        rowwise = input_spec.copy()
+        rowwise[-1] = rowwise_var
+
+        colwise = input_spec.copy()
+        colwise[flatten_axis - 1] = colwise_var
+
         return QuantizeShardyRules(
             tuple(input_spec),
             tuple(rowwise),
             tuple(colwise),
-            {},  # {"block_size_rowwise": 32, "block_size_colwise": 32},
+            blocksizes,
         )
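A quick illustration of the restored rules and of the WAR for a compound factor of size 1. The shapes are hypothetical, and the expected values in the comments follow from the code above (assuming the `MXFP8_1D_SCALING` mode uses this block-scaling implementation):

```python
from transformer_engine.jax.quantize import ScalingMode

# Regular case: both dims span more than one 32-element block.
rules = ScalingMode.MXFP8_1D_SCALING.get_shardy_sharding_rules(
    (128, 64), unique_var="x", flatten_axis=-1
)
# rules.input_spec   == (CompoundFactor('x_0_compound', 'blocksize_y'),
#                        CompoundFactor('x_1_compound', 'blocksize_x'))
# rules.rowwise_rule == (CompoundFactor('x_0_compound', 'blocksize_y'), 'x_1_compound')
# rules.factor_sizes == {'blocksize_x': 32, 'blocksize_y': 32}

# WAR case: the last dim is exactly one block wide. A CompoundFactor there would
# leave its non-block factor with size 32 / 32 = 1, so the dim keeps a plain
# factor and the rowwise scale's last dim gets the fresh 'x_None' variable.
rules = ScalingMode.MXFP8_1D_SCALING.get_shardy_sharding_rules(
    (128, 32), unique_var="x", flatten_axis=-1
)
# rules.input_spec   == (CompoundFactor('x_0_compound', 'blocksize_y'), 'x_1')
# rules.rowwise_rule == (CompoundFactor('x_0_compound', 'blocksize_y'), 'x_None')
# rules.factor_sizes == {'blocksize_y': 32}
```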
@@ -697,18 +709,22 @@ class ScalingMode(Enum):
         return self._get_impl().get_quantize_layout(usage)

     def get_shardy_sharding_rules(
-        self, input_rank, unique_var, flatten_axis=-1
+        self,
+        input_shape,
+        unique_var,
+        flatten_axis=-1,
     ) -> Tuple[Tuple[str]]:
         """Sharding rules for the input and (row, col)wise scale tensors.

         Args:
-            input_rank: The rank of the input tensor (for which we produce the scale tensor)
+            input_shape: The shape of the input tensor (for which we produce the scale tensor)
             unique_var: An otherwise unused Shardy variable name prefix
             flatten_axis: Axis along which data can be flattened to 2D for quantization.

         Returns:
             The Shardy rules for the scaling mode
         """
-        return self._get_impl().get_shardy_sharding_rules(input_rank, unique_var, flatten_axis)
+        return self._get_impl().get_shardy_sharding_rules(input_shape, unique_var, flatten_axis)
     def get_grouped_scale_shape_2x(
         self, data_shape, n_groups, group_axis, is_padded=True, flatten_axis=-1
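Putting the pieces together, a primitive's `shardy_sharding_rule` consumes these rules roughly as the call sites above do. This is a sketch with placeholder operand counts and names, not any specific TE primitive:

```python
from jax.experimental.custom_partitioning import SdyShardingRule
from transformer_engine.jax.quantize import ScalingMode

def shardy_sharding_rule_sketch(scaling_mode, x_shape):
    # Derive the input/scale factor specs from the scaling mode.
    scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
        x_shape, unique_var="Prim_x", flatten_axis=-1
    )
    amax = ("Prim_amax",)
    return SdyShardingRule(
        (scale_rules.input_spec, ("…1",)),    # operands: x and its scale
        (
            scale_rules.input_spec,           # rowwise output
            scale_rules.rowwise_rule,         # rowwise scale_inv
            amax,
        ),
        # Without the factor sizes, Shardy cannot validate the CompoundFactors.
        **scale_rules.factor_sizes,
    )
```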