Unverified Commit 6117b20c authored by Johannes Reifferscheid's avatar Johannes Reifferscheid Committed by GitHub
Browse files

Add experimental Shardy support. (#1642)



* Add experimental Shardy support.

Production use is not yet recommended.

---------
Signed-off-by: Johannes Reifferscheid <jreiffers@nvidia.com>
parent 98b4c0d9
......@@ -21,3 +21,15 @@ do
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::test_te_mxfp8 --num-process=$NUM_GPUS --process-id=$i &
done
wait
# Run the multiprocessing encoder BF16 test with the experimental Shardy
# partitioner enabled, launching one pytest process per GPU in parallel.
for i in $(seq 0 $(($NUM_GPUS-1)))
do
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::test_te_bf16_shardy --num-process=$NUM_GPUS --process-id=$i &
done
wait
# Same as above, but for the DelayedScaling FP8 + Shardy variant.
for i in $(seq 0 $(($NUM_GPUS-1)))
do
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::test_te_delayed_scaling_fp8_shardy --num-process=$NUM_GPUS --process-id=$i &
done
wait
......@@ -258,6 +258,8 @@ def get_state_sharding(state, params_sharding):
def train_and_evaluate(args):
"""Execute model training and evaluation loop."""
print(args)
jax.config.update("jax_use_shardy_partitioner", args.enable_shardy)
train_ds, test_ds, num_embed = get_datasets(args.max_seq_len)
num_gpu = jax.local_device_count()
......@@ -441,6 +443,9 @@ def encoder_parser(args):
parser.add_argument(
"--enable-sp", action="store_true", default=False, help="Enable sequence parallelism."
)
parser.add_argument(
"--enable-shardy", action="store_true", default=False, help="Enable Shardy (experimental)."
)
return parser.parse_args(args)
......@@ -451,10 +456,9 @@ class TestEncoder(unittest.TestCase):
is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
@classmethod
def setUpClass(cls):
def setUp(self):
"""Run 3 epochs for testing"""
cls.args = encoder_parser(["--epochs", "3"])
self.args = encoder_parser(["--epochs", "3"])
@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16(self):
......@@ -503,6 +507,34 @@ class TestEncoder(unittest.TestCase):
actual = train_and_evaluate(self.args)
assert actual[0] < 0.455 and actual[1] > 0.785
@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16_shardy(self):
    """Test Transformer Engine with BF16 and the experimental Shardy partitioner."""
    self.args.enable_shardy = True
    actual = train_and_evaluate(self.args)
    # actual is (loss, accuracy); thresholds match the non-Shardy BF16 test.
    assert actual[0] < 0.455 and actual[1] > 0.785
@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_delayed_scaling_fp8_shardy(self):
    """Test Transformer Engine with DelayedScaling FP8 and the experimental Shardy partitioner."""
    self.args.enable_shardy = True
    self.args.use_fp8 = True
    self.args.fp8_recipe = "DelayedScaling"
    actual = train_and_evaluate(self.args)
    # actual is (loss, accuracy); thresholds match the non-Shardy FP8 test.
    assert actual[0] < 0.455 and actual[1] > 0.785
@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_delayed_scaling_fp8_with_sp_shardy(self):
    """Test Transformer Engine with DelayedScaling FP8 + sequence parallelism under Shardy."""
    self.args.enable_shardy = True
    self.args.enable_sp = True
    self.args.use_fp8 = True
    self.args.fp8_recipe = "DelayedScaling"
    actual = train_and_evaluate(self.args)
    # actual is (loss, accuracy); thresholds match the non-Shardy FP8+SP test.
    assert actual[0] < 0.455 and actual[1] > 0.785
# TODO(jreiffers): Add mxfp8 Shardy tests once supported in JAX.
if __name__ == "__main__":
train_and_evaluate(encoder_parser(None))
......@@ -238,6 +238,7 @@ def get_state_sharding(state, params_sharding):
def train_and_evaluate(args):
"""Execute model training and evaluation loop."""
print(args)
jax.config.update("jax_use_shardy_partitioner", args.enable_shardy)
train_ds, test_ds, num_embed = get_datasets(args.max_seq_len)
num_gpu = jax.local_device_count()
......@@ -409,6 +410,9 @@ def encoder_parser(args):
default="DelayedScaling",
help="Use FP8 recipe (default: DelayedScaling)",
)
parser.add_argument(
"--enable-shardy", action="store_true", default=False, help="Enable Shardy (experimental)."
)
return parser.parse_args(args)
......@@ -419,10 +423,9 @@ class TestEncoder(unittest.TestCase):
is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
@classmethod
def setUpClass(cls):
def setUp(self):
"""Run 3 epochs for testing"""
cls.args = encoder_parser(["--epochs", "3"])
self.args = encoder_parser(["--epochs", "3"])
@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16(self):
......@@ -446,6 +449,24 @@ class TestEncoder(unittest.TestCase):
actual = train_and_evaluate(self.args)
assert actual[0] < 0.535 and actual[1] > 0.73
@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16_shardy(self):
    """Test Transformer Engine with BF16 and the experimental Shardy partitioner."""
    self.args.enable_shardy = True
    actual = train_and_evaluate(self.args)
    # actual is (loss, accuracy); thresholds match the non-Shardy BF16 test.
    assert actual[0] < 0.535 and actual[1] > 0.73
@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_delayed_scaling_fp8_shardy(self):
    """Test Transformer Engine with DelayedScaling FP8 and the experimental Shardy partitioner."""
    self.args.enable_shardy = True
    self.args.use_fp8 = True
    self.args.fp8_recipe = "DelayedScaling"
    actual = train_and_evaluate(self.args)
    # actual is (loss, accuracy); thresholds match the non-Shardy FP8 test.
    assert actual[0] < 0.535 and actual[1] > 0.73
