# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Test transformer_engine.jax.flax.TransformerLayer"""
5
import os
6
from functools import partial
7
from typing import Dict
8
9
10
11
12
13

import flax
import jax
import jax.numpy as jnp
import pytest

14
from utils import assert_allclose, assert_tree_like_allclose, sync_params_values
15
16
17
from utils import DecoderLayer as RefDecoderLayer
from utils import EncoderLayer as RefEncoderLayer

zlsh80826's avatar
zlsh80826 committed
18
19
20
21
from transformer_engine.common.recipe import Format
from transformer_engine.jax.flax import TransformerLayer, TransformerLayerType
from transformer_engine.jax.fp8 import FP8Helper, is_fp8_available

22
23
is_fp8_supported, reason = is_fp8_available()

24

25
@pytest.fixture(autouse=True, scope='function')
def enable_fused_attn():
    """Enable fused attention for each test and restore the prior setting after.

    The original teardown unconditionally ran ``del os.environ[...]``, which
    (a) raised KeyError if a test removed the variable and (b) discarded any
    NVTE_FUSED_ATTN value the user had exported before invoking pytest.
    Saving and restoring the previous value fixes both.
    """
    previous = os.environ.get("NVTE_FUSED_ATTN")
    os.environ["NVTE_FUSED_ATTN"] = "1"
    yield
    if previous is None:
        # Variable was not set before the test: remove it (tolerating a test
        # that already removed it, e.g. BaseRunner overwriting it with "0" is
        # still just an overwrite and is handled fine).
        os.environ.pop("NVTE_FUSED_ATTN", None)
    else:
        os.environ["NVTE_FUSED_ATTN"] = previous


33
34
35
36
# (batch, seqlen, emb_dim) input shapes exercised by every test case.
DATA_SHAPE = [
    pytest.param((32, 128, 1024), id='32-128-1024'),
    pytest.param((32, 512, 1024), id='32-512-1024'),
]

# Computation dtypes and FP8 recipes under test.
DTYPE = [jnp.float32, jnp.bfloat16]
FP8_FORMATS = [Format.E4M3, Format.HYBRID]

# Keyword names of TransformerLayer constructor arguments. Kept as module
# constants so the ATTRS permutations below stay typo-safe and greppable.
_KEY_OF_RESIDUAL_POST_LAYERNORM = "apply_residual_connection_post_layernorm"
_KEY_OF_OUTPUT_LAYERNORM = "output_layernorm"
_KEY_OF_DROP_PATH = "drop_path"
_KEY_OF_FUSE_QKV_PARAMS = "fuse_qkv_params"
_KEY_OF_HIDDEN_DROPOUT = "hidden_dropout"
_KEY_OF_ATTENTION_DROPOUT = "attention_dropout"
_KEY_OF_INTERMEDIATE_DROPOUT = "intermediate_dropout"
_KEY_OF_HIDDEN_DROPOUT_DIMS = "hidden_dropout_dims"
_KEY_OF_INTERMEDIATE_DROPOUT_DIMS = "intermediate_dropout_dims"
_KEY_OF_MLP_ACTIVATIONS = "mlp_activations"
_KEY_OF_LAYERNORM_TYPE = "layernorm_type"
_KEY_OF_LAYERNORM_EPS = "layernorm_epsilon"
_KEY_OF_ZERO_CENTERED_GAMMA = "zero_centered_gamma"
_KEY_OF_TRANSPOSE_BS = "transpose_batch_sequence"
_KEY_OF_SCALE_ATTN_LOGITS = "scale_attn_logits"
_KEY_OF_NUM_HEADS = "num_attention_heads"
_KEY_OF_NUM_GQA_GROUPS = "num_gqa_groups"
_KEY_OF_ENABLE_ROPE = "enable_rotary_pos_emb"
_KEY_OF_ROPE_GROUP_METHOD = "rotary_pos_emb_group_method"
_KEY_OF_SELF_ATTN_BIAS_TYPE = "self_attn_bias_type"
_KEY_OF_SELF_ATTN_MASK_TYPE = "self_attn_mask_type"
_KEY_OF_FLOAT32_ATTENTION_LOGITS = "float32_attention_logits"
_KEY_OF_USE_BIAS = "use_bias"
_KEY_OF_RELATIVE_EMBEDDING = "enable_relative_embedding"
64

