"tests/pytorch/test_transformerengine.py" did not exist on "782101277c90fa3ae800bb31b6465a318f0d954c"
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

from contextlib import nullcontext
import functools
import operator
from typing import Callable, List, Sequence, Union

import jax
import jax.numpy as jnp
import numpy as np
import pytest
from jax import jit, value_and_grad
from flax import linen as nn

from utils import assert_allclose
from transformer_engine.jax.dot import type_safe_dot_general, dequantize, quantize
from transformer_engine.jax.fp8 import FP8MetaPackage, FP8Helper
from transformer_engine.jax.fp8 import is_fp8_available
from transformer_engine.jax.layernorm import layernorm, layernorm_fp8_dot
from transformer_engine.jax.mlp import activation_lu, fused_layernorm_fp8_mlp
from transformer_engine.jax.cpp_extensions import act_lu_fp8, dact_lu_dbias_cast_transpose
from transformer_engine.jax.cpp_extensions import dgated_act_lu_cast_transpose

GEMM_CASES = [
    (256, 256, 512),
    (32, 32, 32),
    (2048, 1024, 2048),
    (2048, 2048, 1024),
    (2048, 1024, 1024),
]
FP8_COMPUTE_TYPE = [jnp.float8_e4m3fn, jnp.float8_e5m2]
LN_CASES = [(512, 1024)]
DTYPES = [jnp.bfloat16, jnp.float32]
is_fp8_supported, reason = is_fp8_available()


def _convert_to_activation_function(fn_or_string):
    """Convert a string to an activation function."""
    if fn_or_string == 'linear':
        return lambda x: x
    if fn_or_string == 'quick_gelu':
        return lambda x: nn.gelu(x, approximate=True)
    if fn_or_string == 'squared_relu':
        return lambda x: functools.reduce(operator.mul, [nn.relu(x), nn.relu(x)])
    if isinstance(fn_or_string, str):
        return getattr(nn, fn_or_string)
    if callable(fn_or_string):
        return fn_or_string
    raise ValueError(f"don't know how to convert {fn_or_string} to an activation function")


class TestFP8Dot:

    @staticmethod
    def _generate_fp8_meta():
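        """Build default FP8 metadata: a dtype, a zeroed amax history, and a unit scale for
        each of the three FP8 tensors of a GEMM (assumed here to be input, kernel, and
        output gradient, matching the FWD/FWD/BWD dtype layout below)."""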
        fp8_dtype_list = [FP8Helper.FWD_DTYPE, FP8Helper.FWD_DTYPE, FP8Helper.BWD_DTYPE]
        amax_list = [
            jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
            jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
            jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
        ]
        scale_list = [
            jnp.ones((1,), jnp.float32),
            jnp.ones((1,), jnp.float32),
            jnp.ones((1,), jnp.float32),
        ]
        return fp8_dtype_list, amax_list, scale_list

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    def test_qdq(self):
        FP8_E4M3_MAX = (jnp.finfo(jnp.float8_e4m3fn).max).astype(jnp.float32)
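        # Pick scale = FP8_E4M3_MAX / amax so the largest magnitude in x maps onto the FP8
        # representable range; scale_inv undoes that scaling when dequantizing.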
        x = jnp.asarray([[-1, 0.1], [2, 3]], jnp.float32)
        amax = jnp.max(jnp.abs(x)).reshape(1)
        scale = jnp.asarray(FP8_E4M3_MAX / amax, jnp.float32).reshape(1)
        scale_inv = (1 / scale).reshape(1)

        y, _ = quantize(x, q_dtype=jnp.float8_e4m3fn, scale=scale)
        z = dequantize(y, dq_dtype=jnp.float32, scale_inv=scale_inv)

        assert_allclose(z, x, dtype=jnp.float8_e4m3fn)

    @pytest.mark.parametrize('m,n,k', GEMM_CASES)
    def test_forward_bf16(self, m, n, k):
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)
        a = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16)
        b = jax.random.normal(subkeys[1], (k, n), jnp.bfloat16)

        primitive_out = type_safe_dot_general(a, b)
        ref_out = jnp.dot(a, b)

        assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('m,n,k', GEMM_CASES)
    def test_forward_fp8_randint(self, m, n, k):
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)

        dtype = jnp.bfloat16

        # TODO(rewang): add float random test
        min_val, max_val = -8, 8
        a = jax.random.randint(subkeys[0], (m, k), min_val, max_val).astype(dtype)
        b = jax.random.randint(subkeys[1], (k, n), min_val, max_val).astype(dtype)

        _, amax_list, scale_list = TestFP8Dot._generate_fp8_meta()
        fp8_meta_pkg = FP8MetaPackage(
            amax_list[0],
            scale_list[0],
            amax_list[1],
            scale_list[1],
            amax_list[2],
            scale_list[2],
        )
        primitive_out = type_safe_dot_general(a, b, fp8_meta_pkg)
        ref_out = jnp.dot(a, b)

        ref_out = ref_out.astype(jnp.float32)
        primitive_out = primitive_out.astype(jnp.float32)

        assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE)

    @pytest.mark.parametrize('m,n,k', GEMM_CASES)
    def test_grad_bf16(self, m, n, k):
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)
        a = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16)
        b = jax.random.normal(subkeys[1], (k, n), jnp.bfloat16)

        def primitive_func(x, y):
            primitive_out = type_safe_dot_general(x, y)
            return jnp.mean(primitive_out)

        def ref_func(x, y):
            return jnp.mean(jnp.dot(x, y))

        value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1))

        value_n_grad_ref_func = value_and_grad(ref_func, (0, 1))

        primitive_out, (primitive_a_grad, primitive_b_grad) = value_n_grad_primitive_func(a, b)
        ref_out, (ref_a_grad, ref_b_grad) = value_n_grad_ref_func(a, b)

        assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16)
        assert_allclose(primitive_a_grad, ref_a_grad, dtype=jnp.bfloat16)
        assert_allclose(primitive_b_grad, ref_b_grad, dtype=jnp.bfloat16)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('m,n,k', GEMM_CASES)
    def test_grad_fp8_dot(self, m, n, k):
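        """Compare the value and gradients of an FP8 dot against a bf16 jnp.dot reference.

        The FP8 path is evaluated a few times, feeding the returned amax/scale values back
        in, so the scaling factors settle before the final comparison.
        """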
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)

        a = jax.random.normal(subkeys[0], (m, k)).astype(jnp.bfloat16)
        b = jax.random.normal(subkeys[1], (k, n)).astype(jnp.bfloat16)

        _, amax_list, scale_list = TestFP8Dot._generate_fp8_meta()

        def primitive_func(x, y, amax_list, scale_list):
            fp8_meta_pkg = FP8MetaPackage(
                amax_list[0],
                scale_list[0],
                amax_list[1],
                scale_list[1],
                amax_list[2],
                scale_list[2],
            )
            primitive_out = type_safe_dot_general(x, y, fp8_meta_pkg)
            return jnp.mean(primitive_out)

        def ref_func(x, y):
            return jnp.mean(jnp.dot(x, y))

        value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2, 3))
        value_n_grad_ref_func = value_and_grad(ref_func, (0, 1))

        ref_out, (ref_a_grad, ref_b_grad) = value_n_grad_ref_func(a, b)

        for _ in range(3):
            primitive_out, (primitive_a_grad, primitive_b_grad, amax_list,
                            scale_list) = value_n_grad_primitive_func(a, b, amax_list, scale_list)

        assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
        assert_allclose(primitive_a_grad, ref_a_grad, dtype=FP8Helper.BWD_DTYPE)
        assert_allclose(primitive_b_grad, ref_b_grad, dtype=FP8Helper.BWD_DTYPE)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('m,n,k', [(256, 128, 512),
                                       (16384, 1024, 2816),
                                       (16384, 2816, 1024),
                                       (16384, 1024, 1024)])
    @pytest.mark.parametrize('activation_type', [('gelu', ),
                                                 ('gelu', 'linear'),
                                                 ('silu', ),
                                                 ('silu', 'linear'),
                                                 ('relu',),
                                                 ('relu', 'linear'),
                                                 ('quick_gelu',),
                                                 ('quick_gelu', 'linear'),
                                                 ('squared_relu',),
                                                 ('squared_relu', 'linear')])
    @pytest.mark.parametrize('use_bias', [True, False])
    def test_grad_fused_layernorm_fp8_mlp(self, m, n, k, activation_type: Sequence[Union[str,
                                                                                         Callable]],
                                          use_bias: bool):
        """Test fused_layernorm_fp8_mlp forward value and gradients against a plain-JAX reference."""
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 6)

        a = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16)
        k1 = jax.random.normal(subkeys[1], (k, len(activation_type), n), jnp.bfloat16) / jnp.sqrt(k)
        k2 = jax.random.normal(subkeys[2], (n, k), jnp.bfloat16) / jnp.sqrt(n)
        s = jax.random.normal(subkeys[5], (k,), jnp.bfloat16)
        if use_bias:
            b1 = jax.random.normal(subkeys[3], (len(activation_type), n), jnp.bfloat16)
            b2 = jax.random.normal(subkeys[4], (k,), jnp.bfloat16)
        else:
            b1 = None
            b2 = None

        def primitive_func(x, ln_s, y, z, w, v, amax_list_1, amax_list_2, scale_list_1,
                           scale_list_2):
            # x: 2D input, ln_s: layernorm scale, y/z: 2D weights, w/v: optional biases
            # out = activation(rmsnorm(x) @ y + w) @ z + v
            fp8_meta_pkg_1 = FP8MetaPackage(
                amax_list_1[0],
                scale_list_1[0],
                amax_list_1[1],
                scale_list_1[1],
                amax_list_1[2],
                scale_list_1[2],
            )
            fp8_meta_pkg_2 = FP8MetaPackage(
                amax_list_2[0],
                scale_list_2[0],
                amax_list_2[1],
                scale_list_2[1],
                amax_list_2[2],
                scale_list_2[2],
            )
            return jnp.mean(
                fused_layernorm_fp8_mlp(x,
                                        ln_s,
                                        None, [y, z], [w, v], [fp8_meta_pkg_1, fp8_meta_pkg_2],
                                        "rmsnorm",
                                        activation_type=activation_type,
                                        use_bias=use_bias))

        def layernorm_fp8_mlp_ref(x: jnp.ndarray, ln_scale: jnp.ndarray, kernel_1: jnp.ndarray,
                                  kernel_2: jnp.ndarray, bias_1: jnp.ndarray, bias_2: jnp.ndarray,
                                  amax_list_1: List[jnp.ndarray], amax_list_2: List[jnp.ndarray],
                                  scale_list_1: List[jnp.ndarray],
                                  scale_list_2: List[jnp.ndarray]) -> jnp.ndarray:
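            """Reference path: RMSNorm in float32, FP8 dot with kernel_1 (+ bias_1),
            the activation(s), then FP8 dot with kernel_2 (+ bias_2)."""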

            x = jnp.asarray(x, jnp.float32)
            mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True)
            y = jnp.asarray(x * jax.lax.rsqrt(mean2 + 1e-6), jnp.bfloat16)
            ln_out = y * ln_scale
            ln_out = jnp.asarray(ln_out, jnp.bfloat16)

            fp8_meta_pkg_1 = FP8MetaPackage(
                amax_list_1[0],
                scale_list_1[0],
                amax_list_1[1],
                scale_list_1[1],
                amax_list_1[2],
                scale_list_1[2],
            )
            linear_1_out = type_safe_dot_general(ln_out, kernel_1, fp8_meta_pkg_1, ((1,), (0,)))

            if use_bias:
                bias_1_shape = (1,) * (linear_1_out.ndim - bias_1.ndim) + bias_1.shape
                linear_1_out += jnp.reshape(bias_1, bias_1_shape)

            x = jnp.split(linear_1_out, len(activation_type), axis=-2)
            acts = []
            for idx, act_fn in enumerate(activation_type):
                x_i = _convert_to_activation_function(act_fn)(x[idx])
                acts.append(x_i)
            x = functools.reduce(operator.mul, acts)

            x = jnp.asarray(jnp.squeeze(x, axis=-2), jnp.bfloat16)

            fp8_meta_pkg_2 = FP8MetaPackage(
                amax_list_2[0],
                scale_list_2[0],
                amax_list_2[1],
                scale_list_2[1],
                amax_list_2[2],
                scale_list_2[2],
            )
            output = type_safe_dot_general(x, kernel_2, fp8_meta_pkg_2, ((1,), (0,)))

            if use_bias:
                bias_2_shape = (1,) * (output.ndim - bias_2.ndim) + bias_2.shape
                output += jnp.reshape(bias_2, bias_2_shape)

            return output

        def ref_func(x, ln_s, y, z, w, v, amax_list_1, amax_list_2, scale_list_1, scale_list_2):
            return jnp.mean(
                layernorm_fp8_mlp_ref(x, ln_s, y, z, w, v, amax_list_1, amax_list_2, scale_list_1,
                                      scale_list_2))

        value_n_grad_primitive_func = jit(
            value_and_grad(primitive_func, (0, 1, 2, 3, 4, 5, 6, 7, 8, 9)))
        value_n_grad_ref_func = jit(value_and_grad(ref_func, (0, 1, 2, 3, 4, 5, 6, 7, 8, 9)))

        _, amax_list_1, scale_list_1 = TestFP8Dot._generate_fp8_meta()
        _, amax_list_2, scale_list_2 = TestFP8Dot._generate_fp8_meta()

        ref_amax_list_1 = amax_list_1
        ref_scale_list_1 = scale_list_1
        ref_amax_list_2 = amax_list_2
        ref_scale_list_2 = scale_list_2

        primitive_amax_list_1 = amax_list_1
        primitive_scale_list_1 = scale_list_1
        primitive_amax_list_2 = amax_list_2
        primitive_scale_list_2 = scale_list_2

        # Convert str to index as str is not a valid type for JAX JIT
        for _ in range(3):
            ref_out, (ref_a_grad, ref_s_grad, ref_k1_grad, ref_k2_grad, ref_b1_grad, ref_b2_grad,
                      ref_amax_list_1, ref_amax_list_2, ref_scale_list_1,
                      ref_scale_list_2) = value_n_grad_ref_func(a, s, k1, k2, b1, b2,
                                                                ref_amax_list_1, ref_amax_list_2,
                                                                ref_scale_list_1, ref_scale_list_2)

        for _ in range(3):
            primitive_out, (primitive_a_grad, primitive_s_grad, primitive_k1_grad,
                            primitive_k2_grad, primitive_b1_grad, primitive_b2_grad,
                            primitive_amax_list_1, primitive_amax_list_2, primitive_scale_list_1,
                            primitive_scale_list_2) = value_n_grad_primitive_func(
                                a, s, k1, k2, b1, b2, primitive_amax_list_1, primitive_amax_list_2,
                                primitive_scale_list_1, primitive_scale_list_2)

        assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
        assert_allclose(jnp.asarray(primitive_a_grad, np.float32),
                        jnp.asarray(ref_a_grad, np.float32),
                        dtype=FP8Helper.BWD_DTYPE)
        assert_allclose(jnp.asarray(primitive_k1_grad, np.float32),
                        jnp.asarray(ref_k1_grad, np.float32),
                        dtype=FP8Helper.BWD_DTYPE)
        assert_allclose(jnp.asarray(primitive_s_grad, np.float32),
                        jnp.asarray(ref_s_grad, np.float32),
                        dtype=FP8Helper.BWD_DTYPE)
        assert_allclose(jnp.asarray(primitive_k2_grad, np.float32),
                        jnp.asarray(ref_k2_grad, np.float32),
                        dtype=FP8Helper.BWD_DTYPE)
        if use_bias:
            assert_allclose(jnp.asarray(primitive_b2_grad, np.float32),
                            jnp.asarray(ref_b2_grad, np.float32),
                            dtype=FP8Helper.BWD_DTYPE)
            assert_allclose(jnp.asarray(primitive_b1_grad, np.float32),
                            jnp.asarray(ref_b1_grad, np.float32),
                            dtype=FP8Helper.BWD_DTYPE)


@pytest.fixture(name="random_inputs")
def random_inputs_fixture(shape):
    key = jax.random.PRNGKey(0)
    subkeys = jax.random.split(key, 4)
    out = jax.random.uniform(subkeys[0], shape, jnp.bfloat16, 5, 8)
    return out


class TestActivationLu:

    def ref_func(self, x, activation_type):

        def ref_act_lu(inputs):
            x = jnp.split(inputs, len(activation_type), axis=-2)
            acts = []
            for idx, act_fn in enumerate(activation_type):
                x_i = _convert_to_activation_function(act_fn)(x[idx])
                acts.append(x_i)
            x = functools.reduce(operator.mul, acts)
            return jnp.mean(x)

        ref_act_func = jit(value_and_grad(ref_act_lu, (0,)))
        return ref_act_func(x)

    def primitive_func(self, inputs):
        return jnp.mean(activation_lu(inputs, activation_type=self.activation_type))

    @pytest.mark.parametrize('shape', [(32, 1, 64), (64, 1, 256)])
    @pytest.mark.parametrize('activation_type', [('gelu',),
                                                 ('gelu', 'linear'),
                                                 ('silu',),
                                                 ('silu', 'linear'),
                                                 ('relu',),
                                                 ('relu', 'linear'),
                                                 ('quick_gelu',),
                                                 ('quick_gelu', 'linear'),
                                                 ('squared_relu',),
                                                 ('squared_relu', 'linear') ])
    def test_activation_lu(self, random_inputs, activation_type):
        x = random_inputs
        x = jnp.repeat(x, len(activation_type), axis=1)
        self.activation_type = activation_type

        value_n_grad_primitive_func = jit(value_and_grad(self.primitive_func, (0,)))

        prim_out, (prim_grad,) = value_n_grad_primitive_func(x)
        ref_out, (ref_grad,) = self.ref_func(x, activation_type)

        assert_allclose(prim_out, ref_out, dtype=x.dtype)
        assert_allclose(prim_grad, ref_grad, dtype=x.dtype)


class TestActivationLuFP8(TestActivationLu):

    def prim_func(self, x):
        amax = self.amax
        scale = self.scale
        scale_inv = self.scale_inv
        activation_type = self.activation_type
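        # _x_t, _dbias and _amax are placeholder inputs: the custom VJP below returns the
        # transposed gradient, dbias and updated amax as their cotangents so the test can
        # read them back alongside dx.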

        @jax.custom_vjp
        def _prim_func(x, _x_t, _dbias, _amax):
            output = _prim_func_fwd(x, _x_t, _dbias, _amax)
            return output

        def _prim_func_fwd(x, _x_t, _dbias, _amax):
            activation_lu_out, _ = act_lu_fp8(x, amax, scale, scale_inv,
                                              FP8Helper.FWD_DTYPE, activation_type)
            activation_lu_out = dequantize(activation_lu_out, x.dtype, scale_inv)
            ctx = (x)
            return activation_lu_out, ctx

        def _prim_func_bwd(ctx, g):
            x = ctx
            if len(self.activation_type) > 1: #gated, no bias
                dactivation_lu, dactivation_lu_trans, amax_out = \
                dgated_act_lu_cast_transpose(g, x, amax, scale, scale_inv,
                                             FP8Helper.BWD_DTYPE, -1, activation_type)
                dbias = jnp.empty(x.shape[-1], x.dtype)
            else: #not gated, with bias
                dactivation_lu, dactivation_lu_trans, dbias, amax_out = \
                dact_lu_dbias_cast_transpose(g, x, amax, scale, scale_inv, FP8Helper.BWD_DTYPE,
                                             -1, -2, self.activation_type)
            dactivation_lu = dequantize(dactivation_lu, x.dtype, scale_inv)
            dactivation_lu_trans = dequantize(dactivation_lu_trans, x.dtype, scale_inv)
            ctx = (dactivation_lu, dactivation_lu_trans, dbias, amax_out)
            return ctx

        _prim_func.defvjp(_prim_func_fwd, _prim_func_bwd)

        dx_trans_no_use = jnp.empty([x.shape[i] for i in self.transpose_indices], dtype=x.dtype)
        dbias_no_use = jnp.empty(x.shape[-1], dtype=x.dtype)
        amax_no_use = jnp.zeros(1, jnp.float32)
        value_n_grad_primitive_func = value_and_grad(lambda a, b, c, d:
            jnp.mean(_prim_func(a, b, c, d)), (0, 1, 2, 3))
        return value_n_grad_primitive_func(x, dx_trans_no_use, dbias_no_use, amax_no_use)


    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('shape', [(32, 1, 64), (64, 1, 256)])
    @pytest.mark.parametrize('activation_type', [('gelu',),
                                                 ('gelu', 'linear'),
                                                 ('silu',),
                                                 ('silu', 'linear'),
                                                 ('relu',),
                                                 ('relu', 'linear'),
                                                 ('quick_gelu',),
                                                 ('quick_gelu', 'linear'),
                                                 ('squared_relu',),
                                                 ('squared_relu', 'linear') ])
    def test_activation_lu(self, random_inputs, activation_type):
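        """Check the FP8 activation forward output, input gradient, transposed gradient,
        dbias (non-gated case) and amax against the unfused reference."""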
        self.amax = jnp.zeros(1, jnp.float32)
        self.scale = jnp.ones(1, jnp.float32)
        self.scale_inv = jnp.ones(1, jnp.float32)
        self.activation_type = activation_type
        self.transpose_indices = (1, 2, 0)

        x = random_inputs
        x = jnp.repeat(x, len(activation_type), axis=1)

        prim_out, (prim_grad, prim_grad_trans, dbias, amax) = self.prim_func(x)
        ref_out, (ref_grad,) = self.ref_func(x, activation_type)

        assert_allclose(prim_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
        assert_allclose(amax, jnp.amax(jnp.abs(ref_grad)), rtol=1e-2)
        if 'linear' not in activation_type:
            assert_allclose(dbias, jnp.sum(ref_grad, axis=tuple(range(x.ndim - 1))))
        assert_allclose(prim_grad, ref_grad, dtype=FP8Helper.BWD_DTYPE)
        assert_allclose(prim_grad_trans,
                        jnp.transpose(ref_grad, self.transpose_indices),
                        dtype=FP8Helper.BWD_DTYPE)


class TestNorm:
    """
    Test transformer_engine.jax.layernorm APIs
    """

    @staticmethod
    def _generate_fp8_meta():
        fp8_dtype_list = [FP8Helper.FWD_DTYPE, FP8Helper.FWD_DTYPE, FP8Helper.BWD_DTYPE]
        amax_list = [
            jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
            jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
            jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
        ]
        scale_list = [
            jnp.ones((1,), jnp.float32),
            jnp.ones((1,), jnp.float32),
            jnp.ones((1,), jnp.float32),
        ]
        return fp8_dtype_list, amax_list, scale_list

    def reference_layernorm(self, x, scale, bias, zero_centered_gamma, eps):
        """
        JAX native layernorm implementations
        - bias is not None: layernorm
        - bias is None: rmsnorm
        """
        x_ = jnp.asarray(x, jnp.float32)
        if bias is None:
            mean = 0.
        else:
            mean = jnp.mean(x_, axis=-1, keepdims=True)
        var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True)
        normed_input = (x_ - mean) * jax.lax.rsqrt(var + eps)
        if zero_centered_gamma:
            scale += 1.
        if bias is None:
            bias = 0.
        return jnp.asarray(normed_input * scale + bias).astype(x.dtype)

    @pytest.mark.parametrize('n, hidden', LN_CASES)
    @pytest.mark.parametrize('dtype', DTYPES)
    @pytest.mark.parametrize('ln_type', ['layernorm', 'rmsnorm'])
    @pytest.mark.parametrize('zero_centered_gamma', [False, True])
    @pytest.mark.parametrize('epsilon', [1e-2, 1e-6])
    def test_layernorm_forward_backward(self, n, hidden, ln_type, zero_centered_gamma, epsilon,
                                        dtype):
        """
        Test transformer_engine.jax.layernorm.layernorm
        """
        expect_assert = False
        if ln_type == 'rmsnorm' and zero_centered_gamma:
            # zero_centered_gamma is not supported for rmsnorm, expect an assertion.
            expect_assert = True

        with pytest.raises(AssertionError, match=r".*zero_centered_gamma is not supported.*"
                          ) if expect_assert else nullcontext():
            key = jax.random.PRNGKey(0)
            subkeys = jax.random.split(key, 3)

            x = jax.random.uniform(subkeys[0], (n, hidden), dtype, -1, 1)
            gamma_range = (-1, 1) if zero_centered_gamma else (0, 2)
            gamma = jax.random.uniform(subkeys[1], (hidden,), jnp.float32, *gamma_range)
            gamma = jnp.asarray(gamma, dtype)
            if ln_type == 'layernorm':
                beta = jax.random.uniform(subkeys[2], (hidden,), jnp.float32, -1, 1)
                beta = jnp.asarray(beta, dtype)
            else:
                beta = None

            def compute_loss(x):
                # Higher precision to compute the loss
                x_ = x.astype(jnp.float32)
                return jnp.mean(jnp.square(x_)).astype(x.dtype)

            jitted_primitive = jit(
                value_and_grad(
                    lambda x, gamma, beta: compute_loss(
                        layernorm(x, gamma, beta, ln_type, zero_centered_gamma, epsilon)),
                    (0, 1, 2)))

            jitted_reference = jit(
                value_and_grad(
                    lambda x, gamma, beta: compute_loss(
                        self.reference_layernorm(x, gamma, beta, zero_centered_gamma, epsilon)),
                    (0, 1, 2)))

            primitive_out, (primitive_dx, primitive_dgamma,
                            primitive_dbeta) = jitted_primitive(x, gamma, beta)
            reference_out, (reference_dx, reference_dgamma,
                            reference_dbeta) = jitted_reference(x, gamma, beta)

            assert_allclose(primitive_out, reference_out, dtype=dtype)
            assert_allclose(primitive_dx, reference_dx, dtype=dtype)
            assert_allclose(primitive_dgamma, reference_dgamma, dtype=dtype)
            if beta is not None:
                assert_allclose(primitive_dbeta, reference_dbeta, dtype=dtype)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('m,n,k', GEMM_CASES)
    @pytest.mark.parametrize('ln_type', ['layernorm', 'rmsnorm'])
    @pytest.mark.parametrize('zero_centered_gamma', [True, False])
    @pytest.mark.parametrize('epsilon', [1e-2, 1e-6])
    def test_ln_fp8_dot_forward_backward(self, m, n, k, ln_type, zero_centered_gamma, epsilon):
        """
        Test transformer_engine.jax.layernorm.layernorm_fp8_dot
        """
        expect_assert = False
        if ln_type == 'rmsnorm' and zero_centered_gamma:
            # zero_centered_gamma is not supported for rmsnorm, expect an assertion.
            expect_assert = True

        with pytest.raises(AssertionError, match=r".*zero_centered_gamma is not supported.*"
                          ) if expect_assert else nullcontext():
            key = jax.random.PRNGKey(0)
            subkeys = jax.random.split(key, 4)

            a = jax.random.normal(subkeys[0], (m, k)).astype(jnp.bfloat16)
            b = jax.random.normal(subkeys[1], (k, n)).astype(jnp.bfloat16)

            gamma = jax.random.normal(subkeys[2], (k,)).astype(jnp.bfloat16)
            if ln_type == 'layernorm':
                beta = jax.random.normal(subkeys[3], (k,)).astype(jnp.bfloat16)
            else:
                beta = None

            _, amax_list_1, scale_list_1 = TestNorm._generate_fp8_meta()

            def primitive_func(x, y, gamma, beta, amax_list_1, scale_list_1):
                fp8_meta_pkg = FP8MetaPackage(
                    amax_list_1[0],
                    scale_list_1[0],
                    amax_list_1[1],
                    scale_list_1[1],
                    amax_list_1[2],
                    scale_list_1[2],
                )
                primitive_out = layernorm_fp8_dot(x, y, gamma, beta, fp8_meta_pkg, ln_type,
                                                  zero_centered_gamma)
                return jnp.mean(primitive_out)

            def ref_func(x, y, gamma, beta, zero_centered_gamma):
                x = self.reference_layernorm(x, gamma, beta, zero_centered_gamma, epsilon)
                return jnp.mean(jnp.dot(x, y))

            value_n_grad_primitive_func = value_and_grad(primitive_func, range(6))
            value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2, 3))

            ref_out, (ref_a_grad, ref_b_grad, ref_gamma_grad,
                      ref_beta_grad) = value_n_grad_ref_func(a, b, gamma, beta, zero_centered_gamma)

            for _ in range(3):
                primitive_out, (primitive_a_grad, primitive_b_grad, primitive_gamma_grad,
                                primitive_beta_grad, amax_list_1,
                                scale_list_1) = value_n_grad_primitive_func(
                                    a, b, gamma, beta, amax_list_1, scale_list_1)

            assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
            assert_allclose(primitive_a_grad, ref_a_grad, dtype=FP8Helper.BWD_DTYPE)
            assert_allclose(primitive_b_grad, ref_b_grad, dtype=FP8Helper.BWD_DTYPE)
            assert_allclose(primitive_gamma_grad, ref_gamma_grad, dtype=FP8Helper.BWD_DTYPE)
            if beta is not None:
                assert_allclose(primitive_beta_grad, ref_beta_grad, dtype=FP8Helper.BWD_DTYPE)