# TODO(jreiffers): Add mxfp8 Shardy tests once supported in JAX.
if __name__ == "__main__":
train_and_evaluate(encoder_parser(None))
......@@ -343,6 +343,7 @@ def get_state_sharding(state, params_sharding):
def train_and_evaluate(args):
"""Execute model training and evaluation loop."""
print(args)
jax.config.update("jax_use_shardy_partitioner", args.enable_shardy)
if args.process_id == 0:
nltk.download("punkt_tab")
......@@ -565,6 +566,9 @@ def encoder_parser(args):
default=0,
help="the ID number of the current process (default: 0)",
)
parser.add_argument(
"--enable-shardy", action="store_true", default=False, help="Enable Shardy (experimental)."
)
return parser.parse_args(args)
......@@ -573,7 +577,7 @@ def encoder_parser(args):
class TestEncoder(unittest.TestCase):
"""Encoder unittests"""
def exec(self, use_fp8, fp8_recipe):
def exec(self, use_fp8, fp8_recipe, *, enable_shardy=False):
"""Run 3 epochs for testing"""
args = encoder_parser([])
......@@ -589,6 +593,7 @@ class TestEncoder(unittest.TestCase):
args.num_process = num_gpu
args.process_id = self.process_id
args.fp8_recipe = fp8_recipe
args.enable_shardy = enable_shardy
return train_and_evaluate(args)
......@@ -614,6 +619,22 @@ class TestEncoder(unittest.TestCase):
result = self.exec(True, "MXFP8BlockScaling")
assert result[0] < 0.505 and result[1] > 0.754
@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16_shardy(self):
    """Test Transformer Engine with BF16 and the experimental Shardy partitioner."""
    result = self.exec(False, None, enable_shardy=True)
    # result is (loss, accuracy).
    assert result[0] < 0.505 and result[1] > 0.755
@unittest.skipIf(
    not is_fp8_supported(), "Device compute capability 9.0+ is required for DelayedScaling FP8"
)
def test_te_delayed_scaling_fp8_shardy(self):
    """Test Transformer Engine with DelayedScaling FP8 and the experimental Shardy partitioner."""
    result = self.exec(True, "DelayedScaling", enable_shardy=True)
    # result is (loss, accuracy).
    assert result[0] < 0.505 and result[1] > 0.754
# TODO(jreiffers): Add mxfp8 Shardy tests once supported in JAX.
if __name__ == "__main__":
train_and_evaluate(encoder_parser(None))
......@@ -330,10 +330,9 @@ class TestEncoder(unittest.TestCase):
is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
@classmethod
def setUpClass(cls):
"""Run 4 epochs for testing"""
cls.args = encoder_parser(["--epochs", "3"])
def setUp(self):
"""Run 3 epochs for testing"""
self.args = encoder_parser(["--epochs", "3"])
@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16(self):
......
......@@ -25,3 +25,5 @@ filterwarnings=
ignore:jax.experimental.maps and .* are deprecated.*:DeprecationWarning
ignore:The host_callback APIs are deprecated .*:DeprecationWarning
ignore:Scan loop is disabled for fused ring attention.*:UserWarning
ignore:jax.extend.ffi.register_ffi_target is deprecated
ignore:jax.extend.ffi.ffi_lowering is deprecated
......@@ -48,31 +48,7 @@ class TestDistributedSelfAttn:
# for loss and dbias
return generate_collectives_count(allreduce=allreduce_total_bytes, allgather=0, other=0)
@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest.mark.parametrize(
"data_shape",
[
pytest.param((32, 512, 12, 64), id="32-512-12-64"),
pytest.param((32, 1024, 16, 128), id="32-1024-16-128"),
],
)
@pytest.mark.parametrize(
"attn_bias_type, bias_shape",
[
pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
pytest.param(AttnBiasType.PRE_SCALE_BIAS, BiasShape._1HSS, id="PRE_SCALE_BIAS-1HSS"),
pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"),
],
)
@pytest.mark.parametrize(
"attn_mask_type",
[
pytest.param(AttnMaskType.PADDING_MASK, id="PADDING_MASK"),
pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL_MASK"),
],
)
@pytest.mark.parametrize("dtype", DTYPES)
def test_self_attn(
def impl_test_self_attn(
self,
device_count,
mesh_shape,
......@@ -83,7 +59,9 @@ class TestDistributedSelfAttn:
bias_shape,
attn_mask_type,
dtype,
use_shardy,
):
jax.config.update("jax_use_shardy_partitioner", use_shardy)
dropout_prob = 0.0
is_training = True
......@@ -137,6 +115,80 @@ class TestDistributedSelfAttn:
)
runner.test_backward()
@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest.mark.parametrize(
    "data_shape",
    [
        pytest.param((32, 512, 12, 64), id="32-512-12-64"),
        pytest.param((32, 1024, 16, 128), id="32-1024-16-128"),
    ],
)
@pytest.mark.parametrize(
    "attn_bias_type, bias_shape",
    [
        pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
        pytest.param(AttnBiasType.PRE_SCALE_BIAS, BiasShape._1HSS, id="PRE_SCALE_BIAS-1HSS"),
        pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"),
    ],
)
@pytest.mark.parametrize(
    "attn_mask_type",
    [
        pytest.param(AttnMaskType.PADDING_MASK, id="PADDING_MASK"),
        pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL_MASK"),
    ],
)
@pytest.mark.parametrize("dtype", DTYPES)
def test_self_attn(
    self,
    device_count,
    mesh_shape,
    mesh_axes,
    mesh_resource,
    data_shape,
    attn_bias_type,
    bias_shape,
    attn_mask_type,
    dtype,
):
    """Distributed self-attention under the default (GSPMD) partitioner.

    Thin wrapper that forwards the full parameter sweep to the shared
    implementation with Shardy disabled.
    """
    self.impl_test_self_attn(
        device_count=device_count,
        mesh_shape=mesh_shape,
        mesh_axes=mesh_axes,
        mesh_resource=mesh_resource,
        data_shape=data_shape,
        attn_bias_type=attn_bias_type,
        bias_shape=bias_shape,
        attn_mask_type=attn_mask_type,
        dtype=dtype,
        use_shardy=False,
    )
@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest.mark.parametrize(
    "attn_bias_type, bias_shape",
    [
        pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
        pytest.param(AttnBiasType.PRE_SCALE_BIAS, BiasShape._1HSS, id="PRE_SCALE_BIAS-1HSS"),
    ],
)
def test_self_attn_shardy(
    self, device_count, mesh_shape, mesh_axes, mesh_resource, attn_bias_type, bias_shape
):
    """Distributed self-attention with the Shardy partitioner enabled.

    Runs a reduced sweep relative to `test_self_attn` (single data shape,
    padding mask, bf16 only) — presumably to bound runtime; confirm if the
    sweep should be widened as Shardy support matures.
    """
    data_shape = (32, 512, 12, 64)
    self.impl_test_self_attn(
        device_count,
        mesh_shape,
        mesh_axes,
        mesh_resource,
        data_shape,
        attn_bias_type,
        bias_shape,
        AttnMaskType.PADDING_MASK,
        jnp.bfloat16,
        use_shardy=True,
    )
