# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
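"""Multi-GPU tests for Transformer Engine's custom-partitioned JAX ops.

Each test shards its inputs over a device mesh, runs the TE custom op alongside
an unsharded reference implementation, and compares outputs, gradients, and the
expected collectives (collective_ref).
"""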
import pytest
import numpy as np
from functools import partial

import jax
import jax.numpy as jnp
from jax import random
from jax.sharding import NamedSharding

from utils import is_devices_enough

from distributed_configs_helper import DistributedConfigsHelper
from distributed_ops_helper import DistributedOpsHelper

from transformer_engine.jax.sharding import global_shard_guard
from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType

configs = DistributedConfigsHelper()  # default device count is len(jax.devices())
ops = DistributedOpsHelper()          # default data type is jnp.float16

@pytest.mark.skipif(not is_devices_enough(configs.device_count),
                    reason=f'Insufficient number of GPUs, need at least {configs.device_count}.')
@pytest.mark.skipif(not ops.use_custom_partitioning(),
                    reason='TE/JAX version does not support sharding with '
                           'jax.experimental.custom_partitioning.')
class TestCustomPartitioningOpsGenerator:

    @pytest.mark.parametrize('mesh_shape, mesh_names, sharding_type, collective_ref',
                             configs.layernorm_refs)
    @pytest.mark.parametrize('zero_centered_gamma', [False, True])
    def test_layernorm(self, mesh_shape, mesh_names, sharding_type, collective_ref,
                       zero_centered_gamma):
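        """LayerNorm forward and backward, sharded via custom partitioning,
        compared against an unsharded reference implementation."""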
        epsilon = 1e-6

        custom_func = partial(ops.custom_layernorm,
                              zero_centered_gamma=zero_centered_gamma,
                              epsilon=epsilon,
                              sharding_type=sharding_type)

        reference_func = partial(ops.reference_layernorm,
                                 zero_centered_gamma=zero_centered_gamma,
                                 epsilon=epsilon)

        batch_size, _, num_heads, head_dim = ops.qkv_shape
        hidden_size = num_heads * head_dim
        input_shape = (batch_size, hidden_size)
        other_shape = (hidden_size,)
        x_ = random.normal(random.PRNGKey(1124), input_shape, dtype=ops.dtype)
        gamma_ = jnp.ones(other_shape, dtype=ops.dtype)
        beta_ = jnp.ones(other_shape, dtype=ops.dtype)

        x_spec, gamma_spec, beta_spec = ops.get_sharding_spec(mesh_names, sharding_type)

        devices = np.asarray(jax.devices()[:configs.device_count]).reshape(*mesh_shape)
        mesh = jax.sharding.Mesh(devices, mesh_names)
        with mesh, global_shard_guard(ops.get_sharding_resource(mesh_names, sharding_type)):
            x_ = jax.device_put(x_, NamedSharding(mesh, x_spec))
            gamma_ = jax.device_put(gamma_, NamedSharding(mesh, gamma_spec))
            beta_ = jax.device_put(beta_, NamedSharding(mesh, beta_spec))
            ops.compare_ops(
                custom_func, reference_func, collective_ref,
                x_, gamma_, beta_, grad_args=(0, 1, 2), dtype=ops.dtype,
                in_shardings=[x_spec, gamma_spec, beta_spec],
                out_shardings=(None, (x_spec, gamma_spec, beta_spec))
            )

    @pytest.mark.parametrize('mesh_shape, mesh_names, sharding_type, collective_ref',
                             configs.layernorm_refs)
    def test_rmsnorm(self, mesh_shape, mesh_names, sharding_type, collective_ref):
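        """RMSNorm forward and backward, sharded via custom partitioning,
        compared against an unsharded reference implementation."""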
        epsilon = 1e-6
        custom_func = partial(ops.custom_rmsnorm, epsilon=epsilon, sharding_type=sharding_type)
        reference_func = partial(ops.reference_rmsnorm, epsilon=epsilon)

        batch_size, _, num_heads, head_dim = ops.qkv_shape
        hidden_size = num_heads * head_dim
        input_shape = (batch_size, hidden_size)
        other_shape = (hidden_size,)

        x_ = random.normal(random.PRNGKey(1124), input_shape, dtype=ops.dtype)
        gamma_ = jnp.ones(other_shape, dtype=ops.dtype)

        x_spec, gamma_spec = ops.get_sharding_spec(mesh_names, sharding_type)

        devices = np.asarray(jax.devices()[:configs.device_count]).reshape(*mesh_shape)
        mesh = jax.sharding.Mesh(devices, mesh_names)
        with mesh, global_shard_guard(ops.get_sharding_resource(mesh_names, sharding_type)):
            x_ = jax.device_put(x_, NamedSharding(mesh, x_spec))
            gamma_ = jax.device_put(gamma_, NamedSharding(mesh, gamma_spec))
            ops.compare_ops(
                custom_func, reference_func, collective_ref,
                x_, gamma_, grad_args=(0, 1), dtype=ops.dtype,
                in_shardings=[x_spec, gamma_spec],
                out_shardings=(None, (x_spec, gamma_spec))
            )

    @pytest.mark.parametrize('mesh_shape, mesh_names, sharding_type, collective_ref',
                             configs.softmax_refs)
    @pytest.mark.parametrize('softmax_type', configs.softmax_types)
    def test_softmax(self, mesh_shape, mesh_names, sharding_type, collective_ref,
                     softmax_type):
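        """Scaled, padding-masked softmax, sharded via custom partitioning,
        compared against an unsharded reference implementation."""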
        batch_size, seq_len, num_heads, head_dim = ops.qkv_shape
        scale_factor = 1./jnp.sqrt(head_dim)

        custom_func = partial(ops.custom_softmax,
                              scale_factor=scale_factor,
                              softmax_type=softmax_type,
                              sharding_type=sharding_type)
        reference_func = partial(ops.reference_softmax,
                                 scale_factor=scale_factor,
                                 softmax_type=softmax_type)

        input_size = (batch_size, num_heads, seq_len, seq_len)
        x_ = random.normal(random.PRNGKey(1124), input_size, dtype=ops.dtype)

        pad_len = int(seq_len * ops.pad_ratio)
        valid_len = seq_len - pad_len
        tokens = jnp.concatenate((jnp.ones((batch_size, valid_len), dtype=ops.dtype),
                                  jnp.zeros((batch_size, pad_len), dtype=ops.dtype)),
                                 axis=-1)
        mask_ = ops.make_mask(tokens, tokens, AttnMaskType.PADDING_MASK)

        x_spec, mask_spec = ops.get_sharding_spec(mesh_names, sharding_type)

        devices = np.asarray(jax.devices()[:configs.device_count]).reshape(*mesh_shape)
        mesh = jax.sharding.Mesh(devices, mesh_names)
        with mesh, global_shard_guard(ops.get_sharding_resource(mesh_names, sharding_type)):
            x_ = jax.device_put(x_, NamedSharding(mesh, x_spec))
            mask_ = jax.device_put(mask_, NamedSharding(mesh, mask_spec))
            ops.compare_ops(
                custom_func, reference_func, collective_ref,
                x_, mask_, grad_args=(0,), dtype=ops.dtype,
                in_shardings=[x_spec, mask_spec],
                out_shardings=(None, (x_spec,))
            )

    @pytest.mark.parametrize(
        'mesh_shape, mesh_names, sharding_type, attn_bias_type, collective_ref',
        configs.self_attn_refs)
    @pytest.mark.parametrize('attn_mask_type', configs.self_attn_mask_types)
    def test_self_fused_attn(self, mesh_shape, mesh_names, sharding_type, collective_ref,
                             attn_bias_type, attn_mask_type, backend):
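        """Self fused attention on packed QKV, sharded via custom partitioning,
        compared against an unsharded reference implementation."""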
        batch_size, seq_len, num_heads, head_dim = ops.qkv_shape
        ops.check_fused_attn_inputs(seq_len, seq_len, head_dim,
                                    ops.pad_ratio, ops.dropout_prob,
                                    attn_bias_type, attn_mask_type, backend)

        dropout_rng = random.PRNGKey(91023051)
        split_rng = random.split(dropout_rng, configs.device_count)
        scale_factor = 1./jnp.sqrt(head_dim)

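        # The custom op is passed one dropout key per device (split_rng); the
        # unsharded reference is passed the single original key.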
        custom_func = partial(ops.custom_self_fused_attn,
                              rng_key=split_rng,
                              dropout_prob=ops.dropout_prob,
                              attn_bias_type=attn_bias_type,
                              attn_mask_type=attn_mask_type,
                              scaling_factor=scale_factor,
                              sharding_type=sharding_type)
        reference_func = partial(ops.reference_self_fused_attn,
                                 rng_key=dropout_rng,
                                 dropout_prob=ops.dropout_prob,
                                 attn_bias_type=attn_bias_type,
                                 attn_mask_type=attn_mask_type,
                                 scaling_factor=scale_factor)

        key = random.PRNGKey(1124)
        subkeys = random.split(key, 2)

        qkv_shape = (batch_size, seq_len, 3, num_heads, head_dim)
        qkv_ = random.normal(subkeys[0], qkv_shape, dtype=ops.dtype)
        bias_shape = (1, num_heads, seq_len, seq_len)
        bias_ = random.normal(subkeys[1], bias_shape, dtype=ops.dtype)

        pad_len = int(seq_len * ops.pad_ratio)
        valid_len = seq_len - pad_len
        tokens = jnp.concatenate((jnp.ones((batch_size, valid_len), dtype=ops.dtype),
                                  jnp.zeros((batch_size, pad_len), dtype=ops.dtype)),
                                 axis=-1)
        mask_ = ops.make_mask(tokens, tokens, attn_mask_type)

        qkv_spec, bias_spec, mask_spec = ops.get_sharding_spec(mesh_names, sharding_type)

        devices = np.asarray(jax.devices()[:configs.device_count]).reshape(*mesh_shape)
        mesh = jax.sharding.Mesh(devices, mesh_names)
        with mesh, global_shard_guard(ops.get_sharding_resource(mesh_names, sharding_type)):
            qkv_ = jax.device_put(qkv_, NamedSharding(mesh, qkv_spec))
            bias_ = jax.device_put(bias_, NamedSharding(mesh, bias_spec))
            mask_ = jax.device_put(mask_, NamedSharding(mesh, mask_spec))
            ops.compare_ops(
                custom_func, reference_func, collective_ref,
                qkv_, bias_, mask_, grad_args=(0, 1), dtype=ops.dtype,
                in_shardings=[qkv_spec, bias_spec, mask_spec],
                out_shardings=(None, (qkv_spec, bias_spec))
            )

    @pytest.mark.parametrize('mesh_shape, mesh_names, sharding_type, collective_ref',
                             configs.cross_attn_refs)
    @pytest.mark.parametrize('attn_mask_type', configs.cross_attn_mask_types)
    def test_cross_fused_attn(self, mesh_shape, mesh_names, sharding_type, collective_ref,
                              attn_mask_type, backend):
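        """Cross fused attention on separate Q and packed KV, sharded via custom
        partitioning, compared against an unsharded reference implementation."""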
        batch_size, seq_len, num_heads, head_dim = ops.qkv_shape
        ops.check_fused_attn_inputs(seq_len, seq_len, head_dim,
                                    ops.pad_ratio, ops.dropout_prob,
                                    AttnBiasType.NO_BIAS, attn_mask_type, backend)

        dropout_rng = random.PRNGKey(91023051)
        split_rng = random.split(dropout_rng, configs.device_count)
        scale_factor = 1./jnp.sqrt(head_dim)

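        # Same RNG convention as the self-attention test above: split keys for
        # the custom op, the single original key for the reference.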
        custom_func = partial(ops.custom_cross_fused_attn,
                              rng_key=split_rng,
                              dropout_prob=ops.dropout_prob,
                              attn_mask_type=attn_mask_type,
                              scaling_factor=scale_factor,
                              sharding_type=sharding_type)
        reference_func = partial(ops.reference_cross_fused_attn,
                                 rng_key=dropout_rng,
                                 dropout_prob=ops.dropout_prob,
                                 attn_mask_type=attn_mask_type,
                                 scaling_factor=scale_factor)

        key = random.PRNGKey(1124)
        subkeys = random.split(key, 2)

        q_shape = (batch_size, seq_len, num_heads, head_dim)
        q_ = random.normal(subkeys[0], q_shape, dtype=ops.dtype)
        kv_shape = (batch_size, seq_len, 2, num_heads, head_dim)
        kv_ = random.normal(subkeys[1], kv_shape, dtype=ops.dtype)

        pad_len = int(seq_len * ops.pad_ratio)
        valid_len = seq_len - pad_len
        tokens = jnp.concatenate((jnp.ones((batch_size, valid_len), dtype=ops.dtype),
                                  jnp.zeros((batch_size, pad_len), dtype=ops.dtype)),
                                 axis=-1)
        mask_ = ops.make_mask(tokens, tokens, attn_mask_type)

        q_spec, kv_spec, mask_spec = ops.get_sharding_spec(mesh_names, sharding_type)

        devices = np.asarray(jax.devices()[:configs.device_count]).reshape(*mesh_shape)
        mesh = jax.sharding.Mesh(devices, mesh_names)
        with mesh, global_shard_guard(ops.get_sharding_resource(mesh_names, sharding_type)):
            q_ = jax.device_put(q_, NamedSharding(mesh, q_spec))
            kv_ = jax.device_put(kv_, NamedSharding(mesh, kv_spec))
            mask_ = jax.device_put(mask_, NamedSharding(mesh, mask_spec))
            ops.compare_ops(
                custom_func, reference_func, collective_ref,
                q_, kv_, mask_, grad_args=(0, 1), dtype=ops.dtype,
                in_shardings=[q_spec, kv_spec, mask_spec],
                out_shardings=(None, (q_spec, kv_spec))
            )