"...git@developer.sourcefind.cn:OpenDAS/TransformerEngine.git" did not exist on "49f7c1db03605d15999cbeae1cc7404b76b855c5"
Unverified Commit 6117b20c authored by Johannes Reifferscheid's avatar Johannes Reifferscheid Committed by GitHub
Browse files

Add experimental Shardy support. (#1642)



* Add experimental Shardy support.

Production use is not yet recommended.

---------
Signed-off-by: default avatarJohannes Reifferscheid <jreiffers@nvidia.com>
parent 98b4c0d9
...@@ -21,3 +21,15 @@ do ...@@ -21,3 +21,15 @@ do
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::test_te_mxfp8 --num-process=$NUM_GPUS --process-id=$i & pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::test_te_mxfp8 --num-process=$NUM_GPUS --process-id=$i &
done done
wait wait
for i in $(seq 0 $(($NUM_GPUS-1)))
do
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::test_te_bf16_shardy --num-process=$NUM_GPUS --process-id=$i &
done
wait
for i in $(seq 0 $(($NUM_GPUS-1)))
do
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::test_te_delayed_scaling_fp8_shardy --num-process=$NUM_GPUS --process-id=$i &
done
wait
...@@ -258,6 +258,8 @@ def get_state_sharding(state, params_sharding): ...@@ -258,6 +258,8 @@ def get_state_sharding(state, params_sharding):
def train_and_evaluate(args): def train_and_evaluate(args):
"""Execute model training and evaluation loop.""" """Execute model training and evaluation loop."""
print(args) print(args)
jax.config.update("jax_use_shardy_partitioner", args.enable_shardy)
train_ds, test_ds, num_embed = get_datasets(args.max_seq_len) train_ds, test_ds, num_embed = get_datasets(args.max_seq_len)
num_gpu = jax.local_device_count() num_gpu = jax.local_device_count()
...@@ -441,6 +443,9 @@ def encoder_parser(args): ...@@ -441,6 +443,9 @@ def encoder_parser(args):
parser.add_argument( parser.add_argument(
"--enable-sp", action="store_true", default=False, help="Enable sequence parallelism." "--enable-sp", action="store_true", default=False, help="Enable sequence parallelism."
) )
parser.add_argument(
"--enable-shardy", action="store_true", default=False, help="Enable Shardy (experimental)."
)
return parser.parse_args(args) return parser.parse_args(args)
...@@ -451,10 +456,9 @@ class TestEncoder(unittest.TestCase): ...@@ -451,10 +456,9 @@ class TestEncoder(unittest.TestCase):
is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING) is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
@classmethod def setUp(self):
def setUpClass(cls):
"""Run 3 epochs for testing""" """Run 3 epochs for testing"""
cls.args = encoder_parser(["--epochs", "3"]) self.args = encoder_parser(["--epochs", "3"])
@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16(self): def test_te_bf16(self):
...@@ -503,6 +507,34 @@ class TestEncoder(unittest.TestCase): ...@@ -503,6 +507,34 @@ class TestEncoder(unittest.TestCase):
actual = train_and_evaluate(self.args) actual = train_and_evaluate(self.args)
assert actual[0] < 0.455 and actual[1] > 0.785 assert actual[0] < 0.455 and actual[1] > 0.785
@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16_shardy(self):
"""Test Transformer Engine with BF16"""
self.args.enable_shardy = True
actual = train_and_evaluate(self.args)
assert actual[0] < 0.455 and actual[1] > 0.785
@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_delayed_scaling_fp8_shardy(self):
"""Test Transformer Engine with DelayedScaling FP8"""
self.args.enable_shardy = True
self.args.use_fp8 = True
self.args.fp8_recipe = "DelayedScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.455 and actual[1] > 0.785
@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_delayed_scaling_fp8_with_sp_shardy(self):
"""Test Transformer Engine with DelayedScaling FP8 + SP"""
self.args.enable_shardy = True
self.args.enable_sp = True
self.args.use_fp8 = True
self.args.fp8_recipe = "DelayedScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.455 and actual[1] > 0.785
# TODO(jreiffers): Add mxfp8 Shardy tests once supported in JAX.
if __name__ == "__main__": if __name__ == "__main__":
train_and_evaluate(encoder_parser(None)) train_and_evaluate(encoder_parser(None))
...@@ -238,6 +238,7 @@ def get_state_sharding(state, params_sharding): ...@@ -238,6 +238,7 @@ def get_state_sharding(state, params_sharding):
def train_and_evaluate(args): def train_and_evaluate(args):
"""Execute model training and evaluation loop.""" """Execute model training and evaluation loop."""
print(args) print(args)
jax.config.update("jax_use_shardy_partitioner", args.enable_shardy)
train_ds, test_ds, num_embed = get_datasets(args.max_seq_len) train_ds, test_ds, num_embed = get_datasets(args.max_seq_len)
num_gpu = jax.local_device_count() num_gpu = jax.local_device_count()
...@@ -409,6 +410,9 @@ def encoder_parser(args): ...@@ -409,6 +410,9 @@ def encoder_parser(args):
default="DelayedScaling", default="DelayedScaling",
help="Use FP8 recipe (default: DelayedScaling)", help="Use FP8 recipe (default: DelayedScaling)",
) )
parser.add_argument(
"--enable-shardy", action="store_true", default=False, help="Enable Shardy (experimental)."
)
return parser.parse_args(args) return parser.parse_args(args)
...@@ -419,10 +423,9 @@ class TestEncoder(unittest.TestCase): ...@@ -419,10 +423,9 @@ class TestEncoder(unittest.TestCase):
is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING) is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
@classmethod def setUp(self):
def setUpClass(cls):
"""Run 3 epochs for testing""" """Run 3 epochs for testing"""
cls.args = encoder_parser(["--epochs", "3"]) self.args = encoder_parser(["--epochs", "3"])
@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16(self): def test_te_bf16(self):
...@@ -446,6 +449,24 @@ class TestEncoder(unittest.TestCase): ...@@ -446,6 +449,24 @@ class TestEncoder(unittest.TestCase):
actual = train_and_evaluate(self.args) actual = train_and_evaluate(self.args)
assert actual[0] < 0.535 and actual[1] > 0.73 assert actual[0] < 0.535 and actual[1] > 0.73
@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16_shardy(self):
"""Test Transformer Engine with BF16"""
self.args.enable_shardy = True
actual = train_and_evaluate(self.args)
assert actual[0] < 0.535 and actual[1] > 0.73
@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_delayed_scaling_fp8_shardy(self):
"""Test Transformer Engine with DelayedScaling FP8"""
self.args.enable_shardy = True
self.args.use_fp8 = True
self.args.fp8_recipe = "DelayedScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.535 and actual[1] > 0.73
# TODO(jreiffers): Add mxfp8 Shardy tests once supported in JAX.
if __name__ == "__main__": if __name__ == "__main__":
train_and_evaluate(encoder_parser(None)) train_and_evaluate(encoder_parser(None))
...@@ -343,6 +343,7 @@ def get_state_sharding(state, params_sharding): ...@@ -343,6 +343,7 @@ def get_state_sharding(state, params_sharding):
def train_and_evaluate(args): def train_and_evaluate(args):
"""Execute model training and evaluation loop.""" """Execute model training and evaluation loop."""
print(args) print(args)
jax.config.update("jax_use_shardy_partitioner", args.enable_shardy)
if args.process_id == 0: if args.process_id == 0:
nltk.download("punkt_tab") nltk.download("punkt_tab")
...@@ -565,6 +566,9 @@ def encoder_parser(args): ...@@ -565,6 +566,9 @@ def encoder_parser(args):
default=0, default=0,
help="the ID number of the current process (default: 0)", help="the ID number of the current process (default: 0)",
) )
parser.add_argument(
"--enable-shardy", action="store_true", default=False, help="Enable Shardy (experimental)."
)
return parser.parse_args(args) return parser.parse_args(args)
...@@ -573,7 +577,7 @@ def encoder_parser(args): ...@@ -573,7 +577,7 @@ def encoder_parser(args):
class TestEncoder(unittest.TestCase): class TestEncoder(unittest.TestCase):
"""Encoder unittests""" """Encoder unittests"""
def exec(self, use_fp8, fp8_recipe): def exec(self, use_fp8, fp8_recipe, *, enable_shardy=False):
"""Run 3 epochs for testing""" """Run 3 epochs for testing"""
args = encoder_parser([]) args = encoder_parser([])
...@@ -589,6 +593,7 @@ class TestEncoder(unittest.TestCase): ...@@ -589,6 +593,7 @@ class TestEncoder(unittest.TestCase):
args.num_process = num_gpu args.num_process = num_gpu
args.process_id = self.process_id args.process_id = self.process_id
args.fp8_recipe = fp8_recipe args.fp8_recipe = fp8_recipe
args.enable_shardy = enable_shardy
return train_and_evaluate(args) return train_and_evaluate(args)
...@@ -614,6 +619,22 @@ class TestEncoder(unittest.TestCase): ...@@ -614,6 +619,22 @@ class TestEncoder(unittest.TestCase):
result = self.exec(True, "MXFP8BlockScaling") result = self.exec(True, "MXFP8BlockScaling")
assert result[0] < 0.505 and result[1] > 0.754 assert result[0] < 0.505 and result[1] > 0.754
@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16_shardy(self):
"""Test Transformer Engine with BF16"""
result = self.exec(False, None, enable_shardy=True)
assert result[0] < 0.505 and result[1] > 0.755
@unittest.skipIf(
not is_fp8_supported(), "Device compute capability 9.0+ is required for DelayedScaling FP8"
)
def test_te_delayed_scaling_fp8_shardy(self):
"""Test Transformer Engine with DelayedScaling FP8"""
result = self.exec(True, "DelayedScaling", enable_shardy=True)
assert result[0] < 0.505 and result[1] > 0.754
# TODO(jreiffers): Add mxfp8 Shardy tests once supported in JAX.
if __name__ == "__main__": if __name__ == "__main__":
train_and_evaluate(encoder_parser(None)) train_and_evaluate(encoder_parser(None))
...@@ -330,10 +330,9 @@ class TestEncoder(unittest.TestCase): ...@@ -330,10 +330,9 @@ class TestEncoder(unittest.TestCase):
is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING) is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
@classmethod def setUp(self):
def setUpClass(cls): """Run 3 epochs for testing"""
"""Run 4 epochs for testing""" self.args = encoder_parser(["--epochs", "3"])
cls.args = encoder_parser(["--epochs", "3"])
@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16(self): def test_te_bf16(self):
......
...@@ -25,3 +25,5 @@ filterwarnings= ...@@ -25,3 +25,5 @@ filterwarnings=
ignore:jax.experimental.maps and .* are deprecated.*:DeprecationWarning ignore:jax.experimental.maps and .* are deprecated.*:DeprecationWarning
ignore:The host_callback APIs are deprecated .*:DeprecationWarning ignore:The host_callback APIs are deprecated .*:DeprecationWarning
ignore:Scan loop is disabled for fused ring attention.*:UserWarning ignore:Scan loop is disabled for fused ring attention.*:UserWarning
ignore:jax.extend.ffi.register_ffi_target is deprecated
ignore:jax.extend.ffi.ffi_lowering is deprecated
...@@ -48,31 +48,7 @@ class TestDistributedSelfAttn: ...@@ -48,31 +48,7 @@ class TestDistributedSelfAttn:
# for loss and dbias # for loss and dbias
return generate_collectives_count(allreduce=allreduce_total_bytes, allgather=0, other=0) return generate_collectives_count(allreduce=allreduce_total_bytes, allgather=0, other=0)
@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) def impl_test_self_attn(
@pytest.mark.parametrize(
"data_shape",
[
pytest.param((32, 512, 12, 64), id="32-512-12-64"),
pytest.param((32, 1024, 16, 128), id="32-1024-16-128"),
],
)
@pytest.mark.parametrize(
"attn_bias_type, bias_shape",
[
pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
pytest.param(AttnBiasType.PRE_SCALE_BIAS, BiasShape._1HSS, id="PRE_SCALE_BIAS-1HSS"),
pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"),
],
)
@pytest.mark.parametrize(
"attn_mask_type",
[
pytest.param(AttnMaskType.PADDING_MASK, id="PADDING_MASK"),
pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL_MASK"),
],
)
@pytest.mark.parametrize("dtype", DTYPES)
def test_self_attn(
self, self,
device_count, device_count,
mesh_shape, mesh_shape,
...@@ -83,7 +59,9 @@ class TestDistributedSelfAttn: ...@@ -83,7 +59,9 @@ class TestDistributedSelfAttn:
bias_shape, bias_shape,
attn_mask_type, attn_mask_type,
dtype, dtype,
use_shardy,
): ):
jax.config.update("jax_use_shardy_partitioner", use_shardy)
dropout_prob = 0.0 dropout_prob = 0.0
is_training = True is_training = True
...@@ -137,6 +115,80 @@ class TestDistributedSelfAttn: ...@@ -137,6 +115,80 @@ class TestDistributedSelfAttn:
) )
runner.test_backward() runner.test_backward()
@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest.mark.parametrize(
"data_shape",
[
pytest.param((32, 512, 12, 64), id="32-512-12-64"),
pytest.param((32, 1024, 16, 128), id="32-1024-16-128"),
],
)
@pytest.mark.parametrize(
"attn_bias_type, bias_shape",
[
pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
pytest.param(AttnBiasType.PRE_SCALE_BIAS, BiasShape._1HSS, id="PRE_SCALE_BIAS-1HSS"),
pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"),
],
)
@pytest.mark.parametrize(
"attn_mask_type",
[
pytest.param(AttnMaskType.PADDING_MASK, id="PADDING_MASK"),
pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL_MASK"),
],
)
@pytest.mark.parametrize("dtype", DTYPES)
def test_self_attn(
self,
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
attn_bias_type,
bias_shape,
attn_mask_type,
dtype,
):
self.impl_test_self_attn(
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
attn_bias_type,
bias_shape,
attn_mask_type,
dtype,
use_shardy=False,
)
@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest.mark.parametrize(
"attn_bias_type, bias_shape",
[
pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
pytest.param(AttnBiasType.PRE_SCALE_BIAS, BiasShape._1HSS, id="PRE_SCALE_BIAS-1HSS"),
],
)
def test_self_attn_shardy(
self, device_count, mesh_shape, mesh_axes, mesh_resource, attn_bias_type, bias_shape
):
data_shape = (32, 512, 12, 64)
self.impl_test_self_attn(
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
attn_bias_type,
bias_shape,
AttnMaskType.PADDING_MASK,
jnp.bfloat16,
use_shardy=True,
)
class TestDistributedCrossAttn: class TestDistributedCrossAttn:
...@@ -203,37 +255,23 @@ class TestDistributedCrossAttn: ...@@ -203,37 +255,23 @@ class TestDistributedCrossAttn:
runner.test_backward() runner.test_backward()
@pytest.mark.parametrize( DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS = [
"device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs() pytest.param(QKVLayout.BSHD_BS2HD, AttnMaskType.CAUSAL_MASK, id="BSHD_KVPACKED-CAUSAL"),
) pytest.param(QKVLayout.BSHD_BSHD_BSHD, AttnMaskType.CAUSAL_MASK, id="BSHD_SEPARATE-CAUSAL"),
@pytest.mark.parametrize( pytest.param(QKVLayout.BSHD_BS2HD, AttnMaskType.NO_MASK, id="HD_KVPACKED-NO_MASK"),
"data_shape", pytest.param(QKVLayout.BSHD_BSHD_BSHD, AttnMaskType.NO_MASK, id="BSHD_SEPARATE-NO_MASK"),
[ pytest.param(
# Sequence lengths will be scaled by CP so that we don't run with tiny sizes. QKVLayout.THD_THD_THD, AttnMaskType.PADDING_CAUSAL_MASK, id="THD_SEPARATE-PADDING_CAUSAL"
pytest.param([2, 128, 8, 128], id="2-128xCP-8-128"), ),
pytest.param([4, 256, 16, 64], id="4-256xCP-16-64"), ]
],
) DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES = [
@pytest.mark.parametrize("kv_groups", [1, 8]) # Sequence lengths will be scaled by CP so that we don't run with tiny sizes.
@pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")]) pytest.param([2, 128, 8, 128], id="2-128xCP-8-128"),
@pytest.mark.parametrize( pytest.param([4, 256, 16, 64], id="4-256xCP-16-64"),
"qkv_layout, attn_mask_type", ]
[
pytest.param(QKVLayout.BSHD_BS2HD, AttnMaskType.CAUSAL_MASK, id="BSHD_KVPACKED-CAUSAL"),
pytest.param(QKVLayout.BSHD_BSHD_BSHD, AttnMaskType.CAUSAL_MASK, id="BSHD_SEPARATE-CAUSAL"),
pytest.param(QKVLayout.BSHD_BS2HD, AttnMaskType.NO_MASK, id="HD_KVPACKED-NO_MASK"),
pytest.param(QKVLayout.BSHD_BSHD_BSHD, AttnMaskType.NO_MASK, id="BSHD_SEPARATE-NO_MASK"),
pytest.param(
QKVLayout.THD_THD_THD,
AttnMaskType.PADDING_CAUSAL_MASK,
id="THD_SEPARATE-PADDING_CAUSAL",
),
],
)
@pytest.mark.parametrize(
"load_balanced",
[pytest.param(True, id="BALANCED"), pytest.param(False, id="UNBALANCED")],
)
class TestDistributedContextParallelSelfAttn: class TestDistributedContextParallelSelfAttn:
def impl_test_context_parallel_attn( def impl_test_context_parallel_attn(
...@@ -249,7 +287,23 @@ class TestDistributedContextParallelSelfAttn: ...@@ -249,7 +287,23 @@ class TestDistributedContextParallelSelfAttn:
qkv_layout, qkv_layout,
load_balanced, load_balanced,
cp_strategy, cp_strategy,
use_shardy,
use_scan_ring=False,
): ):
if qkv_layout.is_thd():
if cp_strategy == CPStrategy.ALL_GATHER:
pytest.skip("THD doesn't support all gather context parallelism.")
if not load_balanced and cp_strategy == CPStrategy.RING:
pytest.skip("THD + ring doesn't support unbalanced context parallelism.")
assert not use_scan_ring or cp_strategy == CPStrategy.RING
if use_scan_ring:
os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] = "1"
else:
os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] = "0"
jax.config.update("jax_use_shardy_partitioner", use_shardy)
attn_bias_type = AttnBiasType.NO_BIAS attn_bias_type = AttnBiasType.NO_BIAS
bias_shape = None bias_shape = None
dropout_prob = 0.0 dropout_prob = 0.0
...@@ -324,7 +378,58 @@ class TestDistributedContextParallelSelfAttn: ...@@ -324,7 +378,58 @@ class TestDistributedContextParallelSelfAttn:
pytest.skip(f"Skipping {kv_groups=} not multiple of {data_shape=} or {tp_size=}") pytest.skip(f"Skipping {kv_groups=} not multiple of {data_shape=} or {tp_size=}")
runner.test_backward() runner.test_backward()
del os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"]
@pytest.mark.parametrize(
"device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs()
)
@pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES[:1])
@pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")])
@pytest.mark.parametrize(
"qkv_layout, attn_mask_type",
DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS,
)
def test_context_parallel_allgather_attn_shardy(
self,
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
attn_mask_type,
dtype,
qkv_layout,
):
kv_groups = 8
self.impl_test_context_parallel_attn(
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
kv_groups,
attn_mask_type,
dtype,
qkv_layout,
load_balanced=True,
cp_strategy=CPStrategy.ALL_GATHER,
use_shardy=True,
)
@pytest.mark.parametrize(
"device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs()
)
@pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES)
@pytest.mark.parametrize("kv_groups", [1, 8])
@pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")])
@pytest.mark.parametrize(
"qkv_layout, attn_mask_type",
DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS,
)
@pytest.mark.parametrize(
"load_balanced",
[pytest.param(True, id="BALANCED"), pytest.param(False, id="UNBALANCED")],
)
def test_context_parallel_allgather_attn( def test_context_parallel_allgather_attn(
self, self,
device_count, device_count,
...@@ -338,9 +443,7 @@ class TestDistributedContextParallelSelfAttn: ...@@ -338,9 +443,7 @@ class TestDistributedContextParallelSelfAttn:
qkv_layout, qkv_layout,
load_balanced, load_balanced,
): ):
if qkv_layout.is_thd(): self.impl_test_context_parallel_attn(
pytest.skip("THD doesn't support all gather context parallelism.")
return self.impl_test_context_parallel_attn(
device_count, device_count,
mesh_shape, mesh_shape,
mesh_axes, mesh_axes,
...@@ -352,8 +455,23 @@ class TestDistributedContextParallelSelfAttn: ...@@ -352,8 +455,23 @@ class TestDistributedContextParallelSelfAttn:
qkv_layout, qkv_layout,
load_balanced, load_balanced,
CPStrategy.ALL_GATHER, CPStrategy.ALL_GATHER,
use_shardy=False,
) )
@pytest.mark.parametrize(
"device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs()
)
@pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES)
@pytest.mark.parametrize("kv_groups", [1, 8])
@pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")])
@pytest.mark.parametrize(
"qkv_layout, attn_mask_type",
DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS,
)
@pytest.mark.parametrize(
"load_balanced",
[pytest.param(True, id="BALANCED"), pytest.param(False, id="UNBALANCED")],
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"use_scan", "use_scan",
[pytest.param(False, id="NO_SCAN"), pytest.param(True, id="USE_SCAN")], [pytest.param(False, id="NO_SCAN"), pytest.param(True, id="USE_SCAN")],
...@@ -372,14 +490,6 @@ class TestDistributedContextParallelSelfAttn: ...@@ -372,14 +490,6 @@ class TestDistributedContextParallelSelfAttn:
load_balanced, load_balanced,
use_scan, use_scan,
): ):
if use_scan:
os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] = "1"
else:
os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] = "0"
if qkv_layout.is_thd() and not load_balanced:
pytest.skip("THD + ring doesn't support unbalanced context parallelism.")
self.impl_test_context_parallel_attn( self.impl_test_context_parallel_attn(
device_count, device_count,
mesh_shape, mesh_shape,
...@@ -392,9 +502,46 @@ class TestDistributedContextParallelSelfAttn: ...@@ -392,9 +502,46 @@ class TestDistributedContextParallelSelfAttn:
qkv_layout, qkv_layout,
load_balanced, load_balanced,
CPStrategy.RING, CPStrategy.RING,
use_shardy=False,
use_scan_ring=use_scan,
)
@pytest.mark.parametrize(
"device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs()
)
@pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES[:1])
@pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")])
@pytest.mark.parametrize(
"qkv_layout, attn_mask_type",
DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS,
)
def test_context_parallel_ring_attn_shardy(
self,
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
attn_mask_type,
dtype,
qkv_layout,
):
kv_groups = 8
self.impl_test_context_parallel_attn(
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
kv_groups,
attn_mask_type,
dtype,
qkv_layout,
load_balanced=True,
cp_strategy=CPStrategy.RING,
use_shardy=False,
use_scan_ring=True,
) )
del os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"]
return
class TestReorderCausalLoadBalancing: class TestReorderCausalLoadBalancing:
......
...@@ -86,6 +86,7 @@ class TestDistributedLayernorm: ...@@ -86,6 +86,7 @@ class TestDistributedLayernorm:
@pytest_parametrize_wrapper("zero_centered_gamma", [False, True]) @pytest_parametrize_wrapper("zero_centered_gamma", [False, True])
@pytest_parametrize_wrapper("shard_weights", [False, True]) @pytest_parametrize_wrapper("shard_weights", [False, True])
@pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES) @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
@pytest_parametrize_wrapper("use_shardy", [False, True])
def test_layernorm( def test_layernorm(
self, self,
device_count, device_count,
...@@ -97,7 +98,9 @@ class TestDistributedLayernorm: ...@@ -97,7 +98,9 @@ class TestDistributedLayernorm:
zero_centered_gamma, zero_centered_gamma,
shard_weights, shard_weights,
fp8_recipe, fp8_recipe,
use_shardy,
): ):
jax.config.update("jax_use_shardy_partitioner", use_shardy)
epsilon = 1e-6 epsilon = 1e-6
ln_type = "layernorm" ln_type = "layernorm"
q_dtype = jnp.float8_e4m3fn q_dtype = jnp.float8_e4m3fn
...@@ -168,6 +171,7 @@ class TestDistributedLayernorm: ...@@ -168,6 +171,7 @@ class TestDistributedLayernorm:
@pytest_parametrize_wrapper("dtype", DTYPES) @pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("shard_weights", [False, True]) @pytest_parametrize_wrapper("shard_weights", [False, True])
@pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES) @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
@pytest_parametrize_wrapper("use_shardy", [False, True])
def test_rmsnorm( def test_rmsnorm(
self, self,
device_count, device_count,
...@@ -178,7 +182,9 @@ class TestDistributedLayernorm: ...@@ -178,7 +182,9 @@ class TestDistributedLayernorm:
dtype, dtype,
shard_weights, shard_weights,
fp8_recipe, fp8_recipe,
use_shardy,
): ):
jax.config.update("jax_use_shardy_partitioner", use_shardy)
epsilon = 1e-6 epsilon = 1e-6
ln_type = "rmsnorm" ln_type = "rmsnorm"
q_dtype = jnp.float8_e4m3fn q_dtype = jnp.float8_e4m3fn
......
...@@ -144,16 +144,10 @@ class TestDistributedLayernormMLP: ...@@ -144,16 +144,10 @@ class TestDistributedLayernormMLP:
) )
) )
@pytest.mark.skipif(not is_fp8_supported, reason=reason) def _test_layernorm_mlp_grad(
@pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs()) self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe, use_shardy
@pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
@pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("use_bias", [True, False])
@pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
def test_layernorm_mlp_grad(
self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe
): ):
jax.config.update("jax_use_shardy_partitioner", use_shardy)
device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config
layernorm_type = "rmsnorm" layernorm_type = "rmsnorm"
...@@ -257,9 +251,59 @@ class TestDistributedLayernormMLP: ...@@ -257,9 +251,59 @@ class TestDistributedLayernormMLP:
err_msg=f"multi_grads[{i}] is not close", err_msg=f"multi_grads[{i}] is not close",
) )
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs())
@pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
@pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("use_bias", [True, False])
@pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
def test_layernorm_mlp_grad(
self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe
):
self._test_layernorm_mlp_grad(
mesh_config,
activation_type,
use_bias,
input_shape,
dtype,
fp8_recipe,
use_shardy=False,
)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs())
@pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
@pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("use_bias", [True, False])
def test_layernorm_mlp_grad_shardy(
self, mesh_config, activation_type, use_bias, input_shape, dtype
):
# We don't test block scaling with Shardy because at the time of writing,
# it is not supported in JAX's scaled_matmul_stablehlo.
self._test_layernorm_mlp_grad(
mesh_config,
activation_type,
use_bias,
input_shape,
dtype,
fp8_recipe=recipe.DelayedScaling(),
use_shardy=True,
)
def _test_layernorm_mlp( def _test_layernorm_mlp(
self, mesh_config, activation_type, use_bias, input_shape, dtype, use_fp8, fp8_recipe=None self,
mesh_config,
activation_type,
use_bias,
input_shape,
dtype,
use_fp8,
fp8_recipe,
use_shardy,
): ):
jax.config.update("jax_use_shardy_partitioner", use_shardy)
batch, seqlen, hidden_in = input_shape batch, seqlen, hidden_in = input_shape
layernorm_type = "rmsnorm" layernorm_type = "rmsnorm"
...@@ -322,9 +366,19 @@ class TestDistributedLayernormMLP: ...@@ -322,9 +366,19 @@ class TestDistributedLayernormMLP:
@pytest_parametrize_wrapper("activation_type", [("gelu",), ("silu", "linear")]) @pytest_parametrize_wrapper("activation_type", [("gelu",), ("silu", "linear")])
@pytest_parametrize_wrapper("dtype", DTYPES) @pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("use_bias", [True, False]) @pytest_parametrize_wrapper("use_bias", [True, False])
def test_layernorm_mlp_layer(self, mesh_config, activation_type, use_bias, input_shape, dtype): @pytest_parametrize_wrapper("use_shardy", [False, True])
def test_layernorm_mlp_layer(
self, mesh_config, activation_type, use_bias, input_shape, dtype, use_shardy
):
self._test_layernorm_mlp( self._test_layernorm_mlp(
mesh_config, activation_type, use_bias, input_shape, dtype, use_fp8=False mesh_config,
activation_type,
use_bias,
input_shape,
dtype,
use_fp8=False,
fp8_recipe=None,
use_shardy=use_shardy,
) )
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=reason)
...@@ -345,4 +399,5 @@ class TestDistributedLayernormMLP: ...@@ -345,4 +399,5 @@ class TestDistributedLayernormMLP:
dtype, dtype,
use_fp8=True, use_fp8=True,
fp8_recipe=fp8_recipe, fp8_recipe=fp8_recipe,
use_shardy=False,
) )
...@@ -28,14 +28,16 @@ class TestDistributedSoftmax: ...@@ -28,14 +28,16 @@ class TestDistributedSoftmax:
all_reduce_loss_bytes = 4 # 1 * FP32 all_reduce_loss_bytes = 4 # 1 * FP32
return generate_collectives_count(allreduce=all_reduce_loss_bytes, allgather=0, other=0) return generate_collectives_count(allreduce=all_reduce_loss_bytes, allgather=0, other=0)
def generate_inputs(self, shape, mesh_resource, softmax_type, dtype, bad_sharding): def generate_inputs(
self, shape, mesh_resource, softmax_type, dtype, bad_sharding, broadcast_batch_mask
):
batch, _, sqelen, _ = shape batch, _, sqelen, _ = shape
x = random.normal(random.PRNGKey(1124), shape, dtype=dtype) x = random.normal(random.PRNGKey(1124), shape, dtype=dtype)
if softmax_type == SoftmaxType.SCALED_UPPER_TRIANG_MASKED: if softmax_type == SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
mask = make_causal_mask(batch, sqelen) mask = make_causal_mask(batch, sqelen)
else: else:
mask = make_self_mask(batch, sqelen) mask = make_self_mask(1 if broadcast_batch_mask else batch, sqelen)
if not bad_sharding: if not bad_sharding:
x_pspec = PartitionSpec( x_pspec = PartitionSpec(
...@@ -45,7 +47,11 @@ class TestDistributedSoftmax: ...@@ -45,7 +47,11 @@ class TestDistributedSoftmax:
x_pspec = PartitionSpec( x_pspec = PartitionSpec(
mesh_resource.dp_resource, None, None, mesh_resource.tp_resource mesh_resource.dp_resource, None, None, mesh_resource.tp_resource
) )
mask_pspec = PartitionSpec(mesh_resource.dp_resource, None, None, None)
if broadcast_batch_mask:
mask_pspec = PartitionSpec(None, None, None, None)
else:
mask_pspec = PartitionSpec(mesh_resource.dp_resource, None, None, None)
return (x, mask), (x_pspec, mask_pspec) return (x, mask), (x_pspec, mask_pspec)
...@@ -67,16 +73,7 @@ class TestDistributedSoftmax: ...@@ -67,16 +73,7 @@ class TestDistributedSoftmax:
output = jax.nn.softmax(x * scale_factor) output = jax.nn.softmax(x * scale_factor)
return jnp.mean(output) return jnp.mean(output)
@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) def impl_test_softmax(
@pytest.mark.parametrize("data_shape", [[32, 12, 128, 128], [64, 16, 1024, 1024]])
@pytest.mark.parametrize(
"softmax_type",
[SoftmaxType.SCALED, SoftmaxType.SCALED_MASKED, SoftmaxType.SCALED_UPPER_TRIANG_MASKED],
)
@pytest.mark.parametrize("scale_factor", [1.0, 3.0])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("bad_sharding", [False, True])
def test_softmax(
self, self,
device_count, device_count,
mesh_shape, mesh_shape,
...@@ -87,15 +84,20 @@ class TestDistributedSoftmax: ...@@ -87,15 +84,20 @@ class TestDistributedSoftmax:
scale_factor, scale_factor,
dtype, dtype,
bad_sharding, bad_sharding,
broadcast_batch_mask,
use_shardy,
): ):
if broadcast_batch_mask and softmax_type != SoftmaxType.SCALED_MASKED:
pytest.skip("Softmax type has no mask.")
jax.config.update("jax_use_shardy_partitioner", use_shardy)
target_func = partial( target_func = partial(
self.target_func, scale_factor=scale_factor, softmax_type=softmax_type self.target_func, scale_factor=scale_factor, softmax_type=softmax_type
) )
ref_func = partial(self.ref_func, scale_factor=scale_factor, dtype=dtype) ref_func = partial(self.ref_func, scale_factor=scale_factor, dtype=dtype)
(x, mask), (x_pspec, mask_pspec) = self.generate_inputs( (x, mask), (x_pspec, mask_pspec) = self.generate_inputs(
data_shape, mesh_resource, softmax_type, dtype, bad_sharding data_shape, mesh_resource, softmax_type, dtype, bad_sharding, broadcast_batch_mask
) )
collective_count_ref = self.generate_collectives_count_ref() collective_count_ref = self.generate_collectives_count_ref()
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
...@@ -129,4 +131,70 @@ class TestDistributedSoftmax: ...@@ -129,4 +131,70 @@ class TestDistributedSoftmax:
assert "Sharding the hidden dimension is not supported" in str(w), ( assert "Sharding the hidden dimension is not supported" in str(w), (
"Softmax primitive did not raise the correct warning for " "Softmax primitive did not raise the correct warning for "
"unsupported sharding in the hidden dimension." "unsupported sharding in the hidden dimension."
f"{str(w)}"
) )
@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest.mark.parametrize("data_shape", [[32, 12, 128, 128], [64, 16, 1024, 1024]])
@pytest.mark.parametrize(
"softmax_type",
[SoftmaxType.SCALED, SoftmaxType.SCALED_MASKED, SoftmaxType.SCALED_UPPER_TRIANG_MASKED],
)
@pytest.mark.parametrize("scale_factor", [1.0, 3.0])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("bad_sharding", [False, True])
@pytest.mark.parametrize("broadcast_batch_mask", [False, True])
def test_softmax(
self,
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
softmax_type,
scale_factor,
dtype,
bad_sharding,
broadcast_batch_mask,
):
self.impl_test_softmax(
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
softmax_type,
scale_factor,
dtype,
bad_sharding,
broadcast_batch_mask,
use_shardy=False,
)
@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest.mark.parametrize("softmax_type", [SoftmaxType.SCALED, SoftmaxType.SCALED_MASKED])
@pytest.mark.parametrize("bad_sharding", [False, True])
@pytest.mark.parametrize("broadcast_batch_mask", [False, True])
def test_softmax_shardy(
self,
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
softmax_type,
bad_sharding,
broadcast_batch_mask,
):
self.impl_test_softmax(
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape=[32, 12, 128, 128],
softmax_type=softmax_type,
scale_factor=1.0,
dtype=DTYPES[0],
bad_sharding=bad_sharding,
broadcast_batch_mask=broadcast_batch_mask,
use_shardy=True,
)
...@@ -10,6 +10,7 @@ from packaging import version ...@@ -10,6 +10,7 @@ from packaging import version
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
from jax import dtypes from jax import dtypes
from jax.experimental.custom_partitioning import SdyShardingRule
from jax.sharding import PartitionSpec from jax.sharding import PartitionSpec
import transformer_engine_jax import transformer_engine_jax
...@@ -406,6 +407,54 @@ class ActLuPrimitive(BasePrimitive): ...@@ -406,6 +407,54 @@ class ActLuPrimitive(BasePrimitive):
return mesh, sharded_impl, out_shardings, arg_shardings return mesh, sharded_impl, out_shardings, arg_shardings
    @staticmethod
    def shardy_sharding_rule(
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        is_2x,
        scale_dtype,
        scale_shapes,
        is_outer,
        mesh,
        value_types,
        result_types,
    ):
        """Shardy sharding rule for the fused activation + quantize primitive.

        Outputs are (out, colwise_out, scale_inv, colwise_scale_inv, amax).
        Scale-tensor factors come from the scaling mode's
        `get_shardy_sharding_rules`.
        """
        del out_dtype, act_enum, act_len, scale_dtype, scale_shapes, is_outer, mesh, result_types
        # NOTE(review): assumes value_types[0] is the activation input — confirm.
        x_rank = len(value_types[0].shape)
        # Scale rules are built for rank - 1: the second-to-last input axis
        # (presumably the activation-pair axis) is dropped from the output.
        scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
            x_rank - 1, unique_var="i", flatten_axis=-2
        )
        # Re-append a factor for the last input axis so `x_axes` covers the
        # full input rank again.
        x_axes = scale_rules.input_spec + (f"x{x_rank-1}",)
        # The rowwise output keeps all axes except the second-to-last one.
        out = (*x_axes[:-2], x_axes[-1])
        scale_inv = scale_rules.rowwise_rule
        colwise_scale_inv = scale_rules.colwise_rule
        if is_2x:
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                # Delayed scaling stores the colwise copy transposed.
                colwise_out = tuple(
                    multidim_transpose(x_axes, static_axis_boundary=-1, transpose_axis=-2)
                )
            else:
                colwise_out = out
        else:
            # No colwise outputs: give them fresh, unconstrained factors.
            colwise_out = ("j",)
            colwise_scale_inv = ("k",)
        # amax is always a unit tensor.
        amax = ("l",)
        return SdyShardingRule(
            (
                x_axes,
                "…1",
            ),
            (out, colwise_out, scale_inv, colwise_scale_inv, amax),
            **scale_rules.factor_sizes,
        )
