# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

from contextlib import nullcontext
import functools
import operator
from typing import Callable, List, Sequence, Union

import jax
import jax.numpy as jnp
import numpy as np
import pytest
from jax import jit, value_and_grad
from flax import linen as nn

from utils import assert_allclose
from transformer_engine.jax.dot import type_safe_dot_general, dequantize, quantize
from transformer_engine.jax.fp8 import FP8MetaPackage, FP8Helper, is_fp8_available
from transformer_engine.jax.layernorm import layernorm, layernorm_fp8_dot
from transformer_engine.jax.layernorm_mlp import activation_lu, fused_layernorm_fp8_mlp
from transformer_engine.jax import cpp_extensions as tex

GEMM_CASES = [
    (256, 256, 512),
    (32, 32, 32),
    (2048, 1024, 2048),
    (2048, 2048, 1024),
    (2048, 1024, 1024),
]
FP8_COMPUTE_TYPE = [jnp.float8_e4m3fn, jnp.float8_e5m2]
LN_CASES = [(512, 1024)]
DTYPES = [jnp.bfloat16, jnp.float32]
is_fp8_supported, reason = is_fp8_available()


def _convert_to_activation_function(fn_or_string):
    """Convert a string to an activation function."""
    if fn_or_string == "linear":
        return lambda x: x
    if fn_or_string == "quick_gelu":
        return lambda x: nn.gelu(x, approximate=True)
    if fn_or_string == "squared_relu":
        return lambda x: functools.reduce(operator.mul, [nn.relu(x), nn.relu(x)])
    if isinstance(fn_or_string, str):
        return getattr(nn, fn_or_string)
    if callable(fn_or_string):
        return fn_or_string
    raise ValueError(f"don't know how to convert {fn_or_string} to an activation function")


class TestFP8Dot:

    @staticmethod
    def _generate_fp8_meta():
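        """Build fresh delayed-scaling metadata: a zeroed amax history and unit scale
        for each of the three FP8 tensors (nominally input, kernel, and output grad)."""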
        fp8_dtype_list = [FP8Helper.FWD_DTYPE, FP8Helper.FWD_DTYPE, FP8Helper.BWD_DTYPE]
        amax_list = [
            jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
            jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
            jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
        ]
        scale_list = [
            jnp.ones((1,), jnp.float32),
            jnp.ones((1,), jnp.float32),
            jnp.ones((1,), jnp.float32),
        ]
        return fp8_dtype_list, amax_list, scale_list

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    def test_qdq(self):
        FP8_E4M3_MAX = (jnp.finfo(jnp.float8_e4m3fn).max).astype(jnp.float32)
        x = jnp.asarray([[-1, 0.1], [2, 3]], jnp.float32)
        amax = jnp.max(jnp.abs(x)).reshape(1)
        scale = jnp.asarray(FP8_E4M3_MAX / amax, jnp.float32).reshape(1)
        scale_inv = (1 / scale).reshape(1)
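        # Round trip: quantize casts x * scale to FP8, dequantize multiplies by
        # scale_inv = 1 / scale, so z should match x up to E4M3 rounding error.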

        y, _ = quantize(x, q_dtype=jnp.float8_e4m3fn, scale=scale)
        z = dequantize(y, dq_dtype=jnp.float32, scale_inv=scale_inv)

        assert_allclose(z, x, dtype=jnp.float8_e4m3fn)

    @pytest.mark.parametrize("m,n,k", GEMM_CASES)
    def test_forward_bf16(self, m, n, k):
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)
        a = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16)
        b = jax.random.normal(subkeys[1], (k, n), jnp.bfloat16)

        primitive_out = type_safe_dot_general(a, b)
        ref_out = jnp.dot(a, b)

        assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("m,n,k", GEMM_CASES)
    def test_forward_fp8_randint(self, m, n, k):
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)

        dtype = jnp.bfloat16

        # TODO(rewang): add float random test
        min_val, max_val = -8, 8
        a = jax.random.randint(subkeys[0], (m, k), min_val, max_val).astype(dtype)
        b = jax.random.randint(subkeys[1], (k, n), min_val, max_val).astype(dtype)
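        # Small integers in [-8, 8) are exactly representable in FP8 E4M3, so with the
        # unit scales below quantization adds no extra error to this comparison.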

        _, amax_list, scale_list = TestFP8Dot._generate_fp8_meta()
        fp8_meta_pkg = FP8MetaPackage(
            amax_list[0],
            scale_list[0],
            amax_list[1],
            scale_list[1],
            amax_list[2],
            scale_list[2],
        )
        primitive_out = type_safe_dot_general(a, b, fp8_meta_pkg)
        ref_out = jnp.dot(a, b)

        ref_out = ref_out.astype(jnp.float32)
        primitive_out = primitive_out.astype(jnp.float32)

        assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE)

    @pytest.mark.parametrize("m,n,k", GEMM_CASES)
    def test_grad_bf16(self, m, n, k):
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)
        a = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16)
        b = jax.random.normal(subkeys[1], (k, n), jnp.bfloat16)

        def primitive_func(x, y):
            primitive_out = type_safe_dot_general(x, y)
            return jnp.mean(primitive_out)

        def ref_func(x, y):
            return jnp.mean(jnp.dot(x, y))

        value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1))

        value_n_grad_ref_func = value_and_grad(ref_func, (0, 1))

        primitive_out, (primitive_a_grad, primitive_b_grad) = value_n_grad_primitive_func(a, b)
        ref_out, (ref_a_grad, ref_b_grad) = value_n_grad_ref_func(a, b)

        assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16)
        assert_allclose(primitive_a_grad, ref_a_grad, dtype=jnp.bfloat16)
        assert_allclose(primitive_b_grad, ref_b_grad, dtype=jnp.bfloat16)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("m,n,k", GEMM_CASES)
    def test_grad_fp8_dot(self, m, n, k):
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)

        a = jax.random.normal(subkeys[0], (m, k)).astype(jnp.bfloat16)
        b = jax.random.normal(subkeys[1], (k, n)).astype(jnp.bfloat16)

        _, amax_list, scale_list = TestFP8Dot._generate_fp8_meta()

        def primitive_func(x, y, amax_list, scale_list):
            fp8_meta_pkg = FP8MetaPackage(
                amax_list[0],
                scale_list[0],
                amax_list[1],
                scale_list[1],
                amax_list[2],
                scale_list[2],
            )
            primitive_out = type_safe_dot_general(x, y, fp8_meta_pkg)
            return jnp.mean(primitive_out)

        def ref_func(x, y):
            return jnp.mean(jnp.dot(x, y))

        value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2, 3))
        value_n_grad_ref_func = value_and_grad(ref_func, (0, 1))

        ref_out, (ref_a_grad, ref_b_grad) = value_n_grad_ref_func(a, b)

        for _ in range(3):
            primitive_out, (primitive_a_grad, primitive_b_grad, amax_list, scale_list) = (
                value_n_grad_primitive_func(a, b, amax_list, scale_list)
            )

        assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
        assert_allclose(primitive_a_grad, ref_a_grad, dtype=FP8Helper.BWD_DTYPE)
        assert_allclose(primitive_b_grad, ref_b_grad, dtype=FP8Helper.BWD_DTYPE)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize(
        "m,n,k", [(256, 128, 512), (16384, 1024, 2816), (16384, 2816, 1024), (16384, 1024, 1024)]
    )
    @pytest.mark.parametrize(
        "activation_type",
        [
            ("gelu",),
            ("gelu", "linear"),
            ("silu",),
            ("silu", "linear"),
            ("relu",),
            ("relu", "linear"),
            ("quick_gelu",),
            ("quick_gelu", "linear"),
            ("squared_relu",),
            ("squared_relu", "linear"),
        ],
    )
    @pytest.mark.parametrize("use_bias", [True, False])
    def test_grad_fused_layernorm_fp8_mlp(
        self, m, n, k, activation_type: Sequence[Union[str, Callable]], use_bias: bool
    ):
        """N/a"""
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 6)

        a = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16)
        k1 = jax.random.normal(subkeys[1], (k, len(activation_type), n), jnp.bfloat16) / jnp.sqrt(k)
        k2 = jax.random.normal(subkeys[2], (n, k), jnp.bfloat16) / jnp.sqrt(n)
        s = jax.random.normal(subkeys[5], (k,), jnp.bfloat16)
        if use_bias:
            b1 = jax.random.normal(subkeys[3], (len(activation_type), n), jnp.bfloat16)
            b2 = jax.random.normal(subkeys[4], (k,), jnp.bfloat16)
        else:
            b1 = None
            b2 = None

        def primitive_func(
            x, ln_s, y, z, w, v, amax_list_1, amax_list_2, scale_list_1, scale_list_2
        ):
            # x: 2D input, ln_s: RMSNorm scale, y/z: 2D kernels, w/v: biases.
            # Roughly: out = act((rmsnorm(x) * ln_s) @ y + w) @ z + v
            fp8_meta_pkg_1 = FP8MetaPackage(
                amax_list_1[0],
                scale_list_1[0],
                amax_list_1[1],
                scale_list_1[1],
                amax_list_1[2],
                scale_list_1[2],
            )
            fp8_meta_pkg_2 = FP8MetaPackage(
                amax_list_2[0],
                scale_list_2[0],
                amax_list_2[1],
                scale_list_2[1],
                amax_list_2[2],
                scale_list_2[2],
            )
            return jnp.mean(
                fused_layernorm_fp8_mlp(
                    x,
                    ln_s,
                    None,
                    [y, z],
                    [w, v],
                    [fp8_meta_pkg_1, fp8_meta_pkg_2],
                    "rmsnorm",
                    activation_type=activation_type,
                    use_bias=use_bias,
                )
            )

        def layernorm_fp8_mlp_ref(
            x: jnp.ndarray,
            ln_scale: jnp.ndarray,
            kernel_1: jnp.ndarray,
            kernel_2: jnp.ndarray,
            bias_1: jnp.ndarray,
            bias_2: jnp.ndarray,
            amax_list_1: List[jnp.ndarray],
            amax_list_2: List[jnp.ndarray],
            scale_list_1: List[jnp.ndarray],
            scale_list_2: List[jnp.ndarray],
        ) -> jnp.ndarray:
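            """Pure-JAX reference: RMSNorm, FP8 dot with kernel_1, (optionally gated)
            activation, then FP8 dot with kernel_2, with biases applied if requested."""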

            x = jnp.asarray(x, jnp.float32)
            mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True)
            y = jnp.asarray(x * jax.lax.rsqrt(mean2 + 1e-6), jnp.bfloat16)
            ln_out = y * ln_scale
            ln_out = jnp.asarray(ln_out, jnp.bfloat16)

            fp8_meta_pkg_1 = FP8MetaPackage(
                amax_list_1[0],
                scale_list_1[0],
                amax_list_1[1],
                scale_list_1[1],
                amax_list_1[2],
                scale_list_1[2],
            )
            linear_1_out = type_safe_dot_general(ln_out, kernel_1, fp8_meta_pkg_1, ((1,), (0,)))

            if use_bias:
                bias_1_shape = (1,) * (linear_1_out.ndim - bias_1.ndim) + bias_1.shape
                linear_1_out += jnp.reshape(bias_1, bias_1_shape)

            x = jnp.split(linear_1_out, len(activation_type), axis=-2)
            acts = []
            for idx, act_fn in enumerate(activation_type):
                x_i = _convert_to_activation_function(act_fn)(x[idx])
                acts.append(x_i)
            x = functools.reduce(operator.mul, acts)

            x = jnp.asarray(jnp.squeeze(x, axis=-2), jnp.bfloat16)

            fp8_meta_pkg_2 = FP8MetaPackage(
                amax_list_2[0],
                scale_list_2[0],
                amax_list_2[1],
                scale_list_2[1],
                amax_list_2[2],
                scale_list_2[2],
            )
            output = type_safe_dot_general(x, kernel_2, fp8_meta_pkg_2, ((1,), (0,)))

            if use_bias:
                bias_2_shape = (1,) * (output.ndim - bias_2.ndim) + bias_2.shape
                output += jnp.reshape(bias_2, bias_2_shape)

            return output

        def ref_func(x, ln_s, y, z, w, v, amax_list_1, amax_list_2, scale_list_1, scale_list_2):
            return jnp.mean(
                layernorm_fp8_mlp_ref(
                    x, ln_s, y, z, w, v, amax_list_1, amax_list_2, scale_list_1, scale_list_2
                )
            )

        value_n_grad_primitive_func = jit(
            value_and_grad(primitive_func, (0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
        )
        value_n_grad_ref_func = jit(value_and_grad(ref_func, (0, 1, 2, 3, 4, 5, 6, 7, 8, 9)))

        _, amax_list_1, scale_list_1 = TestFP8Dot._generate_fp8_meta()
        _, amax_list_2, scale_list_2 = TestFP8Dot._generate_fp8_meta()

        ref_amax_list_1 = amax_list_1
        ref_scale_list_1 = scale_list_1
        ref_amax_list_2 = amax_list_2
        ref_scale_list_2 = scale_list_2

        primitive_amax_list_1 = amax_list_1
        primitive_scale_list_1 = scale_list_1
        primitive_amax_list_2 = amax_list_2
        primitive_scale_list_2 = scale_list_2

        # Run the reference and primitive functions a few times so the amax histories
        # and scales settle before comparing.
        for _ in range(3):
            ref_out, (
                ref_a_grad,
                ref_s_grad,
                ref_k1_grad,
                ref_k2_grad,
                ref_b1_grad,
                ref_b2_grad,
                ref_amax_list_1,
                ref_amax_list_2,
                ref_scale_list_1,
                ref_scale_list_2,
            ) = value_n_grad_ref_func(
                a,
                s,
                k1,
                k2,
                b1,
                b2,
                ref_amax_list_1,
                ref_amax_list_2,
                ref_scale_list_1,
                ref_scale_list_2,
            )

        for _ in range(3):
            primitive_out, (
                primitive_a_grad,
                primitive_s_grad,
                primitive_k1_grad,
                primitive_k2_grad,
                primitive_b1_grad,
                primitive_b2_grad,
                primitive_amax_list_1,
                primitive_amax_list_2,
                primitive_scale_list_1,
                primitive_scale_list_2,
            ) = value_n_grad_primitive_func(
                a,
                s,
                k1,
                k2,
                b1,
                b2,
                primitive_amax_list_1,
                primitive_amax_list_2,
                primitive_scale_list_1,
                primitive_scale_list_2,
            )

        assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
        assert_allclose(
            jnp.asarray(primitive_a_grad, np.float32),
            jnp.asarray(ref_a_grad, np.float32),
            dtype=FP8Helper.BWD_DTYPE,
        )
        assert_allclose(
            jnp.asarray(primitive_k1_grad, np.float32),
            jnp.asarray(ref_k1_grad, np.float32),
            dtype=FP8Helper.BWD_DTYPE,
        )
        assert_allclose(
            jnp.asarray(primitive_s_grad, np.float32),
            jnp.asarray(ref_s_grad, np.float32),
            dtype=FP8Helper.BWD_DTYPE,
        )
        assert_allclose(
            jnp.asarray(primitive_k2_grad, np.float32),
            jnp.asarray(ref_k2_grad, np.float32),
            dtype=FP8Helper.BWD_DTYPE,
        )
        if use_bias:
            assert_allclose(
                jnp.asarray(primitive_b2_grad, np.float32),
                jnp.asarray(ref_b2_grad, np.float32),
                dtype=FP8Helper.BWD_DTYPE,
            )
            assert_allclose(
                jnp.asarray(primitive_b1_grad, np.float32),
                jnp.asarray(ref_b1_grad, np.float32),
                dtype=FP8Helper.BWD_DTYPE,
            )


@pytest.fixture(name="random_inputs")
def random_inputs_fixture(shape):
    key = jax.random.PRNGKey(0)
    subkeys = jax.random.split(key, 4)
    out = jax.random.uniform(subkeys[0], shape, jnp.bfloat16, 5, 8)
    return out


class TestActivationLu:

    def ref_func(self, x, activation_type):

        def ref_act_lu(inputs):
            x = jnp.split(inputs, len(activation_type), axis=-2)
            acts = []
            for idx, act_fn in enumerate(activation_type):
                x_i = _convert_to_activation_function(act_fn)(x[idx])
                acts.append(x_i)
            x = functools.reduce(operator.mul, acts)
            return jnp.mean(x)

        ref_act_func = jit(value_and_grad(ref_act_lu, (0,)))
        return ref_act_func(x)

    def primitive_func(self, inputs):
        return jnp.mean(activation_lu(inputs, activation_type=self.activation_type))

    @pytest.mark.parametrize("shape", [(32, 1, 64), (64, 1, 256)])
    @pytest.mark.parametrize(
        "activation_type",
        [
            ("gelu",),
            ("gelu", "linear"),
            ("silu",),
            ("silu", "linear"),
            ("relu",),
            ("relu", "linear"),
            ("quick_gelu",),
            ("quick_gelu", "linear"),
            ("squared_relu",),
            ("squared_relu", "linear"),
        ],
    )
    def test_activation_lu(self, random_inputs, activation_type):
        x = random_inputs
        x = jnp.repeat(x, len(activation_type), axis=1)
        self.activation_type = activation_type

        value_n_grad_primitive_func = jit(value_and_grad(self.primitive_func, (0,)))

        prim_out, (prim_grad,) = value_n_grad_primitive_func(x)
        ref_out, (ref_grad,) = self.ref_func(x, activation_type)

        assert_allclose(prim_out, ref_out, dtype=x.dtype)
        assert_allclose(prim_grad, ref_grad, dtype=x.dtype)


class TestActivationLuFP8(TestActivationLu):

    def prim_func(self, x):
        amax = self.amax
        scale = self.scale
        scale_inv = self.scale_inv
        activation_type = self.activation_type
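        # The custom VJP below routes the backward pass through TE's fused
        # dact/cast/transpose kernels so that the FP8 gradient, its transpose, and
        # dbias can all be compared against the plain JAX reference.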

        @jax.custom_vjp
        def _prim_func(x, _x_t, _dbias, _amax):
            output = _prim_func_fwd(x, _x_t, _dbias, _amax)
            return output

        def _prim_func_fwd(x, _x_t, _dbias, _amax):
            activation_lu_out, _ = tex.act_lu_fp8(
                x, amax, scale, scale_inv, FP8Helper.FWD_DTYPE, activation_type
            )
            activation_lu_out = dequantize(activation_lu_out, x.dtype, scale_inv)
            ctx = x
            return activation_lu_out, ctx

        def _prim_func_bwd(ctx, g):
            x = ctx
            if len(self.activation_type) > 1:  # gated, no bias
                dactivation_lu, dactivation_lu_trans, amax_out = tex.dgated_act_lu_cast_transpose(
                    g, x, amax, scale, scale_inv, FP8Helper.BWD_DTYPE, -1, activation_type
                )
                dbias = jnp.empty(x.shape[-1], x.dtype)
            else:  # not gated, with bias
                dactivation_lu, dactivation_lu_trans, dbias, amax_out = (
                    tex.dact_lu_dbias_cast_transpose(
                        g,
                        x,
                        amax,
                        scale,
                        scale_inv,
                        FP8Helper.BWD_DTYPE,
                        -1,
                        -2,
                        self.activation_type,
                    )
                )
            dactivation_lu = dequantize(dactivation_lu, x.dtype, scale_inv)
            dactivation_lu_trans = dequantize(dactivation_lu_trans, x.dtype, scale_inv)
            ctx = (dactivation_lu, dactivation_lu_trans, dbias, amax_out)
            return ctx

        _prim_func.defvjp(_prim_func_fwd, _prim_func_bwd)

        dx_trans_no_use = jnp.empty([x.shape[i] for i in self.transpose_indices], dtype=x.dtype)
        dbias_no_use = jnp.empty(x.shape[-1], dtype=x.dtype)
        amax_no_use = jnp.zeros(1, jnp.float32)
        value_n_grad_primitive_func = value_and_grad(
            lambda a, b, c, d: jnp.mean(_prim_func(a, b, c, d)), (0, 1, 2, 3)
        )
        return value_n_grad_primitive_func(x, dx_trans_no_use, dbias_no_use, amax_no_use)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("shape", [(32, 1, 64), (64, 1, 256)])
    @pytest.mark.parametrize(
        "activation_type",
        [
            ("gelu",),
            ("gelu", "linear"),
            ("silu",),
            ("silu", "linear"),
            ("relu",),
            ("relu", "linear"),
            ("quick_gelu",),
            ("quick_gelu", "linear"),
            ("squared_relu",),
            ("squared_relu", "linear"),
        ],
    )
    def test_activation_lu(self, random_inputs, activation_type):
        self.amax = jnp.zeros(1, jnp.float32)
        self.scale = jnp.ones(1, jnp.float32)
        self.scale_inv = jnp.ones(1, jnp.float32)
        self.activation_type = activation_type
        self.transpose_indices = (1, 2, 0)

        x = random_inputs
        x = jnp.repeat(x, len(activation_type), axis=1)
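        # Repeat along axis 1 so there is one slice per activation, matching the gated
        # layout that both the fused kernel and the reference split on axis -2.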

        prim_out, (prim_grad, prim_grad_trans, dbias, amax) = self.prim_func(x)
        ref_out, (ref_grad,) = self.ref_func(x, activation_type)

        assert_allclose(prim_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
        assert_allclose(amax, jnp.amax(jnp.abs(ref_grad)), rtol=1e-2)
        if "linear" not in activation_type:
            assert_allclose(dbias, jnp.sum(ref_grad, axis=tuple(range(x.ndim - 1))))
        assert_allclose(prim_grad, ref_grad, dtype=FP8Helper.BWD_DTYPE)
        assert_allclose(
            prim_grad_trans,
            jnp.transpose(ref_grad, self.transpose_indices),
            dtype=FP8Helper.BWD_DTYPE,
        )


class TestNorm:
    """
    Test transformer_engine.jax.layernorm APIs
    """

    @staticmethod
    def _generate_fp8_meta():
        fp8_dtype_list = [FP8Helper.FWD_DTYPE, FP8Helper.FWD_DTYPE, FP8Helper.BWD_DTYPE]
        amax_list = [
            jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
            jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
            jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
        ]
        scale_list = [
            jnp.ones((1,), jnp.float32),
            jnp.ones((1,), jnp.float32),
            jnp.ones((1,), jnp.float32),
        ]
        return fp8_dtype_list, amax_list, scale_list

    def reference_layernorm(self, x, scale, bias, zero_centered_gamma, eps):
        """
        JAX native layernorm implementations
        - bias is not None: layernorm
        - bias is None: rmsnorm
        """
        x_ = jnp.asarray(x, jnp.float32)
        if bias is None:
            mean = 0.0
        else:
            mean = jnp.mean(x_, axis=-1, keepdims=True)
        var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True)
        normed_input = (x_ - mean) * jax.lax.rsqrt(var + eps)
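        # With zero_centered_gamma the learnable scale is parameterized as an offset
        # from 1, so shift it by 1 before applying.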
        if zero_centered_gamma:
            scale += 1.0
        if bias is None:
            bias = 0.0
        return jnp.asarray(normed_input * scale + bias).astype(x.dtype)

    @pytest.mark.parametrize("n, hidden", LN_CASES)
    @pytest.mark.parametrize("dtype", DTYPES)
    @pytest.mark.parametrize("ln_type", ["layernorm", "rmsnorm"])
    @pytest.mark.parametrize("zero_centered_gamma", [False, True])
    @pytest.mark.parametrize("epsilon", [1e-2, 1e-6])
    def test_layernorm_forward_backward(
        self, n, hidden, ln_type, zero_centered_gamma, epsilon, dtype
    ):
        """
        Test transformer_engine.jax.layernorm.layernorm
        """
        expect_assert = False
        if ln_type == "rmsnorm" and zero_centered_gamma:
            # zero_centered_gamma is not supported for rmsnorm, expect an assertion.
            expect_assert = True

        with (
            pytest.raises(AssertionError, match=r".*zero_centered_gamma is not supported.*")
            if expect_assert
            else nullcontext()
        ):
            key = jax.random.PRNGKey(0)
            subkeys = jax.random.split(key, 3)

            x = jax.random.uniform(subkeys[0], (n, hidden), dtype, -1, 1)
            gamma_range = (-1, 1) if zero_centered_gamma else (0, 2)
            gamma = jax.random.uniform(subkeys[1], (hidden,), jnp.float32, *gamma_range)
            gamma = jnp.asarray(gamma, dtype)
            if ln_type == "layernorm":
                beta = jax.random.uniform(subkeys[2], (hidden,), jnp.float32, -1, 1)
                beta = jnp.asarray(beta, dtype)
            else:
                beta = None

            def compute_loss(x):
                # Higher precision to compute the loss
                x_ = x.astype(jnp.float32)
                return jnp.mean(jnp.square(x_)).astype(x.dtype)

            jitted_primitive = jit(
                value_and_grad(
                    lambda x, gamma, beta: compute_loss(
                        layernorm(x, gamma, beta, ln_type, zero_centered_gamma, epsilon)
                    ),
                    (0, 1, 2),
                )
            )

            jitted_reference = jit(
                value_and_grad(
                    lambda x, gamma, beta: compute_loss(
                        self.reference_layernorm(x, gamma, beta, zero_centered_gamma, epsilon)
                    ),
                    (0, 1, 2),
                )
            )

            primitive_out, (primitive_dx, primitive_dgamma, primitive_dbeta) = jitted_primitive(
                x, gamma, beta
            )
            reference_out, (reference_dx, reference_dgamma, reference_dbeta) = jitted_reference(
                x, gamma, beta
            )

            assert_allclose(primitive_out, reference_out, dtype=dtype)
            assert_allclose(primitive_dx, reference_dx, dtype=dtype)
            assert_allclose(primitive_dgamma, reference_dgamma, dtype=dtype)
            if beta is not None:
                assert_allclose(primitive_dbeta, reference_dbeta, dtype=dtype)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("m,n,k", GEMM_CASES)
    @pytest.mark.parametrize("ln_type", ["layernorm", "rmsnorm"])
    @pytest.mark.parametrize("zero_centered_gamma", [True, False])
    @pytest.mark.parametrize("epsilon", [1e-2, 1e-6])
    def test_ln_fp8_dot_forward_backward(self, m, n, k, ln_type, zero_centered_gamma, epsilon):
        """
        Test transformer_engine.jax.layernorm.layernorm_fp8_dot
        """
        expect_assert = False
        if ln_type == "rmsnorm" and zero_centered_gamma:
            # zero_centered_gamma is not supported for rmsnorm, expect an assertion.
            expect_assert = True

        with (
            pytest.raises(AssertionError, match=r".*zero_centered_gamma is not supported.*")
            if expect_assert
            else nullcontext()
        ):
            key = jax.random.PRNGKey(0)
            subkeys = jax.random.split(key, 4)

            a = jax.random.normal(subkeys[0], (m, k)).astype(jnp.bfloat16)
            b = jax.random.normal(subkeys[1], (k, n)).astype(jnp.bfloat16)

            gamma = jax.random.normal(subkeys[2], (k,)).astype(jnp.bfloat16)
            if ln_type == "layernorm":
                beta = jax.random.normal(subkeys[3], (k,)).astype(jnp.bfloat16)
            else:
                beta = None

            _, amax_list_1, scale_list_1 = TestNorm._generate_fp8_meta()

            def primitive_func(x, y, gamma, beta, amax_list_1, scale_list_1):
                fp8_meta_pkg = FP8MetaPackage(
                    amax_list_1[0],
                    scale_list_1[0],
                    amax_list_1[1],
                    scale_list_1[1],
                    amax_list_1[2],
                    scale_list_1[2],
                )
                primitive_out = layernorm_fp8_dot(
                    x, y, gamma, beta, fp8_meta_pkg, ln_type, zero_centered_gamma
                )
                return jnp.mean(primitive_out)

            def ref_func(x, y, gamma, beta, zero_centered_gamma):
                x = self.reference_layernorm(x, gamma, beta, zero_centered_gamma, epsilon)
                return jnp.mean(jnp.dot(x, y))

            value_n_grad_primitive_func = value_and_grad(primitive_func, range(6))
            value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2, 3))

            ref_out, (ref_a_grad, ref_b_grad, ref_gamma_grad, ref_beta_grad) = (
                value_n_grad_ref_func(a, b, gamma, beta, zero_centered_gamma)
            )

            for _ in range(3):
                primitive_out, (
                    primitive_a_grad,
                    primitive_b_grad,
                    primitive_gamma_grad,
                    primitive_beta_grad,
                    amax_list_1,
                    scale_list_1,
                ) = value_n_grad_primitive_func(a, b, gamma, beta, amax_list_1, scale_list_1)

            assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
            assert_allclose(primitive_a_grad, ref_a_grad, dtype=FP8Helper.BWD_DTYPE)
            assert_allclose(primitive_b_grad, ref_b_grad, dtype=FP8Helper.BWD_DTYPE)
            assert_allclose(primitive_gamma_grad, ref_gamma_grad, dtype=FP8Helper.BWD_DTYPE)
            if beta is not None:
                assert_allclose(primitive_beta_grad, ref_beta_grad, dtype=FP8Helper.BWD_DTYPE)