class TestDistributedCrossAttn:
......@@ -203,37 +255,23 @@ class TestDistributedCrossAttn:
runner.test_backward()
@pytest.mark.parametrize(
"device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs()
)
@pytest.mark.parametrize(
"data_shape",
[
# Sequence lengths will be scaled by CP so that we don't run with tiny sizes.
pytest.param([2, 128, 8, 128], id="2-128xCP-8-128"),
pytest.param([4, 256, 16, 64], id="4-256xCP-16-64"),
],
)
@pytest.mark.parametrize("kv_groups", [1, 8])
@pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")])
@pytest.mark.parametrize(
"qkv_layout, attn_mask_type",
[
DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS = [
pytest.param(QKVLayout.BSHD_BS2HD, AttnMaskType.CAUSAL_MASK, id="BSHD_KVPACKED-CAUSAL"),
pytest.param(QKVLayout.BSHD_BSHD_BSHD, AttnMaskType.CAUSAL_MASK, id="BSHD_SEPARATE-CAUSAL"),
pytest.param(QKVLayout.BSHD_BS2HD, AttnMaskType.NO_MASK, id="HD_KVPACKED-NO_MASK"),
pytest.param(QKVLayout.BSHD_BSHD_BSHD, AttnMaskType.NO_MASK, id="BSHD_SEPARATE-NO_MASK"),
pytest.param(
QKVLayout.THD_THD_THD,
AttnMaskType.PADDING_CAUSAL_MASK,
id="THD_SEPARATE-PADDING_CAUSAL",
QKVLayout.THD_THD_THD, AttnMaskType.PADDING_CAUSAL_MASK, id="THD_SEPARATE-PADDING_CAUSAL"
),
],
)
@pytest.mark.parametrize(
"load_balanced",
[pytest.param(True, id="BALANCED"), pytest.param(False, id="UNBALANCED")],
)
]
DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES = [
# Sequence lengths will be scaled by CP so that we don't run with tiny sizes.
pytest.param([2, 128, 8, 128], id="2-128xCP-8-128"),
pytest.param([4, 256, 16, 64], id="4-256xCP-16-64"),
]
class TestDistributedContextParallelSelfAttn:
def impl_test_context_parallel_attn(
......@@ -249,7 +287,23 @@ class TestDistributedContextParallelSelfAttn:
qkv_layout,
load_balanced,
cp_strategy,
use_shardy,
use_scan_ring=False,
):
if qkv_layout.is_thd():
if cp_strategy == CPStrategy.ALL_GATHER:
pytest.skip("THD doesn't support all gather context parallelism.")
if not load_balanced and cp_strategy == CPStrategy.RING:
pytest.skip("THD + ring doesn't support unbalanced context parallelism.")
assert not use_scan_ring or cp_strategy == CPStrategy.RING
if use_scan_ring:
os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] = "1"
else:
os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] = "0"
jax.config.update("jax_use_shardy_partitioner", use_shardy)
attn_bias_type = AttnBiasType.NO_BIAS
bias_shape = None
dropout_prob = 0.0
......@@ -324,7 +378,58 @@ class TestDistributedContextParallelSelfAttn:
pytest.skip(f"Skipping {kv_groups=} not multiple of {data_shape=} or {tp_size=}")
runner.test_backward()
del os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"]
@pytest.mark.parametrize(
    "device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs()
)
@pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES[:1])
@pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")])
@pytest.mark.parametrize(
    "qkv_layout, attn_mask_type",
    DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS,
)
def test_context_parallel_allgather_attn_shardy(
    self,
    device_count,
    mesh_shape,
    mesh_axes,
    mesh_resource,
    data_shape,
    attn_mask_type,
    dtype,
    qkv_layout,
):
    """All-gather context-parallel attention with the Shardy partitioner.

    Reduced sweep vs. the GSPMD test: first data shape only, kv_groups fixed
    to 8, always load-balanced.
    """
    kv_groups = 8
    self.impl_test_context_parallel_attn(
        device_count,
        mesh_shape,
        mesh_axes,
        mesh_resource,
        data_shape,
        kv_groups,
        attn_mask_type,
        dtype,
        qkv_layout,
        load_balanced=True,
        cp_strategy=CPStrategy.ALL_GATHER,
        use_shardy=True,
    )
@pytest.mark.parametrize(
"device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs()
)
@pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES)
@pytest.mark.parametrize("kv_groups", [1, 8])
@pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")])
@pytest.mark.parametrize(
"qkv_layout, attn_mask_type",
DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS,
)
@pytest.mark.parametrize(
"load_balanced",
[pytest.param(True, id="BALANCED"), pytest.param(False, id="UNBALANCED")],
)
def test_context_parallel_allgather_attn(
self,
device_count,
......@@ -338,9 +443,7 @@ class TestDistributedContextParallelSelfAttn:
qkv_layout,
load_balanced,
):
if qkv_layout.is_thd():
pytest.skip("THD doesn't support all gather context parallelism.")
return self.impl_test_context_parallel_attn(
self.impl_test_context_parallel_attn(
device_count,
mesh_shape,
mesh_axes,
......@@ -352,8 +455,23 @@ class TestDistributedContextParallelSelfAttn:
qkv_layout,
load_balanced,
CPStrategy.ALL_GATHER,
use_shardy=False,
)
@pytest.mark.parametrize(
"device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs()
)
@pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES)
@pytest.mark.parametrize("kv_groups", [1, 8])
@pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")])
@pytest.mark.parametrize(
"qkv_layout, attn_mask_type",
DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS,
)
@pytest.mark.parametrize(
"load_balanced",
[pytest.param(True, id="BALANCED"), pytest.param(False, id="UNBALANCED")],
)
@pytest.mark.parametrize(
"use_scan",
[pytest.param(False, id="NO_SCAN"), pytest.param(True, id="USE_SCAN")],
......@@ -372,14 +490,6 @@ class TestDistributedContextParallelSelfAttn:
load_balanced,
use_scan,
):
if use_scan:
os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] = "1"
else:
os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] = "0"
if qkv_layout.is_thd() and not load_balanced:
pytest.skip("THD + ring doesn't support unbalanced context parallelism.")
self.impl_test_context_parallel_attn(
device_count,
mesh_shape,
......@@ -392,9 +502,46 @@ class TestDistributedContextParallelSelfAttn:
qkv_layout,
load_balanced,
CPStrategy.RING,
use_shardy=False,
use_scan_ring=use_scan,
)
@pytest.mark.parametrize(
    "device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs()
)
@pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES[:1])
@pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")])
@pytest.mark.parametrize(
    "qkv_layout, attn_mask_type",
    DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS,
)
def test_context_parallel_ring_attn_shardy(
    self,
    device_count,
    mesh_shape,
    mesh_axes,
    mesh_resource,
    data_shape,
    attn_mask_type,
    dtype,
    qkv_layout,
):
    """Ring-strategy context-parallel attention with the Shardy partitioner.

    Reduced sweep: first data shape only, kv_groups fixed to 8, load-balanced,
    scan-based ring implementation.
    """
    kv_groups = 8
    self.impl_test_context_parallel_attn(
        device_count,
        mesh_shape,
        mesh_axes,
        mesh_resource,
        data_shape,
        kv_groups,
        attn_mask_type,
        dtype,
        qkv_layout,
        load_balanced=True,
        cp_strategy=CPStrategy.RING,
        # Fix: this is the Shardy variant of the ring test, but Shardy was
        # never enabled (use_shardy=False), so it merely duplicated the GSPMD
        # path (cf. test_context_parallel_allgather_attn_shardy, which passes
        # use_shardy=True).
        use_shardy=True,
        use_scan_ring=True,
    )