register_primitive(ActLuPrimitive) register_primitive(ActLuPrimitive)
...@@ -819,6 +868,46 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -819,6 +868,46 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
return mesh, sharded_impl, out_shardings, arg_shardings return mesh, sharded_impl, out_shardings, arg_shardings
    @staticmethod
    def shardy_sharding_rule(
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        scale_shapes,
        is_dbias,
        act_enum,
        act_len,
        is_outer,
        mesh,
        value_types,
        result_types,
    ):
        """Shardy sharding rule for the fused dact + dbias + quantize primitive.

        Outputs are (out, colwise_out, scale_inv, colwise_scale_inv, amax,
        dbias).
        """
        del out_dtype, scale_dtype, scale_shapes, act_enum, act_len, is_outer, mesh, result_types
        # NOTE(review): assumes value_types[1] carries the full-rank data
        # shape — confirm against the primitive's operand order.
        x_rank = len(value_types[1].shape)
        scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
            x_rank, unique_var="i", flatten_axis=-2
        )
        x_axes = scale_rules.input_spec
        out = x_axes
        if is_2x:
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                # Delayed scaling stores the colwise copy transposed.
                colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=-2))
            else:
                colwise_out = tuple(x_axes)
        else:
            # NOTE(review): unlike ActLuPrimitive / DBiasQuantizePrimitive, the
            # colwise scale rule is NOT replaced by a fresh factor in the
            # non-2x path — confirm this is intentional.
            colwise_out = ("j",)
        # dbias keeps only the trailing axes of the input.
        dbias = x_axes[-2:] if is_dbias else ("k",)
        amax = ("…4",)
        return SdyShardingRule(
            (("…0",), tuple(x_axes), ("…2",)),
            (out, colwise_out, scale_rules.rowwise_rule, scale_rules.colwise_rule, amax, dbias),
            **scale_rules.factor_sizes,
        )
