# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

from functools import partial

import flax
import jax
import jax.numpy as jnp
import pytest

from utils import assert_allclose
from utils import DecoderLayer as RefDecoderLayer
from utils import EncoderLayer as RefEncoderLayer

from transformer_engine.common.recipe import Format
from transformer_engine.jax.flax import TransformerLayer, TransformerLayerType
from transformer_engine.jax.fp8 import FP8Helper, is_fp8_available

is_fp8_supported, reason = is_fp8_available()


@pytest.fixture(autouse=True, scope='function')
def clear_live_arrays():
    """
    Clear all live arrays to keep the resource clean
    """
    yield
    for arr in jax.live_arrays():
        arr.delete()


def loss_fn(diff_xs, no_diff_xs, params, others, model, rngs):
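    """Apply the model to the differentiable and non-differentiable inputs and
    reduce the output to a scalar mean, suitable for jax.value_and_grad."""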
    output = model.apply({"params": params, **others}, *diff_xs, *no_diff_xs, rngs=rngs)
    return jnp.mean(output)


def generate_test_rngs():
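    """Return fixed PRNG keys: one for data generation, plus the RNG dicts for init and apply."""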
    data_rng = jax.random.PRNGKey(0)
    init_rng = {'params': jax.random.PRNGKey(1), 'dropout': jax.random.PRNGKey(2)}
    apply_rng = {'dropout': jax.random.PRNGKey(3)}
    return data_rng, init_rng, apply_rng


def generate_layer(layer_cls, init_rng, diff_inputs, no_diff_inputs):
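    """Instantiate the layer, initialize its variables, and split them into params and the other collections."""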
    layer = layer_cls()
    variables = layer.init(init_rng, *diff_inputs, *no_diff_inputs)
    others, params = flax.core.pop(variables, 'params')
    del variables
    return layer, params, others


def compare_dict(ref_fd, test_fd, rtol=1e-05, atol=1e-08):
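    """Recursively assert that two nested dicts have the same structure and numerically close leaves."""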
    # Unfreeze to stay compatible with both Flax >= 0.7.1 and < 0.7.1,
    # since Flax 0.7.1 removed FrozenDict.
    ref_fd = flax.core.unfreeze(ref_fd)
    test_fd = flax.core.unfreeze(test_fd)
    for key in ref_fd:
        assert key in test_fd, \
            f"{key} not found in test dict {test_fd}"
        assert isinstance(test_fd[key], type(ref_fd[key])), \
            f"The data type does not match between ref and test " \
            f"dict on {key=}"
        if isinstance(ref_fd[key], dict):
            compare_dict(ref_fd[key], test_fd[key], rtol, atol)
        else:
            assert_allclose(ref_fd[key],
                            test_fd[key],
                            rtol=rtol,
                            atol=atol,
                            err_msg=f"{key=} is not close")


DATA_SHAPE = [(32, 128, 1024), (32, 512, 1024)]    # (batch, seqlen, emb_dim)
DTYPE = [jnp.float32, jnp.bfloat16]
FP8_FORMATS = [Format.E4M3, Format.HYBRID]

_KEY_OF_RESIDUAL_POST_LAYERNORM = "apply_residual_connection_post_layernorm"
_KEY_OF_OUTPUT_LAYERNORM = "output_layernorm"
_KEY_OF_DROP_PATH = "drop_path"
_KEY_OF_FUSE_QKV_PARAMS = "fuse_qkv_params"
_KEY_OF_DROPOUT_RATE = "dropout_rate"
_KEY_OF_MLP_ACTIVATIONS = "mlp_activations"
_KEY_OF_FUSE_MLP_WI = "fuse_mlp_wi"
_KEY_OF_LAYERNORM_TYPE = 'layernorm_type'
_KEY_OF_ZERO_CENTERED_GAMMA = 'zero_centered_gamma'
_KEY_OF_TRANSPOSE_BS = 'transpose_batch_sequence'
_KEY_OF_SCALE_ATTN_LOGITS = "scale_attn_logits"
_KEY_OF_NUM_HEADS = 'num_attention_heads'
_KEY_OF_NUM_GQA_GROUPS = 'num_gqa_groups'

BASE_ATTRS = {
    _KEY_OF_TRANSPOSE_BS: True,
    _KEY_OF_NUM_HEADS: 8,
}

ATTRS = [{
    _KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
}, {
    _KEY_OF_LAYERNORM_TYPE: 'layernorm',
}, {
    _KEY_OF_LAYERNORM_TYPE: 'layernorm',
    _KEY_OF_ZERO_CENTERED_GAMMA: True
}, {
    _KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
    _KEY_OF_RESIDUAL_POST_LAYERNORM: True
}, {
    _KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
    _KEY_OF_OUTPUT_LAYERNORM: True
}, {
    _KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
    _KEY_OF_RESIDUAL_POST_LAYERNORM: True,
    _KEY_OF_OUTPUT_LAYERNORM: True
}, {
    _KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
    _KEY_OF_DROP_PATH: 0.1
}, {
    _KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
    _KEY_OF_FUSE_QKV_PARAMS: False
}, {
    _KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
    _KEY_OF_DROPOUT_RATE: 0.0,
    _KEY_OF_MLP_ACTIVATIONS: ('gelu', 'linear'),
    _KEY_OF_FUSE_MLP_WI: True
}, {
    _KEY_OF_SCALE_ATTN_LOGITS: True,
    _KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
    _KEY_OF_DROPOUT_RATE: 0.8,
    _KEY_OF_MLP_ACTIVATIONS: ('gelu', 'linear'),
    _KEY_OF_FUSE_MLP_WI: True
}, {
    _KEY_OF_TRANSPOSE_BS: False,
    _KEY_OF_SCALE_ATTN_LOGITS: True,
    _KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
    _KEY_OF_DROPOUT_RATE: 0.0,
    _KEY_OF_MLP_ACTIVATIONS: ('gelu', 'linear'),
    _KEY_OF_FUSE_MLP_WI: True
}, {
    _KEY_OF_NUM_HEADS: 8,
    _KEY_OF_NUM_GQA_GROUPS: 4
}]

ATTRS = [{**BASE_ATTRS, **attr} for attr in ATTRS]
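# e.g., the first merged entry is:
# {'transpose_batch_sequence': True, 'num_attention_heads': 8, 'layernorm_type': 'rmsnorm'}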


class TestEncoderLayer:

    @staticmethod
    def sync_params(ref, target):
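        """Copy the reference layer's attention and MLP kernels into the test layer's
        parameter tree (reshaped to the fused layouts) so both start from identical weights."""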
        unfreeze_target = flax.core.unfreeze(target)
        unfreeze_attn_scope = unfreeze_target['attention']
        ref_attn_scope = ref['attention']
        for key in ref_attn_scope.keys():
            unfreeze_attn_scope[key]['kernel'] = \
                ref_attn_scope[key]['kernel'].reshape(unfreeze_attn_scope[key]['kernel'].shape)
        unfreeze_target['mlp']['wi_kernel'] = \
            jnp.reshape(ref['mlp']['wi']['kernel'], unfreeze_target['mlp']['wi_kernel'].shape)
        unfreeze_target['mlp']['wo_kernel'] = \
            ref['mlp']['wo']['kernel']
        return ref, unfreeze_target

    def forward_runner(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
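        """Run a forward pass through the reference and the TransformerLayer encoder with
        synced parameters and check that the outputs match."""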
        transpose_batch_sequence = _KEY_OF_TRANSPOSE_BS in attrs and attrs[_KEY_OF_TRANSPOSE_BS]
        batch, seqlen = data_shape[:2]
        if transpose_batch_sequence:
            data_shape = (data_shape[1], data_shape[0], *data_shape[2:])
        sequence_dim = 0 if transpose_batch_sequence else 1

        data_rng, init_rng, apply_rng = generate_test_rngs()
        inputs = (jax.random.normal(data_rng, data_shape, dtype),)

        padded_mask = jnp.zeros((batch, 1, seqlen, seqlen), dtype=jnp.uint8)
        ref_masks = (1 - padded_mask,)
        test_masks = (None, padded_mask)    # The second arg of Transformer is encoded tokens.

        te_layer_attrs = {}
        for k, v in attrs.items():
            if k == 'dropout_rate':
                te_layer_attrs['attention_dropout'] = v
                te_layer_attrs['hidden_dropout'] = v
                te_layer_attrs['intermediate_dropout'] = v
            elif k == 'fuse_mlp_wi':
                continue
            else:
                te_layer_attrs[k] = v
        ref_layer_cls = partial(RefEncoderLayer, dtype=dtype, **attrs)
        layer_cls = partial(TransformerLayer,
                            hidden_dropout_dims=(sequence_dim,),
                            intermediate_dropout_dims=(sequence_dim,),
                            layer_type=TransformerLayerType.ENCODER,
                            self_attn_mask_type='padding',
                            dtype=dtype,
                            **te_layer_attrs)

        ref_layer, ref_params, ref_others = generate_layer(ref_layer_cls, init_rng, inputs,
                                                           ref_masks)
        test_layer, test_params, test_others = generate_layer(layer_cls, init_rng, inputs,
                                                              test_masks)

        ref_params, test_params = TestEncoderLayer.sync_params(ref_params, test_params)

        ref_out = loss_fn(inputs, ref_masks, ref_params, ref_others, ref_layer, apply_rng)
        test_out = loss_fn(inputs, test_masks, test_params, test_others, test_layer, apply_rng)

        assert_allclose(ref_out, test_out, rtol=rtol, atol=atol)

        del data_rng, init_rng, apply_rng

    def forward_backward_runner(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
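        """Run forward and backward passes through both encoders and compare the loss,
        the input gradients (dgrad), and the parameter gradients (wgrad)."""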
        transpose_batch_sequence = _KEY_OF_TRANSPOSE_BS in attrs and attrs[_KEY_OF_TRANSPOSE_BS]
        batch, seqlen = data_shape[:2]
        if transpose_batch_sequence:
            data_shape = (data_shape[1], data_shape[0], *data_shape[2:])
        sequence_dim = 0 if transpose_batch_sequence else 1

        data_rng, init_rng, apply_rng = generate_test_rngs()
        inputs = (jax.random.normal(data_rng, data_shape, dtype),)

        padded_mask = jnp.zeros((batch, 1, seqlen, seqlen), dtype=jnp.uint8)
        ref_masks = (1 - padded_mask,)
        test_masks = (None, padded_mask)    # The second arg of Transformer is encoded tokens.

        te_layer_attrs = {}
        for k, v in attrs.items():
            if k == 'dropout_rate':
                te_layer_attrs['attention_dropout'] = v
                te_layer_attrs['hidden_dropout'] = v
                te_layer_attrs['intermediate_dropout'] = v
            elif k == 'fuse_mlp_wi':
                continue
            else:
                te_layer_attrs[k] = v
        ref_layer_cls = partial(RefEncoderLayer, dtype=dtype, **attrs)
        layer_cls = partial(TransformerLayer,
                            hidden_dropout_dims=(sequence_dim,),
                            intermediate_dropout_dims=(sequence_dim,),
                            layer_type=TransformerLayerType.ENCODER,
                            self_attn_mask_type='padding',
                            dtype=dtype,
                            **te_layer_attrs)
        ref_layer, ref_params, ref_others = generate_layer(ref_layer_cls, init_rng, inputs,
                                                           ref_masks)
        test_layer, test_params, test_others = generate_layer(layer_cls, init_rng, inputs,
                                                              test_masks)

        ref_params, test_params = TestEncoderLayer.sync_params(ref_params, test_params)

        if FP8Helper.is_fp8_enabled():
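            # Run a few steps first so the FP8 metadata (amax history and scaling factors)
            # is populated before the actual comparison.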
            for _ in range(4):
                _, tmp_grad = jax.value_and_grad(loss_fn, argnums=(3,),
                                                 has_aux=False)(inputs, test_masks, test_params,
                                                                test_others, test_layer, apply_rng)
                _, fp8_meta_grad = flax.core.pop(tmp_grad[0], FP8Helper.FP8_COLLECTION_NAME)
                test_others = FP8Helper.update_collections(
                    {FP8Helper.FP8_COLLECTION_NAME: fp8_meta_grad}, test_others)
                test_others = FP8Helper.update_fp8_metas(test_others)
                del tmp_grad, fp8_meta_grad

        grad_fn = jax.value_and_grad(loss_fn, argnums=(0, 2), has_aux=False)

        ref_out, ref_grads = grad_fn(inputs, ref_masks, ref_params, ref_others, ref_layer,
                                     apply_rng)
        test_out, test_grads = grad_fn(inputs, test_masks, test_params, test_others, test_layer,
                                       apply_rng)

        assert_allclose(ref_out, test_out, rtol=rtol, atol=atol)
        assert_allclose(ref_grads[0][0], test_grads[0][0], rtol=rtol, atol=atol)    # dgrad

        def reorganize_test_wgrad(test_wgrad, attrs):
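            """Rename and reshape the test layer's wgrad tree to mirror the reference layout:
            split the layernorm scales out of the attention/MLP scopes and unflatten the
            fused kernels so that compare_dict can match keys."""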
            num_heads = attrs.get(_KEY_OF_NUM_HEADS)
            num_gqa_groups = attrs.get(_KEY_OF_NUM_GQA_GROUPS, num_heads)
            fuse_qkv = attrs.get(_KEY_OF_FUSE_QKV_PARAMS, True) and \
                       num_heads == num_gqa_groups

            attn_name = 'attention'
            unfreeze_test_wgrad = flax.core.unfreeze(test_wgrad)
            if "output_layernorm" not in attrs:
                unfreeze_test_wgrad['pre_attention_layer_norm'] = {}
                pre_attn_layer_key = 'qkv' if fuse_qkv else 'query'
                unfreeze_test_wgrad['pre_attention_layer_norm']['scale'] = \
                    unfreeze_test_wgrad[attn_name][pre_attn_layer_key]['scale']
                del unfreeze_test_wgrad[attn_name][pre_attn_layer_key]['scale']
                if 'ln_bias' in unfreeze_test_wgrad[attn_name][pre_attn_layer_key]:
                    unfreeze_test_wgrad['pre_attention_layer_norm']['ln_bias'] = \
                        unfreeze_test_wgrad[attn_name][pre_attn_layer_key]['ln_bias']
                    del unfreeze_test_wgrad[attn_name][pre_attn_layer_key]['ln_bias']

            for key in unfreeze_test_wgrad[attn_name].keys():
                unfreeze_test_wgrad[attn_name][key]['kernel'] = \
                    jnp.reshape(unfreeze_test_wgrad[attn_name][key]['kernel'],
                        (unfreeze_test_wgrad[attn_name][key]['kernel'].shape[0], -1))

            unfreeze_test_wgrad['pre_mlp_layer_norm'] = {}
            unfreeze_test_wgrad['pre_mlp_layer_norm']['scale'] = \
                unfreeze_test_wgrad['mlp']['scale']
            del unfreeze_test_wgrad['mlp']['scale']
            if 'ln_bias' in unfreeze_test_wgrad['mlp']:
                unfreeze_test_wgrad['pre_mlp_layer_norm']['ln_bias'] = \
                    unfreeze_test_wgrad['mlp']['ln_bias']
                del unfreeze_test_wgrad['mlp']['ln_bias']
            unfreeze_test_wgrad['mlp']['wi'] = {}
            unfreeze_test_wgrad['mlp']['wi']['kernel'] = \
                jnp.reshape(unfreeze_test_wgrad['mlp']['wi_kernel'],
                            (unfreeze_test_wgrad['mlp']['wi_kernel'].shape[0], -1))
            del unfreeze_test_wgrad['mlp']['wi_kernel']
            unfreeze_test_wgrad['mlp']['wo'] = {}
            unfreeze_test_wgrad['mlp']['wo']['kernel'] = \
                unfreeze_test_wgrad['mlp']['wo_kernel']
            del unfreeze_test_wgrad['mlp']['wo_kernel']
            return unfreeze_test_wgrad

        compare_dict(ref_grads[1],
                     reorganize_test_wgrad(test_grads[1], attrs),
                     rtol=rtol,
                     atol=atol)    # wgrad

        del data_rng, init_rng, apply_rng

    @pytest.mark.parametrize('data_shape', DATA_SHAPE)
    @pytest.mark.parametrize('dtype', DTYPE)
    @pytest.mark.parametrize('attrs', ATTRS)
    def test_forward(self, data_shape, dtype, attrs):
        FP8Helper.finalize()    # Ensure FP8 disabled.
        self.forward_runner(data_shape, dtype, attrs, rtol=1e-05, atol=2e-04)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('data_shape', DATA_SHAPE)
    @pytest.mark.parametrize('dtype', DTYPE)
    @pytest.mark.parametrize('fp8_format', FP8_FORMATS)
    @pytest.mark.parametrize('attrs', ATTRS)
    def test_forward_with_fp8(self, data_shape, dtype, fp8_format, attrs):
        FP8Helper.initialize(fp8_format=fp8_format)
        self.forward_runner(data_shape, dtype, attrs, rtol=1e-04, atol=1e-03)
        FP8Helper.finalize()

    @pytest.mark.parametrize('data_shape', DATA_SHAPE)
    @pytest.mark.parametrize('dtype', DTYPE)
    @pytest.mark.parametrize('attrs', ATTRS)
    def test_forward_backward(self, data_shape, dtype, attrs):
        FP8Helper.finalize()    # Ensure FP8 disabled.
        self.forward_backward_runner(data_shape, dtype, attrs, rtol=1e-05, atol=2e-04)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('data_shape', DATA_SHAPE)
    @pytest.mark.parametrize('dtype', DTYPE)
    @pytest.mark.parametrize('fp8_format', FP8_FORMATS)
    @pytest.mark.parametrize('attrs', ATTRS)
    def test_forward_backward_with_fp8(self, data_shape, dtype, fp8_format, attrs):
        FP8Helper.initialize(fp8_format=fp8_format)
        self.forward_backward_runner(data_shape, dtype, attrs, rtol=1e-04, atol=1e-03)
        FP8Helper.finalize()


class TestDecoderLayer:

    @staticmethod
    def sync_params(ref, target):
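        """Copy the reference layer's self-/cross-attention and MLP kernels into the test
        layer's parameter tree so both start from identical weights."""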
        unfreeze_target = flax.core.unfreeze(target)
        for scope in ['self_attention', 'encoder_decoder_attention']:
            unfreeze_scope = unfreeze_target[scope]
            ref_scope = ref[scope]
            for key in unfreeze_scope.keys():
                unfreeze_scope[key]['kernel'] = \
                    ref_scope[key]['kernel'].reshape(unfreeze_scope[key]['kernel'].shape)
        unfreeze_target['mlp']['wi_kernel'] = \
            jnp.reshape(ref['mlp']['wi']['kernel'], unfreeze_target['mlp']['wi_kernel'].shape)
        unfreeze_target['mlp']['wo_kernel'] = \
            ref['mlp']['wo']['kernel']
        return ref, unfreeze_target

    def forward_runner(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
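        """Run a forward pass through the reference and the TransformerLayer decoder with
        synced parameters and check that the outputs match."""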
        transpose_batch_sequence = _KEY_OF_TRANSPOSE_BS in attrs and attrs[_KEY_OF_TRANSPOSE_BS]
        batch, seqlen = data_shape[:2]
        if transpose_batch_sequence:
            data_shape = (data_shape[1], data_shape[0], *data_shape[2:])
        sequence_dim = 0 if transpose_batch_sequence else 1

        data_rng, init_rng, apply_rng = generate_test_rngs()
        inputs = (jax.random.normal(data_rng, data_shape,
                                    dtype), jax.random.normal(data_rng, data_shape, dtype))

        padded_mask = jnp.zeros((batch, 1, seqlen, seqlen), dtype=jnp.uint8)
        causal_mask = jnp.triu(jnp.ones((batch, 1, seqlen, seqlen), dtype=jnp.uint8), k=1)
        ref_masks = (1 - causal_mask, 1 - padded_mask)
        test_masks = (causal_mask, padded_mask)

        te_layer_attrs = {}
        for k, v in attrs.items():
            if k == 'dropout_rate':
                te_layer_attrs['attention_dropout'] = v
                te_layer_attrs['hidden_dropout'] = v
                te_layer_attrs['intermediate_dropout'] = v
            elif k == 'fuse_mlp_wi':
                continue
            else:
                te_layer_attrs[k] = v
        ref_layer_cls = partial(RefDecoderLayer, dtype=dtype, **attrs)
        layer_cls = partial(TransformerLayer,
                            hidden_dropout_dims=(sequence_dim,),
                            intermediate_dropout_dims=(sequence_dim,),
                            layer_type=TransformerLayerType.DECODER,
                            dtype=dtype,
                            **te_layer_attrs)
        ref_layer, ref_params, ref_others = generate_layer(ref_layer_cls, init_rng, inputs,
                                                           ref_masks)
        test_layer, test_params, test_others = generate_layer(layer_cls, init_rng, inputs,
                                                              test_masks)

        ref_params, test_params = TestDecoderLayer.sync_params(ref_params, test_params)

        ref_out = loss_fn(inputs, ref_masks, ref_params, ref_others, ref_layer, apply_rng)
        test_out = loss_fn(inputs, test_masks, test_params, test_others, test_layer, apply_rng)

        assert_allclose(ref_out, test_out, rtol=rtol, atol=atol)

        del data_rng, init_rng, apply_rng

    def forward_backward_runner(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
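        """Run forward and backward passes through both decoders and compare the loss,
        the input gradients (dgrad), and the parameter gradients (wgrad)."""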
        transpose_batch_sequence = _KEY_OF_TRANSPOSE_BS in attrs and attrs[_KEY_OF_TRANSPOSE_BS]
        batch, seqlen = data_shape[:2]
        if transpose_batch_sequence:
            data_shape = (data_shape[1], data_shape[0], *data_shape[2:])
        sequence_dim = 0 if transpose_batch_sequence else 1

        data_rng, init_rng, apply_rng = generate_test_rngs()
        inputs = (jax.random.normal(data_rng, data_shape,
                                    dtype), jax.random.normal(data_rng, data_shape, dtype))

        padded_mask = jnp.zeros((batch, 1, seqlen, seqlen), dtype=jnp.uint8)
        causal_mask = jnp.triu(jnp.ones((batch, 1, seqlen, seqlen), dtype=jnp.uint8), k=1)
        ref_masks = (1 - causal_mask, 1 - padded_mask)
        test_masks = (causal_mask, padded_mask)

        te_layer_attrs = {}
        for k, v in attrs.items():
            if k == 'dropout_rate':
                te_layer_attrs['attention_dropout'] = v
                te_layer_attrs['hidden_dropout'] = v
                te_layer_attrs['intermediate_dropout'] = v
            elif k == 'fuse_mlp_wi':
                continue
            else:
                te_layer_attrs[k] = v
        ref_layer_cls = partial(RefDecoderLayer, dtype=dtype, **attrs)
        layer_cls = partial(TransformerLayer,
                            hidden_dropout_dims=(sequence_dim,),
                            intermediate_dropout_dims=(sequence_dim,),
                            layer_type=TransformerLayerType.DECODER,
                            dtype=dtype,
                            **te_layer_attrs)
        ref_layer, ref_params, ref_others = generate_layer(ref_layer_cls, init_rng, inputs,
                                                           ref_masks)
        test_layer, test_params, test_others = generate_layer(layer_cls, init_rng, inputs,
                                                              test_masks)

        ref_params, test_params = TestDecoderLayer.sync_params(ref_params, test_params)

        if FP8Helper.is_fp8_enabled():
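            # Run a few steps first so the FP8 metadata (amax history and scaling factors)
            # is populated before the actual comparison.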
            for _ in range(4):
                _, tmp_grad = jax.value_and_grad(loss_fn, argnums=(3,),
                                                 has_aux=False)(inputs, test_masks, test_params,
                                                                test_others, test_layer, apply_rng)
                _, fp8_meta_grad = flax.core.pop(tmp_grad[0], FP8Helper.FP8_COLLECTION_NAME)
                test_others = FP8Helper.update_collections(
                    {FP8Helper.FP8_COLLECTION_NAME: fp8_meta_grad}, test_others)
                test_others = FP8Helper.update_fp8_metas(test_others)
                del tmp_grad, fp8_meta_grad

        grad_fn = jax.value_and_grad(loss_fn, argnums=(0, 2), has_aux=False)

        ref_out, ref_grads = grad_fn(inputs, ref_masks, ref_params, ref_others, ref_layer,
                                     apply_rng)
        test_out, test_grads = grad_fn(inputs, test_masks, test_params, test_others, test_layer,
                                       apply_rng)

        assert_allclose(ref_out, test_out, rtol=rtol, atol=atol)
        assert_allclose(ref_grads[0][0], test_grads[0][0], rtol=rtol, atol=atol)    # dgrad

        def reorganize_test_wgrad(test_wgrad, attrs):
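            """Rename and reshape the test layer's wgrad tree to mirror the reference layout:
            split the layernorm scales out of the self-/cross-attention and MLP scopes and
            unflatten the fused kernels so that compare_dict can match keys."""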
            num_heads = attrs.get(_KEY_OF_NUM_HEADS)
            num_gqa_groups = attrs.get(_KEY_OF_NUM_GQA_GROUPS, num_heads)
            fuse_qkv = attrs.get(_KEY_OF_FUSE_QKV_PARAMS, True) and \
                       num_heads == num_gqa_groups

            unfreeze_test_wgrad = flax.core.unfreeze(test_wgrad)
            if "output_layernorm" not in attrs:
                attn_name = 'self_attention'
                unfreeze_test_wgrad['pre_self_attention_layer_norm'] = {}
                pre_attn_layer_key = 'qkv' if fuse_qkv else 'query'
                unfreeze_test_wgrad['pre_self_attention_layer_norm']['scale'] = \
                    unfreeze_test_wgrad[attn_name][pre_attn_layer_key]['scale']
                del unfreeze_test_wgrad[attn_name][pre_attn_layer_key]['scale']
                if 'ln_bias' in unfreeze_test_wgrad[attn_name][pre_attn_layer_key]:
                    unfreeze_test_wgrad['pre_self_attention_layer_norm']['ln_bias'] = \
                        unfreeze_test_wgrad[attn_name][pre_attn_layer_key]['ln_bias']
                    del unfreeze_test_wgrad[attn_name][pre_attn_layer_key]['ln_bias']

            for scope in ['self_attention', 'encoder_decoder_attention']:
                for key in unfreeze_test_wgrad[scope].keys():
                    unfreeze_test_wgrad[scope][key]['kernel'] = \
                        jnp.reshape(unfreeze_test_wgrad[scope][key]['kernel'],
                            (unfreeze_test_wgrad[scope][key]['kernel'].shape[0], -1))

            unfreeze_test_wgrad['pre_cross_attention_layer_norm'] = {}
            unfreeze_test_wgrad['pre_cross_attention_layer_norm']['scale'] = \
                unfreeze_test_wgrad['encoder_decoder_attention']['query']['scale']
            del unfreeze_test_wgrad['encoder_decoder_attention']['query']['scale']
            if 'ln_bias' in unfreeze_test_wgrad['encoder_decoder_attention']['query']:
                unfreeze_test_wgrad['pre_cross_attention_layer_norm']['ln_bias'] = \
                    unfreeze_test_wgrad['encoder_decoder_attention']['query']['ln_bias']
                del unfreeze_test_wgrad['encoder_decoder_attention']['query']['ln_bias']
            unfreeze_test_wgrad['pre_mlp_layer_norm'] = {}
            unfreeze_test_wgrad['pre_mlp_layer_norm']['scale'] = \
                unfreeze_test_wgrad['mlp']['scale']
            del unfreeze_test_wgrad['mlp']['scale']
            if 'ln_bias' in unfreeze_test_wgrad['mlp']:
                unfreeze_test_wgrad['pre_mlp_layer_norm']['ln_bias'] = \
                    unfreeze_test_wgrad['mlp']['ln_bias']
                del unfreeze_test_wgrad['mlp']['ln_bias']
            unfreeze_test_wgrad['mlp']['wi'] = {}
            unfreeze_test_wgrad['mlp']['wi']['kernel'] = \
                jnp.reshape(unfreeze_test_wgrad['mlp']['wi_kernel'],
                            (unfreeze_test_wgrad['mlp']['wi_kernel'].shape[0], -1))
            del unfreeze_test_wgrad['mlp']['wi_kernel']
            unfreeze_test_wgrad['mlp']['wo'] = {}
            unfreeze_test_wgrad['mlp']['wo']['kernel'] = \
                unfreeze_test_wgrad['mlp']['wo_kernel']
            del unfreeze_test_wgrad['mlp']['wo_kernel']
            return unfreeze_test_wgrad

        compare_dict(ref_grads[1],
                     reorganize_test_wgrad(test_grads[1], attrs),
                     rtol=rtol,
                     atol=atol)    # wgrad

        del data_rng, init_rng, apply_rng

    @pytest.mark.parametrize('data_shape', DATA_SHAPE)
    @pytest.mark.parametrize('dtype', DTYPE)
    @pytest.mark.parametrize('attrs', ATTRS)
    def test_forward(self, data_shape, dtype, attrs):
        FP8Helper.finalize()    # Ensure FP8 disabled.
        self.forward_runner(data_shape, dtype, attrs, rtol=1e-05, atol=2e-04)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('data_shape', DATA_SHAPE)
    @pytest.mark.parametrize('dtype', DTYPE)
    @pytest.mark.parametrize('fp8_format', FP8_FORMATS)
    @pytest.mark.parametrize('attrs', ATTRS)
    def test_forward_with_fp8(self, data_shape, dtype, fp8_format, attrs):
        FP8Helper.initialize(fp8_format=fp8_format)
        self.forward_runner(data_shape, dtype, attrs, rtol=1e-04, atol=3e-02)
        FP8Helper.finalize()

    @pytest.mark.parametrize('data_shape', DATA_SHAPE)
    @pytest.mark.parametrize('dtype', DTYPE)
    @pytest.mark.parametrize('attrs', ATTRS)
    def test_forward_backward(self, data_shape, dtype, attrs):
        FP8Helper.finalize()    # Ensure FP8 disabled.
        self.forward_backward_runner(data_shape, dtype, attrs, rtol=1e-05, atol=2e-04)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('data_shape', DATA_SHAPE)
    @pytest.mark.parametrize('dtype', DTYPE)
    @pytest.mark.parametrize('fp8_format', FP8_FORMATS)
    @pytest.mark.parametrize('attrs', ATTRS)
    def test_forward_backward_with_fp8(self, data_shape, dtype, fp8_format, attrs):
        FP8Helper.initialize(fp8_format=fp8_format)
        self.forward_backward_runner(data_shape, dtype, attrs, rtol=1e-04, atol=3e-02)
        FP8Helper.finalize()