zlsh80826's avatar
zlsh80826 committed
65
66
67
# Defaults shared by every test case; individual ATTRS entries override them.
BASE_ATTRS = {
    _KEY_OF_TRANSPOSE_BS: True,
    _KEY_OF_NUM_HEADS: 8,
    _KEY_OF_HIDDEN_DROPOUT: 0,
    _KEY_OF_ATTENTION_DROPOUT: 0,
    _KEY_OF_INTERMEDIATE_DROPOUT: 0,
    _KEY_OF_SELF_ATTN_MASK_TYPE: "padding_causal",
    _KEY_OF_LAYERNORM_TYPE: 'layernorm',
}

# Each entry is one TransformerLayer configuration to test, expressed as a
# delta on top of BASE_ATTRS (merged below).
ATTRS = [{}, {
    _KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
}, {
    _KEY_OF_ZERO_CENTERED_GAMMA: True,
    _KEY_OF_LAYERNORM_EPS: 1e-2,
}, {
    _KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
    _KEY_OF_RESIDUAL_POST_LAYERNORM: True
}, {
    _KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
    _KEY_OF_OUTPUT_LAYERNORM: True
}, {
    _KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
    _KEY_OF_RESIDUAL_POST_LAYERNORM: True,
    _KEY_OF_OUTPUT_LAYERNORM: True
}, {
    _KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
    _KEY_OF_DROP_PATH: 0.1
}, {
    _KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
    _KEY_OF_FUSE_QKV_PARAMS: False
}, {
    _KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
    _KEY_OF_MLP_ACTIVATIONS: ('gelu', 'linear'),
}, {
    _KEY_OF_SCALE_ATTN_LOGITS: True,
    _KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
    _KEY_OF_HIDDEN_DROPOUT: 0.8,
    _KEY_OF_INTERMEDIATE_DROPOUT: 0.5,
    _KEY_OF_MLP_ACTIVATIONS: ('gelu', 'linear'),
    _KEY_OF_USE_BIAS: True,
}, {
    _KEY_OF_TRANSPOSE_BS: False,
    _KEY_OF_SCALE_ATTN_LOGITS: True,
    _KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
    _KEY_OF_MLP_ACTIVATIONS: ('gelu', 'linear'),
}, {
    # Grouped-query attention: 8 heads sharing 4 KV groups.
    _KEY_OF_NUM_HEADS: 8,
    _KEY_OF_NUM_GQA_GROUPS: 4,
    _KEY_OF_TRANSPOSE_BS: False,
    _KEY_OF_SCALE_ATTN_LOGITS: True,
    _KEY_OF_MLP_ACTIVATIONS: ('gelu',),
    _KEY_OF_USE_BIAS: True,
}, {
    _KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
    _KEY_OF_MLP_ACTIVATIONS: (('silu', 'linear')),
}, {
    _KEY_OF_SCALE_ATTN_LOGITS: True,
    _KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
    _KEY_OF_HIDDEN_DROPOUT: 0.8,
    _KEY_OF_INTERMEDIATE_DROPOUT: 0.5,
    _KEY_OF_MLP_ACTIVATIONS: (('silu', 'linear')),
    _KEY_OF_USE_BIAS: True,
}, {
    _KEY_OF_TRANSPOSE_BS: False,
    _KEY_OF_SCALE_ATTN_LOGITS: True,
    _KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
    _KEY_OF_MLP_ACTIVATIONS: (('silu', 'linear')),
}, {
    _KEY_OF_NUM_HEADS: 8,
    _KEY_OF_NUM_GQA_GROUPS: 4,
    _KEY_OF_TRANSPOSE_BS: False,
    _KEY_OF_SCALE_ATTN_LOGITS: True,
    _KEY_OF_LAYERNORM_TYPE: 'layernorm',
    _KEY_OF_MLP_ACTIVATIONS: (('silu',)),
    _KEY_OF_USE_BIAS: True,
}, {
    # Rotary position embeddings, consecutive grouping, MQA (1 KV group).
    _KEY_OF_TRANSPOSE_BS: False,
    _KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
    _KEY_OF_NUM_GQA_GROUPS: 1,
    _KEY_OF_ENABLE_ROPE: True,
    _KEY_OF_ROPE_GROUP_METHOD: "consecutive",
    _KEY_OF_FLOAT32_ATTENTION_LOGITS: True,
}, {
    _KEY_OF_TRANSPOSE_BS: True,
    _KEY_OF_ENABLE_ROPE: True,
    _KEY_OF_ROPE_GROUP_METHOD: "consecutive",
    _KEY_OF_USE_BIAS: True,
}, {
    _KEY_OF_TRANSPOSE_BS: False,
    _KEY_OF_LAYERNORM_TYPE: 'layernorm',
    _KEY_OF_NUM_GQA_GROUPS: 2,
    _KEY_OF_ENABLE_ROPE: True,
    _KEY_OF_ROPE_GROUP_METHOD: "alternate",
    _KEY_OF_USE_BIAS: True,
    _KEY_OF_FLOAT32_ATTENTION_LOGITS: True,
}, {
    _KEY_OF_TRANSPOSE_BS: True,
    _KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
    _KEY_OF_ENABLE_ROPE: True,
    _KEY_OF_ROPE_GROUP_METHOD: "alternate",
    _KEY_OF_USE_BIAS: True,
}, {
    # Dropout with explicit broadcast dims.
    _KEY_OF_HIDDEN_DROPOUT: 0.3,
    _KEY_OF_HIDDEN_DROPOUT_DIMS: (0,),
    _KEY_OF_INTERMEDIATE_DROPOUT: 0.5,
    _KEY_OF_INTERMEDIATE_DROPOUT_DIMS: (1,),
}, {
    _KEY_OF_SELF_ATTN_MASK_TYPE: "padding",
    _KEY_OF_USE_BIAS: True,
}, {
    _KEY_OF_RELATIVE_EMBEDDING: False,
    _KEY_OF_SELF_ATTN_BIAS_TYPE: "no_bias",
}, {
    _KEY_OF_ATTENTION_DROPOUT: 0.3,
}]