del os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"]
return
class TestReorderCausalLoadBalancing:
......
......@@ -86,6 +86,7 @@ class TestDistributedLayernorm:
@pytest_parametrize_wrapper("zero_centered_gamma", [False, True])
@pytest_parametrize_wrapper("shard_weights", [False, True])
@pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
@pytest_parametrize_wrapper("use_shardy", [False, True])
def test_layernorm(
self,
device_count,
......@@ -97,7 +98,9 @@ class TestDistributedLayernorm:
zero_centered_gamma,
shard_weights,
fp8_recipe,
use_shardy,
):
jax.config.update("jax_use_shardy_partitioner", use_shardy)
epsilon = 1e-6
ln_type = "layernorm"
q_dtype = jnp.float8_e4m3fn
......@@ -168,6 +171,7 @@ class TestDistributedLayernorm:
@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("shard_weights", [False, True])
@pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
@pytest_parametrize_wrapper("use_shardy", [False, True])
def test_rmsnorm(
self,
device_count,
......@@ -178,7 +182,9 @@ class TestDistributedLayernorm:
dtype,
shard_weights,
fp8_recipe,
use_shardy,
):
jax.config.update("jax_use_shardy_partitioner", use_shardy)
epsilon = 1e-6
ln_type = "rmsnorm"
q_dtype = jnp.float8_e4m3fn
......
......@@ -144,16 +144,10 @@ class TestDistributedLayernormMLP:
)
)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs())
@pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
@pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("use_bias", [True, False])
@pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
def test_layernorm_mlp_grad(
self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe
def _test_layernorm_mlp_grad(
self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe, use_shardy
):
jax.config.update("jax_use_shardy_partitioner", use_shardy)
device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config
layernorm_type = "rmsnorm"
......@@ -257,9 +251,59 @@ class TestDistributedLayernormMLP:
err_msg=f"multi_grads[{i}] is not close",
)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs())
@pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
@pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("use_bias", [True, False])
@pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
def test_layernorm_mlp_grad(
    self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe
):
    """LayerNorm-MLP gradient test under the default (GSPMD) partitioner.

    Forwards the full parameter sweep to the shared implementation with
    Shardy disabled.
    """
    self._test_layernorm_mlp_grad(
        mesh_config=mesh_config,
        activation_type=activation_type,
        use_bias=use_bias,
        input_shape=input_shape,
        dtype=dtype,
        fp8_recipe=fp8_recipe,
        use_shardy=False,
    )
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs())
@pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
@pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("use_bias", [True, False])
def test_layernorm_mlp_grad_shardy(
    self, mesh_config, activation_type, use_bias, input_shape, dtype
):
    """LayerNorm-MLP gradient test with the Shardy partitioner enabled."""
    # We don't test block scaling with Shardy because at the time of writing,
    # it is not supported in JAX's scaled_matmul_stablehlo.
    self._test_layernorm_mlp_grad(
        mesh_config,
        activation_type,
        use_bias,
        input_shape,
        dtype,
        fp8_recipe=recipe.DelayedScaling(),
        use_shardy=True,
    )
