test_layer.py 27.9 KB
Newer Older
1
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
#
# See LICENSE for license information.

5
import os
6
7
8
9
10
11
12
from functools import partial

import flax
import jax
import jax.numpy as jnp
import pytest

13
from utils import assert_allclose
14
15
16
from utils import DecoderLayer as RefDecoderLayer
from utils import EncoderLayer as RefEncoderLayer

zlsh80826's avatar
zlsh80826 committed
17
18
19
20
from transformer_engine.common.recipe import Format
from transformer_engine.jax.flax import TransformerLayer, TransformerLayerType
from transformer_engine.jax.fp8 import FP8Helper, is_fp8_available

21
22
is_fp8_supported, reason = is_fp8_available()

23

24
25
26
27
28
29
30
31
32
33
@pytest.fixture(autouse=True, scope='module')
def enable_fused_attn():
    """
    Enable fused attention for every test in this module.

    The previous value of ``NVTE_FUSED_ATTN`` (if any) is restored on
    teardown instead of being unconditionally deleted, so the fixture does
    not clobber an environment variable set outside the test session.
    """
    origin = os.environ.get("NVTE_FUSED_ATTN")
    os.environ["NVTE_FUSED_ATTN"] = "1"
    yield
    if origin is None:
        del os.environ["NVTE_FUSED_ATTN"]
    else:
        os.environ["NVTE_FUSED_ATTN"] = origin


34
35
36
37
38
39
40
41
42
43
@pytest.fixture(autouse=True, scope='function')
def clear_live_arrays():
    """
    Delete every live JAX array after each test so device memory does not
    accumulate across the parametrized test matrix.
    """
    yield
    for live_array in jax.live_arrays():
        live_array.delete()


44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
def loss_fn(diff_xs, no_diff_xs, params, others, model, rngs):
    """Apply *model* to the inputs and reduce its output to a scalar mean.

    ``diff_xs`` are the inputs we differentiate with respect to, while
    ``no_diff_xs`` (e.g. attention masks) are passed through untouched.
    """
    variables = {"params": params, **others}
    output = model.apply(variables, *diff_xs, *no_diff_xs, rngs=rngs)
    return jnp.mean(output)


def generate_test_rngs():
    """Build the deterministic PRNG keys shared by all tests.

    Returns a ``(data_rng, init_rng, apply_rng)`` triple: one key for data
    generation, a dict of keys for module initialization, and a dict of
    keys for the forward pass (dropout).
    """
    key = jax.random.PRNGKey
    data_rng = key(0)
    init_rng = {'params': key(1), 'dropout': key(2)}
    apply_rng = {'dropout': key(3)}
    return data_rng, init_rng, apply_rng


def generate_layer(layer_cls, init_rng, diff_inputs, no_diff_inputs):
    """Instantiate *layer_cls*, initialize it, and split out its params.

    Returns ``(layer, params, others)`` where ``others`` holds every
    non-'params' variable collection produced by ``init`` (e.g. FP8 metas).
    """
    layer = layer_cls()
    init_vars = layer.init(init_rng, *diff_inputs, *no_diff_inputs)
    others, params = flax.core.pop(init_vars, 'params')
    del init_vars
    return layer, params, others


64
65
66
67
68
def compare_dict(ref_fd, test_fd, rtol=1e-05, atol=1e-08):
    """Recursively assert that two nested parameter dicts are numerically close.

    Both arguments are unfrozen first, which keeps this helper compatible
    with both Flax>=0.7.1 and <0.7.1 (Flax 0.7.1 removed FrozenDict).
    """
    ref_fd = flax.core.unfreeze(ref_fd)
    test_fd = flax.core.unfreeze(test_fd)
    for key, ref_val in ref_fd.items():
        assert key in test_fd, \
            f"{key} not found in test dict {test_fd}"
        test_val = test_fd[key]
        assert isinstance(test_val, type(ref_val)), \
            f"The data type is not match between ref and test " \
            f"dict on {key=}"
        if isinstance(ref_val, dict):
            # Recurse into sub-trees; leaves are compared elementwise.
            compare_dict(ref_val, test_val, rtol, atol)
        else:
            assert_allclose(ref_val,
                            test_val,
                            rtol=rtol,
                            atol=atol,
                            err_msg=f"{key=} is not close")


85
DATA_SHAPE = [(32, 128, 1024), (32, 512, 1024)]    # (batch, seqlen, emb_dim)
86
87
88
89
90
91
92
93
94
95
96
DTYPE = [jnp.float32, jnp.bfloat16]
FP8_FORMATS = [Format.E4M3, Format.HYBRID]

