# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import pytest

from transformer_engine.jax.flax import extend_logical_axis_rules
from transformer_engine.jax.sharding import global_shard_guard, MeshResource

# Parametrization for TestShardingSideAPI.test_extend_logical_axis_rules.
# Each entry is [base_rules, need_assert]:
#   base_rules  - tuple of (logical_axis_name, mesh_axis) rule pairs
#   need_assert - whether extend_logical_axis_rules(base_rules) is expected to
#                 raise AssertionError for these rules
LOGICAL_RULES = [
    # Only str / None mesh axes: expected to be accepted.
    [(("a1", None), ("a2", "ma2")), False],
    # Tuple-valued mesh axis for "a3": expected to be rejected.
    [(("a1", None), ("a2", "ma2"), ("a3", ("ma31", "ma32"))), True],
    # "a3" listed twice with different mesh axes: expected to be accepted.
    [(("a1", None), ("a2", "ma2"), ("a3", "ma31"), ("a3", "ma32")), False],
    # "batch" mapped to a custom mesh axis: expected to be rejected.
    # NOTE(review): presumably because "batch" clashes with TE's own rules —
    # confirm against extend_logical_axis_rules.
    [(("a1", None), ("a2", "ma2"), ("batch", "batch_1200234")), True],
    # Duplicate "a2" and duplicate "batch" mappings: expected to be rejected.
    # NOTE(review): presumably triggered by the "batch" entries — confirm.
    [(("a1", None), ("a2", "ma2"), ("a2", "ma1"), ("batch", "model"), ("batch", "data")), True],
]

# MeshResource configurations every LOGICAL_RULES case is run under:
# a default MeshResource, plus ones built with "data" and/or "model" as the
# first two positional arguments.
# NOTE(review): those positions are presumably (data-parallel, tensor-parallel)
# resources — confirm against MeshResource's signature.
MeshS = [
    MeshResource(),
    MeshResource("data", None),
    MeshResource(None, "model"),
    MeshResource("data", "model"),
]


class TestShardingSideAPI:
    """Tests for the sharding helper APIs of transformer_engine.jax."""

    @pytest.mark.parametrize("base_rules,need_assert", LOGICAL_RULES)
    @pytest.mark.parametrize("sr", MeshS)
    def test_extend_logical_axis_rules(self, base_rules, need_assert, sr):
        """Check extend_logical_axis_rules under the MeshResource ``sr``.

        Valid ``base_rules`` must come back unchanged, followed by the rules
        that extend_logical_axis_rules produces for an empty input.
        Invalid ``base_rules`` (``need_assert``) must raise ``AssertionError``.
        """
        with global_shard_guard(sr):
            if need_assert:
                # pytest.raises fails the test when nothing is raised. A plain
                # try/except AssertionError around `assert not need_assert` would
                # swallow its own failure and let this case pass either way.
                # NOTE(review): if a need_assert=True entry in LOGICAL_RULES now
                # fails here, the expectation in that entry needs revisiting.
                with pytest.raises(AssertionError):
                    extend_logical_axis_rules(base_rules)
                return

            # Rules contributed by TE itself, i.e. what an empty input extends to.
            target_te_rules = extend_logical_axis_rules(tuple())
            extended_rules = extend_logical_axis_rules(base_rules)
            assert extended_rules == (*base_rules, *target_te_rules)