def _test_layernorm_mlp(
self, mesh_config, activation_type, use_bias, input_shape, dtype, use_fp8, fp8_recipe=None
self,
mesh_config,
activation_type,
use_bias,
input_shape,
dtype,
use_fp8,
fp8_recipe,
use_shardy,
):
jax.config.update("jax_use_shardy_partitioner", use_shardy)
batch, seqlen, hidden_in = input_shape
layernorm_type = "rmsnorm"
......@@ -322,9 +366,19 @@ class TestDistributedLayernormMLP:
@pytest_parametrize_wrapper("activation_type", [("gelu",), ("silu", "linear")])
@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("use_bias", [True, False])
def test_layernorm_mlp_layer(self, mesh_config, activation_type, use_bias, input_shape, dtype):
@pytest_parametrize_wrapper("use_shardy", [False, True])
def test_layernorm_mlp_layer(
self, mesh_config, activation_type, use_bias, input_shape, dtype, use_shardy
):
self._test_layernorm_mlp(
mesh_config, activation_type, use_bias, input_shape, dtype, use_fp8=False
mesh_config,
activation_type,
use_bias,
input_shape,
dtype,
use_fp8=False,
fp8_recipe=None,
use_shardy=use_shardy,
)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
......@@ -345,4 +399,5 @@ class TestDistributedLayernormMLP:
dtype,
use_fp8=True,
fp8_recipe=fp8_recipe,
use_shardy=False,
)
......@@ -28,14 +28,16 @@ class TestDistributedSoftmax:
all_reduce_loss_bytes = 4 # 1 * FP32
return generate_collectives_count(allreduce=all_reduce_loss_bytes, allgather=0, other=0)
def generate_inputs(self, shape, mesh_resource, softmax_type, dtype, bad_sharding):
def generate_inputs(
self, shape, mesh_resource, softmax_type, dtype, bad_sharding, broadcast_batch_mask
):
batch, _, sqelen, _ = shape
x = random.normal(random.PRNGKey(1124), shape, dtype=dtype)
if softmax_type == SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
mask = make_causal_mask(batch, sqelen)
else:
mask = make_self_mask(batch, sqelen)
mask = make_self_mask(1 if broadcast_batch_mask else batch, sqelen)
if not bad_sharding:
x_pspec = PartitionSpec(
......@@ -45,6 +47,10 @@ class TestDistributedSoftmax:
x_pspec = PartitionSpec(
mesh_resource.dp_resource, None, None, mesh_resource.tp_resource
)
if broadcast_batch_mask:
mask_pspec = PartitionSpec(None, None, None, None)
else:
mask_pspec = PartitionSpec(mesh_resource.dp_resource, None, None, None)
return (x, mask), (x_pspec, mask_pspec)
......@@ -67,16 +73,7 @@ class TestDistributedSoftmax:
output = jax.nn.softmax(x * scale_factor)
return jnp.mean(output)
@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest.mark.parametrize("data_shape", [[32, 12, 128, 128], [64, 16, 1024, 1024]])
@pytest.mark.parametrize(
"softmax_type",
[SoftmaxType.SCALED, SoftmaxType.SCALED_MASKED, SoftmaxType.SCALED_UPPER_TRIANG_MASKED],
)
@pytest.mark.parametrize("scale_factor", [1.0, 3.0])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("bad_sharding", [False, True])
def test_softmax(
def impl_test_softmax(
self,
device_count,
mesh_shape,
......@@ -87,15 +84,20 @@ class TestDistributedSoftmax:
scale_factor,
dtype,
bad_sharding,
broadcast_batch_mask,
use_shardy,
):
if broadcast_batch_mask and softmax_type != SoftmaxType.SCALED_MASKED:
pytest.skip("Softmax type has no mask.")
jax.config.update("jax_use_shardy_partitioner", use_shardy)
target_func = partial(
self.target_func, scale_factor=scale_factor, softmax_type=softmax_type
)
ref_func = partial(self.ref_func, scale_factor=scale_factor, dtype=dtype)
(x, mask), (x_pspec, mask_pspec) = self.generate_inputs(
data_shape, mesh_resource, softmax_type, dtype, bad_sharding
data_shape, mesh_resource, softmax_type, dtype, bad_sharding, broadcast_batch_mask
)
collective_count_ref = self.generate_collectives_count_ref()
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
......@@ -129,4 +131,70 @@ class TestDistributedSoftmax:
assert "Sharding the hidden dimension is not supported" in str(w), (
"Softmax primitive did not raise the correct warning for "
"unsupported sharding in the hidden dimension."
f"{str(w)}"
)
@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest.mark.parametrize("data_shape", [[32, 12, 128, 128], [64, 16, 1024, 1024]])
@pytest.mark.parametrize(
    "softmax_type",
    [SoftmaxType.SCALED, SoftmaxType.SCALED_MASKED, SoftmaxType.SCALED_UPPER_TRIANG_MASKED],
)
@pytest.mark.parametrize("scale_factor", [1.0, 3.0])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("bad_sharding", [False, True])
@pytest.mark.parametrize("broadcast_batch_mask", [False, True])
def test_softmax(
    self,
    device_count,
    mesh_shape,
    mesh_axes,
    mesh_resource,
    data_shape,
    softmax_type,
    scale_factor,
    dtype,
    bad_sharding,
    broadcast_batch_mask,
):
    """Distributed softmax under the default (GSPMD) partitioner.

    Forwards the full parameter sweep to the shared implementation with
    Shardy disabled.
    """
    self.impl_test_softmax(
        device_count=device_count,
        mesh_shape=mesh_shape,
        mesh_axes=mesh_axes,
        mesh_resource=mesh_resource,
        data_shape=data_shape,
        softmax_type=softmax_type,
        scale_factor=scale_factor,
        dtype=dtype,
        bad_sharding=bad_sharding,
        broadcast_batch_mask=broadcast_batch_mask,
        use_shardy=False,
    )
@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest.mark.parametrize("softmax_type", [SoftmaxType.SCALED, SoftmaxType.SCALED_MASKED])
@pytest.mark.parametrize("bad_sharding", [False, True])
@pytest.mark.parametrize("broadcast_batch_mask", [False, True])
def test_softmax_shardy(
    self,
    device_count,
    mesh_shape,
    mesh_axes,
    mesh_resource,
    softmax_type,
    bad_sharding,
    broadcast_batch_mask,
):
    """Distributed softmax with the Shardy partitioner enabled.

    Reduced sweep relative to `test_softmax`: single data shape, unit scale
    factor, first dtype only, and no SCALED_UPPER_TRIANG_MASKED case.
    """
    self.impl_test_softmax(
        device_count,
        mesh_shape,
        mesh_axes,
        mesh_resource,
        data_shape=[32, 12, 128, 128],
        softmax_type=softmax_type,
        scale_factor=1.0,
        dtype=DTYPES[0],
        bad_sharding=bad_sharding,
        broadcast_batch_mask=broadcast_batch_mask,
        use_shardy=True,
    )
......@@ -10,6 +10,7 @@ from packaging import version
import jax
import jax.numpy as jnp
from jax import dtypes
from jax.experimental.custom_partitioning import SdyShardingRule
from jax.sharding import PartitionSpec
import transformer_engine_jax
......@@ -406,6 +407,54 @@ class ActLuPrimitive(BasePrimitive):
return mesh, sharded_impl, out_shardings, arg_shardings
@staticmethod
def shardy_sharding_rule(
    out_dtype,
    act_enum,
    act_len,
    scaling_mode,
    is_2x,
    scale_dtype,
    scale_shapes,
    is_outer,
    mesh,
    value_types,
    result_types,
):
    """Build the Shardy (SDY) sharding rule for the ActLu primitive.

    Maps the input's named factors onto the five outputs
    (out, colwise_out, scale_inv, colwise_scale_inv, amax). Most static
    arguments are only present to mirror the primitive's static-arg list
    and are unused here.
    """
    del out_dtype, act_enum, act_len, scale_dtype, scale_shapes, is_outer, mesh, result_types
    x_rank = len(value_types[0].shape)
    # Scaling-mode-specific factor names; input_spec covers all but the
    # trailing axis of the input (hence x_rank - 1).
    scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
        x_rank - 1, unique_var="i", flatten_axis=-2
    )
    x_axes = scale_rules.input_spec + (f"x{x_rank-1}",)
    # Output drops the second-to-last input axis — presumably the activation
    # axis of length act_len; TODO confirm against the primitive's abstract().
    out = (*x_axes[:-2], x_axes[-1])
    scale_inv = scale_rules.rowwise_rule
    colwise_scale_inv = scale_rules.colwise_rule
    if is_2x:
        if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
            # Delayed tensor scaling emits a transposed column-wise output.
            colwise_out = tuple(
                multidim_transpose(x_axes, static_axis_boundary=-1, transpose_axis=-2)
            )
        else:
            colwise_out = out
    else:
        # No column-wise outputs: use fresh placeholder factors so they are
        # unconstrained.
        colwise_out = ("j",)
        colwise_scale_inv = ("k",)
    # amax is always a unit tensor.
    amax = ("l",)
    return SdyShardingRule(
        (
            x_axes,
            "…1",
        ),
        (out, colwise_out, scale_inv, colwise_scale_inv, amax),
        **scale_rules.factor_sizes,
    )
register_primitive(ActLuPrimitive)
......@@ -819,6 +868,46 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
return mesh, sharded_impl, out_shardings, arg_shardings
@staticmethod
def shardy_sharding_rule(
    out_dtype,
    scaling_mode,
    is_2x,
    scale_dtype,
    scale_shapes,
    is_dbias,
    act_enum,
    act_len,
    is_outer,
    mesh,
    value_types,
    result_types,
):
    """Build the Shardy (SDY) sharding rule for DActLu + dbias + quantize.

    The rule is derived from the second operand (value_types[1]); the other
    operands get placeholder variadic specs. Unused static arguments mirror
    the primitive's static-arg list.
    """
    del out_dtype, scale_dtype, scale_shapes, act_enum, act_len, is_outer, mesh, result_types
    x_rank = len(value_types[1].shape)
    scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
        x_rank, unique_var="i", flatten_axis=-2
    )
    x_axes = scale_rules.input_spec
    # Row-wise output shares the input's factors.
    out = x_axes
    if is_2x:
        if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
            # Delayed tensor scaling emits a transposed column-wise output.
            colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=-2))
        else:
            colwise_out = tuple(x_axes)
    else:
        # No column-wise output: fresh placeholder factor.
        colwise_out = ("j",)
    # dbias reduces over all but the last two axes when requested; otherwise
    # it is a placeholder.
    dbias = x_axes[-2:] if is_dbias else ("k",)
    amax = ("…4",)
    return SdyShardingRule(
        (("…0",), tuple(x_axes), ("…2",)),
        (out, colwise_out, scale_rules.rowwise_rule, scale_rules.colwise_rule, amax, dbias),
        **scale_rules.factor_sizes,
    )