# Materialize the full configuration for each case.
ATTRS = [{**BASE_ATTRS, **attr} for attr in ATTRS]


185
186
187
188
189
class BaseRunner:
    """Base runner comparing a reference Flax layer against TransformerLayer.

    Subclasses must set:
        layer_type: the TransformerLayerType to instantiate under test.
        reference_layer: the reference flax.linen.Module implementation.
        transformations: param-path rename map consumed by sync_params_values
            when copying reference weights into the test layer.
    and provide generate_inputs(data_shape, dtype).
    """
    layer_type: TransformerLayerType = None
    reference_layer: flax.linen.Module = None
    transformations: Dict[str, str] = None

    def __init__(self, attrs):
        # attrs: TransformerLayer keyword arguments for this test case.
        self.attrs = attrs
        self._generate_test_rngs()
        # Disable fused attention for attention dropout because the different dropout impl
        # in the fused kernel would make reference and test outputs diverge.
        if attrs.get(_KEY_OF_ATTENTION_DROPOUT, False) and os.getenv('NVTE_FUSED_ATTN'):
            os.environ['NVTE_FUSED_ATTN'] = "0"

    def _generate_test_rngs(self):
        """Derive deterministic PRNG-key dicts for init() and apply()."""
        root_rng = jax.random.PRNGKey(0)
        params_rng, init_dropout_rng, apply_dropout_rng = jax.random.split(root_rng, 3)
        self.init_rng = {'params': params_rng, 'dropout': init_dropout_rng}
        self.apply_rng = {'dropout': apply_dropout_rng}

    def _generate_layer(self, layer_cls, diff_inputs, no_diff_inputs):
        """Instantiate layer_cls and return (layer, params, other_collections)."""
        layer = layer_cls()
        variables = layer.init(self.init_rng, *diff_inputs, *no_diff_inputs)
        # Split the trainable 'params' collection from everything else
        # (e.g. FP8 metadata collections).
        others, params = flax.core.pop(variables, 'params')
        del variables
        return layer, params, others

    def _loss_fn(self, diff_xs, no_diff_xs, params, others, model):
        """Scalar loss: mean of the layer output, accumulated in float32."""
        variables = {'params': params, **others}
        output = model.apply(variables, *diff_xs, *no_diff_xs, rngs=self.apply_rng)
        # Reduce in float32 for accuracy, then cast back so the loss dtype
        # matches the layer output dtype.
        return jnp.mean(output, dtype=jnp.float32).astype(output.dtype)

    def _sync_params(self, ref, target):
        """Copy the reference params to target"""
        target = sync_params_values(target, ref, self.transformations)
        return ref, target

    def test_forward(self, data_shape, dtype, rtol=1e-05, atol=1e-08):
        """Test only the forward"""
        inputs, (ref_masks, test_masks) = self.generate_inputs(data_shape, dtype)

        ref_layer_cls = partial(self.reference_layer, dtype=dtype, **self.attrs)
        layer_cls = partial(TransformerLayer, layer_type=self.layer_type, dtype=dtype, **self.attrs)

        ref_layer, ref_params, ref_others = self._generate_layer(ref_layer_cls, inputs, ref_masks)
        test_layer, test_params, test_others = self._generate_layer(layer_cls, inputs, test_masks)
        # Both layers must start from identical weights for the comparison.
        ref_params, test_params = self._sync_params(ref_params, test_params)

        ref_out = self._loss_fn(inputs, ref_masks, ref_params, ref_others, ref_layer)
        test_out = self._loss_fn(inputs, test_masks, test_params, test_others, test_layer)

        assert_allclose(ref_out, test_out, rtol=rtol, atol=atol)

    def test_backward(self, data_shape, dtype, rtol=1e-05, atol=1e-08):
        """Test forward and backward through value_and_grad()"""
        inputs, (ref_masks, test_masks) = self.generate_inputs(data_shape, dtype)

        ref_layer_cls = partial(self.reference_layer, dtype=dtype, **self.attrs)
        layer_cls = partial(TransformerLayer, layer_type=self.layer_type, dtype=dtype, **self.attrs)

        ref_layer, ref_params, ref_others = self._generate_layer(ref_layer_cls, inputs, ref_masks)
        test_layer, test_params, test_others = self._generate_layer(layer_cls, inputs, test_masks)

        ref_params, test_params = self._sync_params(ref_params, test_params)

        if FP8Helper.is_fp8_enabled():
            # Warm up the FP8 scaling metadata: run a few iterations so the
            # amax history / scale factors settle before the actual comparison.
            for _ in range(4):
                # argnums=(3,) differentiates w.r.t. test_others, whose
                # gradient carries the updated FP8 meta collection.
                _, tmp_grad = jax.value_and_grad(self._loss_fn, argnums=(3,), has_aux=False)(
                    inputs,
                    test_masks,
                    test_params,
                    test_others,
                    test_layer,
                )
                _, fp8_meta_grad = flax.core.pop(tmp_grad[0], FP8Helper.FP8_COLLECTION_NAME)
                # Fold the FP8 meta gradient back into the collections, then
                # recompute scales from the refreshed amax history.
                test_others = FP8Helper.update_collections(
                    {FP8Helper.FP8_COLLECTION_NAME: fp8_meta_grad}, test_others)
                test_others = FP8Helper.update_fp8_metas(test_others)
                del tmp_grad, fp8_meta_grad

        # Gradients w.r.t. the differentiable inputs (argnum 0) and the
        # trainable params (argnum 2).
        grad_fn = jax.value_and_grad(self._loss_fn, argnums=(0, 2), has_aux=False)

        ref_out, (ref_dgrads, ref_wgrads) = grad_fn(inputs, ref_masks, ref_params, ref_others,
                                                    ref_layer)
        test_out, (test_dgrads, test_wgrads) = grad_fn(inputs, test_masks, test_params, test_others,
                                                       test_layer)

        assert_allclose(ref_out, test_out, rtol=rtol, atol=atol)
        assert_tree_like_allclose(ref_dgrads, test_dgrads, rtol=rtol, atol=atol)

        # Re-map the reference wgrads into the test layer's param layout
        # before the tree-wise comparison.
        _, restructed_ref_wgrads = self._sync_params(ref_wgrads, test_wgrads)
        assert_tree_like_allclose(restructed_ref_wgrads, test_wgrads, rtol=rtol, atol=atol)