register_primitive(DActLuDBiasQuantizePrimitive) register_primitive(DActLuDBiasQuantizePrimitive)
......
...@@ -14,6 +14,7 @@ import jax ...@@ -14,6 +14,7 @@ import jax
import jax.numpy as jnp import jax.numpy as jnp
from jax import dtypes, lax from jax import dtypes, lax
from jax.sharding import PartitionSpec, NamedSharding from jax.sharding import PartitionSpec, NamedSharding
from jax.experimental.custom_partitioning import SdyShardingRule
import transformer_engine_jax import transformer_engine_jax
from transformer_engine_jax import NVTE_Fused_Attn_Backend from transformer_engine_jax import NVTE_Fused_Attn_Backend
...@@ -42,6 +43,7 @@ from ..sharding import ( ...@@ -42,6 +43,7 @@ from ..sharding import (
get_mesh_axis_rank, get_mesh_axis_rank,
get_all_mesh_axes, get_all_mesh_axes,
num_of_devices, num_of_devices,
with_sharding_constraint,
) )
...@@ -618,6 +620,35 @@ class FusedAttnFwdPrimitive(BasePrimitive): ...@@ -618,6 +620,35 @@ class FusedAttnFwdPrimitive(BasePrimitive):
impl = partial(FusedAttnFwdPrimitive.impl, config=config) impl = partial(FusedAttnFwdPrimitive.impl, config=config)
return mesh, impl, out_shardings, arg_shardings return mesh, impl, out_shardings, arg_shardings
    @staticmethod
    def shardy_sharding_rule(config, mesh, value_types, result_types):
        """Shardy sharding rule for the fused-attention forward primitive.

        Only the first (QKV) operand's layout matters; all other operands get
        unique batching placeholders. Outputs are
        (out, softmax_aux, rng_state).
        """
        del mesh, result_types
        # Keep in sync with `infer_sharding_from_operands`.
        # We only need the first input. Fill up the rest with placeholders.
        input_spec = [(f"…{x}",) for x in range(len(value_types))]
        # The RNG state sharding cannot be expressed as a Shardy rule. We use with_sharding_constraint
        # instead. This has to happen outside of the primitive, see `fused_attn_fwd`.
        rng_sharding = (f"…{len(value_types)}",)
        if config.qkv_layout.is_qkvpacked():
            # Packed QKV: a single operand with an explicit "three" axis.
            input_spec[0] = ("…0", "seqlen", "three", "head", "hidden")
        elif config.qkv_layout.is_kvpacked() or config.qkv_layout.is_separate():
            input_spec[0] = ("…0", "seqlen", "head", "hidden")
        else:
            raise ValueError(f"Unsupported {config.qkv_layout=}")
        # The softmax aux layout differs between packed (cuDNN >= 9.6 with a
        # THD layout) and unpacked softmax.
        is_packed_softmax = get_cudnn_version() >= (9, 6, 0) and config.qkv_layout.is_thd()
        out_sharding = ("…0", "seqlen", "head", "hidden")
        if is_packed_softmax:
            softmax_aux_sharding = ("…0", "seqlen", "head", "i")
        else:
            softmax_aux_sharding = ("…0", "head", "seqlen", "i")
        return SdyShardingRule(
            tuple(input_spec), (out_sharding, softmax_aux_sharding, rng_sharding)
        )