register_primitive(DActLuDBiasQuantizePrimitive)
......
......@@ -14,6 +14,7 @@ import jax
import jax.numpy as jnp
from jax import dtypes, lax
from jax.sharding import PartitionSpec, NamedSharding
from jax.experimental.custom_partitioning import SdyShardingRule
import transformer_engine_jax
from transformer_engine_jax import NVTE_Fused_Attn_Backend
......@@ -42,6 +43,7 @@ from ..sharding import (
get_mesh_axis_rank,
get_all_mesh_axes,
num_of_devices,
with_sharding_constraint,
)
......@@ -618,6 +620,35 @@ class FusedAttnFwdPrimitive(BasePrimitive):
impl = partial(FusedAttnFwdPrimitive.impl, config=config)
return mesh, impl, out_shardings, arg_shardings
@staticmethod
def shardy_sharding_rule(config, mesh, value_types, result_types):
    """Build the Shardy (SDY) sharding rule for the fused-attention forward op.

    Only the first operand's layout is expressed precisely; every other
    operand gets an opaque variadic placeholder.
    """
    del mesh, result_types
    # Keep in sync with `infer_sharding_from_operands`.
    # We only need the first input. Fill up the rest with placeholders.
    input_spec = [(f"…{x}",) for x in range(len(value_types))]
    # The RNG state sharding cannot be expressed as a Shardy rule. We use with_sharding_constraint
    # instead. This has to happen outside of the primitive, see `fused_attn_fwd`.
    rng_sharding = (f"…{len(value_types)}",)
    if config.qkv_layout.is_qkvpacked():
        input_spec[0] = ("…0", "seqlen", "three", "head", "hidden")
    elif config.qkv_layout.is_kvpacked() or config.qkv_layout.is_separate():
        input_spec[0] = ("…0", "seqlen", "head", "hidden")
    else:
        raise ValueError(f"Unsupported {config.qkv_layout=}")

    # Packed softmax (cuDNN >= 9.6 with THD layout) stores the aux tensor with
    # seqlen before head; otherwise head comes first.
    is_packed_softmax = get_cudnn_version() >= (9, 6, 0) and config.qkv_layout.is_thd()

    out_sharding = ("…0", "seqlen", "head", "hidden")
    if is_packed_softmax:
        softmax_aux_sharding = ("…0", "seqlen", "head", "i")
    else:
        softmax_aux_sharding = ("…0", "head", "seqlen", "i")
    return SdyShardingRule(
        tuple(input_spec), (out_sharding, softmax_aux_sharding, rng_sharding)
    )
register_primitive(FusedAttnFwdPrimitive)
......@@ -998,6 +1029,15 @@ class FusedAttnBwdPrimitive(BasePrimitive):
return mesh, sharded_impl, out_shardings, arg_shardings
@staticmethod
def shardy_sharding_rule(config, mesh, value_types, result_types):
    """Build a fully-opaque Shardy rule for the fused-attention backward op.

    Every operand and result gets its own variadic placeholder factor, i.e.
    no sharding constraints are imposed by this rule.
    """
    del config, mesh
    # We only care about the four first arguments.
    # Keep in sync with `infer_sharding_from_operands`.
    # NOTE(review): the comment above mentions the first four arguments, but
    # placeholders are generated for all operands/results — confirm intent.
    input_spec = tuple((f"…{x}",) for x in range(len(value_types)))
    output_spec = tuple((f"…{x}",) for x in range(len(result_types)))
    return SdyShardingRule(input_spec, output_spec)