class EncoderRunner(BaseRunner):
    """Encoder runner implementations"""
    layer_type = TransformerLayerType.ENCODER
    reference_layer = RefEncoderLayer
    # Param-path renames: TransformerLayer path -> reference layer path,
    # consumed by sync_params_values when copying weights across.
    transformations = {
        'attention/qkv/scale': 'pre_attention_layer_norm/scale',
        'attention/qkv/ln_bias': 'pre_attention_layer_norm/ln_bias',
        'attention/query/scale': 'pre_attention_layer_norm/scale',
        'attention/query/ln_bias': 'pre_attention_layer_norm/ln_bias',
        'mlp/wi_kernel': 'mlp/wi/kernel',
        'mlp/wi_bias': 'mlp/wi/bias',
        'mlp/wo_kernel': 'mlp/wo/kernel',
        'mlp/wo_bias': 'mlp/wo/bias',
        'mlp/scale': 'pre_mlp_layer_norm/scale',
        'mlp/ln_bias': 'pre_mlp_layer_norm/ln_bias',
    }

    def generate_inputs(self, data_shape, dtype):
        """
        Return inputs, (ref_masks, test_masks)

        data_shape is (batch, seqlen, emb_dim); the data is laid out as
        (seqlen, batch, emb_dim) instead when transpose_batch_sequence is set.
        """
        transpose_batch_sequence = self.attrs[_KEY_OF_TRANSPOSE_BS]
        batch, seqlen = data_shape[:2]
        if transpose_batch_sequence:
            data_shape = (data_shape[1], data_shape[0], *data_shape[2:])

        data_rng = jax.random.PRNGKey(2024)
        inputs = (jax.random.normal(data_rng, data_shape, dtype),)

        # Test-layer mask convention: 1 = masked out, 0 = attend.
        padded_mask = jnp.zeros((batch, 1, seqlen, seqlen), dtype=jnp.uint8)
        causal_mask = jnp.triu(jnp.ones((batch, 1, seqlen, seqlen), dtype=jnp.uint8), k=1)
        # Bug fix: the original checked for the misspelled 'casual', so a
        # 'causal' mask type would silently fall through to the padding mask.
        if self.attrs[_KEY_OF_SELF_ATTN_MASK_TYPE] in ['causal', 'padding_causal']:
            mask = causal_mask
        else:
            mask = padded_mask

        # The reference layer uses the inverted convention (1 = attend).
        ref_masks = (1 - mask,)
        test_masks = (None, mask)    # The second arg of Transformer is encoded tokens.

        return inputs, (ref_masks, test_masks)


