# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import jax
import numpy as np
import pytest
from jax.experimental import maps

from utils import is_devices_enough
from transformer_engine.jax.flax import extend_logical_axis_rules
from transformer_engine.jax.sharding import get_dot_sharding_meta
from transformer_engine.jax.sharding import get_elementwise_sharding_meta
from transformer_engine.jax.sharding import get_fp8_meta_sharding_meta
from transformer_engine.jax.sharding import global_shard_guard
from transformer_engine.jax.sharding import infer_major_sharding_type
from transformer_engine.jax.sharding import is_dp_enabled, is_tp_enabled
from transformer_engine.jax.sharding import ShardingMeta, ShardingResource, ShardingType


def _get_sharding_resource(mesh_names, sharding_type):
    """Build the ShardingResource matching ``sharding_type`` on a mesh.

    DP, DP_TP_COL and DP_TP_ROW take their data-parallel resource from the
    first mesh axis name.  TP_COL and TP_ROW take their tensor-parallel
    resource from the first mesh axis name, while DP_TP_COL and DP_TP_ROW
    take it from the second.  Any resource a sharding type does not use is
    left as ``None``.
    """
    data_parallel_types = (ShardingType.DP, ShardingType.DP_TP_COL, ShardingType.DP_TP_ROW)
    tensor_only_types = (ShardingType.TP_COL, ShardingType.TP_ROW)
    hybrid_types = (ShardingType.DP_TP_COL, ShardingType.DP_TP_ROW)

    dp_resource = mesh_names[0] if sharding_type in data_parallel_types else None

    if sharding_type in hybrid_types:
        tp_resource = mesh_names[1]
    elif sharding_type in tensor_only_types:
        tp_resource = mesh_names[0]
    else:
        tp_resource = None

    return ShardingResource(dp_resource, tp_resource)


# Number of devices the multi-device tests below require.
DEVICE_COUNT = 4

# (mesh_shape, mesh_names, sharding_type) combinations exercised by the tests.
# 1-D meshes put all DEVICE_COUNT devices on a single named axis; the 2-D
# meshes lay them out as ("dp", "tp").
MESH_CONFIG = [
    ((4,), ("dp",), ShardingType.DP),
    ((4,), ("tp",), ShardingType.TP_COL),
    ((4,), ("tp",), ShardingType.TP_ROW),
    ((2, 2), ("dp", "tp"), ShardingType.DP_TP_COL),
    ((2, 2), ("dp", "tp"), ShardingType.DP_TP_ROW),
]

# [base_rules, need_assert] pairs for test_extend_logical_axis_rules.
# need_assert is True when extending base_rules is expected to raise an
# AssertionError.
LOGICAL_RULES = [
    [(('a1', None), ('a2', 'ma2')), False],
    [(('a1', None), ('a2', 'ma2'), ('a3', ('ma31', 'ma32'))), True],
    [(('a1', None), ('a2', 'ma2'), ('a3', 'ma31'), ('a3', 'ma32')), False],
    [(('a1', None), ('a2', 'ma2'), ('batch', 'batch_1200234')), True],
    [(('a1', None), ('a2', 'ma2'), ('a2', 'ma1'), ('batch', 'model'), ('batch', 'data')), True],
]

# Global sharding resources (positional order: dp resource, tp resource) that
# test_extend_logical_axis_rules runs under: neither, DP only, TP only, both.
SRS = [
    ShardingResource(),
    ShardingResource('data', None),
    ShardingResource(None, 'model'),
    ShardingResource('data', 'model'),
]


class TestShardingSideAPI:
    """Tests for the Flax-side logical-axis-rule helper."""

    @pytest.mark.parametrize('base_rules,need_assert', LOGICAL_RULES)
    @pytest.mark.parametrize('sr', SRS)
    def test_extend_logical_axis_rules(self, base_rules, need_assert, sr):
        """extend_logical_axis_rules appends TE's rules after ``base_rules``.

        If ``need_assert`` is True, extending ``base_rules`` must raise an
        AssertionError.  Otherwise the result must equal ``base_rules``
        followed by the rules TE produces when extending an empty rule set.
        """
        with global_shard_guard(sr):
            if need_assert:
                # pytest.raises fails the test when no AssertionError is raised.
                # Wrapping `assert not need_assert` in `except AssertionError`
                # would instead swallow that very failure.
                with pytest.raises(AssertionError):
                    extend_logical_axis_rules(base_rules)
            else:
                target_te_rules = extend_logical_axis_rules(tuple())
                extended_rules = extend_logical_axis_rules(base_rules)
                assert extended_rules == (*base_rules, *target_te_rules)