register_primitive(FusedAttnBwdPrimitive)
......@@ -2436,13 +2476,15 @@ def fused_attn_fwd(
primitive = FusedRingAttnFwdPrimitive.outer_primitive
seq_desc_flatten, _ = jax.tree.flatten(sequence_descriptor)
return primitive.bind(
output, softmax_aux, rng_state = primitive.bind(
*qkv_for_primitive,
bias,
seed,
*seq_desc_flatten,
config=fused_config,
)
rng_state = with_sharding_constraint(rng_state, PartitionSpec(get_all_mesh_axes(), None))
return (output, softmax_aux, rng_state)
def fused_attn_bwd(
......
......@@ -98,6 +98,15 @@ class BasePrimitive(metaclass=ABCMeta):
"""
return NotImplemented
@staticmethod
@abstractmethod
def shardy_sharding_rule(*args):
    """
    Returns the sharding rule for this primitive.

    Subclasses receive their primitive-specific static arguments followed
    by `mesh`, `value_types` and `result_types` (see the implementations
    in this file) and may ignore any of them.  The default is the identity
    rule string, which maps the single input's layout unchanged onto the
    single output.
    """
    del args
    return "... -> ..."
def register_primitive(cls):
"""
......@@ -123,7 +132,9 @@ def register_primitive(cls):
batching.primitive_batchers[outer_p] = cls.batcher
outer_p_lower = custom_partitioning(cls.impl, static_argnums=cls.impl_static_args)
outer_p_lower.def_partition(
infer_sharding_from_operands=cls.infer_sharding_from_operands, partition=cls.partition
infer_sharding_from_operands=cls.infer_sharding_from_operands,
partition=cls.partition,
sharding_rule=cls.shardy_sharding_rule,
)
mlir.register_lowering(
outer_p, mlir.lower_fun(outer_p_lower, multiple_results=cls.multiple_results)
......
......@@ -12,6 +12,7 @@ from packaging import version
import jax
import jax.numpy as jnp
from jax import dtypes
from jax.experimental.custom_partitioning import SdyShardingRule
from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec
......@@ -519,6 +520,57 @@ class NormFwdPrimitive(BasePrimitive):
return mesh, sharded_impl, out_shardings, arg_shardings
@staticmethod
def shardy_sharding_rule(
    norm_type,
    zero_centered_gamma,
    epsilon,
    out_dtype,
    scaling_mode,
    is_2x,
    scale_dtype,
    scale_shapes,
    is_outer,
    mesh,
    value_types,
    result_types,
):
    """Build the Shardy (SDY) sharding rule for the norm forward primitive.

    Only `norm_type`, `scaling_mode`, `is_2x` and the rank of the first
    operand (`value_types[0].shape`) influence the rule; the remaining
    parameters exist to mirror the partition signature and are deleted.

    Returns:
        An SdyShardingRule mapping the four operands to the seven results
        (out, colwise_out, rowwise scale-inv, colwise scale-inv, amax, mu,
        rsigma).
    """
    del (
        zero_centered_gamma,
        epsilon,
        out_dtype,
        scale_dtype,
        scale_shapes,
        is_outer,
        mesh,
        result_types,
    )
    # Axis variables for the input and the (row, col)wise scale tensors are
    # produced by the scaling mode; "i" is an otherwise-unused name prefix.
    scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
        len(value_types[0].shape), unique_var="i", flatten_axis=-1
    )
    x_axes = scale_rules.input_spec
    # Output shards like the input except the last axis, which gets its own
    # fresh factor "k".
    out = x_axes[:-1] + ("k",)
    # Without 2x quantization the colwise output is a placeholder tensor in
    # its own independent batching group.
    colwise_out = out if is_2x else ("…4",)
    # rsigma follows the input minus the last (reduced) axis; mu is the same
    # except for RMSNorm, which has no mu (placeholder group instead).
    rsigma = x_axes[:-1]
    mu = ("…5",) if norm_type == NVTE_Norm_Type.RMSNorm else rsigma
    # amax lives in its own independent group.
    amax = ("…6",)
    # NOTE(review): operands after x are presumably gamma/beta/scale — each
    # gets an independent group; confirm against the lowering signature.
    return SdyShardingRule(
        (x_axes, ("…1",), ("…2",), ("…3",)),
        (
            out,
            colwise_out,
            scale_rules.rowwise_rule,
            scale_rules.colwise_rule,
            amax,
            mu,
            rsigma,
        ),
        **scale_rules.factor_sizes,
    )
register_primitive(NormFwdPrimitive)
......@@ -722,6 +774,11 @@ class NormBwdPrimitive(BasePrimitive):
return mesh, sharded_impl, out_shardings, arg_shardings
@staticmethod
def shardy_sharding_rule(*args):
del args
return "...0, ...1 i, ...2, ...3, ...4 -> ...1 j, k, l"
register_primitive(NormBwdPrimitive)
......
......@@ -10,6 +10,7 @@ from packaging import version
import jax
import jax.numpy as jnp
from jax import dtypes
from jax.experimental.custom_partitioning import SdyShardingRule
from jax.sharding import PartitionSpec
import transformer_engine_jax
......@@ -470,6 +471,48 @@ class DBiasQuantizePrimitive(BasePrimitive):
return mesh, sharded_impl, out_shardings, arg_shardings
@staticmethod
def shardy_sharding_rule(
    out_dtype,
    scaling_mode,
    q_layout,
    flatten_axis,
    scale_dtype,
    scale_shapes,
    is_dbias,
    is_outer,
    mesh,
    value_types,
    result_types,
):
    """Build the Shardy (SDY) sharding rule for the dbias-quantize primitive.

    Only `scaling_mode`, `q_layout`, `flatten_axis`, `is_dbias` and the
    rank of the first operand are used; the other parameters mirror the
    partition signature and are deleted.

    Returns:
        An SdyShardingRule mapping (x, <second operand>) to
        (out, colwise_out, rowwise scale-inv, colwise scale-inv, amax, dbias).
    """
    del out_dtype, scale_dtype, scale_shapes, is_outer, mesh, result_types
    scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
        len(value_types[0].shape), unique_var="i", flatten_axis=flatten_axis
    )
    x_axes = scale_rules.input_spec
    colwise_scale_inv = scale_rules.colwise_rule
    # The rowwise output shards exactly like the input.
    out = x_axes
    if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
        if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
            # Delayed scaling stores the colwise copy transposed around
            # `flatten_axis`.
            colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=flatten_axis))
        else:
            colwise_out = x_axes
    else:
        # No colwise output is produced: placeholders with fresh factors.
        colwise_out = ("j",)
        colwise_scale_inv = ("k",)
    # dbias keeps the axes from flatten_axis onward (presumably reduced over
    # the flattened leading axes — confirm against the impl); otherwise a
    # placeholder.
    dbias = x_axes[flatten_axis:] if is_dbias else ("l",)
    # amax lives in its own independent group.
    amax = ("m",)
    return SdyShardingRule(
        (x_axes, ("…1",)),
        (out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax, dbias),
        **scale_rules.factor_sizes,
    )
