# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import pytest

7
from transformer_engine.jax.flax import extend_logical_axis_rules
8
from transformer_engine.jax.sharding import global_shard_guard, MeshResource
9

Ming-Xu Huang's avatar
Ming-Xu Huang committed
10
11
12
13
14
15
16
# Parametrization table for the 'base_rules,need_assert' arguments of
# test_extend_logical_axis_rules. Each entry is [base_rules, need_assert]:
# base_rules is a tuple of (logical_axis, mesh_axis) pairs and need_assert
# says whether an AssertionError is expected for those rules.
LOGICAL_RULES = [
    # Two distinct logical axes, one of them mapped to None.
    [
        (("a1", None), ("a2", "ma2")),
        False,
    ],
    # 'a3' is mapped to a tuple of mesh axes instead of a single name.
    [
        (("a1", None), ("a2", "ma2"), ("a3", ("ma31", "ma32"))),
        True,
    ],
    # 'a3' is listed twice, each time with a single mesh axis name.
    [
        (("a1", None), ("a2", "ma2"), ("a3", "ma31"), ("a3", "ma32")),
        False,
    ],
    # 'batch' is mapped to 'batch_1200234'.
    # NOTE(review): presumably 'batch' collides with a rule the library
    # defines itself -- confirm against extend_logical_axis_rules.
    [
        (("a1", None), ("a2", "ma2"), ("batch", "batch_1200234")),
        True,
    ],
    # Both 'a2' and 'batch' are listed twice with different mesh axes.
    [
        (
            ("a1", None),
            ("a2", "ma2"),
            ("a2", "ma1"),
            ("batch", "model"),
            ("batch", "data"),
        ),
        True,
    ],
]
17
18
19
20
21
22

# Mesh resource configurations used to parametrize the 'sr' argument of
# test_extend_logical_axis_rules: no axes, only the first, only the second,
# and both.
# NOTE(review): presumably the two positional arguments are the data-parallel
# and tensor-parallel mesh axis names -- confirm against MeshResource.
MeshS = [
    MeshResource(),
    MeshResource('data', None),
    MeshResource(None, 'model'),
    MeshResource('data', 'model'),
]


26
27
28
class TestShardingSideAPI:
    """Tests for ``extend_logical_axis_rules`` under different mesh resources."""

    @pytest.mark.parametrize('base_rules,need_assert', LOGICAL_RULES)
    @pytest.mark.parametrize('sr', MeshS)
    def test_extend_logical_axis_rules(self, base_rules, need_assert, sr):
        """Check how ``extend_logical_axis_rules`` treats ``base_rules``.

        Args:
            base_rules: Tuple of ``(logical_axis, mesh_axis)`` pairs handed to
                ``extend_logical_axis_rules``.
            need_assert: ``True`` when the call is expected to raise
                ``AssertionError``; ``False`` when it is expected to succeed.
            sr: Mesh resource installed through ``global_shard_guard`` while
                the check runs.
        """
        with global_shard_guard(sr):
            if need_assert:
                # pytest.raises fails the test when nothing is raised. A
                # try/except AssertionError around the test's own asserts
                # would swallow that failure and let the case pass vacuously.
                with pytest.raises(AssertionError):
                    extend_logical_axis_rules(base_rules)
            else:
                # Rules returned for an empty input are the ones the library
                # adds on its own; they must follow the caller's rules.
                target_te_rules = extend_logical_axis_rules(tuple())
                extended_rules = extend_logical_axis_rules(base_rules)
                assert extended_rules == (*base_rules, *target_te_rules)