# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
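"""Tests for Transformer Engine JAX custom-call compute paths: FP8 GEMM, fused activations, and layernorm."""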

from contextlib import nullcontext
import functools
import operator
from typing import Callable, List, Sequence, Union

import jax
import jax.numpy as jnp
import numpy as np
import pytest
from jax import jit, value_and_grad
from flax import linen as nn

from utils import assert_allclose
from transformer_engine.jax.dot import (
    type_safe_dot_general,
    dequantize,
    quantize
)
from transformer_engine.jax.fp8 import (
    FP8MetaPackage,
    FP8Helper,
    is_fp8_available
)
from transformer_engine.jax.layernorm import (
    layernorm,
    layernorm_fp8_dot
)
from transformer_engine.jax.layernorm_mlp import (
    activation_lu,
    fused_layernorm_fp8_mlp
)
from transformer_engine.jax import cpp_extensions as tex

GEMM_CASES = [
    (256, 256, 512),
    (32, 32, 32),
    (2048, 1024, 2048),
    (2048, 2048, 1024),
    (2048, 1024, 1024),
]
FP8_COMPUTE_TYPE = [jnp.float8_e4m3fn, jnp.float8_e5m2]
LN_CASES = [(512, 1024)]
DTYPES = [jnp.bfloat16, jnp.float32]
is_fp8_supported, reason = is_fp8_available()


def _convert_to_activation_function(fn_or_string):
    """Convert a string to an activation function."""
    if fn_or_string == 'linear':
        return lambda x: x
    if fn_or_string == 'quick_gelu':
        return lambda x: nn.gelu(x, approximate=True)
    if fn_or_string == 'squared_relu':
        return lambda x: functools.reduce(operator.mul, [nn.relu(x), nn.relu(x)])
    if isinstance(fn_or_string, str):
        return getattr(nn, fn_or_string)
    if callable(fn_or_string):
        return fn_or_string
    raise ValueError(f"don't know how to convert {fn_or_string} to an activation function")


class TestFP8Dot:

    @staticmethod
    def _generate_fp8_meta():
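        """Build default FP8 metadata: dtypes, zero-initialized amax histories, and unit scales
        for three tensors (two forward, one backward)."""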
        fp8_dtype_list = [FP8Helper.FWD_DTYPE, FP8Helper.FWD_DTYPE, FP8Helper.BWD_DTYPE]
        amax_list = [
            jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
            jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
            jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
        ]
        scale_list = [
            jnp.ones((1,), jnp.float32),
            jnp.ones((1,), jnp.float32),
            jnp.ones((1,), jnp.float32),
        ]
        return fp8_dtype_list, amax_list, scale_list

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    def test_qdq(self):
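        """Quantize/dequantize roundtrip: with scale = FP8_E4M3_MAX / amax,
        dequantize(quantize(x)) should match x to within FP8 precision."""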
        FP8_E4M3_MAX = (jnp.finfo(jnp.float8_e4m3fn).max).astype(jnp.float32)
        x = jnp.asarray([[-1, 0.1], [2, 3]], jnp.float32)
        amax = jnp.max(jnp.abs(x)).reshape(1)
        scale = jnp.asarray(FP8_E4M3_MAX / amax, jnp.float32).reshape(1)
        scale_inv = (1 / scale).reshape(1)

        y, _ = quantize(x, q_dtype=jnp.float8_e4m3fn, scale=scale)
        z = dequantize(y, dq_dtype=jnp.float32, scale_inv=scale_inv)

        assert_allclose(z, x, dtype=jnp.float8_e4m3fn)

    @pytest.mark.parametrize('m,n,k', GEMM_CASES)
    def test_forward_bf16(self, m, n, k):
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)
        a = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16)
        b = jax.random.normal(subkeys[1], (k, n), jnp.bfloat16)

        primitive_out = type_safe_dot_general(a, b)
        ref_out = jnp.dot(a, b)

        assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('m,n,k', GEMM_CASES)
    def test_forward_fp8_randint(self, m, n, k):
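        """FP8 GEMM on small random integers (exactly representable in FP8) against a bf16 reference."""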
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)

        dtype = jnp.bfloat16

        # TODO(rewang): add float random test
        min_val, max_val = -8, 8
        a = jax.random.randint(subkeys[0], (m, k), min_val, max_val).astype(dtype)
        b = jax.random.randint(subkeys[1], (k, n), min_val, max_val).astype(dtype)

        _, amax_list, scale_list = TestFP8Dot._generate_fp8_meta()
        fp8_meta_pkg = FP8MetaPackage(
            amax_list[0],
            scale_list[0],
            amax_list[1],
            scale_list[1],
            amax_list[2],
            scale_list[2],
        )
        primitive_out = type_safe_dot_general(a, b, fp8_meta_pkg)
        ref_out = jnp.dot(a, b)

        ref_out = ref_out.astype(jnp.float32)
        primitive_out = primitive_out.astype(jnp.float32)

        assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE)

    @pytest.mark.parametrize('m,n,k', GEMM_CASES)
    def test_grad_bf16(self, m, n, k):
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)
        a = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16)
        b = jax.random.normal(subkeys[1], (k, n), jnp.bfloat16)

        def primitive_func(x, y):
            primitive_out = type_safe_dot_general(x, y)
            return jnp.mean(primitive_out)

        def ref_func(x, y):
            return jnp.mean(jnp.dot(x, y))

        value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1))

        value_n_grad_ref_func = value_and_grad(ref_func, (0, 1))

        primitive_out, (primitive_a_grad, primitive_b_grad) = value_n_grad_primitive_func(a, b)
        ref_out, (ref_a_grad, ref_b_grad) = value_n_grad_ref_func(a, b)

        assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16)
        assert_allclose(primitive_a_grad, ref_a_grad, dtype=jnp.bfloat16)
        assert_allclose(primitive_b_grad, ref_b_grad, dtype=jnp.bfloat16)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('m,n,k', GEMM_CASES)
    def test_grad_fp8_dot(self, m, n, k):
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)

        a = jax.random.normal(subkeys[0], (m, k)).astype(jnp.bfloat16)
        b = jax.random.normal(subkeys[1], (k, n)).astype(jnp.bfloat16)

        _, amax_list, scale_list = TestFP8Dot._generate_fp8_meta()

        def primitive_func(x, y, amax_list, scale_list):
            fp8_meta_pkg = FP8MetaPackage(
                amax_list[0],
                scale_list[0],
                amax_list[1],
                scale_list[1],
                amax_list[2],
                scale_list[2],
            )
            primitive_out = type_safe_dot_general(x, y, fp8_meta_pkg)
            return jnp.mean(primitive_out)

        def ref_func(x, y):
            return jnp.mean(jnp.dot(x, y))

        value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2, 3))
        value_n_grad_ref_func = value_and_grad(ref_func, (0, 1))

        ref_out, (ref_a_grad, ref_b_grad) = value_n_grad_ref_func(a, b)

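        # Run a few iterations: the "gradients" returned for amax_list and scale_list are
        # the updated FP8 metadata, which feed into the next call (delayed scaling).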
        for _ in range(3):
            primitive_out, (primitive_a_grad, primitive_b_grad, amax_list,
                            scale_list) = value_n_grad_primitive_func(a, b, amax_list, scale_list)

        assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
        assert_allclose(primitive_a_grad, ref_a_grad, dtype=FP8Helper.BWD_DTYPE)
        assert_allclose(primitive_b_grad, ref_b_grad, dtype=FP8Helper.BWD_DTYPE)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('m,n,k', [(256, 128, 512),
                                       (16384, 1024, 2816),
                                       (16384, 2816, 1024),
                                       (16384, 1024, 1024)])
    @pytest.mark.parametrize('activation_type', [('gelu', ),
                                                 ('gelu', 'linear'),
                                                 ('silu', ),
                                                 ('silu', 'linear'),
                                                 ('relu',),
                                                 ('relu', 'linear'),
                                                 ('quick_gelu',),
                                                 ('quick_gelu', 'linear'),
                                                 ('squared_relu',),
                                                 ('squared_relu', 'linear')])
    @pytest.mark.parametrize('use_bias', [True, False])
    def test_grad_fused_layernorm_fp8_mlp(self, m, n, k, activation_type: Sequence[Union[str,
                                                                                         Callable]],
                                          use_bias: bool):
        """  N/a """
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 6)

        a = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16)
        k1 = jax.random.normal(subkeys[1], (k, len(activation_type), n), jnp.bfloat16) / jnp.sqrt(k)
        k2 = jax.random.normal(subkeys[2], (n, k), jnp.bfloat16) / jnp.sqrt(n)
        s = jax.random.normal(subkeys[5], (k,), jnp.bfloat16)
        if use_bias:
            b1 = jax.random.normal(subkeys[3], (len(activation_type), n), jnp.bfloat16)
            b2 = jax.random.normal(subkeys[4], (k,), jnp.bfloat16)
        else:
            b1 = None
            b2 = None

        def primitive_func(x, ln_s, y, z, w, v, amax_list_1, amax_list_2, scale_list_1,
                           scale_list_2):
            # x is the input tensor (2D); y, z are the weights (2D); w, v are the biases.
            # The fused op computes roughly: out = act(rmsnorm(x) @ y + w) @ z + v
            fp8_meta_pkg_1 = FP8MetaPackage(
                amax_list_1[0],
                scale_list_1[0],
                amax_list_1[1],
                scale_list_1[1],
                amax_list_1[2],
                scale_list_1[2],
            )
            fp8_meta_pkg_2 = FP8MetaPackage(
                amax_list_2[0],
                scale_list_2[0],
                amax_list_2[1],
                scale_list_2[1],
                amax_list_2[2],
                scale_list_2[2],
            )
            return jnp.mean(
                fused_layernorm_fp8_mlp(x,
                                        ln_s,
                                        None, [y, z], [w, v], [fp8_meta_pkg_1, fp8_meta_pkg_2],
                                        "rmsnorm",
                                        activation_type=activation_type,
                                        use_bias=use_bias))

        def layernorm_fp8_mlp_ref(x: jnp.ndarray, ln_scale: jnp.ndarray, kernel_1: jnp.ndarray,
                                  kernel_2: jnp.ndarray, bias_1: jnp.ndarray, bias_2: jnp.ndarray,
                                  amax_list_1: List[jnp.ndarray], amax_list_2: List[jnp.ndarray],
                                  scale_list_1: List[jnp.ndarray],
                                  scale_list_2: List[jnp.ndarray]) -> jnp.ndarray:

            x = jnp.asarray(x, jnp.float32)
            mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True)
            y = jnp.asarray(x * jax.lax.rsqrt(mean2 + 1e-6), jnp.bfloat16)
            ln_out = y * ln_scale
            ln_out = jnp.asarray(ln_out, jnp.bfloat16)

            fp8_meta_pkg_1 = FP8MetaPackage(
                amax_list_1[0],
                scale_list_1[0],
                amax_list_1[1],
                scale_list_1[1],
                amax_list_1[2],
                scale_list_1[2],
            )
            linear_1_out = type_safe_dot_general(ln_out, kernel_1, fp8_meta_pkg_1, ((1,), (0,)))

            if use_bias:
                bias_1_shape = (1,) * (linear_1_out.ndim - bias_1.ndim) + bias_1.shape
                linear_1_out += jnp.reshape(bias_1, bias_1_shape)

            x = jnp.split(linear_1_out, len(activation_type), axis=-2)
            acts = []
            for idx, act_fn in enumerate(activation_type):
                x_i = _convert_to_activation_function(act_fn)(x[idx])
                acts.append(x_i)
            x = functools.reduce(operator.mul, acts)

            x = jnp.asarray(jnp.squeeze(x, axis=-2), jnp.bfloat16)

            fp8_meta_pkg_2 = FP8MetaPackage(
                amax_list_2[0],
                scale_list_2[0],
                amax_list_2[1],
                scale_list_2[1],
                amax_list_2[2],
                scale_list_2[2],
            )
            output = type_safe_dot_general(x, kernel_2, fp8_meta_pkg_2, ((1,), (0,)))

            if use_bias:
                bias_2_shape = (1,) * (output.ndim - bias_2.ndim) + bias_2.shape
                output += jnp.reshape(bias_2, bias_2_shape)

            return output

        def ref_func(x, ln_s, y, z, w, v, amax_list_1, amax_list_2, scale_list_1, scale_list_2):
            return jnp.mean(
                layernorm_fp8_mlp_ref(x, ln_s, y, z, w, v, amax_list_1, amax_list_2, scale_list_1,
                                      scale_list_2))

        value_n_grad_primitive_func = jit(
            value_and_grad(primitive_func, (0, 1, 2, 3, 4, 5, 6, 7, 8, 9)))
        value_n_grad_ref_func = jit(value_and_grad(ref_func, (0, 1, 2, 3, 4, 5, 6, 7, 8, 9)))

        _, amax_list_1, scale_list_1 = TestFP8Dot._generate_fp8_meta()
        _, amax_list_2, scale_list_2 = TestFP8Dot._generate_fp8_meta()

        ref_amax_list_1 = amax_list_1
        ref_scale_list_1 = scale_list_1
        ref_amax_list_2 = amax_list_2
        ref_scale_list_2 = scale_list_2

        primitive_amax_list_1 = amax_list_1
        primitive_scale_list_1 = scale_list_1
        primitive_amax_list_2 = amax_list_2
        primitive_scale_list_2 = scale_list_2

        # Run each function a few times: the "gradients" returned for the amax/scale lists
        # are the updated FP8 metadata, which feed back into the next iteration
        # (delayed scaling).
        for _ in range(3):
            ref_out, (ref_a_grad, ref_s_grad, ref_k1_grad, ref_k2_grad, ref_b1_grad, ref_b2_grad,
                      ref_amax_list_1, ref_amax_list_2, ref_scale_list_1,
                      ref_scale_list_2) = value_n_grad_ref_func(a, s, k1, k2, b1, b2,
                                                                ref_amax_list_1, ref_amax_list_2,
                                                                ref_scale_list_1, ref_scale_list_2)

        for _ in range(3):
            primitive_out, (primitive_a_grad, primitive_s_grad, primitive_k1_grad,
                            primitive_k2_grad, primitive_b1_grad, primitive_b2_grad,
                            primitive_amax_list_1, primitive_amax_list_2, primitive_scale_list_1,
                            primitive_scale_list_2) = value_n_grad_primitive_func(
                                a, s, k1, k2, b1, b2, primitive_amax_list_1, primitive_amax_list_2,
                                primitive_scale_list_1, primitive_scale_list_2)

        assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
        assert_allclose(jnp.asarray(primitive_a_grad, np.float32),
                        jnp.asarray(ref_a_grad, np.float32),
                        dtype=FP8Helper.BWD_DTYPE)
        assert_allclose(jnp.asarray(primitive_k1_grad, np.float32),
                        jnp.asarray(ref_k1_grad, np.float32),
                        dtype=FP8Helper.BWD_DTYPE)
        assert_allclose(jnp.asarray(primitive_s_grad, np.float32),
                        jnp.asarray(ref_s_grad, np.float32),
                        dtype=FP8Helper.BWD_DTYPE)
        assert_allclose(jnp.asarray(primitive_k2_grad, np.float32),
                        jnp.asarray(ref_k2_grad, np.float32),
                        dtype=FP8Helper.BWD_DTYPE)
        if use_bias:
            assert_allclose(jnp.asarray(primitive_b2_grad, np.float32),
                            jnp.asarray(ref_b2_grad, np.float32),
                            dtype=FP8Helper.BWD_DTYPE)
            assert_allclose(jnp.asarray(primitive_b1_grad, np.float32),
                            jnp.asarray(ref_b1_grad, np.float32),
                            dtype=FP8Helper.BWD_DTYPE)


@pytest.fixture(name="random_inputs")
def random_inputs_fixture(shape):
    key = jax.random.PRNGKey(0)
    subkeys = jax.random.split(key, 4)
    out = jax.random.uniform(subkeys[0], shape, jnp.bfloat16, 5, 8)
    return out


class TestActivationLu:

    def ref_func(self, x, activation_type):

        def ref_act_lu(inputs):
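            # Reference for (possibly gated) activations: split along the activation axis,
            # apply each activation to its slice, then multiply the slices elementwise.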
            x = jnp.split(inputs, len(activation_type), axis=-2)
            acts = []
            for idx, act_fn in enumerate(activation_type):
                x_i = _convert_to_activation_function(act_fn)(x[idx])
                acts.append(x_i)
            x = functools.reduce(operator.mul, acts)
            return jnp.mean(x)

        ref_act_func = jit(value_and_grad(ref_act_lu, (0,)))
        return ref_act_func(x)

    def primitive_func(self, inputs):
        return jnp.mean(activation_lu(inputs, activation_type=self.activation_type))

    @pytest.mark.parametrize('shape', [(32, 1, 64), (64, 1, 256)])
    @pytest.mark.parametrize('activation_type', [('gelu',),
                                                 ('gelu', 'linear'),
                                                 ('silu',),
                                                 ('silu', 'linear'),
                                                 ('relu',),
                                                 ('relu', 'linear'),
                                                 ('quick_gelu',),
                                                 ('quick_gelu', 'linear'),
                                                 ('squared_relu',),
                                                 ('squared_relu', 'linear') ])
    def test_activation_lu(self, random_inputs, activation_type):
        x = random_inputs
        x = jnp.repeat(x, len(activation_type), axis=1)
        self.activation_type = activation_type

        value_n_grad_primitive_func = jit(value_and_grad(self.primitive_func, (0,)))

        prim_out, (prim_grad,) = value_n_grad_primitive_func(x)
        ref_out, (ref_grad,) = self.ref_func(x, activation_type)

        assert_allclose(prim_out, ref_out, dtype=x.dtype)
        assert_allclose(prim_grad, ref_grad, dtype=x.dtype)


class TestActivationLuFP8(TestActivationLu):

    def prim_func(self, x):
        amax = self.amax
        scale = self.scale
        scale_inv = self.scale_inv
        activation_type = self.activation_type

        @jax.custom_vjp
        def _prim_func(x, _x_t, _dbias, _amax):
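            # _x_t, _dbias and _amax are dummy inputs; the backward pass returns the
            # transposed gradient, the bias gradient and the updated amax through their
            # cotangents so the test can inspect them.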
            output = _prim_func_fwd(x, _x_t, _dbias, _amax)
            return output

        def _prim_func_fwd(x, _x_t, _dbias, _amax):
            activation_lu_out, _ = tex.act_lu_fp8(x, amax, scale, scale_inv,
                                              FP8Helper.FWD_DTYPE, activation_type)
            activation_lu_out = dequantize(activation_lu_out, x.dtype, scale_inv)
            ctx = x
            return activation_lu_out, ctx

        def _prim_func_bwd(ctx, g):
            x = ctx
            if len(self.activation_type) > 1:  # gated, no bias
                dactivation_lu, dactivation_lu_trans, amax_out = \
                tex.dgated_act_lu_cast_transpose(g, x, amax, scale, scale_inv,
                                             FP8Helper.BWD_DTYPE, -1, activation_type)
                dbias = jnp.empty(x.shape[-1], x.dtype)
            else:  # not gated, with bias
                dactivation_lu, dactivation_lu_trans, dbias, amax_out = \
                tex.dact_lu_dbias_cast_transpose(g, x, amax, scale, scale_inv, FP8Helper.BWD_DTYPE,
                                             -1, -2, self.activation_type)
            dactivation_lu = dequantize(dactivation_lu, x.dtype, scale_inv)
            dactivation_lu_trans = dequantize(dactivation_lu_trans, x.dtype, scale_inv)
            ctx = (dactivation_lu, dactivation_lu_trans, dbias, amax_out)
            return ctx

        _prim_func.defvjp(_prim_func_fwd, _prim_func_bwd)

        dx_trans_no_use = jnp.empty([x.shape[i] for i in self.transpose_indices], dtype=x.dtype)
        dbias_no_use = jnp.empty(x.shape[-1], dtype=x.dtype)
        amax_no_use = jnp.zeros(1, jnp.float32)
        value_n_grad_primitive_func = value_and_grad(lambda a, b, c, d:
            jnp.mean(_prim_func(a, b, c, d)), (0, 1, 2, 3))
        return value_n_grad_primitive_func(x, dx_trans_no_use, dbias_no_use, amax_no_use)


    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('shape', [(32, 1, 64), (64, 1, 256)])
    @pytest.mark.parametrize('activation_type', [('gelu',),
                                                 ('gelu', 'linear'),
                                                 ('silu',),
                                                 ('silu', 'linear'),
                                                 ('relu',),
                                                 ('relu', 'linear'),
                                                 ('quick_gelu',),
                                                 ('quick_gelu', 'linear'),
                                                 ('squared_relu',),
                                                 ('squared_relu', 'linear') ])
    def test_activation_lu(self, random_inputs, activation_type):
        self.amax = jnp.zeros(1, jnp.float32)
        self.scale = jnp.ones(1, jnp.float32)
        self.scale_inv = jnp.ones(1, jnp.float32)
        self.activation_type = activation_type
        self.transpose_indices = (1, 2, 0)

        x = random_inputs
        x = jnp.repeat(x, len(activation_type), axis=1)

        prim_out, (prim_grad, prim_grad_trans, dbias, amax) = self.prim_func(x)
        ref_out, (ref_grad,) = self.ref_func(x, activation_type)

        assert_allclose(prim_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
        assert_allclose(amax, jnp.amax(jnp.abs(ref_grad)), rtol=1e-2)
        if 'linear' not in activation_type:
            assert_allclose(dbias, jnp.sum(ref_grad, axis=tuple(range(x.ndim - 1))))
        assert_allclose(prim_grad, ref_grad, dtype=FP8Helper.BWD_DTYPE)
        assert_allclose(prim_grad_trans,
                        jnp.transpose(ref_grad, self.transpose_indices),
                        dtype=FP8Helper.BWD_DTYPE)


class TestNorm:
    """
    Test transformer_engine.jax.layernorm APIs
    """

    @staticmethod
    def _generate_fp8_meta():
        fp8_dtype_list = [FP8Helper.FWD_DTYPE, FP8Helper.FWD_DTYPE, FP8Helper.BWD_DTYPE]
        amax_list = [
            jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
            jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
            jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
        ]
        scale_list = [
            jnp.ones((1,), jnp.float32),
            jnp.ones((1,), jnp.float32),
            jnp.ones((1,), jnp.float32),
        ]
        return fp8_dtype_list, amax_list, scale_list

    def reference_layernorm(self, x, scale, bias, zero_centered_gamma, eps):
        """
        JAX native layernorm implementations
        - bias is not None: layernorm
        - bias is None: rmsnorm
        """
        x_ = jnp.asarray(x, jnp.float32)
        if bias is None:
            mean = 0.
        else:
            mean = jnp.mean(x_, axis=-1, keepdims=True)
        var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True)
        normed_input = (x_ - mean) * jax.lax.rsqrt(var + eps)
        if zero_centered_gamma:
            scale += 1.
        if bias is None:
            bias = 0.
        return jnp.asarray(normed_input * scale + bias).astype(x.dtype)

    @pytest.mark.parametrize('n, hidden', LN_CASES)
    @pytest.mark.parametrize('dtype', DTYPES)
    @pytest.mark.parametrize('ln_type', ['layernorm', 'rmsnorm'])
    @pytest.mark.parametrize('zero_centered_gamma', [False, True])
    @pytest.mark.parametrize('epsilon', [1e-2, 1e-6])
    def test_layernorm_forward_backward(self, n, hidden, ln_type, zero_centered_gamma, epsilon,
                                        dtype):
        """
        Test transformer_engine.jax.layernorm.layernorm
        """
        expect_assert = False
        if ln_type == 'rmsnorm' and zero_centered_gamma:
            # zero_centered_gamma is not supported for rmsnorm, expect an assertion.
            expect_assert = True

        with pytest.raises(AssertionError, match=r".*zero_centered_gamma is not supported.*"
                          ) if expect_assert else nullcontext():
            key = jax.random.PRNGKey(0)
            subkeys = jax.random.split(key, 3)

            x = jax.random.uniform(subkeys[0], (n, hidden), dtype, -1, 1)
            gamma_range = (-1, 1) if zero_centered_gamma else (0, 2)
            gamma = jax.random.uniform(subkeys[1], (hidden,), jnp.float32, *gamma_range)
            gamma = jnp.asarray(gamma, dtype)
            if ln_type == 'layernorm':
                beta = jax.random.uniform(subkeys[2], (hidden,), jnp.float32, -1, 1)
                beta = jnp.asarray(beta, dtype)
            else:
                beta = None

            def compute_loss(x):
                # Higher precision to compute the loss
                x_ = x.astype(jnp.float32)
                return jnp.mean(jnp.square(x_)).astype(x.dtype)

            jitted_primitive = jit(
                value_and_grad(
                    lambda x, gamma, beta: compute_loss(
                        layernorm(x, gamma, beta, ln_type, zero_centered_gamma, epsilon)),
                    (0, 1, 2)))

            jitted_reference = jit(
                value_and_grad(
                    lambda x, gamma, beta: compute_loss(
                        self.reference_layernorm(x, gamma, beta, zero_centered_gamma, epsilon)),
                    (0, 1, 2)))

            primitive_out, (primitive_dx, primitive_dgamma,
                            primitive_dbeta) = jitted_primitive(x, gamma, beta)
            reference_out, (reference_dx, reference_dgamma,
                            reference_dbeta) = jitted_reference(x, gamma, beta)

            assert_allclose(primitive_out, reference_out, dtype=dtype)
            assert_allclose(primitive_dx, reference_dx, dtype=dtype)
            assert_allclose(primitive_dgamma, reference_dgamma, dtype=dtype)
            if beta is not None:
                assert_allclose(primitive_dbeta, reference_dbeta, dtype=dtype)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('m,n,k', GEMM_CASES)
    @pytest.mark.parametrize('ln_type', ['layernorm', 'rmsnorm'])
    @pytest.mark.parametrize('zero_centered_gamma', [True, False])
    @pytest.mark.parametrize('epsilon', [1e-2, 1e-6])
    def test_ln_fp8_dot_forward_backward(self, m, n, k, ln_type, zero_centered_gamma, epsilon):
        """
        Test transformer_engine.jax.layernorm.layernorm_fp8_dot
        """
        expect_assert = False
        if ln_type == 'rmsnorm' and zero_centered_gamma:
            # zero_centered_gamma is not supported for rmsnorm, expect an assertion.
            expect_assert = True

        with pytest.raises(AssertionError, match=r".*zero_centered_gamma is not supported.*"
                          ) if expect_assert else nullcontext():
            key = jax.random.PRNGKey(0)
            subkeys = jax.random.split(key, 4)

            a = jax.random.normal(subkeys[0], (m, k)).astype(jnp.bfloat16)
            b = jax.random.normal(subkeys[1], (k, n)).astype(jnp.bfloat16)

            gamma = jax.random.normal(subkeys[2], (k,)).astype(jnp.bfloat16)
            if ln_type == 'layernorm':
                beta = jax.random.normal(subkeys[3], (k,)).astype(jnp.bfloat16)
            else:
                beta = None

            _, amax_list_1, scale_list_1 = TestNorm._generate_fp8_meta()

            def primitive_func(x, y, gamma, beta, amax_list_1, scale_list_1):
                fp8_meta_pkg = FP8MetaPackage(
                    amax_list_1[0],
                    scale_list_1[0],
                    amax_list_1[1],
                    scale_list_1[1],
                    amax_list_1[2],
                    scale_list_1[2],
                )
                primitive_out = layernorm_fp8_dot(x, y, gamma, beta, fp8_meta_pkg, ln_type,
                                                  zero_centered_gamma)
                return jnp.mean(primitive_out)

            def ref_func(x, y, gamma, beta, zero_centered_gamma):
                x = self.reference_layernorm(x, gamma, beta, zero_centered_gamma, epsilon)
                return jnp.mean(jnp.dot(x, y))

            value_n_grad_primitive_func = value_and_grad(primitive_func, range(6))
            value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2, 3))

            ref_out, (ref_a_grad, ref_b_grad, ref_gamma_grad,
                      ref_beta_grad) = value_n_grad_ref_func(a, b, gamma, beta, zero_centered_gamma)

            for _ in range(3):
                primitive_out, (primitive_a_grad, primitive_b_grad, primitive_gamma_grad,
                                primitive_beta_grad, amax_list_1,
                                scale_list_1) = value_n_grad_primitive_func(
                                    a, b, gamma, beta, amax_list_1, scale_list_1)

            assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
            assert_allclose(primitive_a_grad, ref_a_grad, dtype=FP8Helper.BWD_DTYPE)
            assert_allclose(primitive_b_grad, ref_b_grad, dtype=FP8Helper.BWD_DTYPE)
            assert_allclose(primitive_gamma_grad, ref_gamma_grad, dtype=FP8Helper.BWD_DTYPE)
            if beta is not None:
                assert_allclose(primitive_beta_grad, ref_beta_grad, dtype=FP8Helper.BWD_DTYPE)