class TestGeneralFunc:
    """Tests for sharding-type inference and the DP/TP predicates."""

    @pytest.mark.parametrize('mesh_shape,mesh_names,sharding_type', MESH_CONFIG)
    @pytest.mark.skipif(not is_devices_enough(DEVICE_COUNT), reason='Num of GPU is not enough')
    def test_infer_major_sharding_type(self, mesh_shape, mesh_names, sharding_type):
        """infer_major_sharding_type() reflects the active mesh and global resource.

        The expected result is the first element of ``sharding_type.value``.
        """
        devices = np.asarray(jax.devices()[:DEVICE_COUNT]).reshape(*mesh_shape)
        with global_shard_guard(_get_sharding_resource(mesh_names, sharding_type)):
            with maps.Mesh(devices, mesh_names):
                assert infer_major_sharding_type() is sharding_type.value[0]

    @pytest.mark.parametrize('mesh_shape,mesh_names,sharding_type', MESH_CONFIG)
    def test_is_dp_enabled(
            self,
            mesh_shape,    # pylint: disable=unused-argument
            mesh_names,    # pylint: disable=unused-argument
            sharding_type):
        """is_dp_enabled is True exactly for DP, DP_TP_COL and DP_TP_ROW."""
        if sharding_type in (ShardingType.DP, ShardingType.DP_TP_COL, ShardingType.DP_TP_ROW):
            assert is_dp_enabled(sharding_type.value[0])
        else:
            assert not is_dp_enabled(sharding_type.value[0])

    @pytest.mark.parametrize('mesh_shape,mesh_names,sharding_type', MESH_CONFIG)
    def test_is_tp_enabled(
            self,
            mesh_shape,    # pylint: disable=unused-argument
            mesh_names,    # pylint: disable=unused-argument
            sharding_type):
        """is_tp_enabled is True for every configured type except pure DP."""
        if sharding_type is ShardingType.DP:
            assert not is_tp_enabled(sharding_type.value[0])
        else:
            assert is_tp_enabled(sharding_type.value[0])