class DecoderRunner(BaseRunner):
    """
    Decoder runner implementations
    """
    layer_type = TransformerLayerType.DECODER
    reference_layer = RefDecoderLayer
    # Param-path renames: TransformerLayer path -> reference layer path,
    # consumed by sync_params_values when copying weights across.
    transformations = {
        'encoder_decoder_attention/qkv/scale': 'pre_cross_attention_layer_norm/scale',
        'encoder_decoder_attention/qkv/ln_bias': 'pre_cross_attention_layer_norm/ln_bias',
        'encoder_decoder_attention/query/scale': 'pre_cross_attention_layer_norm/scale',
        'encoder_decoder_attention/query/ln_bias': 'pre_cross_attention_layer_norm/ln_bias',
        'self_attention/qkv/scale': 'pre_self_attention_layer_norm/scale',
        'self_attention/qkv/ln_bias': 'pre_self_attention_layer_norm/ln_bias',
        'self_attention/query/scale': 'pre_self_attention_layer_norm/scale',
        'self_attention/query/ln_bias': 'pre_self_attention_layer_norm/ln_bias',
        'mlp/wi_kernel': 'mlp/wi/kernel',
        'mlp/wi_bias': 'mlp/wi/bias',
        'mlp/wo_kernel': 'mlp/wo/kernel',
        'mlp/wo_bias': 'mlp/wo/bias',
        'mlp/scale': 'pre_mlp_layer_norm/scale',
        'mlp/ln_bias': 'pre_mlp_layer_norm/ln_bias',
    }

    def generate_inputs(self, data_shape, dtype):
        """
        Return inputs, (ref_masks, test_masks)

        Produces two tensors (decoder input and encoded output) plus a
        self-attention mask and a cross-attention padding mask.
        """
        transpose_batch_sequence = self.attrs[_KEY_OF_TRANSPOSE_BS]
        batch, seqlen = data_shape[:2]
        if transpose_batch_sequence:
            data_shape = (data_shape[1], data_shape[0], *data_shape[2:])

        data_rng = jax.random.PRNGKey(0)
        data_rng_0, data_rng_1 = jax.random.split(data_rng, 2)
        inputs = (jax.random.normal(data_rng_0, data_shape,
                                    dtype), jax.random.normal(data_rng_1, data_shape, dtype))

        # Test-layer mask convention: 1 = masked out, 0 = attend.
        padded_mask = jnp.zeros((batch, 1, seqlen, seqlen), dtype=jnp.uint8)
        causal_mask = jnp.triu(jnp.ones((batch, 1, seqlen, seqlen), dtype=jnp.uint8), k=1)
        # Bug fix: the original checked for the misspelled 'casual', so a
        # 'causal' mask type would silently fall through to the padding mask.
        if self.attrs[_KEY_OF_SELF_ATTN_MASK_TYPE] in ['causal', 'padding_causal']:
            self_mask = causal_mask
        else:
            self_mask = padded_mask

        # The reference layer uses the inverted convention (1 = attend).
        ref_masks = (1 - self_mask, 1 - padded_mask)
        test_masks = (self_mask, padded_mask)

        return inputs, (ref_masks, test_masks)


