# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Test transformer_engine.jax.flax.TransformerLayer"""
import os
from functools import partial
from typing import Dict

import flax
import jax
import jax.numpy as jnp
import pytest

from utils import assert_allclose, assert_tree_like_allclose, sync_params_values
from utils import DecoderLayer as RefDecoderLayer
from utils import EncoderLayer as RefEncoderLayer

from transformer_engine.common.recipe import Format
from transformer_engine.jax.flax import TransformerLayer, TransformerLayerType
from transformer_engine.jax.fp8 import FP8Helper, is_fp8_available

is_fp8_supported, reason = is_fp8_available()


@pytest.fixture(autouse=True, scope='function')
def enable_fused_attn():
    """Enable fused attention"""
    os.environ["NVTE_FUSED_ATTN"] = "1"
    yield
    del os.environ["NVTE_FUSED_ATTN"]


DATA_SHAPE = [    # (batch, seqlen, emb_dim)
    pytest.param((32, 128, 1024), id='32-128-1024'),
    pytest.param((32, 512, 1024), id='32-512-1024'),
]
DTYPE = [jnp.float32, jnp.bfloat16]
FP8_FORMATS = [Format.E4M3, Format.HYBRID]

_KEY_OF_RESIDUAL_POST_LAYERNORM = "apply_residual_connection_post_layernorm"
_KEY_OF_OUTPUT_LAYERNORM = "output_layernorm"
_KEY_OF_DROP_PATH = "drop_path"
_KEY_OF_FUSE_QKV_PARAMS = "fuse_qkv_params"
_KEY_OF_HIDDEN_DROPOUT = "hidden_dropout"
_KEY_OF_ATTENTION_DROPOUT = "attention_dropout"
_KEY_OF_INTERMEDIATE_DROPOUT = "intermediate_dropout"
_KEY_OF_HIDDEN_DROPOUT_DIMS = "hidden_dropout_dims"
_KEY_OF_INTERMEDIATE_DROPOUT_DIMS = "intermediate_dropout_dims"
_KEY_OF_MLP_ACTIVATIONS = "mlp_activations"
_KEY_OF_LAYERNORM_TYPE = "layernorm_type"
_KEY_OF_LAYERNORM_EPS = "layernorm_epsilon"
_KEY_OF_ZERO_CENTERED_GAMMA = "zero_centered_gamma"
_KEY_OF_TRANSPOSE_BS = "transpose_batch_sequence"
_KEY_OF_SCALE_ATTN_LOGITS = "scale_attn_logits"
_KEY_OF_NUM_HEADS = "num_attention_heads"
_KEY_OF_NUM_GQA_GROUPS = "num_gqa_groups"
_KEY_OF_ENABLE_ROPE = "enable_rotary_pos_emb"
_KEY_OF_ROPE_GROUP_METHOD = "rotary_pos_emb_group_method"
_KEY_OF_SELF_ATTN_BIAS_TYPE = "self_attn_bias_type"
_KEY_OF_SELF_ATTN_MASK_TYPE = "self_attn_mask_type"
_KEY_OF_FLOAT32_ATTENTION_LOGITS = "float32_attention_logits"
_KEY_OF_USE_BIAS = "use_bias"
_KEY_OF_RELATIVE_EMBEDDING = "enable_relative_embedding"

BASE_ATTRS = {
    _KEY_OF_TRANSPOSE_BS: True,
    _KEY_OF_NUM_HEADS: 8,
    _KEY_OF_HIDDEN_DROPOUT: 0,
    _KEY_OF_ATTENTION_DROPOUT: 0,
    _KEY_OF_INTERMEDIATE_DROPOUT: 0,
    _KEY_OF_SELF_ATTN_MASK_TYPE: "padding_causal",
    _KEY_OF_LAYERNORM_TYPE: 'layernorm',
}

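# Each dict below overrides a subset of BASE_ATTRS, so every parametrized case
# exercises a different combination of TransformerLayer options.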
ATTRS = [{}, {
    _KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
}, {
    _KEY_OF_ZERO_CENTERED_GAMMA: True,
    _KEY_OF_LAYERNORM_EPS: 1e-2,
}, {
    _KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
    _KEY_OF_RESIDUAL_POST_LAYERNORM: True
}, {
    _KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
    _KEY_OF_OUTPUT_LAYERNORM: True
}, {
    _KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
    _KEY_OF_RESIDUAL_POST_LAYERNORM: True,
    _KEY_OF_OUTPUT_LAYERNORM: True
}, {
    _KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
    _KEY_OF_DROP_PATH: 0.1
}, {
    _KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
    _KEY_OF_FUSE_QKV_PARAMS: False
}, {
    _KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
    _KEY_OF_MLP_ACTIVATIONS: ('gelu', 'linear'),
}, {
    _KEY_OF_SCALE_ATTN_LOGITS: True,
    _KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
    _KEY_OF_HIDDEN_DROPOUT: 0.8,
    _KEY_OF_INTERMEDIATE_DROPOUT: 0.5,
    _KEY_OF_MLP_ACTIVATIONS: ('gelu', 'linear'),
    _KEY_OF_USE_BIAS: True,
}, {
    _KEY_OF_TRANSPOSE_BS: False,
    _KEY_OF_SCALE_ATTN_LOGITS: True,
    _KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
    _KEY_OF_MLP_ACTIVATIONS: ('gelu', 'linear'),
}, {
    _KEY_OF_NUM_HEADS: 8,
    _KEY_OF_NUM_GQA_GROUPS: 4,
    _KEY_OF_TRANSPOSE_BS: False,
    _KEY_OF_SCALE_ATTN_LOGITS: True,
    _KEY_OF_MLP_ACTIVATIONS: ('gelu',),
    _KEY_OF_USE_BIAS: True,
}, {
    _KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
    _KEY_OF_MLP_ACTIVATIONS: (('silu', 'linear')),
}, {
    _KEY_OF_SCALE_ATTN_LOGITS: True,
    _KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
    _KEY_OF_HIDDEN_DROPOUT: 0.8,
    _KEY_OF_INTERMEDIATE_DROPOUT: 0.5,
    _KEY_OF_MLP_ACTIVATIONS: (('silu', 'linear')),
    _KEY_OF_USE_BIAS: True,
}, {
    _KEY_OF_TRANSPOSE_BS: False,
    _KEY_OF_SCALE_ATTN_LOGITS: True,
    _KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
    _KEY_OF_MLP_ACTIVATIONS: (('silu', 'linear')),
}, {
    _KEY_OF_NUM_HEADS: 8,
    _KEY_OF_NUM_GQA_GROUPS: 4,
    _KEY_OF_TRANSPOSE_BS: False,
    _KEY_OF_SCALE_ATTN_LOGITS: True,
    _KEY_OF_LAYERNORM_TYPE: 'layernorm',
    _KEY_OF_MLP_ACTIVATIONS: (('silu',)),
    _KEY_OF_USE_BIAS: True,
}, {
    _KEY_OF_TRANSPOSE_BS: False,
    _KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
    _KEY_OF_NUM_GQA_GROUPS: 1,
    _KEY_OF_ENABLE_ROPE: True,
    _KEY_OF_ROPE_GROUP_METHOD: "consecutive",
    _KEY_OF_FLOAT32_ATTENTION_LOGITS: True,
}, {
    _KEY_OF_TRANSPOSE_BS: True,
    _KEY_OF_ENABLE_ROPE: True,
    _KEY_OF_ROPE_GROUP_METHOD: "consecutive",
    _KEY_OF_USE_BIAS: True,
}, {
    _KEY_OF_TRANSPOSE_BS: False,
    _KEY_OF_LAYERNORM_TYPE: 'layernorm',
    _KEY_OF_NUM_GQA_GROUPS: 2,
    _KEY_OF_ENABLE_ROPE: True,
    _KEY_OF_ROPE_GROUP_METHOD: "alternate",
    _KEY_OF_USE_BIAS: True,
    _KEY_OF_FLOAT32_ATTENTION_LOGITS: True,
}, {
    _KEY_OF_TRANSPOSE_BS: True,
    _KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
    _KEY_OF_ENABLE_ROPE: True,
    _KEY_OF_ROPE_GROUP_METHOD: "alternate",
    _KEY_OF_USE_BIAS: True,
}, {
    _KEY_OF_HIDDEN_DROPOUT: 0.3,
    _KEY_OF_HIDDEN_DROPOUT_DIMS: (0,),
    _KEY_OF_INTERMEDIATE_DROPOUT: 0.5,
    _KEY_OF_INTERMEDIATE_DROPOUT_DIMS: (1,),
}, {
    _KEY_OF_SELF_ATTN_MASK_TYPE: "padding",
    _KEY_OF_USE_BIAS: True,
}, {
    _KEY_OF_RELATIVE_EMBEDDING: False,
    _KEY_OF_SELF_ATTN_BIAS_TYPE: "no_bias",
}, {
    _KEY_OF_ATTENTION_DROPOUT: 0.3,
}, {
    _KEY_OF_MLP_ACTIVATIONS: (('relu', 'relu')),
}]

ATTRS = [{**BASE_ATTRS, **attr} for attr in ATTRS]


class BaseRunner:
    """Base runner to define forward and backward tests"""
    layer_type: TransformerLayerType = None
    reference_layer: flax.linen.Module = None
    transformations: Dict[str, str] = None

    def __init__(self, attrs):
        self.attrs = attrs
        self._generate_test_rngs()
        # Disable fused attention when attention dropout is used, because the
        # fused and unfused dropout implementations differ.
        if attrs.get(_KEY_OF_ATTENTION_DROPOUT, False) and os.getenv('NVTE_FUSED_ATTN'):
            os.environ['NVTE_FUSED_ATTN'] = "0"

    def _generate_test_rngs(self):
        root_rng = jax.random.PRNGKey(0)
        params_rng, init_dropout_rng, apply_dropout_rng = jax.random.split(root_rng, 3)
        self.init_rng = {'params': params_rng, 'dropout': init_dropout_rng}
        self.apply_rng = {'dropout': apply_dropout_rng}

    def _generate_layer(self, layer_cls, diff_inputs, no_diff_inputs):
        layer = layer_cls()
        variables = layer.init(self.init_rng, *diff_inputs, *no_diff_inputs)
        others, params = flax.core.pop(variables, 'params')
        del variables
        return layer, params, others

    def _loss_fn(self, diff_xs, no_diff_xs, params, others, model):
        variables = {'params': params, **others}
        output = model.apply(variables, *diff_xs, *no_diff_xs, rngs=self.apply_rng)
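        # Accumulate the mean in float32, then cast back to the output dtype.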
        return jnp.mean(output, dtype=jnp.float32).astype(output.dtype)

    def _sync_params(self, ref, target):
        """Copy the reference params to target"""
        target = sync_params_values(target, ref, self.transformations)
        return ref, target

    def test_forward(self, data_shape, dtype, rtol=1e-05, atol=1e-08):
        """Test only the forward"""
        inputs, (ref_masks, test_masks) = self.generate_inputs(data_shape, dtype)

        ref_layer_cls = partial(self.reference_layer, dtype=dtype, **self.attrs)
        layer_cls = partial(TransformerLayer, layer_type=self.layer_type, dtype=dtype, **self.attrs)

        ref_layer, ref_params, ref_others = self._generate_layer(ref_layer_cls, inputs, ref_masks)
        test_layer, test_params, test_others = self._generate_layer(layer_cls, inputs, test_masks)
        ref_params, test_params = self._sync_params(ref_params, test_params)

        ref_out = self._loss_fn(inputs, ref_masks, ref_params, ref_others, ref_layer)
        test_out = self._loss_fn(inputs, test_masks, test_params, test_others, test_layer)

        assert_allclose(ref_out, test_out, rtol=rtol, atol=atol)

    def test_backward(self, data_shape, dtype, rtol=1e-05, atol=1e-08):
        """Test forward and backward through value_and_grad()"""
        inputs, (ref_masks, test_masks) = self.generate_inputs(data_shape, dtype)

        ref_layer_cls = partial(self.reference_layer, dtype=dtype, **self.attrs)
        layer_cls = partial(TransformerLayer, layer_type=self.layer_type, dtype=dtype, **self.attrs)

        ref_layer, ref_params, ref_others = self._generate_layer(ref_layer_cls, inputs, ref_masks)
        test_layer, test_params, test_others = self._generate_layer(layer_cls, inputs, test_masks)

        ref_params, test_params = self._sync_params(ref_params, test_params)

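        # With FP8 enabled, run a few warm-up iterations first so the FP8 scaling
        # metadata (amax history and scale factors) settles before the gradient
        # comparison below.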
        if FP8Helper.is_fp8_enabled():
            for _ in range(4):
                _, tmp_grad = jax.value_and_grad(self._loss_fn, argnums=(3,), has_aux=False)(
                    inputs,
                    test_masks,
                    test_params,
                    test_others,
                    test_layer,
                )
                _, fp8_meta_grad = flax.core.pop(tmp_grad[0], FP8Helper.FP8_COLLECTION_NAME)
                test_others = FP8Helper.update_collections(
                    {FP8Helper.FP8_COLLECTION_NAME: fp8_meta_grad}, test_others)
                del tmp_grad, fp8_meta_grad

        grad_fn = jax.value_and_grad(self._loss_fn, argnums=(0, 2), has_aux=False)

        ref_out, (ref_dgrads, ref_wgrads) = grad_fn(inputs, ref_masks, ref_params, ref_others,
                                                    ref_layer)
        test_out, (test_dgrads, test_wgrads) = grad_fn(inputs, test_masks, test_params, test_others,
                                                       test_layer)

        assert_allclose(ref_out, test_out, rtol=rtol, atol=atol)
        assert_tree_like_allclose(ref_dgrads, test_dgrads, rtol=rtol, atol=atol)

        _, restructured_ref_wgrads = self._sync_params(ref_wgrads, test_wgrads)
        assert_tree_like_allclose(restructured_ref_wgrads, test_wgrads, rtol=rtol, atol=atol)


class EncoderRunner(BaseRunner):
    """Encoder runner implementations"""
    layer_type = TransformerLayerType.ENCODER
    reference_layer = RefEncoderLayer
    transformations = {
        'attention/qkv/scale': 'pre_attention_layer_norm/scale',
        'attention/qkv/ln_bias': 'pre_attention_layer_norm/ln_bias',
        'attention/query/scale': 'pre_attention_layer_norm/scale',
        'attention/query/ln_bias': 'pre_attention_layer_norm/ln_bias',
        'mlp/wi_kernel': 'mlp/wi/kernel',
        'mlp/wi_bias': 'mlp/wi/bias',
        'mlp/wo_kernel': 'mlp/wo/kernel',
        'mlp/wo_bias': 'mlp/wo/bias',
        'mlp/scale': 'pre_mlp_layer_norm/scale',
        'mlp/ln_bias': 'pre_mlp_layer_norm/ln_bias',
    }

    def generate_inputs(self, data_shape, dtype):
        """
        Return inputs, (ref_masks, test_masks)
        """
        transpose_batch_sequence = self.attrs[_KEY_OF_TRANSPOSE_BS]
        batch, seqlen = data_shape[:2]
        if transpose_batch_sequence:
            data_shape = (data_shape[1], data_shape[0], *data_shape[2:])

        data_rng = jax.random.PRNGKey(2024)
        inputs = (jax.random.normal(data_rng, data_shape, dtype),)

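        # padded_mask masks nothing; causal_mask marks the strict upper triangle
        # (future positions) as masked out.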
        padded_mask = jnp.zeros((batch, 1, seqlen, seqlen), dtype=jnp.uint8)
        causal_mask = jnp.triu(jnp.ones((batch, 1, seqlen, seqlen), dtype=jnp.uint8), k=1)
        if self.attrs[_KEY_OF_SELF_ATTN_MASK_TYPE] in ['causal', 'padding_causal']:
            mask = causal_mask
        else:
            mask = padded_mask

        # The reference layer expects 1 for positions that may be attended to, while
        # TransformerLayer expects 1 for masked-out positions, hence the inversion.
        ref_masks = (1 - mask,)
        test_masks = (None, mask)    # The second arg of TransformerLayer is the encoded tokens, unused by the encoder.

        return inputs, (ref_masks, test_masks)


class DecoderRunner(BaseRunner):
    """
    Decoder runner implementations
    """
    layer_type = TransformerLayerType.DECODER
    reference_layer = RefDecoderLayer
    transformations = {
        'encoder_decoder_attention/qkv/scale': 'pre_cross_attention_layer_norm/scale',
        'encoder_decoder_attention/qkv/ln_bias': 'pre_cross_attention_layer_norm/ln_bias',
        'encoder_decoder_attention/query/scale': 'pre_cross_attention_layer_norm/scale',
        'encoder_decoder_attention/query/ln_bias': 'pre_cross_attention_layer_norm/ln_bias',
        'self_attention/qkv/scale': 'pre_self_attention_layer_norm/scale',
        'self_attention/qkv/ln_bias': 'pre_self_attention_layer_norm/ln_bias',
        'self_attention/query/scale': 'pre_self_attention_layer_norm/scale',
        'self_attention/query/ln_bias': 'pre_self_attention_layer_norm/ln_bias',
        'mlp/wi_kernel': 'mlp/wi/kernel',
        'mlp/wi_bias': 'mlp/wi/bias',
        'mlp/wo_kernel': 'mlp/wo/kernel',
        'mlp/wo_bias': 'mlp/wo/bias',
        'mlp/scale': 'pre_mlp_layer_norm/scale',
        'mlp/ln_bias': 'pre_mlp_layer_norm/ln_bias',
    }

    def generate_inputs(self, data_shape, dtype):
        """
        Return inputs, (ref_masks, test_masks)
        """
        transpose_batch_sequence = self.attrs[_KEY_OF_TRANSPOSE_BS]
        batch, seqlen = data_shape[:2]
        if transpose_batch_sequence:
            data_shape = (data_shape[1], data_shape[0], *data_shape[2:])

        data_rng = jax.random.PRNGKey(0)
        data_rng_0, data_rng_1 = jax.random.split(data_rng, 2)
        inputs = (jax.random.normal(data_rng_0, data_shape, dtype),
                  jax.random.normal(data_rng_1, data_shape, dtype))

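        # Self-attention uses the causal or padded mask selected below, while the
        # cross-attention mask is the all-zero padded mask (nothing masked).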
        padded_mask = jnp.zeros((batch, 1, seqlen, seqlen), dtype=jnp.uint8)
        causal_mask = jnp.triu(jnp.ones((batch, 1, seqlen, seqlen), dtype=jnp.uint8), k=1)
        if self.attrs[_KEY_OF_SELF_ATTN_MASK_TYPE] in ['causal', 'padding_causal']:
            self_mask = causal_mask
        else:
            self_mask = padded_mask

        ref_masks = (1 - self_mask, 1 - padded_mask)
        test_masks = (self_mask, padded_mask)

        return inputs, (ref_masks, test_masks)


@pytest.mark.parametrize('data_shape', DATA_SHAPE)
@pytest.mark.parametrize('dtype', DTYPE)
@pytest.mark.parametrize('attrs', ATTRS)
class BaseTester():
    """
    Pytest interface to invoke the runner
    """
    runner = BaseRunner

    def test_forward(self, data_shape, dtype, attrs):
        """Test normal datatype forward"""
        FP8Helper.finalize()    # Ensure FP8 disabled.
        self.runner(attrs).test_forward(data_shape, dtype, rtol=1e-5, atol=7e-5)

    def test_backward(self, data_shape, dtype, attrs):
        """Test normal datatype backward"""
        FP8Helper.finalize()    # Ensure FP8 disabled.
        self.runner(attrs).test_backward(data_shape, dtype, rtol=1e-5, atol=7e-5)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('fp8_format', FP8_FORMATS)
    def test_forward_with_fp8(self, data_shape, dtype, attrs, fp8_format):
        """Test forward with fp8 enabled"""
        FP8Helper.initialize(fp8_format=fp8_format)
        self.runner(attrs).test_forward(data_shape, dtype, rtol=1e-4, atol=1e-3)
        FP8Helper.finalize()

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('fp8_format', FP8_FORMATS)
    def test_backward_with_fp8(self, data_shape, dtype, attrs, fp8_format):
        """Test backward with fp8 enabled"""
        FP8Helper.initialize(fp8_format=fp8_format)
        self.runner(attrs).test_backward(data_shape, dtype, rtol=1e-4, atol=1e-3)
        FP8Helper.finalize()


class TestEncoderLayer(BaseTester):
    """
    Test transformer_engine.jax.flax.TransformerLayer(layer_type=Encoder)
    """
    runner = EncoderRunner


class TestDecoderLayer(BaseTester):
    """
    Test transformer_engine.jax.flax.TransformerLayer(layer_type=Decoder)
    """
    runner = DecoderRunner