# Keyword-argument names shared by the reference layers and TransformerLayer.
_KEY_OF_RESIDUAL_POST_LAYERNORM = "apply_residual_connection_post_layernorm"
_KEY_OF_OUTPUT_LAYERNORM = "output_layernorm"
_KEY_OF_DROP_PATH = "drop_path"
_KEY_OF_FUSE_QKV_PARAMS = "fuse_qkv_params"
_KEY_OF_DROPOUT_RATE = "dropout_rate"
_KEY_OF_MLP_ACTIVATIONS = "mlp_activations"
_KEY_OF_FUSE_MLP_WI = "fuse_mlp_wi"
_KEY_OF_LAYERNORM_TYPE = 'layernorm_type'
_KEY_OF_ZERO_CENTERED_GAMMA = 'zero_centered_gamma'
_KEY_OF_TRANSPOSE_BS = 'transpose_batch_sequence'
_KEY_OF_SCALE_ATTN_LOGITS = "scale_attn_logits"
_KEY_OF_NUM_HEADS = 'num_attention_heads'
_KEY_OF_NUM_GQA_GROUPS = 'num_gqa_groups'
_KEY_OF_ENABLE_ROPE = "enable_rotary_pos_emb"
_KEY_OF_ROPE_GROUP_METHOD = "rotary_pos_emb_group_method"

# Defaults applied to every ATTRS entry below.
BASE_ATTRS = {
    _KEY_OF_TRANSPOSE_BS: True,
    _KEY_OF_NUM_HEADS: 8,
    _KEY_OF_DROPOUT_RATE: 0,
}

# Per-test configuration overrides; each dict is merged over BASE_ATTRS.
ATTRS = [
    # Layernorm variants.
    {
        _KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
    },
    {
        _KEY_OF_LAYERNORM_TYPE: 'layernorm',
    },
    {
        _KEY_OF_LAYERNORM_TYPE: 'layernorm',
        _KEY_OF_ZERO_CENTERED_GAMMA: True
    },
    # Residual/output layernorm placement.
    {
        _KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
        _KEY_OF_RESIDUAL_POST_LAYERNORM: True
    },
    {
        _KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
        _KEY_OF_OUTPUT_LAYERNORM: True
    },
    {
        _KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
        _KEY_OF_RESIDUAL_POST_LAYERNORM: True,
        _KEY_OF_OUTPUT_LAYERNORM: True
    },
    # Stochastic depth and un-fused QKV.
    {
        _KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
        _KEY_OF_DROP_PATH: 0.1
    },
    {
        _KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
        _KEY_OF_FUSE_QKV_PARAMS: False
    },
    # Gated MLP activations with fused wi.
    {
        _KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
        _KEY_OF_DROPOUT_RATE: 0.0,
        _KEY_OF_MLP_ACTIVATIONS: (('gelu', 'linear')),
        _KEY_OF_FUSE_MLP_WI: True
    },
    {
        _KEY_OF_SCALE_ATTN_LOGITS: True,
        _KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
        _KEY_OF_DROPOUT_RATE: 0.8,
        _KEY_OF_MLP_ACTIVATIONS: (('gelu', 'linear')),
        _KEY_OF_FUSE_MLP_WI: True
    },
    {
        _KEY_OF_TRANSPOSE_BS: False,
        _KEY_OF_SCALE_ATTN_LOGITS: True,
        _KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
        _KEY_OF_DROPOUT_RATE: 0.0,
        _KEY_OF_MLP_ACTIVATIONS: (('gelu', 'linear')),
        _KEY_OF_FUSE_MLP_WI: True
    },
    # Grouped-query attention.
    {
        _KEY_OF_NUM_HEADS: 8,
        _KEY_OF_NUM_GQA_GROUPS: 4,
        _KEY_OF_TRANSPOSE_BS: False,
        _KEY_OF_SCALE_ATTN_LOGITS: True,
        _KEY_OF_LAYERNORM_TYPE: 'layernorm',
        _KEY_OF_DROPOUT_RATE: 0.0,
        _KEY_OF_MLP_ACTIVATIONS: (('gelu',)),
        _KEY_OF_FUSE_MLP_WI: True
    },
    # Rotary position embeddings, both group methods x both BS layouts.
    {
        _KEY_OF_TRANSPOSE_BS: False,
        _KEY_OF_LAYERNORM_TYPE: 'layernorm',
        _KEY_OF_DROPOUT_RATE: 0.0,
        _KEY_OF_FUSE_MLP_WI: True,
        _KEY_OF_ENABLE_ROPE: True,
        _KEY_OF_ROPE_GROUP_METHOD: "consecutive"
    },
    {
        _KEY_OF_TRANSPOSE_BS: True,
        _KEY_OF_LAYERNORM_TYPE: 'layernorm',
        _KEY_OF_DROPOUT_RATE: 0.0,
        _KEY_OF_FUSE_MLP_WI: True,
        _KEY_OF_ENABLE_ROPE: True,
        _KEY_OF_ROPE_GROUP_METHOD: "consecutive"
    },
    {
        _KEY_OF_TRANSPOSE_BS: False,
        _KEY_OF_LAYERNORM_TYPE: 'layernorm',
        _KEY_OF_DROPOUT_RATE: 0.0,
        _KEY_OF_FUSE_MLP_WI: True,
        _KEY_OF_ENABLE_ROPE: True,
        _KEY_OF_ROPE_GROUP_METHOD: "alternate"
    },
    {
        _KEY_OF_TRANSPOSE_BS: True,
        _KEY_OF_LAYERNORM_TYPE: 'layernorm',
        _KEY_OF_DROPOUT_RATE: 0.0,
        _KEY_OF_FUSE_MLP_WI: True,
        _KEY_OF_ENABLE_ROPE: True,
        _KEY_OF_ROPE_GROUP_METHOD: "alternate"
    },
]

