# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Test transformer_engine.jax.flax.TransformerLayer"""
import os
from functools import partial
from typing import Dict, Tuple, Optional

import flax
import jax
import jax.numpy as jnp
import pytest

from utils import (
    assert_allclose,
    assert_tree_like_allclose,
    dtype_tols,
    sync_params_values,
)
from utils import DecoderLayer as RefDecoderLayer
from utils import EncoderLayer as RefEncoderLayer

from transformer_engine.common import recipe
from transformer_engine.jax.flax import TransformerLayer, TransformerLayerType
from transformer_engine.jax.quantize import (
    QuantizeConfig,
    ScalingMode,
    is_fp8_available,
    update_collections,
)


@pytest.fixture(autouse=True, scope="function")
def enable_fused_attn():
    """Enable fused attention"""
    os.environ["NVTE_FUSED_ATTN"] = "1"
    yield
    del os.environ["NVTE_FUSED_ATTN"]


is_fp8_supported, reason = is_fp8_available()
is_mxfp8_supported, reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)

# Find supported scaling modes.
QUANTIZE_RECIPES = []
if is_fp8_supported:
    QUANTIZE_RECIPES.append(pytest.param(recipe.DelayedScaling(), id="DelayedScaling"))
if is_mxfp8_supported:
    QUANTIZE_RECIPES.append(pytest.param(recipe.MXFP8BlockScaling(), id="MXFP8BlockScaling"))


DATA_SHAPE = [  # (batch, seqlen, emb_dim)
    pytest.param((32, 128, 1024), id="32-128-1024"),
]
DTYPE = [jnp.bfloat16]

_KEY_OF_RESIDUAL_POST_LAYERNORM = "apply_residual_connection_post_layernorm"
_KEY_OF_OUTPUT_LAYERNORM = "output_layernorm"
_KEY_OF_DROP_PATH = "drop_path"
_KEY_OF_FUSE_QKV_PARAMS = "fuse_qkv_params"
_KEY_OF_HIDDEN_DROPOUT = "hidden_dropout"
_KEY_OF_ATTENTION_DROPOUT = "attention_dropout"
_KEY_OF_INTERMEDIATE_DROPOUT = "intermediate_dropout"
_KEY_OF_HIDDEN_DROPOUT_DIMS = "hidden_dropout_dims"
_KEY_OF_INTERMEDIATE_DROPOUT_DIMS = "intermediate_dropout_dims"
_KEY_OF_MLP_ACTIVATIONS = "mlp_activations"
_KEY_OF_LAYERNORM_TYPE = "layernorm_type"
_KEY_OF_LAYERNORM_EPS = "layernorm_epsilon"
_KEY_OF_ZERO_CENTERED_GAMMA = "zero_centered_gamma"
_KEY_OF_TRANSPOSE_BS = "transpose_batch_sequence"
_KEY_OF_SCALE_ATTN_LOGITS = "scale_attn_logits"
_KEY_OF_NUM_HEADS = "num_attention_heads"
_KEY_OF_NUM_GQA_GROUPS = "num_gqa_groups"
_KEY_OF_ENABLE_ROPE = "enable_rotary_pos_emb"
_KEY_OF_ROPE_GROUP_METHOD = "rotary_pos_emb_group_method"
_KEY_OF_SELF_ATTN_BIAS_TYPE = "self_attn_bias_type"
_KEY_OF_SELF_ATTN_MASK_TYPE = "self_attn_mask_type"
_KEY_OF_FLOAT32_ATTENTION_LOGITS = "float32_attention_logits"
_KEY_OF_USE_BIAS = "use_bias"
_KEY_OF_RELATIVE_EMBEDDING = "enable_relative_embedding"
_KEY_OF_WINDOW_SIZE = "window_size"

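# Defaults shared by every test case in ATTRS below.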
BASE_ATTRS = {
    _KEY_OF_TRANSPOSE_BS: True,
    _KEY_OF_NUM_HEADS: 8,
    _KEY_OF_HIDDEN_DROPOUT: 0,
    _KEY_OF_ATTENTION_DROPOUT: 0.0,
    _KEY_OF_INTERMEDIATE_DROPOUT: 0,
    _KEY_OF_SELF_ATTN_MASK_TYPE: "padding_causal",
    _KEY_OF_LAYERNORM_TYPE: "layernorm",
    _KEY_OF_WINDOW_SIZE: (-1, -1),
}

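# Per-test attribute overrides; each dict exercises a different TransformerLayer
# configuration (normalization, dropout, RoPE, GQA, sliding window, ...).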
ATTRS = [
    # attrs0
    {},
    # attrs1
    {
        _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
    },
    # attrs2
    {
        _KEY_OF_ZERO_CENTERED_GAMMA: True,
        _KEY_OF_LAYERNORM_EPS: 1e-2,
    },
    # attrs3
    {_KEY_OF_LAYERNORM_TYPE: "rmsnorm", _KEY_OF_RESIDUAL_POST_LAYERNORM: True},
    # attrs4
    {_KEY_OF_LAYERNORM_TYPE: "rmsnorm", _KEY_OF_OUTPUT_LAYERNORM: True},
    # attrs5
    {
        _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
        _KEY_OF_RESIDUAL_POST_LAYERNORM: True,
        _KEY_OF_OUTPUT_LAYERNORM: True,
    },
    # attrs6
    {_KEY_OF_LAYERNORM_TYPE: "rmsnorm", _KEY_OF_DROP_PATH: 0.1},
    # attrs7
    {_KEY_OF_LAYERNORM_TYPE: "rmsnorm", _KEY_OF_FUSE_QKV_PARAMS: False},
    # attrs8
    {
        _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
        _KEY_OF_MLP_ACTIVATIONS: ("gelu", "linear"),
    },
    # attrs9
    {
        _KEY_OF_SCALE_ATTN_LOGITS: True,
        _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
        _KEY_OF_HIDDEN_DROPOUT: 0.8,
        _KEY_OF_INTERMEDIATE_DROPOUT: 0.5,
        _KEY_OF_MLP_ACTIVATIONS: ("gelu", "linear"),
        _KEY_OF_USE_BIAS: True,
    },
    # attrs10
    {
        _KEY_OF_TRANSPOSE_BS: False,
        _KEY_OF_SCALE_ATTN_LOGITS: True,
        _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
        _KEY_OF_MLP_ACTIVATIONS: ("gelu", "linear"),
    },
    # attrs11
    {
        _KEY_OF_NUM_HEADS: 8,
        _KEY_OF_NUM_GQA_GROUPS: 4,
        _KEY_OF_TRANSPOSE_BS: False,
        _KEY_OF_SCALE_ATTN_LOGITS: True,
        _KEY_OF_MLP_ACTIVATIONS: ("gelu",),
        _KEY_OF_USE_BIAS: True,
    },
    # attrs12
    {
        _KEY_OF_TRANSPOSE_BS: False,
        _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
        _KEY_OF_NUM_GQA_GROUPS: 1,
        _KEY_OF_ENABLE_ROPE: True,
        _KEY_OF_ROPE_GROUP_METHOD: "consecutive",
        _KEY_OF_FLOAT32_ATTENTION_LOGITS: True,
    },
    # attrs13
    {
        _KEY_OF_TRANSPOSE_BS: True,
        _KEY_OF_ENABLE_ROPE: True,
        _KEY_OF_ROPE_GROUP_METHOD: "consecutive",
        _KEY_OF_USE_BIAS: True,
    },
    # attrs14
    {
        _KEY_OF_TRANSPOSE_BS: False,
        _KEY_OF_LAYERNORM_TYPE: "layernorm",
        _KEY_OF_NUM_GQA_GROUPS: 2,
        _KEY_OF_ENABLE_ROPE: True,
        _KEY_OF_ROPE_GROUP_METHOD: "alternate",
        _KEY_OF_USE_BIAS: True,
        _KEY_OF_FLOAT32_ATTENTION_LOGITS: True,
    },
    # attrs15
    {
        _KEY_OF_TRANSPOSE_BS: True,
        _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
        _KEY_OF_ENABLE_ROPE: True,
        _KEY_OF_ROPE_GROUP_METHOD: "alternate",
        _KEY_OF_USE_BIAS: True,
    },
    # attrs16
    {
        _KEY_OF_HIDDEN_DROPOUT: 0.3,
        _KEY_OF_HIDDEN_DROPOUT_DIMS: (0,),
        _KEY_OF_INTERMEDIATE_DROPOUT: 0.5,
        _KEY_OF_INTERMEDIATE_DROPOUT_DIMS: (1,),
    },
    # attrs17
    {
        _KEY_OF_SELF_ATTN_MASK_TYPE: "padding",
        _KEY_OF_USE_BIAS: True,
    },
    # attrs18
    {
        _KEY_OF_RELATIVE_EMBEDDING: False,
        _KEY_OF_SELF_ATTN_BIAS_TYPE: "no_bias",
    },
    # attrs19
    {
        _KEY_OF_ATTENTION_DROPOUT: 0.3,
    },
    # attrs20
    {
        _KEY_OF_MLP_ACTIVATIONS: (("relu", "relu")),
    },
    # attrs21
    {
        _KEY_OF_TRANSPOSE_BS: False,
        _KEY_OF_RELATIVE_EMBEDDING: False,
        _KEY_OF_SELF_ATTN_MASK_TYPE: "causal",
        _KEY_OF_WINDOW_SIZE: (64, 0),  # Left size must < DATA_SHAPE seqlen
        _KEY_OF_FLOAT32_ATTENTION_LOGITS: True,
    },
    # attrs22
    {
        _KEY_OF_TRANSPOSE_BS: False,
        _KEY_OF_RELATIVE_EMBEDDING: False,
        _KEY_OF_SELF_ATTN_MASK_TYPE: "padding",
        _KEY_OF_WINDOW_SIZE: (2, 2),
    },
]

ATTRS = [{**BASE_ATTRS, **attr} for attr in ATTRS]


class BaseRunner:
    """Base runner to define forward and backward tests"""

    layer_type: TransformerLayerType = None
    reference_layer: flax.linen.Module = None
    transformations: Dict[str, str] = None

    def __init__(self, attrs):
        self.attrs = attrs
        self._generate_test_rngs()
        # Disable fused attention when attention dropout is enabled, because the
        # fused and unfused dropout implementations differ.
        if attrs.get(_KEY_OF_ATTENTION_DROPOUT, False) and os.getenv("NVTE_FUSED_ATTN"):
            os.environ["NVTE_FUSED_ATTN"] = "0"

    def _generate_test_rngs(self):
        root_rng = jax.random.PRNGKey(0)
        params_rng, init_dropout_rng, apply_dropout_rng = jax.random.split(root_rng, 3)
        self.init_rng = {"params": params_rng, "dropout": init_dropout_rng}
        self.apply_rng = {"dropout": apply_dropout_rng}

    def _generate_layer(self, layer_cls, diff_inputs, no_diff_inputs):
        layer = layer_cls()
        variables = layer.init(self.init_rng, *diff_inputs, *no_diff_inputs)
        others, params = flax.core.pop(variables, "params")
        del variables
        return layer, params, others

    def _loss_fn(self, diff_xs, no_diff_xs, params, others, model):
        variables = {"params": params, **others}
        output = model.apply(variables, *diff_xs, *no_diff_xs, rngs=self.apply_rng)
        return jnp.mean(output, dtype=jnp.float32).astype(output.dtype)

    def _sync_params(self, ref, target):
        """Copy the reference params to target"""
        target = sync_params_values(target, ref, self.transformations)
        return ref, target

    def test_forward(
        self,
        data_shape: Tuple[int],
        dtype: jnp.dtype,
        rtol: Optional[float] = None,
        atol: Optional[float] = None,
    ) -> None:
        """Test only the forward"""
        inputs, (ref_masks, test_masks) = self.generate_inputs(data_shape, dtype)

        ref_layer_cls = partial(self.reference_layer, **self.attrs)
        layer_cls = partial(TransformerLayer, layer_type=self.layer_type, **self.attrs)

        ref_layer, ref_params, ref_others = self._generate_layer(ref_layer_cls, inputs, ref_masks)
        test_layer, test_params, test_others = self._generate_layer(layer_cls, inputs, test_masks)
        ref_params, test_params = self._sync_params(ref_params, test_params)

        ref_out = self._loss_fn(inputs, ref_masks, ref_params, ref_others, ref_layer)
        test_out = self._loss_fn(inputs, test_masks, test_params, test_others, test_layer)

        tols = dtype_tols(dtype, rtol=rtol, atol=atol)
        assert_allclose(ref_out, test_out, **tols)

    def test_backward(
        self,
        data_shape: Tuple[int],
        dtype: jnp.dtype,
        rtol: Optional[float] = None,
        atol: Optional[float] = None,
    ) -> None:
        """Test forward and backward through value_and_grad()"""
        inputs, (ref_masks, test_masks) = self.generate_inputs(data_shape, dtype)

        ref_layer_cls = partial(self.reference_layer, **self.attrs)
        layer_cls = partial(TransformerLayer, layer_type=self.layer_type, **self.attrs)

        ref_layer, ref_params, ref_others = self._generate_layer(ref_layer_cls, inputs, ref_masks)
        test_layer, test_params, test_others = self._generate_layer(layer_cls, inputs, test_masks)

        ref_params, test_params = self._sync_params(ref_params, test_params)

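        # With FP8 enabled, run a few iterations first so that the quantization
        # metadata (e.g. the amax history used by delayed scaling) is populated
        # before comparing against the reference.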
        if QuantizeConfig.is_fp8_enabled():
            for _ in range(4):
                _, updated_state = jax.value_and_grad(self._loss_fn, argnums=(3,), has_aux=False)(
                    inputs,
                    test_masks,
                    test_params,
                    test_others,
                    test_layer,
                )
                if QuantizeConfig.SCALING_MODE == ScalingMode.DELAYED_TENSOR_SCALING:
                    _, updated_quantize_meta = flax.core.pop(
                        updated_state[0], QuantizeConfig.COLLECTION_NAME
                    )
                    test_others = update_collections(
                        {QuantizeConfig.COLLECTION_NAME: updated_quantize_meta}, test_others
                    )
                    del updated_quantize_meta
                del updated_state

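        # Differentiate w.r.t. the differentiable inputs (argnum 0) and the params (argnum 2).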
        grad_fn = jax.value_and_grad(self._loss_fn, argnums=(0, 2), has_aux=False)

        ref_out, (ref_dgrads, ref_wgrads) = grad_fn(
            inputs, ref_masks, ref_params, ref_others, ref_layer
        )
        test_out, (test_dgrads, test_wgrads) = grad_fn(
            inputs, test_masks, test_params, test_others, test_layer
        )

        tols = dtype_tols(dtype, rtol=rtol, atol=atol)
        assert_allclose(ref_out, test_out, **tols)
        assert_tree_like_allclose(ref_dgrads, test_dgrads, **tols)

        _, restructured_ref_wgrads = self._sync_params(ref_wgrads, test_wgrads)
        assert_tree_like_allclose(restructured_ref_wgrads, test_wgrads, **tols)


class EncoderRunner(BaseRunner):
    """Encoder runner implementations"""

    layer_type = TransformerLayerType.ENCODER
    reference_layer = RefEncoderLayer
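    # Maps TransformerLayer parameter paths to the reference layer's paths so that
    # sync_params_values can copy the reference weights into the test layer.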
    transformations = {
        "attention/qkv/scale": "pre_attention_layer_norm/scale",
        "attention/qkv/ln_bias": "pre_attention_layer_norm/ln_bias",
        "attention/query/scale": "pre_attention_layer_norm/scale",
        "attention/query/ln_bias": "pre_attention_layer_norm/ln_bias",
        "mlp/wi_kernel": "mlp/wi/kernel",
        "mlp/wi_bias": "mlp/wi/bias",
        "mlp/wo_kernel": "mlp/wo/kernel",
        "mlp/wo_bias": "mlp/wo/bias",
        "mlp/scale": "pre_mlp_layer_norm/scale",
        "mlp/ln_bias": "pre_mlp_layer_norm/ln_bias",
    }

    def generate_inputs(self, data_shape, dtype):
        """
        Return inputs, (ref_masks, test_masks)
        """
        transpose_batch_sequence = self.attrs[_KEY_OF_TRANSPOSE_BS]
        batch, seqlen = data_shape[:2]
        if transpose_batch_sequence:
            data_shape = (data_shape[1], data_shape[0], *data_shape[2:])

        data_rng = jax.random.PRNGKey(2024)
        inputs = (jax.random.normal(data_rng, data_shape, dtype),)

        padded_mask = jnp.zeros((batch, 1, seqlen, seqlen), dtype=jnp.uint8)
        causal_mask = jnp.triu(jnp.ones((batch, 1, seqlen, seqlen), dtype=jnp.uint8), k=1)
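        # The test layer's mask marks masked-out positions with 1; the reference
        # layer expects the inverse convention, hence the (1 - mask) below.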
        if self.attrs[_KEY_OF_SELF_ATTN_MASK_TYPE] in ["causal", "padding_causal"]:
            mask = causal_mask
        else:
            mask = padded_mask

        ref_masks = (1 - mask,)
        # The second positional arg of TransformerLayer is the encoded tokens, which
        # an encoder-only layer does not use, so pass None for it.
        test_masks = (None, mask)

        return inputs, (ref_masks, test_masks)


class DecoderRunner(BaseRunner):
    """
    Decoder runner implementations
    """

    layer_type = TransformerLayerType.DECODER
    reference_layer = RefDecoderLayer
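    # Parameter-path mapping for the decoder reference layer; see EncoderRunner above.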
    transformations = {
        "encoder_decoder_attention/qkv/scale": "pre_cross_attention_layer_norm/scale",
        "encoder_decoder_attention/qkv/ln_bias": "pre_cross_attention_layer_norm/ln_bias",
        "encoder_decoder_attention/query/scale": "pre_cross_attention_layer_norm/scale",
        "encoder_decoder_attention/query/ln_bias": "pre_cross_attention_layer_norm/ln_bias",
        "self_attention/qkv/scale": "pre_self_attention_layer_norm/scale",
        "self_attention/qkv/ln_bias": "pre_self_attention_layer_norm/ln_bias",
        "self_attention/query/scale": "pre_self_attention_layer_norm/scale",
        "self_attention/query/ln_bias": "pre_self_attention_layer_norm/ln_bias",
        "mlp/wi_kernel": "mlp/wi/kernel",
        "mlp/wi_bias": "mlp/wi/bias",
        "mlp/wo_kernel": "mlp/wo/kernel",
        "mlp/wo_bias": "mlp/wo/bias",
        "mlp/scale": "pre_mlp_layer_norm/scale",
        "mlp/ln_bias": "pre_mlp_layer_norm/ln_bias",
    }

    def generate_inputs(self, data_shape, dtype):
        """
        Return inputs, (ref_masks, test_masks)
        """
        transpose_batch_sequence = self.attrs[_KEY_OF_TRANSPOSE_BS]
        batch, seqlen = data_shape[:2]
        if transpose_batch_sequence:
            data_shape = (data_shape[1], data_shape[0], *data_shape[2:])

        data_rng = jax.random.PRNGKey(0)
        data_rng_0, data_rng_1 = jax.random.split(data_rng, 2)
        inputs = (
            jax.random.normal(data_rng_0, data_shape, dtype),
            jax.random.normal(data_rng_1, data_shape, dtype),
        )

        padded_mask = jnp.zeros((batch, 1, seqlen, seqlen), dtype=jnp.uint8)
        causal_mask = jnp.triu(jnp.ones((batch, 1, seqlen, seqlen), dtype=jnp.uint8), k=1)
        if self.attrs[_KEY_OF_SELF_ATTN_MASK_TYPE] in ["causal", "padding_causal"]:
            self_mask = causal_mask
        else:
            self_mask = padded_mask

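        # The first mask is for self-attention, the second for cross-attention over
        # the encoder output.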
        ref_masks = (1 - self_mask, 1 - padded_mask)
        test_masks = (self_mask, padded_mask)

        return inputs, (ref_masks, test_masks)


@pytest.mark.parametrize("data_shape", DATA_SHAPE)
@pytest.mark.parametrize("dtype", DTYPE)
@pytest.mark.parametrize("attrs", ATTRS)
class BaseTester:
    """
    Pytest interface to invoke the runner
    """

    runner = BaseRunner

    def test_forward(self, data_shape, dtype, attrs):
        """Test normal datatype forward"""
        QuantizeConfig.finalize()  # Ensure FP8 disabled.
        self.runner(attrs).test_forward(data_shape, dtype)

    def test_backward(self, data_shape, dtype, attrs):
        """Test normal datatype backward"""
        QuantizeConfig.finalize()  # Ensure FP8 disabled.
        self.runner(attrs).test_backward(data_shape, dtype)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("fp8_recipe", QUANTIZE_RECIPES)
    def test_forward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe):
        """Test forward with fp8 enabled"""
        QuantizeConfig.initialize(fp8_recipe=fp8_recipe)
        self.runner(attrs).test_forward(data_shape, dtype, rtol=1e-4, atol=1e-3)
        QuantizeConfig.finalize()

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("fp8_recipe", QUANTIZE_RECIPES)
    def test_backward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe):
        """Test backward with fp8 enabled"""
        QuantizeConfig.initialize(fp8_recipe=fp8_recipe)
        self.runner(attrs).test_backward(data_shape, dtype, rtol=1e-4, atol=1e-3)
        QuantizeConfig.finalize()


class TestEncoderLayer(BaseTester):
    """
    Test transformer_engine.jax.flax.TransformerLayer(layer_type=Encoder)
    """

    runner = EncoderRunner


class TestDecoderLayer(BaseTester):
    """
    Test transformer_engine.jax.flax.TransformerLayer(layer_type=Decoder)
    """

    runner = DecoderRunner