register_primitive(FusedAttnFwdPrimitive) register_primitive(FusedAttnFwdPrimitive)
...@@ -998,6 +1029,15 @@ class FusedAttnBwdPrimitive(BasePrimitive): ...@@ -998,6 +1029,15 @@ class FusedAttnBwdPrimitive(BasePrimitive):
return mesh, sharded_impl, out_shardings, arg_shardings return mesh, sharded_impl, out_shardings, arg_shardings
@staticmethod
def shardy_sharding_rule(config, mesh, value_types, result_types):
del config, mesh
# We only care about the four first arguments.
# Keep in sync with `infer_sharding_from_operands`.
input_spec = tuple((f"…{x}",) for x in range(len(value_types)))
output_spec = tuple((f"…{x}",) for x in range(len(result_types)))
return SdyShardingRule(input_spec, output_spec)
register_primitive(FusedAttnBwdPrimitive) register_primitive(FusedAttnBwdPrimitive)
...@@ -2436,13 +2476,15 @@ def fused_attn_fwd( ...@@ -2436,13 +2476,15 @@ def fused_attn_fwd(
primitive = FusedRingAttnFwdPrimitive.outer_primitive primitive = FusedRingAttnFwdPrimitive.outer_primitive
seq_desc_flatten, _ = jax.tree.flatten(sequence_descriptor) seq_desc_flatten, _ = jax.tree.flatten(sequence_descriptor)
return primitive.bind( output, softmax_aux, rng_state = primitive.bind(
*qkv_for_primitive, *qkv_for_primitive,
bias, bias,
seed, seed,
*seq_desc_flatten, *seq_desc_flatten,
config=fused_config, config=fused_config,
) )
rng_state = with_sharding_constraint(rng_state, PartitionSpec(get_all_mesh_axes(), None))
return (output, softmax_aux, rng_state)
def fused_attn_bwd( def fused_attn_bwd(
......
...@@ -98,6 +98,15 @@ class BasePrimitive(metaclass=ABCMeta): ...@@ -98,6 +98,15 @@ class BasePrimitive(metaclass=ABCMeta):
""" """
return NotImplemented return NotImplemented
    @staticmethod
    @abstractmethod
    def shardy_sharding_rule(*args):
        """
        Returns the sharding rule for this primitive.

        Subclasses override this with either an einsum-like rule string
        (e.g. "... -> ...") or a callable returning an `SdyShardingRule`.
        The result is forwarded to
        `custom_partitioning.def_partition(sharding_rule=...)` by
        `register_primitive`.
        """
        del args
        return "... -> ..."
def register_primitive(cls): def register_primitive(cls):
""" """
...@@ -123,7 +132,9 @@ def register_primitive(cls): ...@@ -123,7 +132,9 @@ def register_primitive(cls):
batching.primitive_batchers[outer_p] = cls.batcher batching.primitive_batchers[outer_p] = cls.batcher
outer_p_lower = custom_partitioning(cls.impl, static_argnums=cls.impl_static_args) outer_p_lower = custom_partitioning(cls.impl, static_argnums=cls.impl_static_args)
outer_p_lower.def_partition( outer_p_lower.def_partition(
infer_sharding_from_operands=cls.infer_sharding_from_operands, partition=cls.partition infer_sharding_from_operands=cls.infer_sharding_from_operands,
partition=cls.partition,
sharding_rule=cls.shardy_sharding_rule,
) )
mlir.register_lowering( mlir.register_lowering(
outer_p, mlir.lower_fun(outer_p_lower, multiple_results=cls.multiple_results) outer_p, mlir.lower_fun(outer_p_lower, multiple_results=cls.multiple_results)
......
...@@ -12,6 +12,7 @@ from packaging import version ...@@ -12,6 +12,7 @@ from packaging import version
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
from jax import dtypes from jax import dtypes
from jax.experimental.custom_partitioning import SdyShardingRule
from jax.interpreters.mlir import ir from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec from jax.sharding import PartitionSpec
...@@ -519,6 +520,57 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -519,6 +520,57 @@ class NormFwdPrimitive(BasePrimitive):
return mesh, sharded_impl, out_shardings, arg_shardings return mesh, sharded_impl, out_shardings, arg_shardings
    @staticmethod
    def shardy_sharding_rule(
        norm_type,
        zero_centered_gamma,
        epsilon,
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        scale_shapes,
        is_outer,
        mesh,
        value_types,
        result_types,
    ):
        """Shardy sharding rule for the fused norm + quantize forward primitive.

        Outputs are (out, colwise_out, scale_inv, colwise_scale_inv, amax, mu,
        rsigma).  The output's last axis gets a fresh factor "k", so it is not
        tied to the input's last axis.
        """
        del (
            zero_centered_gamma,
            epsilon,
            out_dtype,
            scale_dtype,
            scale_shapes,
            is_outer,
            mesh,
            result_types,
        )
        scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
            len(value_types[0].shape), unique_var="i", flatten_axis=-1
        )
        x_axes = scale_rules.input_spec
        out = x_axes[:-1] + ("k",)
        # Without a colwise output, fall back to a standalone placeholder.
        colwise_out = out if is_2x else ("…4",)
        # rsigma is shaped like the input minus its last axis.
        rsigma = x_axes[:-1]
        # RMSNorm has no mean; use a placeholder for mu in that case.
        mu = ("…5",) if norm_type == NVTE_Norm_Type.RMSNorm else rsigma
        amax = ("…6",)
        return SdyShardingRule(
            (x_axes, ("…1",), ("…2",), ("…3",)),
            (
                out,
                colwise_out,
                scale_rules.rowwise_rule,
                scale_rules.colwise_rule,
                amax,
                mu,
                rsigma,
            ),
            **scale_rules.factor_sizes,
        )
register_primitive(NormFwdPrimitive) register_primitive(NormFwdPrimitive)
...@@ -722,6 +774,11 @@ class NormBwdPrimitive(BasePrimitive): ...@@ -722,6 +774,11 @@ class NormBwdPrimitive(BasePrimitive):
return mesh, sharded_impl, out_shardings, arg_shardings return mesh, sharded_impl, out_shardings, arg_shardings
@staticmethod
def shardy_sharding_rule(*args):
del args
return "...0, ...1 i, ...2, ...3, ...4 -> ...1 j, k, l"
register_primitive(NormBwdPrimitive) register_primitive(NormBwdPrimitive)
......
...@@ -10,6 +10,7 @@ from packaging import version ...@@ -10,6 +10,7 @@ from packaging import version
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
from jax import dtypes from jax import dtypes
from jax.experimental.custom_partitioning import SdyShardingRule
from jax.sharding import PartitionSpec from jax.sharding import PartitionSpec
import transformer_engine_jax import transformer_engine_jax
...@@ -470,6 +471,48 @@ class DBiasQuantizePrimitive(BasePrimitive): ...@@ -470,6 +471,48 @@ class DBiasQuantizePrimitive(BasePrimitive):
return mesh, sharded_impl, out_shardings, arg_shardings return mesh, sharded_impl, out_shardings, arg_shardings
    @staticmethod
    def shardy_sharding_rule(
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        scale_shapes,
        is_dbias,
        is_outer,
        mesh,
        value_types,
        result_types,
    ):
        """Shardy sharding rule for the dbias + quantize primitive.

        Outputs are (out, colwise_out, scale_inv, colwise_scale_inv, amax,
        dbias).  Scale-tensor factors come from the scaling mode's
        `get_shardy_sharding_rules`.
        """
        del out_dtype, scale_dtype, scale_shapes, is_outer, mesh, result_types
        scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
            len(value_types[0].shape), unique_var="i", flatten_axis=flatten_axis
        )
        x_axes = scale_rules.input_spec
        colwise_scale_inv = scale_rules.colwise_rule
        # The rowwise output is sharded exactly like the input.
        out = x_axes
        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                # Delayed scaling stores the colwise copy transposed.
                colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=flatten_axis))
            else:
                colwise_out = x_axes
        else:
            # No colwise outputs: give them fresh, unconstrained factors.
            colwise_out = ("j",)
            colwise_scale_inv = ("k",)
        # dbias keeps the axes from flatten_axis onward.
        dbias = x_axes[flatten_axis:] if is_dbias else ("l",)
        amax = ("m",)
        return SdyShardingRule(
            (x_axes, ("…1",)),
            (out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax, dbias),
            **scale_rules.factor_sizes,
        )
register_primitive(DBiasQuantizePrimitive) register_primitive(DBiasQuantizePrimitive)
......
...@@ -330,6 +330,11 @@ class ScaledSoftmaxFwdPrimitive(SoftmaxPrimitive): ...@@ -330,6 +330,11 @@ class ScaledSoftmaxFwdPrimitive(SoftmaxPrimitive):
ScaledSoftmaxFwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos ScaledSoftmaxFwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos
) )
@staticmethod
def shardy_sharding_rule(*args):
del args
return "... -> ..."
register_primitive(ScaledSoftmaxFwdPrimitive) register_primitive(ScaledSoftmaxFwdPrimitive)
...@@ -400,6 +405,11 @@ class ScaledSoftmaxBwdPrimitive(SoftmaxPrimitive): ...@@ -400,6 +405,11 @@ class ScaledSoftmaxBwdPrimitive(SoftmaxPrimitive):
ScaledSoftmaxBwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos ScaledSoftmaxBwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos
) )
@staticmethod
def shardy_sharding_rule(*args):
del args
return "..., ... -> ..."
register_primitive(ScaledSoftmaxBwdPrimitive) register_primitive(ScaledSoftmaxBwdPrimitive)
...@@ -525,6 +535,11 @@ class ScaledMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive): ...@@ -525,6 +535,11 @@ class ScaledMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive):
ScaledMaskedSoftmaxFwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos ScaledMaskedSoftmaxFwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos
) )
@staticmethod
def shardy_sharding_rule(*args):
del args
return "...1, ...2 -> ...1"
register_primitive(ScaledMaskedSoftmaxFwdPrimitive) register_primitive(ScaledMaskedSoftmaxFwdPrimitive)
...@@ -596,6 +611,11 @@ class ScaledMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive): ...@@ -596,6 +611,11 @@ class ScaledMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
ScaledMaskedSoftmaxBwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos ScaledMaskedSoftmaxBwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos
) )
@staticmethod
def shardy_sharding_rule(*args):
del args
return "..., ... -> ..."
register_primitive(ScaledMaskedSoftmaxBwdPrimitive) register_primitive(ScaledMaskedSoftmaxBwdPrimitive)
...@@ -682,6 +702,11 @@ class ScaledUpperTriangMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive): ...@@ -682,6 +702,11 @@ class ScaledUpperTriangMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive):
result_infos, result_infos,
) )
@staticmethod
def shardy_sharding_rule(*args):
del args
return "... -> ..."
register_primitive(ScaledUpperTriangMaskedSoftmaxFwdPrimitive) register_primitive(ScaledUpperTriangMaskedSoftmaxFwdPrimitive)
...@@ -761,6 +786,11 @@ class ScaledUpperTriangMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive): ...@@ -761,6 +786,11 @@ class ScaledUpperTriangMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
result_infos, result_infos,
) )
@staticmethod
def shardy_sharding_rule(*args):
del args
return "..., ... -> ..."
register_primitive(ScaledUpperTriangMaskedSoftmaxBwdPrimitive) register_primitive(ScaledUpperTriangMaskedSoftmaxBwdPrimitive)
......
...@@ -16,13 +16,33 @@ from typing import Tuple, Dict ...@@ -16,13 +16,33 @@ from typing import Tuple, Dict
from functools import reduce from functools import reduce
import operator import operator
from jax.experimental.custom_partitioning import CompoundFactor
from jax.tree_util import register_pytree_node_class from jax.tree_util import register_pytree_node_class
import jax.numpy as jnp import jax.numpy as jnp
from transformer_engine_jax import JAXX_Scaling_Mode from transformer_engine_jax import JAXX_Scaling_Mode
__all__ = ["ScalingMode"] __all__ = ["QuantizeShardyRules", "ScalingMode"]
@dataclass
class QuantizeShardyRules:
    """Information necessary to shard scale tensors with Shardy.

    Attributes:
        input_spec: Specification for the input axes. For block scaling some
            entries are `CompoundFactor`s rather than plain strings.
        rowwise_rule: Sharding rule for the row-wise scale tensor, depends on
            the axes in `input_spec`
        colwise_rule: Likewise for the column-wise scale tensor.
        factor_sizes: For block scaling, contains the block size factor, which is
            used in `input_spec`.
    """

    # Variable-length tuple annotations: `Tuple[str]` would mean a 1-tuple.
    input_spec: Tuple[str, ...]
    rowwise_rule: Tuple[str, ...]
    colwise_rule: Tuple[str, ...]
    factor_sizes: Dict[str, int]