ATTRS = [{**BASE_ATTRS, **attr} for attr in ATTRS]


class TestEncoderLayer:
    """Parity tests: TE TransformerLayer (encoder) vs. the reference Flax encoder."""

    @staticmethod
    def sync_params(ref, target):
        """Copy the reference layer's weights into the TE layer's param tree.

        Kernels are reshaped into the target layout (TE stores attention and
        fused-MLP kernels with different shapes than the reference), so both
        layers start from numerically identical weights.
        """
        unfreeze_target = flax.core.unfreeze(target)
        unfreeze_attn_scope = unfreeze_target['attention']
        ref_attn_scope = ref['attention']
        for key in ref_attn_scope.keys():
            unfreeze_attn_scope[key]['kernel'] = \
                ref_attn_scope[key]['kernel'].reshape(unfreeze_attn_scope[key]['kernel'].shape)
        # MLP: TE fuses wi/wo into flat '*_kernel' leaves.
        unfreeze_target['mlp']['wi_kernel'] = \
            jnp.reshape(ref['mlp']['wi']['kernel'], unfreeze_target['mlp']['wi_kernel'].shape)
        unfreeze_target['mlp']['wo_kernel'] = \
            ref['mlp']['wo']['kernel']
        return ref, unfreeze_target

    def forward_runner(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
        """Run reference and TE encoders forward and compare the mean outputs."""
        transpose_batch_sequence = _KEY_OF_TRANSPOSE_BS in attrs and attrs[_KEY_OF_TRANSPOSE_BS]
        batch, seqlen = data_shape[:2]
        if transpose_batch_sequence:
            # Swap to (seqlen, batch, ...) layout.
            data_shape = (data_shape[1], data_shape[0], *data_shape[2:])
        sequence_dim = 0 if transpose_batch_sequence else 1

        data_rng, init_rng, apply_rng = generate_test_rngs()
        inputs = (jax.random.normal(data_rng, data_shape, dtype),)

        # Reference masks are "1 = attend"; TE masks are "1 = masked out".
        padded_mask = jnp.zeros((batch, 1, seqlen, seqlen), dtype=jnp.uint8)
        ref_masks = (1 - padded_mask,)
        test_masks = (None, padded_mask)    # The second arg of Transformer is encoded tokens.

        # Translate reference kwargs to TransformerLayer kwargs.
        te_layer_attrs = {}
        for k, v in attrs.items():
            if k == 'dropout_rate':
                # One reference rate fans out to TE's three dropout knobs.
                te_layer_attrs['attention_dropout'] = v
                te_layer_attrs['hidden_dropout'] = v
                te_layer_attrs['intermediate_dropout'] = v
            elif k == 'fuse_mlp_wi':
                # Reference-only option; TE has no such kwarg.
                continue
            else:
                te_layer_attrs[k] = v
        ref_layer_cls = partial(RefEncoderLayer, dtype=dtype, **attrs)
        layer_cls = partial(TransformerLayer,
                            hidden_dropout_dims=(sequence_dim,),
                            intermediate_dropout_dims=(sequence_dim,),
                            layer_type=TransformerLayerType.ENCODER,
                            self_attn_mask_type='padding',
                            dtype=dtype,
                            **te_layer_attrs)

        ref_layer, ref_params, ref_others = generate_layer(ref_layer_cls, init_rng, inputs,
                                                           ref_masks)
        test_layer, test_params, test_others = generate_layer(layer_cls, init_rng, inputs,
                                                              test_masks)

        ref_params, test_params = TestEncoderLayer.sync_params(ref_params, test_params)

        ref_out = loss_fn(inputs, ref_masks, ref_params, ref_others, ref_layer, apply_rng)
        test_out = loss_fn(inputs, test_masks, test_params, test_others, test_layer, apply_rng)

        if attrs[_KEY_OF_DROPOUT_RATE] == 0.:    # Skip elementwise checking for dropout
            assert_allclose(ref_out, test_out, rtol=rtol, atol=atol)

        del data_rng, init_rng, apply_rng

    def forward_backward_runner(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
        """Run forward+backward on both encoders and compare loss, dgrad and wgrad."""
        transpose_batch_sequence = _KEY_OF_TRANSPOSE_BS in attrs and attrs[_KEY_OF_TRANSPOSE_BS]
        batch, seqlen = data_shape[:2]
        if transpose_batch_sequence:
            data_shape = (data_shape[1], data_shape[0], *data_shape[2:])
        sequence_dim = 0 if transpose_batch_sequence else 1

        data_rng, init_rng, apply_rng = generate_test_rngs()
        inputs = (jax.random.normal(data_rng, data_shape, dtype),)

        padded_mask = jnp.zeros((batch, 1, seqlen, seqlen), dtype=jnp.uint8)
        ref_masks = (1 - padded_mask,)
        test_masks = (None, padded_mask)    # The second arg of Transformer is encoded tokens.

        te_layer_attrs = {}
        for k, v in attrs.items():
            if k == 'dropout_rate':
                te_layer_attrs['attention_dropout'] = v
                te_layer_attrs['hidden_dropout'] = v
                te_layer_attrs['intermediate_dropout'] = v
            elif k == 'fuse_mlp_wi':
                continue
            else:
                te_layer_attrs[k] = v
        ref_layer_cls = partial(RefEncoderLayer, dtype=dtype, **attrs)
        layer_cls = partial(TransformerLayer,
                            hidden_dropout_dims=(sequence_dim,),
                            intermediate_dropout_dims=(sequence_dim,),
                            layer_type=TransformerLayerType.ENCODER,
                            self_attn_mask_type='padding',
                            dtype=dtype,
                            **te_layer_attrs)
        ref_layer, ref_params, ref_others = generate_layer(ref_layer_cls, init_rng, inputs,
                                                           ref_masks)
        test_layer, test_params, test_others = generate_layer(layer_cls, init_rng, inputs,
                                                              test_masks)

        ref_params, test_params = TestEncoderLayer.sync_params(ref_params, test_params)

        if FP8Helper.is_fp8_enabled():
            # Warm up the FP8 scaling metas with a few grad steps so the
            # comparison below runs with settled amax/scale values.
            for _ in range(4):
                _, tmp_grad = jax.value_and_grad(loss_fn, argnums=(3,),
                                                 has_aux=False)(inputs, test_masks, test_params,
                                                                test_others, test_layer, apply_rng)
                _, fp8_meta_grad = flax.core.pop(tmp_grad[0], FP8Helper.FP8_COLLECTION_NAME)
                test_others = FP8Helper.update_collections(
                    {FP8Helper.FP8_COLLECTION_NAME: fp8_meta_grad}, test_others)
                test_others = FP8Helper.update_fp8_metas(test_others)
                del tmp_grad, fp8_meta_grad

        # Differentiate w.r.t. the inputs (dgrad) and the params (wgrad).
        grad_fn = jax.value_and_grad(loss_fn, argnums=(0, 2), has_aux=False)

        ref_out, ref_grads = grad_fn(inputs, ref_masks, ref_params, ref_others, ref_layer,
                                     apply_rng)
        test_out, test_grads = grad_fn(inputs, test_masks, test_params, test_others, test_layer,
                                       apply_rng)

        def reorganize_test_wgrad(test_wgrad, attrs):
            """Rearrange the TE wgrad tree into the reference layout for compare_dict."""
            num_heads = attrs.get(_KEY_OF_NUM_HEADS)
            num_gqa_groups = attrs.get(_KEY_OF_NUM_GQA_GROUPS, num_heads)
            # QKV is only fused when heads == GQA groups (MHA).
            fuse_qkv = attrs.get(_KEY_OF_FUSE_QKV_PARAMS, True) and \
                       num_heads == num_gqa_groups

            attn_name = 'attention'
            unfreeze_test_wgrad = flax.core.unfreeze(test_wgrad)
            if "output_layernorm" not in attrs:
                # TE folds the pre-attention LN into the qkv/query module;
                # hoist its scale (and bias) back out to match the reference.
                unfreeze_test_wgrad['pre_attention_layer_norm'] = {}
                pre_attn_layer_key = 'qkv' if fuse_qkv else 'query'
                unfreeze_test_wgrad['pre_attention_layer_norm']['scale'] = \
                    unfreeze_test_wgrad[attn_name][pre_attn_layer_key]['scale']
                del unfreeze_test_wgrad[attn_name][pre_attn_layer_key]['scale']
                if 'ln_bias' in unfreeze_test_wgrad[attn_name][pre_attn_layer_key]:
                    unfreeze_test_wgrad['pre_attention_layer_norm']['ln_bias'] = \
                        unfreeze_test_wgrad[attn_name][pre_attn_layer_key]['ln_bias']
                    del unfreeze_test_wgrad[attn_name][pre_attn_layer_key]['ln_bias']

            # Flatten attention kernels to 2-D to match the reference shapes.
            for key in unfreeze_test_wgrad[attn_name].keys():
                unfreeze_test_wgrad[attn_name][key]['kernel'] = \
                    jnp.reshape(unfreeze_test_wgrad[attn_name][key]['kernel'],
                        (unfreeze_test_wgrad[attn_name][key]['kernel'].shape[0], -1))

            # Hoist the pre-MLP LN and unfuse wi/wo back into nested modules.
            unfreeze_test_wgrad['pre_mlp_layer_norm'] = {}
            unfreeze_test_wgrad['pre_mlp_layer_norm']['scale'] = \
                unfreeze_test_wgrad['mlp']['scale']
            del unfreeze_test_wgrad['mlp']['scale']
            if 'ln_bias' in unfreeze_test_wgrad['mlp']:
                unfreeze_test_wgrad['pre_mlp_layer_norm']['ln_bias'] = \
                    unfreeze_test_wgrad['mlp']['ln_bias']
                del unfreeze_test_wgrad['mlp']['ln_bias']
            unfreeze_test_wgrad['mlp']['wi'] = {}
            unfreeze_test_wgrad['mlp']['wi']['kernel'] = \
                jnp.reshape(unfreeze_test_wgrad['mlp']['wi_kernel'],
                            (unfreeze_test_wgrad['mlp']['wi_kernel'].shape[0], -1))
            del unfreeze_test_wgrad['mlp']['wi_kernel']
            unfreeze_test_wgrad['mlp']['wo'] = {}
            unfreeze_test_wgrad['mlp']['wo']['kernel'] = \
                unfreeze_test_wgrad['mlp']['wo_kernel']
            del unfreeze_test_wgrad['mlp']['wo_kernel']
            return unfreeze_test_wgrad

        if attrs[_KEY_OF_DROPOUT_RATE] == 0.:    # Skip elementwise checking for dropout
            assert_allclose(ref_out, test_out, rtol=rtol, atol=atol)
            assert_allclose(ref_grads[0][0], test_grads[0][0], rtol=rtol, atol=atol)    # dgrad

            compare_dict(ref_grads[1],
                         reorganize_test_wgrad(test_grads[1], attrs),
                         rtol=rtol,
                         atol=atol)    # wgrad

        del data_rng, init_rng, apply_rng

    @pytest.mark.parametrize('data_shape', DATA_SHAPE)
    @pytest.mark.parametrize('dtype', DTYPE)
    @pytest.mark.parametrize('attrs', ATTRS)
    def test_forward(self, data_shape, dtype, attrs):
        """Forward-only parity without FP8."""
        FP8Helper.finalize()    # Ensure FP8 disabled.
        self.forward_runner(data_shape, dtype, attrs, rtol=1e-05, atol=2e-04)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('data_shape', DATA_SHAPE)
    @pytest.mark.parametrize('dtype', DTYPE)
    @pytest.mark.parametrize('fp8_format', FP8_FORMATS)
    @pytest.mark.parametrize('attrs', ATTRS)
    def test_forward_with_fp8(self, data_shape, dtype, fp8_format, attrs):
        """Forward-only parity with FP8 enabled (looser tolerances)."""
        FP8Helper.initialize(fp8_format=fp8_format)
        self.forward_runner(data_shape, dtype, attrs, rtol=1e-04, atol=1e-03)
        FP8Helper.finalize()

    @pytest.mark.parametrize('data_shape', DATA_SHAPE)
    @pytest.mark.parametrize('dtype', DTYPE)
    @pytest.mark.parametrize('attrs', ATTRS)
    def test_forward_backward(self, data_shape, dtype, attrs):
        """Forward+backward parity without FP8."""
        FP8Helper.finalize()    # Ensure FP8 disabled.
        self.forward_backward_runner(data_shape, dtype, attrs, rtol=1e-05, atol=2e-04)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('data_shape', DATA_SHAPE)
    @pytest.mark.parametrize('dtype', DTYPE)
    @pytest.mark.parametrize('fp8_format', FP8_FORMATS)
    @pytest.mark.parametrize('attrs', ATTRS)
    def test_forward_backward_with_fp8(self, data_shape, dtype, fp8_format, attrs):
        """Forward+backward parity with FP8 enabled (looser tolerances)."""
        FP8Helper.initialize(fp8_format=fp8_format)
        self.forward_backward_runner(data_shape, dtype, attrs, rtol=1e-04, atol=1e-03)
        FP8Helper.finalize()


class TestDecoderLayer:
    """Parity tests: TE TransformerLayer (decoder) vs. the reference Flax decoder."""

    @staticmethod
    def sync_params(ref, target):
        """Copy the reference layer's weights into the TE layer's param tree.

        Both self-attention and cross-attention kernels are reshaped into the
        target layout, and the fused MLP kernels are filled from the
        reference wi/wo modules, so both layers start from identical weights.
        """
        unfreeze_target = flax.core.unfreeze(target)
        for scope in ['self_attention', 'encoder_decoder_attention']:
            unfreeze_scope = unfreeze_target[scope]
            ref_scope = ref[scope]
            for key in unfreeze_scope.keys():
                unfreeze_scope[key]['kernel'] = \
                    ref_scope[key]['kernel'].reshape(unfreeze_scope[key]['kernel'].shape)
        unfreeze_target['mlp']['wi_kernel'] = \
            jnp.reshape(ref['mlp']['wi']['kernel'], unfreeze_target['mlp']['wi_kernel'].shape)
        unfreeze_target['mlp']['wo_kernel'] = \
            ref['mlp']['wo']['kernel']
        return ref, unfreeze_target

    def forward_runner(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
        """Run reference and TE decoders forward and compare the mean outputs."""
        transpose_batch_sequence = _KEY_OF_TRANSPOSE_BS in attrs and attrs[_KEY_OF_TRANSPOSE_BS]
        batch, seqlen = data_shape[:2]
        if transpose_batch_sequence:
            # Swap to (seqlen, batch, ...) layout.
            data_shape = (data_shape[1], data_shape[0], *data_shape[2:])
        sequence_dim = 0 if transpose_batch_sequence else 1

        data_rng, init_rng, apply_rng = generate_test_rngs()
        # Decoder takes (targets, encoded); both are drawn with the same key
        # here — presumably intentional for reproducibility, not independence.
        inputs = (jax.random.normal(data_rng, data_shape,
                                    dtype), jax.random.normal(data_rng, data_shape, dtype))

        # Reference masks are "1 = attend"; TE masks are "1 = masked out".
        padded_mask = jnp.zeros((batch, 1, seqlen, seqlen), dtype=jnp.uint8)
        causal_mask = jnp.triu(jnp.ones((batch, 1, seqlen, seqlen), dtype=jnp.uint8), k=1)
        ref_masks = (1 - causal_mask, 1 - padded_mask)
        test_masks = (causal_mask, padded_mask)

        # Translate reference kwargs to TransformerLayer kwargs.
        te_layer_attrs = {}
        for k, v in attrs.items():
            if k == 'dropout_rate':
                # One reference rate fans out to TE's three dropout knobs.
                te_layer_attrs['attention_dropout'] = v
                te_layer_attrs['hidden_dropout'] = v
                te_layer_attrs['intermediate_dropout'] = v
            elif k == 'fuse_mlp_wi':
                # Reference-only option; TE has no such kwarg.
                continue
            else:
                te_layer_attrs[k] = v
        ref_layer_cls = partial(RefDecoderLayer, dtype=dtype, **attrs)
        layer_cls = partial(TransformerLayer,
                            hidden_dropout_dims=(sequence_dim,),
                            intermediate_dropout_dims=(sequence_dim,),
                            layer_type=TransformerLayerType.DECODER,
                            dtype=dtype,
                            **te_layer_attrs)
        ref_layer, ref_params, ref_others = generate_layer(ref_layer_cls, init_rng, inputs,
                                                           ref_masks)
        test_layer, test_params, test_others = generate_layer(layer_cls, init_rng, inputs,
                                                              test_masks)

        ref_params, test_params = TestDecoderLayer.sync_params(ref_params, test_params)

        ref_out = loss_fn(inputs, ref_masks, ref_params, ref_others, ref_layer, apply_rng)
        test_out = loss_fn(inputs, test_masks, test_params, test_others, test_layer, apply_rng)

        if attrs[_KEY_OF_DROPOUT_RATE] == 0.:    # Skip elementwise checking for dropout
            assert_allclose(ref_out, test_out, rtol=rtol, atol=atol)

        del data_rng, init_rng, apply_rng

    def forward_backward_runner(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
        """Run forward+backward on both decoders and compare loss, dgrad and wgrad."""
        transpose_batch_sequence = _KEY_OF_TRANSPOSE_BS in attrs and attrs[_KEY_OF_TRANSPOSE_BS]
        batch, seqlen = data_shape[:2]
        if transpose_batch_sequence:
            data_shape = (data_shape[1], data_shape[0], *data_shape[2:])
        sequence_dim = 0 if transpose_batch_sequence else 1

        data_rng, init_rng, apply_rng = generate_test_rngs()
        inputs = (jax.random.normal(data_rng, data_shape,
                                    dtype), jax.random.normal(data_rng, data_shape, dtype))

        padded_mask = jnp.zeros((batch, 1, seqlen, seqlen), dtype=jnp.uint8)
        causal_mask = jnp.triu(jnp.ones((batch, 1, seqlen, seqlen), dtype=jnp.uint8), k=1)
        ref_masks = (1 - causal_mask, 1 - padded_mask)
        test_masks = (causal_mask, padded_mask)

        te_layer_attrs = {}
        for k, v in attrs.items():
            if k == 'dropout_rate':
                te_layer_attrs['attention_dropout'] = v
                te_layer_attrs['hidden_dropout'] = v
                te_layer_attrs['intermediate_dropout'] = v
            elif k == 'fuse_mlp_wi':
                continue
            else:
                te_layer_attrs[k] = v
        ref_layer_cls = partial(RefDecoderLayer, dtype=dtype, **attrs)
        layer_cls = partial(TransformerLayer,
                            hidden_dropout_dims=(sequence_dim,),
                            intermediate_dropout_dims=(sequence_dim,),
                            layer_type=TransformerLayerType.DECODER,
                            dtype=dtype,
                            **te_layer_attrs)
        ref_layer, ref_params, ref_others = generate_layer(ref_layer_cls, init_rng, inputs,
                                                           ref_masks)
        test_layer, test_params, test_others = generate_layer(layer_cls, init_rng, inputs,
                                                              test_masks)

        ref_params, test_params = TestDecoderLayer.sync_params(ref_params, test_params)

        if FP8Helper.is_fp8_enabled():
            # Warm up the FP8 scaling metas with a few grad steps so the
            # comparison below runs with settled amax/scale values.
            for _ in range(4):
                _, tmp_grad = jax.value_and_grad(loss_fn, argnums=(3,),
                                                 has_aux=False)(inputs, test_masks, test_params,
                                                                test_others, test_layer, apply_rng)
                _, fp8_meta_grad = flax.core.pop(tmp_grad[0], FP8Helper.FP8_COLLECTION_NAME)
                test_others = FP8Helper.update_collections(
                    {FP8Helper.FP8_COLLECTION_NAME: fp8_meta_grad}, test_others)
                test_others = FP8Helper.update_fp8_metas(test_others)
                del tmp_grad, fp8_meta_grad

        # Differentiate w.r.t. the inputs (dgrad) and the params (wgrad).
        grad_fn = jax.value_and_grad(loss_fn, argnums=(0, 2), has_aux=False)

        ref_out, ref_grads = grad_fn(inputs, ref_masks, ref_params, ref_others, ref_layer,
                                     apply_rng)
        test_out, test_grads = grad_fn(inputs, test_masks, test_params, test_others, test_layer,
                                       apply_rng)

        def reorganize_test_wgrad(test_wgrad, attrs):
            """Rearrange the TE wgrad tree into the reference layout for compare_dict."""
            num_heads = attrs.get(_KEY_OF_NUM_HEADS)
            num_gqa_groups = attrs.get(_KEY_OF_NUM_GQA_GROUPS, num_heads)
            # QKV is only fused when heads == GQA groups (MHA).
            fuse_qkv = attrs.get(_KEY_OF_FUSE_QKV_PARAMS, True) and \
                       num_heads == num_gqa_groups

            unfreeze_test_wgrad = flax.core.unfreeze(test_wgrad)
            if "output_layernorm" not in attrs:
                # TE folds the pre-self-attention LN into the qkv/query
                # module; hoist its scale (and bias) back out.
                attn_name = 'self_attention'
                unfreeze_test_wgrad['pre_self_attention_layer_norm'] = {}
                pre_attn_layer_key = 'qkv' if fuse_qkv else 'query'
                unfreeze_test_wgrad['pre_self_attention_layer_norm']['scale'] = \
                    unfreeze_test_wgrad[attn_name][pre_attn_layer_key]['scale']
                del unfreeze_test_wgrad[attn_name][pre_attn_layer_key]['scale']
                if 'ln_bias' in unfreeze_test_wgrad[attn_name][pre_attn_layer_key]:
                    unfreeze_test_wgrad['pre_self_attention_layer_norm']['ln_bias'] = \
                        unfreeze_test_wgrad[attn_name][pre_attn_layer_key]['ln_bias']
                    del unfreeze_test_wgrad[attn_name][pre_attn_layer_key]['ln_bias']

            # Flatten all attention kernels to 2-D to match reference shapes.
            for scope in ['self_attention', 'encoder_decoder_attention']:
                for key in unfreeze_test_wgrad[scope].keys():
                    unfreeze_test_wgrad[scope][key]['kernel'] = \
                        jnp.reshape(unfreeze_test_wgrad[scope][key]['kernel'],
                            (unfreeze_test_wgrad[scope][key]['kernel'].shape[0], -1))

            # Hoist the pre-cross-attention LN out of the cross-attn query.
            unfreeze_test_wgrad['pre_cross_attention_layer_norm'] = {}
            unfreeze_test_wgrad['pre_cross_attention_layer_norm']['scale'] = \
                unfreeze_test_wgrad['encoder_decoder_attention']['query']['scale']
            del unfreeze_test_wgrad['encoder_decoder_attention']['query']['scale']
            if 'ln_bias' in unfreeze_test_wgrad['encoder_decoder_attention']['query']:
                unfreeze_test_wgrad['pre_cross_attention_layer_norm']['ln_bias'] = \
                    unfreeze_test_wgrad['encoder_decoder_attention']['query']['ln_bias']
                del unfreeze_test_wgrad['encoder_decoder_attention']['query']['ln_bias']
            # Hoist the pre-MLP LN and unfuse wi/wo back into nested modules.
            unfreeze_test_wgrad['pre_mlp_layer_norm'] = {}
            unfreeze_test_wgrad['pre_mlp_layer_norm']['scale'] = \
                unfreeze_test_wgrad['mlp']['scale']
            del unfreeze_test_wgrad['mlp']['scale']
            if 'ln_bias' in unfreeze_test_wgrad['mlp']:
                unfreeze_test_wgrad['pre_mlp_layer_norm']['ln_bias'] = \
                    unfreeze_test_wgrad['mlp']['ln_bias']
                del unfreeze_test_wgrad['mlp']['ln_bias']
            unfreeze_test_wgrad['mlp']['wi'] = {}
            unfreeze_test_wgrad['mlp']['wi']['kernel'] = \
                jnp.reshape(unfreeze_test_wgrad['mlp']['wi_kernel'],
                            (unfreeze_test_wgrad['mlp']['wi_kernel'].shape[0], -1))
            del unfreeze_test_wgrad['mlp']['wi_kernel']
            unfreeze_test_wgrad['mlp']['wo'] = {}
            unfreeze_test_wgrad['mlp']['wo']['kernel'] = \
                unfreeze_test_wgrad['mlp']['wo_kernel']
            del unfreeze_test_wgrad['mlp']['wo_kernel']
            return unfreeze_test_wgrad

        if attrs[_KEY_OF_DROPOUT_RATE] == 0.:    # Skip elementwise checking for dropout
            assert_allclose(ref_out, test_out, rtol=rtol, atol=atol)
            assert_allclose(ref_grads[0][0], test_grads[0][0], rtol=rtol, atol=atol)    # dgrad
            compare_dict(ref_grads[1],
                         reorganize_test_wgrad(test_grads[1], attrs),
                         rtol=rtol,
                         atol=atol)    # wgrad

        del data_rng, init_rng, apply_rng

    @pytest.mark.parametrize('data_shape', DATA_SHAPE)
    @pytest.mark.parametrize('dtype', DTYPE)
    @pytest.mark.parametrize('attrs', ATTRS)
    def test_forward(self, data_shape, dtype, attrs):
        """Forward-only parity without FP8."""
        FP8Helper.finalize()    # Ensure FP8 disabled.
        self.forward_runner(data_shape, dtype, attrs, rtol=1e-05, atol=2e-04)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('data_shape', DATA_SHAPE)
    @pytest.mark.parametrize('dtype', DTYPE)
    @pytest.mark.parametrize('fp8_format', FP8_FORMATS)
    @pytest.mark.parametrize('attrs', ATTRS)
    def test_forward_with_fp8(self, data_shape, dtype, fp8_format, attrs):
        """Forward-only parity with FP8 enabled (looser tolerances)."""
        FP8Helper.initialize(fp8_format=fp8_format)
        self.forward_runner(data_shape, dtype, attrs, rtol=1e-04, atol=3e-02)
        FP8Helper.finalize()

    @pytest.mark.parametrize('data_shape', DATA_SHAPE)
    @pytest.mark.parametrize('dtype', DTYPE)
    @pytest.mark.parametrize('attrs', ATTRS)
    def test_forward_backward(self, data_shape, dtype, attrs):
        """Forward+backward parity without FP8."""
        FP8Helper.finalize()    # Ensure FP8 disabled.
        self.forward_backward_runner(data_shape, dtype, attrs, rtol=1e-05, atol=3e-04)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('data_shape', DATA_SHAPE)
    @pytest.mark.parametrize('dtype', DTYPE)
    @pytest.mark.parametrize('fp8_format', FP8_FORMATS)
    @pytest.mark.parametrize('attrs', ATTRS)
    def test_forward_backward_with_fp8(self, data_shape, dtype, fp8_format, attrs):
        """Forward+backward parity with FP8 enabled (looser tolerances)."""
        FP8Helper.initialize(fp8_format=fp8_format)
        self.forward_backward_runner(data_shape, dtype, attrs, rtol=1e-04, atol=3e-02)
        FP8Helper.finalize()