register_primitive(DBiasQuantizePrimitive)
......
......@@ -330,6 +330,11 @@ class ScaledSoftmaxFwdPrimitive(SoftmaxPrimitive):
ScaledSoftmaxFwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos
)
@staticmethod
def shardy_sharding_rule(*args):
del args
return "... -> ..."
register_primitive(ScaledSoftmaxFwdPrimitive)
......@@ -400,6 +405,11 @@ class ScaledSoftmaxBwdPrimitive(SoftmaxPrimitive):
ScaledSoftmaxBwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos
)
@staticmethod
def shardy_sharding_rule(*args):
del args
return "..., ... -> ..."
register_primitive(ScaledSoftmaxBwdPrimitive)
......@@ -525,6 +535,11 @@ class ScaledMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive):
ScaledMaskedSoftmaxFwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos
)
@staticmethod
def shardy_sharding_rule(*args):
del args
return "...1, ...2 -> ...1"
register_primitive(ScaledMaskedSoftmaxFwdPrimitive)
......@@ -596,6 +611,11 @@ class ScaledMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
ScaledMaskedSoftmaxBwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos
)
@staticmethod
def shardy_sharding_rule(*args):
del args
return "..., ... -> ..."
register_primitive(ScaledMaskedSoftmaxBwdPrimitive)
......@@ -682,6 +702,11 @@ class ScaledUpperTriangMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive):
result_infos,
)
@staticmethod
def shardy_sharding_rule(*args):
del args
return "... -> ..."
register_primitive(ScaledUpperTriangMaskedSoftmaxFwdPrimitive)
......@@ -761,6 +786,11 @@ class ScaledUpperTriangMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
result_infos,
)
@staticmethod
def shardy_sharding_rule(*args):
del args
return "..., ... -> ..."
register_primitive(ScaledUpperTriangMaskedSoftmaxBwdPrimitive)
......
......@@ -16,13 +16,33 @@ from typing import Tuple, Dict
from functools import reduce
import operator
from jax.experimental.custom_partitioning import CompoundFactor
from jax.tree_util import register_pytree_node_class
import jax.numpy as jnp
from transformer_engine_jax import JAXX_Scaling_Mode
__all__ = ["ScalingMode"]
__all__ = ["QuantizeShardyRules", "ScalingMode"]
@dataclass
class QuantizeShardyRules:
    """Information necessary to shard scale tensors with Shardy.

    Attributes:
        input_spec: Specification for the input axes
        rowwise_rule: Sharding rule for the row-wise scale tensor, depends on
            the axes in `input_spec`
        colwise_rule: Likewise for the column-wise scale tensor.
        factor_sizes: For block scaling, contains the block size factor, which is
            used in `input_spec`.
    """

    # Variadic tuples, one entry per tensor axis.  For block scaling an
    # input_spec entry may be a CompoundFactor rather than a plain str.
    input_spec: Tuple[str, ...]
    rowwise_rule: Tuple[str, ...]
    colwise_rule: Tuple[str, ...]
    factor_sizes: Dict[str, int]
class ScalingModeMetadataImpl(ABC):
......@@ -59,6 +79,21 @@ class ScalingModeMetadataImpl(ABC):
The shape for scale tensors
"""
@abstractmethod
def get_shardy_sharding_rules(
    self, input_rank, unique_var, flatten_axis
) -> QuantizeShardyRules:
    """Sharding rules for the input and (row, col)wise scale tensors.

    Implemented per scaling mode (delayed vs. block scaling) by the
    concrete ScalingModeMetadataImpl subclasses in this file.

    Args:
        input_rank: The rank of the input tensor (for which we produce the scale tensor)
        unique_var: An otherwise unused Shardy variable name prefix
        flatten_axis: Axis along which data can be flattened to 2D for quantization.

    Returns:
        The Shardy rules for the scaling mode
    """
class DelayedScalingModeMetadataImpl(ScalingModeMetadataImpl):
"""Implementation for delayed scaling mode.
......@@ -95,6 +130,23 @@ class DelayedScalingModeMetadataImpl(ScalingModeMetadataImpl):
del data_shape, is_colwise
return (1,)
def get_shardy_sharding_rules(
    self, input_rank, unique_var, flatten_axis
) -> QuantizeShardyRules:
    """Shardy rules for delayed scaling.

    The input may be sharded freely on every axis, and the scale tensor is
    a scalar carrying a single fresh factor.

    Args:
        input_rank: The rank of the input tensor (for which we produce the scale tensor)
        unique_var: An otherwise unused Shardy variable name prefix
        flatten_axis: Ignored; delayed scaling has no block structure.

    Returns:
        The Shardy rules for the scaling mode
    """
    del flatten_axis  # irrelevant for scalar scales
    axis_vars = tuple(f"x{dim}" for dim in range(input_rank))
    scalar_scale = (unique_var,)
    return QuantizeShardyRules(axis_vars, scalar_scale, scalar_scale, {})
class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
"""Implementation for block scaling mode.
......@@ -217,6 +269,45 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
return (*first_dim_scale_shape, *last_dim_scale_shape)
def get_shardy_sharding_rules(
    self, input_rank, unique_var, flatten_axis
) -> QuantizeShardyRules:
    """Sharding rules for the input and (row, col)wise scale tensors.

    Args:
        input_rank: The rank of the input tensor (for which we produce the scale tensor)
        unique_var: An otherwise unused Shardy variable name prefix
        flatten_axis: Axis along which data can be flattened to 2D for quantization.

    Returns:
        The Shardy rules for the scaling mode

    Raises:
        NotImplementedError: If this mode's block dims are not (1, 32); the
            compound factors below hard-code that block shape.
    """
    # This implementation needs to be updated for different block dims.
    # Use an explicit exception instead of `assert` so the guard survives
    # `python -O` (assertions are stripped under optimization), and check
    # before doing any work.
    if self._block_dims != (1, 32):
        raise NotImplementedError(
            "get_shardy_sharding_rules currently supports only block dims"
            f" (1, 32), got {self._block_dims}"
        )

    input_spec = [f"x{i}" for i in range(input_rank)]

    # We have to use two different factors in the two CompoundFactors because of Shardy
    # verifier requirements, even though they are the same.
    rowwise_var = unique_var
    colwise_var = f"{unique_var}_"
    input_spec[flatten_axis - 1] = CompoundFactor(colwise_var, "block_size_colwise")
    input_spec[-1] = CompoundFactor(rowwise_var, "block_size_rowwise")

    # The rowwise and colwise scale tensors should be sharded the same way as the input.
    # However, we need to adjust the dimensions where the block scaling factor applies.
    rowwise = input_spec.copy()
    rowwise[-1] = rowwise_var
    colwise = input_spec.copy()
    colwise[flatten_axis - 1] = colwise_var

    return QuantizeShardyRules(
        tuple(input_spec),
        tuple(rowwise),
        tuple(colwise),
        {"block_size_rowwise": 32, "block_size_colwise": 32},
    )
@dataclass(frozen=True)
@register_pytree_node_class
......@@ -290,6 +381,20 @@ class ScalingMode(Enum):
"""
return self._get_impl().get_scale_shape(data_shape, is_colwise, is_padded, flatten_axis)
def get_shardy_sharding_rules(
    self, input_rank, unique_var, flatten_axis=-1
) -> QuantizeShardyRules:
    """Sharding rules for the input and (row, col)wise scale tensors.

    Args:
        input_rank: The rank of the input tensor (for which we produce the scale tensor)
        unique_var: An otherwise unused Shardy variable name prefix
        flatten_axis: Axis along which data can be flattened to 2D for
            quantization. Defaults to -1 (the last axis).

    Returns:
        The Shardy rules for the scaling mode
    """
    # Delegate to the mode-specific implementation (delayed vs. block scaling).
    return self._get_impl().get_shardy_sharding_rules(input_rank, unique_var, flatten_axis)
def __eq__(self, other):
"""Compare this scaling mode with another.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment