test_distributed_custom_ops.py 12.6 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pytest
import numpy as np
from functools import partial

import jax
import jax.numpy as jnp
from jax import random
from jax.sharding import NamedSharding

from utils import is_devices_enough
from sharding_configs import *
from custom_ops_helper import *
from transformer_engine.jax.sharding import global_shard_guard
from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType

# Module-level objects shared by every test in this file.
# `configs` supplies the pytest parametrization tables used below (e.g. `layernorm_refs`,
# `softmax_refs`, `self_attn_refs`, `cross_attn_refs`) and `device_count`, which sizes the mesh.
configs = ShardingConfigs()   # default device count is len(jax.devices())
# `helper` supplies the custom/reference op wrappers, test shapes/dtype (`qkv_shape`, `dtype`,
# `pad_ratio`, `dropout_prob`), sharding specs/resources and the `compare_ops` checker.
helper = CustomOpsTestHelper()

def _make_mesh(mesh_shape, mesh_names):
    """Build a device mesh over the first ``configs.device_count`` devices.

    Args:
        mesh_shape: Shape the selected device list is reshaped into; its element count
            must equal the number of selected devices.
        mesh_names: Axis names for the mesh, one per entry of ``mesh_shape``.

    Returns:
        A ``jax.sharding.Mesh`` with devices laid out as ``mesh_shape``.
    """
    devices = np.asarray(jax.devices()[:configs.device_count]).reshape(*mesh_shape)
    return jax.sharding.Mesh(devices, mesh_names)


def _make_padded_tokens(batch_size, seq_len):
    """Return a ``(batch_size, seq_len)`` token array with trailing padding.

    Each row holds ``seq_len - int(seq_len * helper.pad_ratio)`` ones (valid tokens)
    followed by zeros (padding), in ``helper.dtype``.
    """
    pad_len = int(seq_len * helper.pad_ratio)
    valid_len = seq_len - pad_len
    return jnp.concatenate((jnp.ones((batch_size, valid_len), dtype=helper.dtype),
                            jnp.zeros((batch_size, pad_len), dtype=helper.dtype)),
                           axis=-1)


@pytest.mark.skipif(not helper.use_custom_partitioning(),
                    reason=('TE/JAX version does not support sharding with '
                            'jax.experimental.custom_partitioning.'))
@pytest.mark.skipif(not is_devices_enough(configs.device_count), reason='Num of GPU is not enough')
class TestCustomPartitioningOpsGenerator:
    """Check TE custom-partitioned ops against reference implementations under sharding.

    Every test builds a device mesh from the parametrized ``mesh_shape``/``mesh_names``,
    places its inputs with the specs returned by ``helper.get_sharding_spec``, and hands
    the custom op, the reference op and the expected ``collective_ref`` to
    ``helper.compare_ops`` inside the mesh and a ``global_shard_guard``.
    """

    @pytest.mark.parametrize('mesh_shape, mesh_names, sharding_type, collective_ref',
                             configs.layernorm_refs)
    @pytest.mark.parametrize('zero_centered_gamma', [False, True])
    def test_layernorm(self, mesh_shape, mesh_names, sharding_type, collective_ref,
                       zero_centered_gamma):
        """LayerNorm on a (batch, hidden) input; grads w.r.t. x, gamma, beta (args 0, 1, 2)."""
        epsilon = 1e-6
        custom_func = partial(helper.custom_layernorm,
                              zero_centered_gamma=zero_centered_gamma,
                              epsilon=epsilon,
                              sharding_type=sharding_type)
        reference_func = partial(helper.reference_layernorm,
                                 zero_centered_gamma=zero_centered_gamma,
                                 epsilon=epsilon)

        batch_size, _, num_heads, head_dim = helper.qkv_shape
        hidden_size = num_heads * head_dim
        input_shape = (batch_size, hidden_size)
        other_shape = (hidden_size,)
        x_ = random.normal(random.PRNGKey(1124), input_shape, dtype=helper.dtype)
        gamma_ = jnp.ones(other_shape, dtype=helper.dtype)
        beta_ = jnp.ones(other_shape, dtype=helper.dtype)

        x_spec, gamma_spec, beta_spec = helper.get_sharding_spec(mesh_names, sharding_type)
        mesh = _make_mesh(mesh_shape, mesh_names)
        with mesh, global_shard_guard(helper.get_sharding_resource(mesh_names, sharding_type)):
            x_ = jax.device_put(x_, NamedSharding(mesh, x_spec))
            gamma_ = jax.device_put(gamma_, NamedSharding(mesh, gamma_spec))
            beta_ = jax.device_put(beta_, NamedSharding(mesh, beta_spec))
            helper.compare_ops(
                custom_func, reference_func, collective_ref,
                x_, gamma_, beta_, grad_args=(0, 1, 2), dtype=helper.dtype,
                in_shardings=[x_spec, gamma_spec, beta_spec],
                out_shardings=(None, (x_spec, gamma_spec, beta_spec))
            )

    @pytest.mark.parametrize('mesh_shape, mesh_names, sharding_type, collective_ref',
                             configs.layernorm_refs)
    def test_rmsnorm(self, mesh_shape, mesh_names, sharding_type, collective_ref):
        """RMSNorm on a (batch, hidden) input; grads w.r.t. x and gamma (args 0, 1)."""
        epsilon = 1e-6
        custom_func = partial(helper.custom_rmsnorm, epsilon=epsilon, sharding_type=sharding_type)
        reference_func = partial(helper.reference_rmsnorm, epsilon=epsilon)

        batch_size, _, num_heads, head_dim = helper.qkv_shape
        hidden_size = num_heads * head_dim
        input_shape = (batch_size, hidden_size)
        other_shape = (hidden_size,)
        x_ = random.normal(random.PRNGKey(1124), input_shape, dtype=helper.dtype)
        gamma_ = jnp.ones(other_shape, dtype=helper.dtype)

        x_spec, gamma_spec = helper.get_sharding_spec(mesh_names, sharding_type)
        mesh = _make_mesh(mesh_shape, mesh_names)
        with mesh, global_shard_guard(helper.get_sharding_resource(mesh_names, sharding_type)):
            x_ = jax.device_put(x_, NamedSharding(mesh, x_spec))
            gamma_ = jax.device_put(gamma_, NamedSharding(mesh, gamma_spec))
            helper.compare_ops(
                custom_func, reference_func, collective_ref,
                x_, gamma_, grad_args=(0, 1), dtype=helper.dtype,
                in_shardings=[x_spec, gamma_spec],
                out_shardings=(None, (x_spec, gamma_spec))
            )

    @pytest.mark.parametrize('mesh_shape, mesh_names, sharding_type, collective_ref',
                             configs.softmax_refs)
    @pytest.mark.parametrize('softmax_type', configs.softmax_types)
    def test_softmax(self, mesh_shape, mesh_names, sharding_type, collective_ref,
                     softmax_type):
        """Softmax on (batch, heads, seq, seq) logits with a padding mask."""
        batch_size, seq_len, num_heads, head_dim = helper.qkv_shape
        scale_factor = 1. / jnp.sqrt(head_dim)

        custom_func = partial(helper.custom_softmax,
                              scale_factor=scale_factor,
                              softmax_type=softmax_type,
                              sharding_type=sharding_type)
        reference_func = partial(helper.reference_softmax,
                                 scale_factor=scale_factor,
                                 softmax_type=softmax_type)

        input_size = (batch_size, num_heads, seq_len, seq_len)
        x_ = random.normal(random.PRNGKey(1124), input_size, dtype=helper.dtype)

        tokens = _make_padded_tokens(batch_size, seq_len)
        mask_ = helper.make_mask(tokens, tokens, AttnMaskType.PADDING_MASK)

        x_spec, mask_spec = helper.get_sharding_spec(mesh_names, sharding_type)
        mesh = _make_mesh(mesh_shape, mesh_names)
        with mesh, global_shard_guard(helper.get_sharding_resource(mesh_names, sharding_type)):
            x_ = jax.device_put(x_, NamedSharding(mesh, x_spec))
            mask_ = jax.device_put(mask_, NamedSharding(mesh, mask_spec))
            # `grad_args` and the gradient sharding are scalars here, not 1-tuples (the
            # earlier `(0)` / `(x_spec)` spellings were plain parenthesized values).
            # Presumably grad_args is forwarded as a jax argnum, where a scalar argnum
            # yields a single gradient, which is why out_shardings uses a bare x_spec.
            # NOTE(review): only this test passes an extra leading positional 0 ahead of
            # x_; kept as-is — confirm against compare_ops / custom_softmax signatures.
            helper.compare_ops(
                custom_func, reference_func, collective_ref,
                0, x_, mask_, grad_args=0, dtype=helper.dtype,
                in_shardings=[x_spec, mask_spec],
                out_shardings=(None, x_spec)
            )

    @pytest.mark.parametrize(
        'mesh_shape, mesh_names, sharding_type, attn_bias_type, collective_ref',
        configs.self_attn_refs)
    @pytest.mark.parametrize('attn_mask_type', configs.self_attn_mask_types)
    def test_self_fused_attn(self, mesh_shape, mesh_names, sharding_type, collective_ref,
                             attn_bias_type, attn_mask_type, backend):
        """Fused self-attention on packed QKV; grads w.r.t. qkv and bias (args 0, 1)."""
        batch_size, seq_len, num_heads, head_dim = helper.qkv_shape
        helper.check_fused_attn_inputs(seq_len, seq_len, head_dim,
                                       helper.pad_ratio, helper.dropout_prob,
                                       attn_bias_type, attn_mask_type, backend)

        dropout_rng = random.PRNGKey(91023051)
        # One key per device for the partitioned op.
        split_rng = random.split(dropout_rng, configs.device_count)
        scale_factor = 1. / jnp.sqrt(head_dim)

        custom_func = partial(helper.custom_self_fused_attn,
                              rng_key=split_rng,
                              dropout_prob=helper.dropout_prob,
                              attn_bias_type=attn_bias_type,
                              attn_mask_type=attn_mask_type,
                              scaling_factor=scale_factor,
                              sharding_type=sharding_type)
        # NOTE(review): the reference gets the unsplit key here but the split keys in
        # test_cross_fused_attn — confirm which form the reference ops expect.
        reference_func = partial(helper.reference_self_fused_attn,
                                 rng_key=dropout_rng,
                                 dropout_prob=helper.dropout_prob,
                                 attn_bias_type=attn_bias_type,
                                 attn_mask_type=attn_mask_type,
                                 scaling_factor=scale_factor)

        key = random.PRNGKey(1124)
        subkeys = random.split(key, 2)

        qkv_shape = (batch_size, seq_len, 3, num_heads, head_dim)
        qkv_ = random.normal(subkeys[0], qkv_shape, dtype=helper.dtype)
        bias_shape = (1, num_heads, seq_len, seq_len)
        bias_ = random.normal(subkeys[1], bias_shape, dtype=helper.dtype)

        tokens = _make_padded_tokens(batch_size, seq_len)
        mask_ = helper.make_mask(tokens, tokens, attn_mask_type)

        qkv_spec, bias_spec, mask_spec = helper.get_sharding_spec(mesh_names, sharding_type)
        mesh = _make_mesh(mesh_shape, mesh_names)
        with mesh, global_shard_guard(helper.get_sharding_resource(mesh_names, sharding_type)):
            qkv_ = jax.device_put(qkv_, NamedSharding(mesh, qkv_spec))
            bias_ = jax.device_put(bias_, NamedSharding(mesh, bias_spec))
            mask_ = jax.device_put(mask_, NamedSharding(mesh, mask_spec))
            helper.compare_ops(
                custom_func, reference_func, collective_ref,
                qkv_, bias_, mask_, grad_args=(0, 1), dtype=helper.dtype,
                in_shardings=[qkv_spec, bias_spec, mask_spec],
                out_shardings=(None, (qkv_spec, bias_spec))
            )

    @pytest.mark.parametrize('mesh_shape, mesh_names, sharding_type, collective_ref',
                             configs.cross_attn_refs)
    @pytest.mark.parametrize('attn_mask_type', configs.cross_attn_mask_types)
    def test_cross_fused_attn(self, mesh_shape, mesh_names, sharding_type, collective_ref,
                              attn_mask_type, backend):
        """Fused cross-attention on separate Q and packed KV; grads w.r.t. q and kv."""
        batch_size, seq_len, num_heads, head_dim = helper.qkv_shape
        helper.check_fused_attn_inputs(seq_len, seq_len, head_dim,
                                       helper.pad_ratio, helper.dropout_prob,
                                       AttnBiasType.NO_BIAS, attn_mask_type, backend)

        dropout_rng = random.PRNGKey(91023051)
        # One key per device for the partitioned op.
        split_rng = random.split(dropout_rng, configs.device_count)
        scale_factor = 1. / jnp.sqrt(head_dim)

        custom_func = partial(helper.custom_cross_fused_attn,
                              rng_key=split_rng,
                              dropout_prob=helper.dropout_prob,
                              attn_mask_type=attn_mask_type,
                              scaling_factor=scale_factor,
                              sharding_type=sharding_type)
        # NOTE(review): unlike test_self_fused_attn, the reference receives the split
        # keys here — confirm which form reference_cross_fused_attn expects.
        reference_func = partial(helper.reference_cross_fused_attn,
                                 rng_key=split_rng,
                                 dropout_prob=helper.dropout_prob,
                                 attn_mask_type=attn_mask_type,
                                 scaling_factor=scale_factor)

        key = random.PRNGKey(1124)
        subkeys = random.split(key, 2)

        q_shape = (batch_size, seq_len, num_heads, head_dim)
        q_ = random.normal(subkeys[0], q_shape, dtype=helper.dtype)
        kv_shape = (batch_size, seq_len, 2, num_heads, head_dim)
        kv_ = random.normal(subkeys[1], kv_shape, dtype=helper.dtype)

        tokens = _make_padded_tokens(batch_size, seq_len)
        mask_ = helper.make_mask(tokens, tokens, attn_mask_type)

        q_spec, kv_spec, mask_spec = helper.get_sharding_spec(mesh_names, sharding_type)
        mesh = _make_mesh(mesh_shape, mesh_names)
        with mesh, global_shard_guard(helper.get_sharding_resource(mesh_names, sharding_type)):
            q_ = jax.device_put(q_, NamedSharding(mesh, q_spec))
            kv_ = jax.device_put(kv_, NamedSharding(mesh, kv_spec))
            mask_ = jax.device_put(mask_, NamedSharding(mesh, mask_spec))
            helper.compare_ops(
                custom_func, reference_func, collective_ref,
                q_, kv_, mask_, grad_args=(0, 1), dtype=helper.dtype,
                in_shardings=[q_spec, kv_spec, mask_spec],
                out_shardings=(None, (q_spec, kv_spec))
            )