class TestShardingMetaGenerator:
    """Compare TE's ShardingMeta generators against hand-built references.

    Every test reshapes the first DEVICE_COUNT devices into the mesh from
    MESH_CONFIG and installs the matching global ShardingResource.  It then
    checks the generator's output against a reference ShardingMeta built by
    a local ``get_ref_sm`` helper.
    """

    BATCH_AXIS_NAME = 'batch'
    MODEL_AXIS_NAME = 'model'

    @pytest.mark.parametrize('mesh_shape,mesh_names,sharding_type', MESH_CONFIG)
    @pytest.mark.skipif(not is_devices_enough(DEVICE_COUNT), reason='Num of GPU is not enough')
    def test_fp8_meta(self, mesh_shape, mesh_names, sharding_type, num_of_fp8_meta=4):
        """get_fp8_meta_sharding_meta: empty per-meta axes; only axis resources vary."""

        def stack_axes_meta(mapping):
            # One copy of ``mapping`` per FP8 meta entry.
            return tuple(mapping for _ in range(num_of_fp8_meta))

        def get_ref_sm():
            if sharding_type == ShardingType.DP:
                return ShardingMeta(stack_axes_meta({}), stack_axes_meta({}),
                                    {TestShardingMetaGenerator.BATCH_AXIS_NAME: mesh_names[0]}, (),
                                    ())

            if sharding_type == ShardingType.TP_COL:
                return ShardingMeta(stack_axes_meta({}), stack_axes_meta({}),
                                    {TestShardingMetaGenerator.MODEL_AXIS_NAME: mesh_names[0]}, (),
                                    ())

            if sharding_type == ShardingType.TP_ROW:
                return ShardingMeta(stack_axes_meta({}), stack_axes_meta({}),
                                    {TestShardingMetaGenerator.MODEL_AXIS_NAME: mesh_names[0]}, (),
                                    ())

            if sharding_type == ShardingType.DP_TP_COL:
                return ShardingMeta(
                    stack_axes_meta({}), stack_axes_meta({}), {
                        TestShardingMetaGenerator.BATCH_AXIS_NAME: mesh_names[0],
                        TestShardingMetaGenerator.MODEL_AXIS_NAME: mesh_names[1]
                    }, (), ())

            if sharding_type == ShardingType.DP_TP_ROW:
                return ShardingMeta(
                    stack_axes_meta({}), stack_axes_meta({}), {
                        TestShardingMetaGenerator.BATCH_AXIS_NAME: mesh_names[0],
                        TestShardingMetaGenerator.MODEL_AXIS_NAME: mesh_names[1]
                    }, (), ())
            return None

        devices = np.asarray(jax.devices()[:DEVICE_COUNT]).reshape(*mesh_shape)
        with global_shard_guard(_get_sharding_resource(mesh_names, sharding_type)):
            with maps.Mesh(devices, mesh_names):
                test_sm = get_fp8_meta_sharding_meta(
                    sharding_type,
                    num_of_fp8_meta,
                    dp_axis_name=TestShardingMetaGenerator.BATCH_AXIS_NAME,
                    tp_axis_name=TestShardingMetaGenerator.MODEL_AXIS_NAME)
                assert test_sm == get_ref_sm()

    @pytest.mark.parametrize('mesh_shape,mesh_names,sharding_type', MESH_CONFIG)
    @pytest.mark.parametrize('a_shape, b_shape', [((64, 128, 256), (256, 512)),
                                                  ((128, 64, 512), (512, 256))])
    @pytest.mark.parametrize('batch_dim_of_a', [0, 1])
    @pytest.mark.skipif(not is_devices_enough(DEVICE_COUNT), reason='Num of GPU is not enough')
    def test_dot(self, mesh_shape, mesh_names, sharding_type, a_shape, b_shape, batch_dim_of_a):
        """get_dot_sharding_meta matches a hand-built reference for ``a @ b``."""
        model_dim_of_a = len(a_shape) - 1
        # TP_ROW / DP_TP_ROW pass b's contracting dim (0) as its model dim;
        # every other type passes dim 1.
        model_dim_of_b = 0 if sharding_type in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW) else 1
        contracting_dims = ((-1,), (0,))

        def get_ref_sm():
            out_shape = (*a_shape[:min(contracting_dims[0])],
                         *b_shape[max(contracting_dims[1]) + 1:])
            if sharding_type == ShardingType.DP:
                a_new_shape = (*a_shape[:batch_dim_of_a], mesh_shape[0], -1,
                               *a_shape[batch_dim_of_a + 1:])
                return ShardingMeta(({
                    batch_dim_of_a: TestShardingMetaGenerator.BATCH_AXIS_NAME
                }, {}), ({
                    batch_dim_of_a: TestShardingMetaGenerator.BATCH_AXIS_NAME
                }), {TestShardingMetaGenerator.BATCH_AXIS_NAME: mesh_names[0]},
                                    [a_new_shape, b_shape], [out_shape])

            if sharding_type == ShardingType.TP_COL:
                b_new_shape = (b_shape[0], mesh_shape[0], b_shape[1] // mesh_shape[0])
                return ShardingMeta(({}, {
                    1: TestShardingMetaGenerator.MODEL_AXIS_NAME
                }), ({
                    len(out_shape) - 1: TestShardingMetaGenerator.MODEL_AXIS_NAME
                }), {TestShardingMetaGenerator.MODEL_AXIS_NAME: mesh_names[0]},
                                    [a_shape, b_new_shape], [out_shape])

            if sharding_type == ShardingType.TP_ROW:
                a_new_shape = (*a_shape[:-1], mesh_shape[0], a_shape[-1] // mesh_shape[0])
                b_new_shape = (mesh_shape[0], b_shape[0] // mesh_shape[0], b_shape[1])
                return ShardingMeta(({
                    len(a_new_shape) - 2: TestShardingMetaGenerator.MODEL_AXIS_NAME
                }, {
                    0: TestShardingMetaGenerator.MODEL_AXIS_NAME
                }), ({}), {TestShardingMetaGenerator.MODEL_AXIS_NAME: mesh_names[0]},
                                    [a_new_shape, b_new_shape], [out_shape])

            if sharding_type == ShardingType.DP_TP_COL:
                a_new_shape = (*a_shape[:batch_dim_of_a], mesh_shape[0],
                               a_shape[batch_dim_of_a] // mesh_shape[0],
                               *a_shape[batch_dim_of_a + 1:])
                b_new_shape = (b_shape[0], mesh_shape[1], b_shape[1] // mesh_shape[1])
                return ShardingMeta(
                    ({
                        batch_dim_of_a: TestShardingMetaGenerator.BATCH_AXIS_NAME
                    }, {
                        1: TestShardingMetaGenerator.MODEL_AXIS_NAME
                    }), ({
                        batch_dim_of_a: TestShardingMetaGenerator.BATCH_AXIS_NAME,
                        len(out_shape): TestShardingMetaGenerator.MODEL_AXIS_NAME
                    }), {
                        TestShardingMetaGenerator.BATCH_AXIS_NAME: mesh_names[0],
                        TestShardingMetaGenerator.MODEL_AXIS_NAME: mesh_names[1]
                    }, [a_new_shape, b_new_shape], [out_shape])

            if sharding_type == ShardingType.DP_TP_ROW:
                a_new_shape = (*a_shape[:batch_dim_of_a], mesh_shape[0],
                               a_shape[batch_dim_of_a] // mesh_shape[0],
                               *a_shape[batch_dim_of_a + 1:-1], mesh_shape[1],
                               a_shape[-1] // mesh_shape[1])
                b_new_shape = (mesh_shape[1], b_shape[0] // mesh_shape[1], b_shape[1])
                return ShardingMeta(
                    ({
                        batch_dim_of_a: TestShardingMetaGenerator.BATCH_AXIS_NAME,
                        len(a_new_shape) - 2: TestShardingMetaGenerator.MODEL_AXIS_NAME
                    }, {
                        0: TestShardingMetaGenerator.MODEL_AXIS_NAME
                    }), ({
                        batch_dim_of_a: TestShardingMetaGenerator.BATCH_AXIS_NAME
                    }), {
                        TestShardingMetaGenerator.BATCH_AXIS_NAME: mesh_names[0],
                        TestShardingMetaGenerator.MODEL_AXIS_NAME: mesh_names[1]
                    }, [a_new_shape, b_new_shape], [out_shape])
            return None

        devices = np.asarray(jax.devices()[:DEVICE_COUNT]).reshape(*mesh_shape)
        with global_shard_guard(_get_sharding_resource(mesh_names, sharding_type)):
            with maps.Mesh(devices, mesh_names):
                test_sm = get_dot_sharding_meta(
                    sharding_type,
                    a_shape,
                    b_shape,
                    batch_dim_of_a,
                    model_dim_of_a,
                    model_dim_of_b,
                    contracting_dims,
                    dp_axis_name=TestShardingMetaGenerator.BATCH_AXIS_NAME,
                    tp_axis_name=TestShardingMetaGenerator.MODEL_AXIS_NAME)
                assert test_sm == get_ref_sm()

    @pytest.mark.parametrize('mesh_shape,mesh_names,sharding_type', MESH_CONFIG)
    @pytest.mark.parametrize('input_shape', [(64, 128, 256), (128, 64, 512)])
    @pytest.mark.parametrize('other_shape', [(256,), (512,)])
    @pytest.mark.parametrize('batch_dim', [0, 1])
    @pytest.mark.skipif(not is_devices_enough(DEVICE_COUNT), reason='Num of GPU is not enough')
    def test_elementwise(self, mesh_shape, mesh_names, sharding_type, input_shape, other_shape,
                         batch_dim):
        """get_elementwise_sharding_meta returns the reference meta, or raises.

        When the reference says ``need_assert``, the generator must raise
        NotImplementedError or AssertionError.
        """

        def get_ref_sm():
            """Return ``(reference ShardingMeta or None, need_assert)``."""
            if input_shape[-1] != other_shape[0]:
                # Mismatched trailing dims: the generator is expected to raise.
                return None, True

            # Membership test: `is` against a tuple is always False, which
            # silently turned the DP / DP_TP_COL cases into expected failures.
            if sharding_type in (ShardingType.DP_TP_COL, ShardingType.DP):
                input_new_shape = (*input_shape[:batch_dim], mesh_shape[0], -1,
                                   *input_shape[batch_dim + 1:])
                ref_sharding_meta = ShardingMeta(({
                    batch_dim: TestShardingMetaGenerator.BATCH_AXIS_NAME
                }, {}), ({
                    batch_dim: TestShardingMetaGenerator.BATCH_AXIS_NAME
                }), {TestShardingMetaGenerator.BATCH_AXIS_NAME: mesh_names[0]},
                                                 [input_new_shape, other_shape], [input_shape])
                return ref_sharding_meta, False

            if sharding_type is ShardingType.TP_COL:
                ref_sharding_meta = ShardingMeta(({}, {}), ({}), {}, [input_shape, other_shape],
                                                 [input_shape])
                return ref_sharding_meta, False

            if sharding_type is ShardingType.TP_ROW:
                input_new_shape = (*input_shape[:-1], mesh_shape[0], -1)
                other_new_shape = (mesh_shape[0], -1)

                ref_sharding_meta = ShardingMeta(({
                    len(input_new_shape) - 2: TestShardingMetaGenerator.MODEL_AXIS_NAME
                }, {
                    0: TestShardingMetaGenerator.MODEL_AXIS_NAME
                }), ({
                    len(input_new_shape) - 2: TestShardingMetaGenerator.MODEL_AXIS_NAME
                }), {TestShardingMetaGenerator.MODEL_AXIS_NAME: mesh_names[0]},
                                                 [input_new_shape, other_new_shape], [input_shape])
                return ref_sharding_meta, False

            if sharding_type is ShardingType.DP_TP_ROW:
                input_new_shape = (*input_shape[:batch_dim], mesh_shape[0], -1,
                                   *input_shape[batch_dim + 1:-1], mesh_shape[1],
                                   input_shape[-1] // mesh_shape[1])
                # `other` is split like input's last dim, i.e. over the second
                # mesh axis (size mesh_shape[1]), not the first.
                other_new_shape = (mesh_shape[1], -1)

                ref_sharding_meta = ShardingMeta(
                    ({
                        batch_dim: TestShardingMetaGenerator.BATCH_AXIS_NAME,
                        len(input_new_shape) - 2: TestShardingMetaGenerator.MODEL_AXIS_NAME
                    }, {
                        0: TestShardingMetaGenerator.MODEL_AXIS_NAME
                    }), ({
                        batch_dim: TestShardingMetaGenerator.BATCH_AXIS_NAME,
                        len(input_new_shape) - 2: TestShardingMetaGenerator.MODEL_AXIS_NAME
                    }), {
                        TestShardingMetaGenerator.BATCH_AXIS_NAME: mesh_names[0],
                        TestShardingMetaGenerator.MODEL_AXIS_NAME: mesh_names[1]
                    }, [input_new_shape, other_new_shape], [input_shape])
                return ref_sharding_meta, False

            return None, True

        devices = np.asarray(jax.devices()[:DEVICE_COUNT]).reshape(*mesh_shape)
        with global_shard_guard(_get_sharding_resource(mesh_names, sharding_type)):
            with maps.Mesh(devices, mesh_names):
                ref_sm, need_assert = get_ref_sm()

                def generate():
                    return get_elementwise_sharding_meta(
                        sharding_type,
                        input_shape,
                        other_shape,
                        batch_dim,
                        dp_axis_name=TestShardingMetaGenerator.BATCH_AXIS_NAME,
                        tp_axis_name=TestShardingMetaGenerator.MODEL_AXIS_NAME)

                if need_assert:
                    # pytest.raises fails when nothing is raised; catching
                    # AssertionError around `assert not need_assert` would
                    # swallow that failure instead.
                    with pytest.raises((NotImplementedError, AssertionError)):
                        generate()
                else:
                    assert generate() == ref_sm