@pytest.mark.parametrize('data_shape', DATA_SHAPE)
@pytest.mark.parametrize('dtype', DTYPE)
@pytest.mark.parametrize('attrs', ATTRS)
class BaseTester():
    """
    Pytest entry point: parametrized over shapes, dtypes and layer attrs,
    delegating the actual checks to the runner class set by subclasses.
    """
    runner = BaseRunner

    def test_forward(self, data_shape, dtype, attrs):
        """Forward-only parity check in the regular (non-FP8) path."""
        FP8Helper.finalize()    # Ensure FP8 disabled.
        runner = self.runner(attrs)
        runner.test_forward(data_shape, dtype, rtol=1e-5, atol=7e-5)

    def test_backward(self, data_shape, dtype, attrs):
        """Forward + backward parity check in the regular (non-FP8) path."""
        FP8Helper.finalize()    # Ensure FP8 disabled.
        runner = self.runner(attrs)
        runner.test_backward(data_shape, dtype, rtol=1e-5, atol=7e-5)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('fp8_format', FP8_FORMATS)
    def test_forward_with_fp8(self, data_shape, dtype, attrs, fp8_format):
        """Forward parity check with FP8 autocasting enabled."""
        FP8Helper.initialize(fp8_format=fp8_format)
        runner = self.runner(attrs)
        # Looser tolerances: FP8 quantization adds noise.
        runner.test_forward(data_shape, dtype, rtol=1e-4, atol=1e-3)
        FP8Helper.finalize()

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('fp8_format', FP8_FORMATS)
    def test_backward_with_fp8(self, data_shape, dtype, attrs, fp8_format):
        """Forward + backward parity check with FP8 autocasting enabled."""
        FP8Helper.initialize(fp8_format=fp8_format)
        runner = self.runner(attrs)
        # Looser tolerances: FP8 quantization adds noise.
        runner.test_backward(data_shape, dtype, rtol=1e-4, atol=1e-3)
        FP8Helper.finalize()


class TestEncoderLayer(BaseTester):
    """
    Test transformer_engine.jax.flax.TransformerLayer(layer_type=Encoder)
    """
    # EncoderRunner supplies encoder inputs/masks and the param-path mapping.
    runner = EncoderRunner


class TestDecoderLayer(BaseTester):
    """
    Test transformer_engine.jax.flax.TransformerLayer(layer_type=Decoder)
    """
    # DecoderRunner supplies decoder inputs/masks and the param-path mapping.
    runner = DecoderRunner