# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

from contextlib import nullcontext
import functools
import operator
from typing import Callable, List, Sequence, Union

import jax
import jax.numpy as jnp
import numpy as np
import pytest
from jax import jit, value_and_grad
from flax import linen as nn

from utils import assert_allclose
from transformer_engine.jax.dot import type_safe_dot_general, dequantize, quantize
from transformer_engine.jax.fp8 import FP8MetaPackage, FP8Helper, is_fp8_available
from transformer_engine.jax.layernorm import layernorm, layernorm_fp8_dot
from transformer_engine.jax.layernorm_mlp import activation_lu, fused_layernorm_fp8_mlp
from transformer_engine.jax.cpp_extensions.activation import _jax_act_lu
from transformer_engine.jax import cpp_extensions as tex


GEMM_CASES = [
    (256, 256, 512),
    (32, 32, 32),
    (2048, 1024, 2048),
    (2048, 2048, 1024),
    (2048, 1024, 1024),
]
FP8_COMPUTE_TYPE = [jnp.float8_e4m3fn, jnp.float8_e5m2]
LN_CASES = [(512, 1024)]
DTYPES = [jnp.bfloat16, jnp.float32]
is_fp8_supported, reason = is_fp8_available()


class TestFP8Dot:
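    """Test TE BF16/FP8 GEMM custom calls and the fused LayerNorm-MLP against pure-JAX references."""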

    @staticmethod
    def _generate_fp8_meta():
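        """Build default FP8 meta for one GEMM: [fwd, fwd, bwd] dtypes, zeroed amax histories, unit scales."""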
        fp8_dtype_list = [FP8Helper.FWD_DTYPE, FP8Helper.FWD_DTYPE, FP8Helper.BWD_DTYPE]
        amax_list = [
            jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
            jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
            jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
        ]
        scale_list = [
            jnp.ones((1,), jnp.float32),
            jnp.ones((1,), jnp.float32),
            jnp.ones((1,), jnp.float32),
        ]
        return fp8_dtype_list, amax_list, scale_list

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    def test_qdq(self):
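        """Quantize/dequantize round trip should reproduce the input within FP8 tolerance."""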
        FP8_E4M3_MAX = (jnp.finfo(jnp.float8_e4m3fn).max).astype(jnp.float32)
        x = jnp.asarray([[-1, 0.1], [2, 3]], jnp.float32)
        amax = jnp.max(jnp.abs(x)).reshape(1)
        scale = jnp.asarray(FP8_E4M3_MAX / amax, jnp.float32).reshape(1)
        scale_inv = (1 / scale).reshape(1)

        y, _ = quantize(x, q_dtype=jnp.float8_e4m3fn, scale=scale)
        z = dequantize(y, dq_dtype=jnp.float32, scale_inv=scale_inv)

        assert_allclose(z, x, dtype=jnp.float8_e4m3fn)

    @pytest.mark.parametrize("m,n,k", GEMM_CASES)
    def test_forward_bf16(self, m, n, k):
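        """BF16 GEMM through type_safe_dot_general should match jnp.dot."""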
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)
        a = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16)
        b = jax.random.normal(subkeys[1], (k, n), jnp.bfloat16)

        primitive_out = type_safe_dot_general(a, b)
        ref_out = jnp.dot(a, b)

        assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("m,n,k", GEMM_CASES)
    def test_forward_fp8_randint(self, m, n, k):
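        """FP8 GEMM on small random integers should match jnp.dot."""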
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)

        dtype = jnp.bfloat16

        # TODO(rewang): add float random test
        min_val, max_val = -8, 8
        a = jax.random.randint(subkeys[0], (m, k), min_val, max_val).astype(dtype)
        b = jax.random.randint(subkeys[1], (k, n), min_val, max_val).astype(dtype)

        _, amax_list, scale_list = TestFP8Dot._generate_fp8_meta()
        fp8_meta_pkg = FP8MetaPackage(
            amax_list[0],
            scale_list[0],
            amax_list[1],
            scale_list[1],
            amax_list[2],
            scale_list[2],
        )
        primitive_out = type_safe_dot_general(a, b, fp8_meta_pkg)
        ref_out = jnp.dot(a, b)

        ref_out = ref_out.astype(jnp.float32)
        primitive_out = primitive_out.astype(jnp.float32)

        assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE)

    @pytest.mark.parametrize("m,n,k", GEMM_CASES)
    def test_grad_bf16(self, m, n, k):
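        """Value and gradients of a BF16 GEMM should match the jnp.dot reference."""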
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)
        a = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16)
        b = jax.random.normal(subkeys[1], (k, n), jnp.bfloat16)

        def primitive_func(x, y):
            primitive_out = type_safe_dot_general(x, y)
            return jnp.mean(primitive_out)

        def ref_func(x, y):
            return jnp.mean(jnp.dot(x, y))

        value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1))

        value_n_grad_ref_func = value_and_grad(ref_func, (0, 1))

        primitive_out, (primitive_a_grad, primitive_b_grad) = value_n_grad_primitive_func(a, b)
        ref_out, (ref_a_grad, ref_b_grad) = value_n_grad_ref_func(a, b)

        assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16)
        assert_allclose(primitive_a_grad, ref_a_grad, dtype=jnp.bfloat16)
        assert_allclose(primitive_b_grad, ref_b_grad, dtype=jnp.bfloat16)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("m,n,k", GEMM_CASES)
    def test_grad_fp8_dot(self, m, n, k):
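        """Value and gradients of an FP8 GEMM should match the jnp.dot reference."""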
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)

        a = jax.random.normal(subkeys[0], (m, k)).astype(jnp.bfloat16)
        b = jax.random.normal(subkeys[1], (k, n)).astype(jnp.bfloat16)

        _, amax_list, scale_list = TestFP8Dot._generate_fp8_meta()

        def primitive_func(x, y, amax_list, scale_list):
            fp8_meta_pkg = FP8MetaPackage(
                amax_list[0],
                scale_list[0],
                amax_list[1],
                scale_list[1],
                amax_list[2],
                scale_list[2],
            )
            primitive_out = type_safe_dot_general(x, y, fp8_meta_pkg)
            return jnp.mean(primitive_out)

        def ref_func(x, y):
            return jnp.mean(jnp.dot(x, y))

        value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2, 3))
        value_n_grad_ref_func = value_and_grad(ref_func, (0, 1))

        ref_out, (ref_a_grad, ref_b_grad) = value_n_grad_ref_func(a, b)

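        # Iterate a few times, feeding the values returned for amax_list/scale_list back
        # in, so the FP8 scaling state settles before comparing against the reference.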
        for _ in range(3):
            primitive_out, (primitive_a_grad, primitive_b_grad, amax_list, scale_list) = (
                value_n_grad_primitive_func(a, b, amax_list, scale_list)
            )

        assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
        assert_allclose(primitive_a_grad, ref_a_grad, dtype=FP8Helper.BWD_DTYPE)
        assert_allclose(primitive_b_grad, ref_b_grad, dtype=FP8Helper.BWD_DTYPE)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize(
        "m,n,k", [(256, 128, 512), (16384, 1024, 2816), (16384, 2816, 1024), (16384, 1024, 1024)]
    )
    @pytest.mark.parametrize(
        "activation_type",
        [
            ("gelu",),
            ("gelu", "linear"),
            ("silu",),
            ("silu", "linear"),
            ("relu",),
            ("relu", "linear"),
            ("quick_gelu",),
            ("quick_gelu", "linear"),
            ("squared_relu",),
            ("squared_relu", "linear"),
        ],
    )
    @pytest.mark.parametrize("use_bias", [True, False])
    def test_grad_fused_layernorm_fp8_mlp(
        self, m, n, k, activation_type: Sequence[Union[str, Callable]], use_bias: bool
    ):
        """N/a"""
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 6)

        a = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16)
        k1 = jax.random.normal(subkeys[1], (k, len(activation_type), n), jnp.bfloat16) / jnp.sqrt(k)
        k2 = jax.random.normal(subkeys[2], (n, k), jnp.bfloat16) / jnp.sqrt(n)
        s = jax.random.normal(subkeys[5], (k,), jnp.bfloat16)
        if use_bias:
            b1 = jax.random.normal(subkeys[3], (len(activation_type), n), jnp.bfloat16)
            b2 = jax.random.normal(subkeys[4], (k,), jnp.bfloat16)
        else:
            b1 = None
            b2 = None

        def primitive_func(
            x, ln_s, y, z, w, v, amax_list_1, amax_list_2, scale_list_1, scale_list_2
        ):
            # x is the input tensor (2D); y, z are the weight matrices; w, v are biases
            # out = activation((rmsnorm(x, ln_s) @ y) + w) @ z + v
            fp8_meta_pkg_1 = FP8MetaPackage(
                amax_list_1[0],
                scale_list_1[0],
                amax_list_1[1],
                scale_list_1[1],
                amax_list_1[2],
                scale_list_1[2],
            )
            fp8_meta_pkg_2 = FP8MetaPackage(
                amax_list_2[0],
                scale_list_2[0],
                amax_list_2[1],
                scale_list_2[1],
                amax_list_2[2],
                scale_list_2[2],
            )
            return jnp.mean(
                fused_layernorm_fp8_mlp(
                    x,
                    ln_s,
                    None,
                    [y, z],
                    [w, v],
                    [fp8_meta_pkg_1, fp8_meta_pkg_2],
                    "rmsnorm",
                    activation_type=activation_type,
                    use_bias=use_bias,
                )
            )

        def layernorm_fp8_mlp_ref(
            x: jnp.ndarray,
            ln_scale: jnp.ndarray,
            kernel_1: jnp.ndarray,
            kernel_2: jnp.ndarray,
            bias_1: jnp.ndarray,
            bias_2: jnp.ndarray,
            amax_list_1: List[jnp.ndarray],
            amax_list_2: List[jnp.ndarray],
            scale_list_1: List[jnp.ndarray],
            scale_list_2: List[jnp.ndarray],
        ) -> jnp.ndarray:

            x = jnp.asarray(x, jnp.float32)
            mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True)
            y = jnp.asarray(x * jax.lax.rsqrt(mean2 + 1e-6), jnp.bfloat16)
            ln_out = y * ln_scale
            ln_out = jnp.asarray(ln_out, jnp.bfloat16)

            fp8_meta_pkg_1 = FP8MetaPackage(
                amax_list_1[0],
                scale_list_1[0],
                amax_list_1[1],
                scale_list_1[1],
                amax_list_1[2],
                scale_list_1[2],
            )
            linear_1_out = type_safe_dot_general(ln_out, kernel_1, fp8_meta_pkg_1, ((1,), (0,)))

            if use_bias:
                bias_1_shape = (1,) * (linear_1_out.ndim - bias_1.ndim) + bias_1.shape
                linear_1_out += jnp.reshape(bias_1, bias_1_shape)

            x = _jax_act_lu(linear_1_out, activation_type)

            fp8_meta_pkg_2 = FP8MetaPackage(
                amax_list_2[0],
                scale_list_2[0],
                amax_list_2[1],
                scale_list_2[1],
                amax_list_2[2],
                scale_list_2[2],
            )
            output = type_safe_dot_general(x, kernel_2, fp8_meta_pkg_2, ((1,), (0,)))

            if use_bias:
                bias_2_shape = (1,) * (output.ndim - bias_2.ndim) + bias_2.shape
                output += jnp.reshape(bias_2, bias_2_shape)

            return output

        def ref_func(x, ln_s, y, z, w, v, amax_list_1, amax_list_2, scale_list_1, scale_list_2):
            return jnp.mean(
                layernorm_fp8_mlp_ref(
                    x, ln_s, y, z, w, v, amax_list_1, amax_list_2, scale_list_1, scale_list_2
                )
            )

        value_n_grad_primitive_func = jit(
            value_and_grad(primitive_func, (0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
        )
        value_n_grad_ref_func = jit(value_and_grad(ref_func, (0, 1, 2, 3, 4, 5, 6, 7, 8, 9)))

        _, amax_list_1, scale_list_1 = TestFP8Dot._generate_fp8_meta()
        _, amax_list_2, scale_list_2 = TestFP8Dot._generate_fp8_meta()

        ref_amax_list_1 = amax_list_1
        ref_scale_list_1 = scale_list_1
        ref_amax_list_2 = amax_list_2
        ref_scale_list_2 = scale_list_2

        primitive_amax_list_1 = amax_list_1
        primitive_scale_list_1 = scale_list_1
        primitive_amax_list_2 = amax_list_2
        primitive_scale_list_2 = scale_list_2

        # Run each function for a few iterations, feeding the returned amax/scale values
        # back in, so the FP8 scaling state is updated before comparing.
        for _ in range(3):
            ref_out, (
                ref_a_grad,
                ref_s_grad,
                ref_k1_grad,
                ref_k2_grad,
                ref_b1_grad,
                ref_b2_grad,
                ref_amax_list_1,
                ref_amax_list_2,
                ref_scale_list_1,
                ref_scale_list_2,
            ) = value_n_grad_ref_func(
                a,
                s,
                k1,
                k2,
                b1,
                b2,
                ref_amax_list_1,
                ref_amax_list_2,
                ref_scale_list_1,
                ref_scale_list_2,
            )

        for _ in range(3):
            primitive_out, (
                primitive_a_grad,
                primitive_s_grad,
                primitive_k1_grad,
                primitive_k2_grad,
                primitive_b1_grad,
                primitive_b2_grad,
                primitive_amax_list_1,
                primitive_amax_list_2,
                primitive_scale_list_1,
                primitive_scale_list_2,
            ) = value_n_grad_primitive_func(
                a,
                s,
                k1,
                k2,
                b1,
                b2,
                primitive_amax_list_1,
                primitive_amax_list_2,
                primitive_scale_list_1,
                primitive_scale_list_2,
            )

        assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
        assert_allclose(
            jnp.asarray(primitive_a_grad, np.float32),
            jnp.asarray(ref_a_grad, np.float32),
            dtype=FP8Helper.BWD_DTYPE,
        )
        assert_allclose(
            jnp.asarray(primitive_k1_grad, np.float32),
            jnp.asarray(ref_k1_grad, np.float32),
            dtype=FP8Helper.BWD_DTYPE,
        )
        assert_allclose(
            jnp.asarray(primitive_s_grad, np.float32),
            jnp.asarray(ref_s_grad, np.float32),
            dtype=FP8Helper.BWD_DTYPE,
        )
        assert_allclose(
            jnp.asarray(primitive_k2_grad, np.float32),
            jnp.asarray(ref_k2_grad, np.float32),
            dtype=FP8Helper.BWD_DTYPE,
        )
        if use_bias:
            assert_allclose(
                jnp.asarray(primitive_b2_grad, np.float32),
                jnp.asarray(ref_b2_grad, np.float32),
                dtype=FP8Helper.BWD_DTYPE,
            )
            assert_allclose(
                jnp.asarray(primitive_b1_grad, np.float32),
                jnp.asarray(ref_b1_grad, np.float32),
                dtype=FP8Helper.BWD_DTYPE,
            )


@pytest.fixture(name="random_inputs")
def random_inputs_fixture(shape):
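    """Uniform bf16 tensor of the requested shape with values in [5, 8)."""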
    key = jax.random.PRNGKey(0)
    subkeys = jax.random.split(key, 4)
    out = jax.random.uniform(subkeys[0], shape, jnp.bfloat16, 5, 8)
    return out


class TestActivationLu:
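    """Test activation_lu against the pure-JAX reference _jax_act_lu."""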

    def ref_func(self, x, activation_type):
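        """Reference value and gradient computed with _jax_act_lu and JAX autodiff."""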

        def ref_act_lu(inputs):
            x = _jax_act_lu(inputs, activation_type)
            return jnp.mean(x)

        ref_act_func = jit(value_and_grad(ref_act_lu, (0,)))
        return ref_act_func(x)

    def primitive_func(self, inputs):
        return jnp.mean(activation_lu(inputs, activation_type=self.activation_type))

    @pytest.mark.parametrize("shape", [(32, 1, 64), (64, 1, 256)])
    @pytest.mark.parametrize(
        "activation_type",
        [
            ("gelu",),
            ("gelu", "linear"),
            ("silu",),
            ("silu", "linear"),
            ("relu",),
            ("relu", "linear"),
            ("quick_gelu",),
            ("quick_gelu", "linear"),
            ("squared_relu",),
            ("squared_relu", "linear"),
        ],
    )
    def test_activation_lu(self, random_inputs, activation_type):
        x = random_inputs
        x = jnp.repeat(x, len(activation_type), axis=1)
        self.activation_type = activation_type

        value_n_grad_primitive_func = jit(value_and_grad(self.primitive_func, (0,)))

        prim_out, (prim_grad,) = value_n_grad_primitive_func(x)
        ref_out, (ref_grad,) = self.ref_func(x, activation_type)

        assert_allclose(prim_out, ref_out, dtype=x.dtype)
        assert_allclose(prim_grad, ref_grad, dtype=x.dtype)


class TestActivationLuFP8(TestActivationLu):
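    """Test the FP8 activation custom calls (act_lu_fp8 and the fused backward kernels)."""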

    def prim_func(self, x):
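        """Wrap the FP8 activation custom calls in a custom_vjp so value_and_grad also
        exposes the fused backward outputs (dgrad, transposed dgrad, dbias, amax)
        through the dummy argument slots."""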
        amax = self.amax
        scale = self.scale
        scale_inv = self.scale_inv
        activation_type = self.activation_type

        @jax.custom_vjp
        def _prim_func(x, _x_t, _dbias, _amax):
            output = _prim_func_fwd(x, _x_t, _dbias, _amax)
            return output

        def _prim_func_fwd(x, _x_t, _dbias, _amax):
            activation_lu_out, _ = tex.act_lu_fp8(
                x, amax, scale, scale_inv, FP8Helper.FWD_DTYPE, activation_type
            )
            activation_lu_out = dequantize(activation_lu_out, x.dtype, scale_inv)
            ctx = x
            return activation_lu_out, ctx

        def _prim_func_bwd(ctx, g):
            x = ctx
            if len(self.activation_type) > 1:  # gated, no bias
                dactivation_lu, dactivation_lu_trans, amax_out = tex.dgated_act_lu_cast_transpose(
                    g, x, amax, scale, scale_inv, FP8Helper.BWD_DTYPE, -1, activation_type
                )
                dbias = jnp.empty(x.shape[-1], x.dtype)
            else:  # not gated, with bias
                dactivation_lu, dactivation_lu_trans, dbias, amax_out = (
                    tex.dact_lu_dbias_cast_transpose(
                        g,
                        x,
                        amax,
                        scale,
                        scale_inv,
                        FP8Helper.BWD_DTYPE,
                        -1,
                        -2,
                        self.activation_type,
                    )
                )
            dactivation_lu = dequantize(dactivation_lu, x.dtype, scale_inv)
            dactivation_lu_trans = dequantize(dactivation_lu_trans, x.dtype, scale_inv)
            ctx = (dactivation_lu, dactivation_lu_trans, dbias, amax_out)
            return ctx

        _prim_func.defvjp(_prim_func_fwd, _prim_func_bwd)

        dx_trans_no_use = jnp.empty([x.shape[i] for i in self.transpose_indices], dtype=x.dtype)
        dbias_no_use = jnp.empty(x.shape[-1], dtype=x.dtype)
        amax_no_use = jnp.zeros(1, jnp.float32)
        value_n_grad_primitive_func = value_and_grad(
            lambda a, b, c, d: jnp.mean(_prim_func(a, b, c, d)), (0, 1, 2, 3)
        )
        return value_n_grad_primitive_func(x, dx_trans_no_use, dbias_no_use, amax_no_use)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("shape", [(32, 1, 64), (64, 1, 256)])
    @pytest.mark.parametrize(
        "activation_type",
        [
            ("gelu",),
            ("gelu", "linear"),
            ("silu",),
            ("silu", "linear"),
            ("relu",),
            ("relu", "linear"),
            ("quick_gelu",),
            ("quick_gelu", "linear"),
            ("squared_relu",),
            ("squared_relu", "linear"),
        ],
    )
    def test_activation_lu(self, random_inputs, activation_type):
        self.amax = jnp.zeros(1, jnp.float32)
        self.scale = jnp.ones(1, jnp.float32)
        self.scale_inv = jnp.ones(1, jnp.float32)
        self.activation_type = activation_type
        self.transpose_indices = (1, 2, 0)

        x = random_inputs
        x = jnp.repeat(x, len(activation_type), axis=1)

        prim_out, (prim_grad, prim_grad_trans, dbias, amax) = self.prim_func(x)
        ref_out, (ref_grad,) = self.ref_func(x, activation_type)

        assert_allclose(prim_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
        assert_allclose(amax, jnp.amax(jnp.abs(ref_grad)), rtol=1e-2)
        if "linear" not in activation_type:
            assert_allclose(dbias, jnp.sum(ref_grad, axis=tuple(range(x.ndim - 1))))
        assert_allclose(prim_grad, ref_grad, dtype=FP8Helper.BWD_DTYPE)
        assert_allclose(
            prim_grad_trans,
            jnp.transpose(ref_grad, self.transpose_indices),
            dtype=FP8Helper.BWD_DTYPE,
        )


class TestNorm:
    """
    Test transformer_engine.jax.layernorm APIs
    """

    @staticmethod
    def _generate_fp8_meta():
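        """Build default FP8 meta for one GEMM: [fwd, fwd, bwd] dtypes, zeroed amax histories, unit scales."""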
        fp8_dtype_list = [FP8Helper.FWD_DTYPE, FP8Helper.FWD_DTYPE, FP8Helper.BWD_DTYPE]
        amax_list = [
            jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
            jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
            jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
        ]
        scale_list = [
            jnp.ones((1,), jnp.float32),
            jnp.ones((1,), jnp.float32),
            jnp.ones((1,), jnp.float32),
        ]
        return fp8_dtype_list, amax_list, scale_list

    def reference_layernorm(self, x, scale, bias, zero_centered_gamma, eps):
        """
        JAX native layernorm implementations
        - bias is not None: layernorm
        - bias is None: rmsnorm
        """
        x_ = jnp.asarray(x, jnp.float32)
        if bias is None:
            mean = 0.0
        else:
            mean = jnp.mean(x_, axis=-1, keepdims=True)
        var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True)
        normed_input = (x_ - mean) * jax.lax.rsqrt(var + eps)
        if zero_centered_gamma:
            scale += 1.0
        if bias is None:
            bias = 0.0
        return jnp.asarray(normed_input * scale + bias).astype(x.dtype)

    @pytest.mark.parametrize("n, hidden", LN_CASES)
    @pytest.mark.parametrize("dtype", DTYPES)
    @pytest.mark.parametrize("ln_type", ["layernorm", "rmsnorm"])
    @pytest.mark.parametrize("zero_centered_gamma", [False, True])
    @pytest.mark.parametrize("epsilon", [1e-2, 1e-6])
    def test_layernorm_forward_backward(
        self, n, hidden, ln_type, zero_centered_gamma, epsilon, dtype
    ):
        """
        Test transformer_engine.jax.layernorm.layernorm
        """
        expect_assert = False
        if ln_type == "rmsnorm" and zero_centered_gamma:
            # zero_centered_gamma is not supported for rmsnorm, expect an assertion.
            expect_assert = True

        with (
            pytest.raises(AssertionError, match=r".*zero_centered_gamma is not supported.*")
            if expect_assert
            else nullcontext()
        ):
            key = jax.random.PRNGKey(0)
            subkeys = jax.random.split(key, 3)

            x = jax.random.uniform(subkeys[0], (n, hidden), dtype, -1, 1)
            gamma_range = (-1, 1) if zero_centered_gamma else (0, 2)
            gamma = jax.random.uniform(subkeys[1], (hidden,), jnp.float32, *gamma_range)
            gamma = jnp.asarray(gamma, dtype)
            if ln_type == "layernorm":
                beta = jax.random.uniform(subkeys[2], (hidden,), jnp.float32, -1, 1)
                beta = jnp.asarray(beta, dtype)
            else:
                beta = None

            def compute_loss(x):
                # Higher precision to compute the loss
                x_ = x.astype(jnp.float32)
                return jnp.mean(jnp.square(x_)).astype(x.dtype)

            jitted_primitive = jit(
                value_and_grad(
                    lambda x, gamma, beta: compute_loss(
                        layernorm(x, gamma, beta, ln_type, zero_centered_gamma, epsilon)
                    ),
                    (0, 1, 2),
                )
            )

            jitted_reference = jit(
                value_and_grad(
                    lambda x, gamma, beta: compute_loss(
                        self.reference_layernorm(x, gamma, beta, zero_centered_gamma, epsilon)
                    ),
                    (0, 1, 2),
                )
            )

            primitive_out, (primitive_dx, primitive_dgamma, primitive_dbeta) = jitted_primitive(
                x, gamma, beta
            )
            reference_out, (reference_dx, reference_dgamma, reference_dbeta) = jitted_reference(
                x, gamma, beta
            )

            assert_allclose(primitive_out, reference_out, dtype=dtype)
            assert_allclose(primitive_dx, reference_dx, dtype=dtype)
            assert_allclose(primitive_dgamma, reference_dgamma, dtype=dtype)
            if beta is not None:
                assert_allclose(primitive_dbeta, reference_dbeta, dtype=dtype)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("m,n,k", GEMM_CASES)
    @pytest.mark.parametrize("ln_type", ["layernorm", "rmsnorm"])
    @pytest.mark.parametrize("zero_centered_gamma", [True, False])
    @pytest.mark.parametrize("epsilon", [1e-2, 1e-6])
    def test_ln_fp8_dot_forward_backward(self, m, n, k, ln_type, zero_centered_gamma, epsilon):
        """
        Test transformer_engine.jax.layernorm.layernorm_fp8_dot
        """
        expect_assert = False
        if ln_type == "rmsnorm" and zero_centered_gamma:
            # zero_centered_gamma is not supported for rmsnorm, expect an assertion.
            expect_assert = True

        with (
            pytest.raises(AssertionError, match=r".*zero_centered_gamma is not supported.*")
            if expect_assert
            else nullcontext()
        ):
            key = jax.random.PRNGKey(0)
            subkeys = jax.random.split(key, 4)

            a = jax.random.normal(subkeys[0], (m, k)).astype(jnp.bfloat16)
            b = jax.random.normal(subkeys[1], (k, n)).astype(jnp.bfloat16)

            gamma = jax.random.normal(subkeys[2], (k,)).astype(jnp.bfloat16)
            if ln_type == "layernorm":
                beta = jax.random.normal(subkeys[3], (k,)).astype(jnp.bfloat16)
            else:
                beta = None

            _, amax_list_1, scale_list_1 = TestNorm._generate_fp8_meta()

            def primitive_func(x, y, gamma, beta, amax_list_1, scale_list_1):
                fp8_meta_pkg = FP8MetaPackage(
                    amax_list_1[0],
                    scale_list_1[0],
                    amax_list_1[1],
                    scale_list_1[1],
                    amax_list_1[2],
                    scale_list_1[2],
                )
                primitive_out = layernorm_fp8_dot(
                    x, y, gamma, beta, fp8_meta_pkg, ln_type, zero_centered_gamma
                )
                return jnp.mean(primitive_out)

            def ref_func(x, y, gamma, beta, zero_centered_gamma):
                x = self.reference_layernorm(x, gamma, beta, zero_centered_gamma, epsilon)
                return jnp.mean(jnp.dot(x, y))

            value_n_grad_primitive_func = value_and_grad(primitive_func, range(6))
            value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2, 3))

            ref_out, (ref_a_grad, ref_b_grad, ref_gamma_grad, ref_beta_grad) = (
                value_n_grad_ref_func(a, b, gamma, beta, zero_centered_gamma)
            )

            for _ in range(3):
                primitive_out, (
                    primitive_a_grad,
                    primitive_b_grad,
                    primitive_gamma_grad,
                    primitive_beta_grad,
                    amax_list_1,
                    scale_list_1,
                ) = value_n_grad_primitive_func(a, b, gamma, beta, amax_list_1, scale_list_1)

            assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
            assert_allclose(primitive_a_grad, ref_a_grad, dtype=FP8Helper.BWD_DTYPE)
            assert_allclose(primitive_b_grad, ref_b_grad, dtype=FP8Helper.BWD_DTYPE)
            assert_allclose(primitive_gamma_grad, ref_gamma_grad, dtype=FP8Helper.BWD_DTYPE)
            if beta is not None:
                assert_allclose(primitive_beta_grad, ref_beta_grad, dtype=FP8Helper.BWD_DTYPE)