class ScalingModeMetadataImpl(ABC): class ScalingModeMetadataImpl(ABC):
...@@ -59,6 +79,21 @@ class ScalingModeMetadataImpl(ABC): ...@@ -59,6 +79,21 @@ class ScalingModeMetadataImpl(ABC):
The shape for scale tensors The shape for scale tensors
""" """
@abstractmethod
def get_shardy_sharding_rules(
self, input_rank, unique_var, flatten_axis
) -> QuantizeShardyRules:
"""Sharding rules for the input and (row, col)wise scale tensors.
Args:
input_rank: The rank of the input tensor (for which we produce the scale tensor)
unique_var: An otherwise unused Shardy variable name prefix
flatten_axis: Axis along which data can be flattened to 2D for quantization.
Returns:
The Shardy rules for the scaling mode
"""
class DelayedScalingModeMetadataImpl(ScalingModeMetadataImpl): class DelayedScalingModeMetadataImpl(ScalingModeMetadataImpl):
"""Implementation for delayed scaling mode. """Implementation for delayed scaling mode.
...@@ -95,6 +130,23 @@ class DelayedScalingModeMetadataImpl(ScalingModeMetadataImpl): ...@@ -95,6 +130,23 @@ class DelayedScalingModeMetadataImpl(ScalingModeMetadataImpl):
del data_shape, is_colwise del data_shape, is_colwise
return (1,) return (1,)
def get_shardy_sharding_rules(
self, input_rank, unique_var, flatten_axis
) -> QuantizeShardyRules:
"""Sharding rules for the input and (row, col)wise scale tensors.
Args:
input_rank: The rank of the input tensor (for which we produce the scale tensor)
unique_var: An otherwise unused Shardy variable name prefix
flatten_axis: Axis along which data can be flattened to 2D for quantization.
Returns:
The Shardy rules for the scaling mode
"""
del flatten_axis
input_spec = tuple(f"x{i}" for i in range(input_rank))
return QuantizeShardyRules(input_spec, (unique_var,), (unique_var,), {})
class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl): class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
"""Implementation for block scaling mode. """Implementation for block scaling mode.
...@@ -217,6 +269,45 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl): ...@@ -217,6 +269,45 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
return (*first_dim_scale_shape, *last_dim_scale_shape) return (*first_dim_scale_shape, *last_dim_scale_shape)
    def get_shardy_sharding_rules(
        self, input_rank, unique_var, flatten_axis
    ) -> QuantizeShardyRules:
        """Sharding rules for the input and (row, col)wise scale tensors.

        Args:
            input_rank: The rank of the input tensor (for which we produce the scale tensor)
            unique_var: An otherwise unused Shardy variable name prefix
            flatten_axis: Axis along which data can be flattened to 2D for quantization.

        Returns:
            The Shardy rules for the scaling mode
        """
        input_spec = [f"x{i}" for i in range(input_rank)]

        # We have to use two different factors in the two CompoundFactors because of Shardy
        # verifier requirements, even though they are the same.
        rowwise_var = unique_var
        colwise_var = f"{unique_var}_"
        input_spec[flatten_axis - 1] = CompoundFactor(colwise_var, "block_size_colwise")
        input_spec[-1] = CompoundFactor(rowwise_var, "block_size_rowwise")

        # The rowwise and colwise scale tensors should be sharded the same way as the input.
        # However, we need to adjust the dimensions where the block scaling factor applies.
        rowwise = input_spec.copy()
        rowwise[-1] = rowwise_var

        colwise = input_spec.copy()
        colwise[flatten_axis - 1] = colwise_var

        # This implementation needs to be updated for different block dims.
        # NOTE(review): hard-coded 1x32 block assumption; the factor sizes
        # below must stay in sync with it.
        assert self._block_dims == (1, 32)

        return QuantizeShardyRules(
            tuple(input_spec),
            tuple(rowwise),
            tuple(colwise),
            {"block_size_rowwise": 32, "block_size_colwise": 32},
        )
@dataclass(frozen=True) @dataclass(frozen=True)
@register_pytree_node_class @register_pytree_node_class
...@@ -290,6 +381,20 @@ class ScalingMode(Enum): ...@@ -290,6 +381,20 @@ class ScalingMode(Enum):
""" """
return self._get_impl().get_scale_shape(data_shape, is_colwise, is_padded, flatten_axis) return self._get_impl().get_scale_shape(data_shape, is_colwise, is_padded, flatten_axis)
    def get_shardy_sharding_rules(
        self, input_rank, unique_var, flatten_axis=-1
    ) -> QuantizeShardyRules:
        """Sharding rules for the input and (row, col)wise scale tensors.

        Args:
            input_rank: The rank of the input tensor (for which we produce the scale tensor)
            unique_var: An otherwise unused Shardy variable name prefix
            flatten_axis: Axis along which data can be flattened to 2D for
                quantization. Defaults to -1.

        Returns:
            The Shardy rules for the scaling mode
        """
        return self._get_impl().get_shardy_sharding_rules(input_rank, unique_var, flatten_axis)
def __eq__(self, other): def __eq__(self, other):
"""Compare this scaling mode with another. """Compare this scaling mode with another.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment