# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

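"""Tests for TE/JAX custom-call compute primitives: FP8 GEMM, fused activations, and layernorm."""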
from contextlib import nullcontext
import functools
import operator
from typing import Callable, Sequence, Union

import jax
import jax.numpy as jnp
import numpy as np
import pytest
from jax import jit, value_and_grad
from flax import linen as nn

from utils import assert_allclose
from transformer_engine.jax.dot import type_safe_dot_general, dequantize, quantize
from transformer_engine.jax.fp8 import FP8MetaPackage, FP8Helper
from transformer_engine.jax.fp8 import is_fp8_available
from transformer_engine.jax.layernorm import layernorm, layernorm_fp8_dot
from transformer_engine.jax.mlp import activation_lu, fused_layernorm_fp8_mlp
from transformer_engine.jax.cpp_extensions import act_lu_fp8, dact_lu_dbias_cast_transpose
from transformer_engine.jax.cpp_extensions import dgated_act_lu_cast_transpose

GEMM_CASES = [
    (256, 256, 512),
    (32, 32, 32),
    (2048, 1024, 2048),
    (2048, 2048, 1024),
    (2048, 1024, 1024),
]
FP8_COMPUTE_TYPE = [jnp.float8_e4m3fn, jnp.float8_e5m2]
LN_CASES = [(512, 1024)]
DTYPES = [jnp.bfloat16, jnp.float32]
is_fp8_supported, reason = is_fp8_available()


def _convert_to_activation_function(fn_or_string):
    """Convert a string to an activation function."""
    if fn_or_string == 'linear':
        return lambda x: x
    if fn_or_string == 'quick_gelu':
        return lambda x: nn.gelu(x, approximate=True)
    if fn_or_string == 'squared_relu':
        return lambda x: functools.reduce(operator.mul, [nn.relu(x), nn.relu(x)])
    if isinstance(fn_or_string, str):
        return getattr(nn, fn_or_string)
    if callable(fn_or_string):
        return fn_or_string
    raise ValueError(f"don't know how to convert {fn_or_string} to an activation function")


class TestFP8Dot:
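    """Test FP8 quantize/dequantize and type_safe_dot_general against pure-JAX references."""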

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    def test_qdq(self):
        FP8_E4M3_MAX = (jnp.finfo(jnp.float8_e4m3fn).max).astype(jnp.float32)
        x = jnp.asarray([[-1, 0.1], [2, 3]], jnp.float32)
        amax = jnp.max(jnp.abs(x)).reshape(1)
        scale = jnp.asarray(FP8_E4M3_MAX / amax, jnp.float32).reshape(1)
        scale_inv = (1 / scale).reshape(1)

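        # Round trip: quantize with scale = FP8_E4M3_MAX / amax so the largest |x|
        # maps to the edge of the e4m3 range; dequantizing with scale_inv should
        # recover x to within FP8 rounding error (checked by the fp8-aware tolerance).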
        y, _ = quantize(x, q_dtype=jnp.float8_e4m3fn, scale=scale)
        z = dequantize(y, dq_dtype=jnp.float32, scale_inv=scale_inv)

        assert_allclose(z, x, dtype=jnp.float8_e4m3fn)

    @pytest.mark.parametrize('m,n,k', GEMM_CASES)
    def test_forward_bf16(self, m, n, k):
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)
        a = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16)
        b = jax.random.normal(subkeys[1], (k, n), jnp.bfloat16)

        primitive_out = type_safe_dot_general(a, b)
        ref_out = jnp.dot(a, b)

        assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('m,n,k', GEMM_CASES)
    def test_forward_fp8_randint(self, m, n, k):
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)

        # TODO(rewang): add float random test
        min_val, max_val = -8, 8
        a = jax.random.randint(subkeys[0], (m, k), min_val, max_val).astype(jnp.bfloat16)
        b = jax.random.randint(subkeys[1], (k, n), min_val, max_val).astype(jnp.bfloat16)

        fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM)
        fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM, FP8Helper.AMAX_HISTORY_LEN),
                                   jnp.float32)
        fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
        fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
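
        # Scales start at 1 and the amax history at 0, so the first FP8 GEMM runs
        # effectively unscaled; the amax it records calibrates later iterations.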
        fp8_meta_pkg = FP8MetaPackage(1, fp8_max, fp8_metas_amax, fp8_metas_scale,
                                      fp8_metas_scale_inv)

        primitive_out = type_safe_dot_general(a, b, fp8_meta_pkg)

        # Recompute scale/scale_inv from the recorded amax, then rerun the GEMM
        # with the calibrated scaling factors.
        fp8_metas_scale, fp8_metas_scale_inv = FP8Helper.update_fp8_scale(
            fp8_max, fp8_metas_amax, fp8_metas_scale)
        fp8_meta_pkg = FP8MetaPackage(1, fp8_max, fp8_metas_amax, fp8_metas_scale,
                                      fp8_metas_scale_inv)

        primitive_out = type_safe_dot_general(a, b, fp8_meta_pkg)
        ref_out = jnp.dot(a, b)

        ref_out = ref_out.astype(jnp.float32)
        primitive_out = primitive_out.astype(jnp.float32)

        assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE)

    @pytest.mark.parametrize('m,n,k', GEMM_CASES)
    def test_grad_bf16(self, m, n, k):
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)
        a = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16)
        b = jax.random.normal(subkeys[1], (k, n), jnp.bfloat16)

        def primitive_func(x, y):
            primitive_out = type_safe_dot_general(x, y)
            return jnp.mean(primitive_out)

        def ref_func(x, y):
            return jnp.mean(jnp.dot(x, y))

        value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1))

        value_n_grad_ref_func = value_and_grad(ref_func, (0, 1))

        primitive_out, (primitive_a_grad, primitive_b_grad) = value_n_grad_primitive_func(a, b)
        ref_out, (ref_a_grad, ref_b_grad) = value_n_grad_ref_func(a, b)

        assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16)
        assert_allclose(primitive_a_grad, ref_a_grad, dtype=jnp.bfloat16)
        assert_allclose(primitive_b_grad, ref_b_grad, dtype=jnp.bfloat16)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('m,n,k', GEMM_CASES)
    def test_grad_fp8_dot(self, m, n, k):
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)

        a = jax.random.normal(subkeys[0], (m, k)).astype(jnp.bfloat16)
        b = jax.random.normal(subkeys[1], (k, n)).astype(jnp.bfloat16)

        fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM)
        fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM, FP8Helper.AMAX_HISTORY_LEN),
                                   jnp.float32)
        fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
        fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)

        def primitive_func(x, y, fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv):
            fp8_meta_pkg = FP8MetaPackage(1, fp8_max, fp8_metas_amax, fp8_metas_scale,
                                          fp8_metas_scale_inv)
            primitive_out = type_safe_dot_general(x, y, fp8_meta_pkg)
            return jnp.mean(primitive_out)

        def ref_func(x, y):
            return jnp.mean(jnp.dot(x, y))

        value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2, 3, 4, 5))
        value_n_grad_ref_func = value_and_grad(ref_func, (0, 1))

        ref_out, (ref_a_grad, ref_b_grad) = value_n_grad_ref_func(a, b)

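        # Run a few iterations so the amax history fills and the FP8 scales update.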
        for _ in range(3):
            primitive_out, (primitive_a_grad, primitive_b_grad, fp8_max, fp8_metas_amax,
                            fp8_metas_scale, fp8_metas_scale_inv) = value_n_grad_primitive_func(
                                a, b, fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv)

        assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
        assert_allclose(primitive_a_grad, ref_a_grad, dtype=FP8Helper.BWD_DTYPE)
        assert_allclose(primitive_b_grad, ref_b_grad, dtype=FP8Helper.BWD_DTYPE)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('m,n,k', [(256, 128, 512),
                                       (16384, 1024, 2816),
                                       (16384, 2816, 1024),
                                       (16384, 1024, 1024)])
    @pytest.mark.parametrize('activation_type', [('gelu',),
                                                 ('gelu', 'linear'),
                                                 ('silu',),
                                                 ('silu', 'linear'),
                                                 ('relu',),
                                                 ('relu', 'linear'),
                                                 ('quick_gelu',),
                                                 ('quick_gelu', 'linear'),
                                                 ('squared_relu',),
                                                 ('squared_relu', 'linear')])
    @pytest.mark.parametrize('use_bias', [True, False])
    def test_grad_fused_layernorm_fp8_mlp(self, m, n, k,
            activation_type: Sequence[Union[str, Callable]], use_bias: bool):
        """Test fused_layernorm_fp8_mlp forward and gradients against a JAX reference."""
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 6)

        a = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16)
        k1 = jax.random.normal(subkeys[1], (k, len(activation_type), n), jnp.bfloat16) / jnp.sqrt(k)
        k2 = jax.random.normal(subkeys[2], (n, k), jnp.bfloat16) / jnp.sqrt(n)
        s = jax.random.normal(subkeys[5], (k,), jnp.bfloat16)
        if use_bias:
            b1 = jax.random.normal(subkeys[3], (len(activation_type), n), jnp.bfloat16)
            b2 = jax.random.normal(subkeys[4], (k,), jnp.bfloat16)
        else:
            b1 = None
            b2 = None

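        # The fused MLP performs two GEMMs, so allocate two sets of FP8 metas.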
        init_fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM * 2)
        init_fp8_metas_amax = jnp.zeros(
            (FP8Helper.NUM_META_PER_GEMM * 2, FP8Helper.AMAX_HISTORY_LEN), jnp.float32)
        init_fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM * 2, 1), jnp.float32)
        init_fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM * 2, 1), jnp.float32)

        def primitive_func(x, ln_s, y, z, w, v, fp8_max, fp8_metas_amax, fp8_metas_scale,
                           fp8_metas_scale_inv):
            # x: 2D input; ln_s: layernorm scale
            # y, z: 2D weight matrices; w, v: biases (None when use_bias=False)
            # out = activation((rmsnorm(x) * ln_s) @ y + w) @ z + v
            fp8_meta_pkg = FP8MetaPackage(2, fp8_max, fp8_metas_amax, fp8_metas_scale,
                                          fp8_metas_scale_inv)
            return jnp.mean(
                fused_layernorm_fp8_mlp(x, ln_s, None, [y, z], [w, v], fp8_meta_pkg, "rmsnorm",
                                        activation_type=activation_type, use_bias=use_bias))


        def layernorm_fp8_mlp_ref(x: jnp.ndarray, ln_scale: jnp.ndarray, kernel_1: jnp.ndarray,
                                kernel_2: jnp.ndarray, bias_1: jnp.ndarray, bias_2: jnp.ndarray,
                                fp8_maxs: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray,
                                scale_inv: jnp.ndarray) -> jnp.ndarray:

            x = jnp.asarray(x, jnp.float32)
            mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True)
            y = jnp.asarray(x * jax.lax.rsqrt(mean2 + 1e-6), jnp.bfloat16)
            ln_out = y * ln_scale
            ln_out = jnp.asarray(ln_out, jnp.bfloat16)

            fp8_gemm_1_pkg = FP8MetaPackage(1, fp8_maxs[:FP8Helper.NUM_META_PER_GEMM],
                                            amax[:FP8Helper.NUM_META_PER_GEMM],
                                            scale[:FP8Helper.NUM_META_PER_GEMM],
                                            scale_inv[:FP8Helper.NUM_META_PER_GEMM])
            linear_1_out = type_safe_dot_general(ln_out, kernel_1, fp8_gemm_1_pkg, ((1,), (0,)))

            if use_bias:
                bias_1_shape = (1,) * (linear_1_out.ndim - bias_1.ndim) + bias_1.shape
                linear_1_out += jnp.reshape(bias_1, bias_1_shape)

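            # Gated activation types (e.g. ('gelu', 'linear')) keep one slice per
            # function along axis -2: apply each function and multiply the slices.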
            x = jnp.split(linear_1_out, len(activation_type), axis=-2)
            acts = []
            for idx, act_fn in enumerate(activation_type):
                x_i = _convert_to_activation_function(act_fn)(x[idx])
                acts.append(x_i)
            x = functools.reduce(operator.mul, acts)

            x = jnp.asarray(jnp.squeeze(x, axis=-2), jnp.bfloat16)

            fp8_gemm_2_pkg = FP8MetaPackage(1, fp8_maxs[FP8Helper.NUM_META_PER_GEMM:],
                                            amax[FP8Helper.NUM_META_PER_GEMM:],
                                            scale[FP8Helper.NUM_META_PER_GEMM:],
                                            scale_inv[FP8Helper.NUM_META_PER_GEMM:])
            output = type_safe_dot_general(x, kernel_2, fp8_gemm_2_pkg, ((1,), (0,)))

            if use_bias:
                bias_2_shape = (1,) * (output.ndim - bias_2.ndim) + bias_2.shape
                output += jnp.reshape(bias_2, bias_2_shape)

            return output

        def ref_func(x, ln_s, y, z, w, v, fp8_max, fp8_metas_amax, fp8_metas_scale,
                     fp8_metas_scale_inv):
            return jnp.mean(
                layernorm_fp8_mlp_ref(x, ln_s, y, z, w, v, fp8_max, fp8_metas_amax, fp8_metas_scale,
                                    fp8_metas_scale_inv))

        value_n_grad_primitive_func = jit(
            value_and_grad(primitive_func, (0, 1, 2, 3, 4, 5, 6, 7, 8, 9)))
        value_n_grad_ref_func = jit(value_and_grad(ref_func, (0, 1, 2, 3, 4, 5, 6, 7, 8, 9)))

        ref_fp8_max = init_fp8_max
        ref_fp8_metas_amax = init_fp8_metas_amax
        ref_fp8_metas_scale = init_fp8_metas_scale
        ref_fp8_metas_scale_inv = init_fp8_metas_scale_inv

        pri_fp8_max = init_fp8_max
        pri_fp8_metas_amax = init_fp8_metas_amax
        pri_fp8_metas_scale = init_fp8_metas_scale
        pri_fp8_metas_scale_inv = init_fp8_metas_scale_inv

        # Run a few iterations so the amax history fills and the FP8 scales
        # update before the final comparison.
        for _ in range(3):
            ref_out, (ref_a_grad, ref_s_grad, ref_k1_grad, ref_k2_grad, ref_b1_grad, ref_b2_grad,
                      ref_fp8_max, ref_fp8_metas_amax, ref_fp8_metas_scale,
                      ref_fp8_metas_scale_inv) = value_n_grad_ref_func(
                          a, s, k1, k2, b1, b2, ref_fp8_max, ref_fp8_metas_amax,
                          ref_fp8_metas_scale, ref_fp8_metas_scale_inv)

        for _ in range(3):
            primitive_out, (primitive_a_grad, primitive_s_grad, primitive_k1_grad,
                            primitive_k2_grad, primitive_b1_grad, primitive_b2_grad, pri_fp8_max,
                            pri_fp8_metas_amax, pri_fp8_metas_scale,
                            pri_fp8_metas_scale_inv) = value_n_grad_primitive_func(
                                a, s, k1, k2, b1, b2, pri_fp8_max, pri_fp8_metas_amax,
                                pri_fp8_metas_scale, pri_fp8_metas_scale_inv)

        assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
        assert_allclose(jnp.asarray(primitive_a_grad, np.float32),
                        jnp.asarray(ref_a_grad, np.float32),
                        dtype=FP8Helper.BWD_DTYPE)
        assert_allclose(jnp.asarray(primitive_k1_grad, np.float32),
                        jnp.asarray(ref_k1_grad, np.float32),
                        dtype=FP8Helper.BWD_DTYPE)
        assert_allclose(jnp.asarray(primitive_s_grad, np.float32),
                        jnp.asarray(ref_s_grad, np.float32),
                        dtype=FP8Helper.BWD_DTYPE)
        assert_allclose(jnp.asarray(primitive_k2_grad, np.float32),
                        jnp.asarray(ref_k2_grad, np.float32),
                        dtype=FP8Helper.BWD_DTYPE)
        if use_bias:
            assert_allclose(jnp.asarray(primitive_b2_grad, np.float32),
                            jnp.asarray(ref_b2_grad, np.float32),
                            dtype=FP8Helper.BWD_DTYPE)
            assert_allclose(jnp.asarray(primitive_b1_grad, np.float32),
                            jnp.asarray(ref_b1_grad, np.float32),
                            dtype=FP8Helper.BWD_DTYPE)


@pytest.fixture(name="random_inputs")
def random_inputs_fixture(shape):
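    """Uniform random bf16 tensor in [5, 8) with the parametrized shape."""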
    key = jax.random.PRNGKey(0)
    subkeys = jax.random.split(key, 4)
    out = jax.random.uniform(subkeys[0], shape, jnp.bfloat16, 5, 8)
    return out


class TestActivationLu:
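    """Test activation_lu (optionally gated activations) against a JAX reference."""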

    def ref_func(self, x, activation_type):
        def ref_act_lu(inputs):
            x = jnp.split(inputs, len(activation_type), axis=-2)
            acts = []
            for idx, act_fn in enumerate(activation_type):
                x_i = _convert_to_activation_function(act_fn)(x[idx])
                acts.append(x_i)
            x = functools.reduce(operator.mul, acts)
            return jnp.mean(x)

        ref_act_func = jit(value_and_grad(ref_act_lu, (0,)))
        return ref_act_func(x)

    def primitive_func(self, inputs):
        return jnp.mean(activation_lu(inputs, activation_type=self.activation_type))

    @pytest.mark.parametrize('shape', [(32, 1, 64), (64, 1, 256)])
    @pytest.mark.parametrize('activation_type', [('gelu',),
                                                 ('gelu', 'linear'),
                                                 ('silu',),
                                                 ('silu', 'linear'),
                                                 ('relu',),
                                                 ('relu', 'linear'),
                                                 ('quick_gelu',),
                                                 ('quick_gelu', 'linear'),
                                                 ('squared_relu',),
                                                 ('squared_relu', 'linear')])
    def test_activation_lu(self, random_inputs, activation_type):
        x = random_inputs
        x = jnp.repeat(x, len(activation_type), axis=1)
        self.activation_type = activation_type

        value_n_grad_primitive_func = jit(
            value_and_grad(self.primitive_func, (0,)))

        prim_out, (prim_grad,) = value_n_grad_primitive_func(x)
        ref_out, (ref_grad,) = self.ref_func(x, activation_type)

        assert_allclose(prim_out, ref_out, dtype=x.dtype)
        assert_allclose(prim_grad, ref_grad, dtype=x.dtype)


class TestActivationLuFP8(TestActivationLu):
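    """Run the activation tests through the FP8 cast/transpose custom calls."""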

    def prim_func(self, x):
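        """Differentiate act_lu_fp8 through a custom VJP whose backward pass runs
        the fused dact/cast/transpose kernels and also surfaces dbias and amax."""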
        amax = self.amax
        scale = self.scale
        scale_inv = self.scale_inv
        activation_type = self.activation_type

        @jax.custom_vjp
        def _prim_func(x, _x_t, _dbias, _amax):
            # _prim_func_fwd returns (output, residuals); the primal is just output.
            output, _ = _prim_func_fwd(x, _x_t, _dbias, _amax)
            return output

        def _prim_func_fwd(x, _x_t, _dbias, _amax):
            activation_lu_out, _ = act_lu_fp8(x, amax, scale, scale_inv,
                                              FP8Helper.FWD_DTYPE, activation_type)
            activation_lu_out = dequantize(activation_lu_out, x.dtype, scale_inv)
            ctx = x    # save the input as the residual for the backward pass
            return activation_lu_out, ctx

        def _prim_func_bwd(ctx, g):
            x = ctx
            if len(self.activation_type) > 1:    # gated activation: no dbias
                dactivation_lu, dactivation_lu_trans, amax_out = \
                    dgated_act_lu_cast_transpose(g, x, amax, scale, scale_inv,
                                                 FP8Helper.BWD_DTYPE, -1, activation_type)
                dbias = jnp.empty(x.shape[-1], x.dtype)
            else:    # non-gated activation: with dbias
                dactivation_lu, dactivation_lu_trans, dbias, amax_out = \
                    dact_lu_dbias_cast_transpose(g, x, amax, scale, scale_inv,
                                                 FP8Helper.BWD_DTYPE, -1, -2,
                                                 self.activation_type)
            dactivation_lu = dequantize(dactivation_lu, x.dtype, scale_inv)
            dactivation_lu_trans = dequantize(dactivation_lu_trans, x.dtype, scale_inv)
            # One cotangent per primal input: (dx, dx_t, dbias, damax).
            return (dactivation_lu, dactivation_lu_trans, dbias, amax_out)

        _prim_func.defvjp(_prim_func_fwd, _prim_func_bwd)

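        # Placeholder tensors: value_and_grad surfaces the extra backward outputs
        # (transposed grad, dbias, amax) as the "gradients" of these dummy inputs.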
        dx_trans_no_use = jnp.empty([x.shape[i] for i in self.transpose_indices], dtype=x.dtype)
        dbias_no_use = jnp.empty(x.shape[-1], dtype=x.dtype)
        amax_no_use = jnp.zeros(1, jnp.float32)
        value_n_grad_primitive_func = value_and_grad(lambda a, b, c, d:
            jnp.mean(_prim_func(a, b, c, d)), (0, 1, 2, 3))
        return value_n_grad_primitive_func(x, dx_trans_no_use, dbias_no_use, amax_no_use)


    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('shape', [(32, 1, 64), (64, 1, 256)])
    @pytest.mark.parametrize('activation_type', [('gelu',),
                                                 ('gelu', 'linear'),
                                                 ('silu',),
                                                 ('silu', 'linear'),
                                                 ('relu',),
                                                 ('relu', 'linear'),
                                                 ('quick_gelu',),
                                                 ('quick_gelu', 'linear'),
                                                 ('squared_relu',),
                                                 ('squared_relu', 'linear')])
    def test_activation_lu(self, random_inputs, activation_type):
        self.amax = jnp.zeros(1, jnp.float32)
        self.scale = jnp.ones(1, jnp.float32)
        self.scale_inv = jnp.ones(1, jnp.float32)
        self.activation_type = activation_type
        self.transpose_indices = (1, 2, 0)

        x = random_inputs
        x = jnp.repeat(x, len(activation_type), axis=1)


        prim_out, (prim_grad, prim_grad_trans, dbias, amax) = self.prim_func(x)
        ref_out, (ref_grad,) = self.ref_func(x, activation_type)

        assert_allclose(prim_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
        assert_allclose(amax, jnp.amax(jnp.abs(ref_grad)), rtol=1e-2)
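        # dbias comes only from the non-gated backward kernel; gated variants
        # (those paired with 'linear') return a dummy dbias instead.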
        if 'linear' not in activation_type:
            assert_allclose(dbias, jnp.sum(ref_grad, axis=tuple(range(x.ndim - 1))))
        assert_allclose(prim_grad, ref_grad, dtype=FP8Helper.BWD_DTYPE)
        assert_allclose(prim_grad_trans,
                        jnp.transpose(ref_grad, self.transpose_indices),
                        dtype=FP8Helper.BWD_DTYPE)


class TestNorm:
    """
    Test transformer_engine.jax.layernorm APIs
    """

    def reference_layernorm(self, x, scale, bias, zero_centered_gamma, eps):
        """
        JAX native layernorm implementations
        - bias is not None: layernorm
        - bias is None: rmsnorm
        """
        x_ = jnp.asarray(x, jnp.float32)
        if bias is None:
            mean = 0.
        else:
            mean = jnp.mean(x_, axis=-1, keepdims=True)
        var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True)
        normed_input = (x_ - mean) * jax.lax.rsqrt(var + eps)
        if zero_centered_gamma:
            scale += 1.
        if bias is None:
            bias = 0.
        return jnp.asarray(normed_input * scale + bias).astype(x.dtype)

    @pytest.mark.parametrize('n, hidden', LN_CASES)
    @pytest.mark.parametrize('dtype', DTYPES)
    @pytest.mark.parametrize('ln_type', ['layernorm', 'rmsnorm'])
    @pytest.mark.parametrize('zero_centered_gamma', [False, True])
    @pytest.mark.parametrize('epsilon', [1e-2, 1e-6])
    def test_layernorm_forward_backward(self, n, hidden, ln_type, zero_centered_gamma, epsilon,
                                        dtype):
        """
        Test transformer_engine.jax.layernorm.layernorm
        """
        expect_assert = False
        if ln_type == 'rmsnorm' and zero_centered_gamma:
            # zero_centered_gamma is not supported for rmsnorm, expect an assertion.
            expect_assert = True

        with pytest.raises(AssertionError, match=r".*zero_centered_gamma is not supported.*"
                          ) if expect_assert else nullcontext():
            key = jax.random.PRNGKey(0)
            subkeys = jax.random.split(key, 3)

            x = jax.random.uniform(subkeys[0], (n, hidden), dtype, -1, 1)
            gamma_range = (-1, 1) if zero_centered_gamma else (0, 2)
            gamma = jax.random.uniform(subkeys[1], (hidden,), jnp.float32, *gamma_range)
            gamma = jnp.asarray(gamma, dtype)
            if ln_type == 'layernorm':
                beta = jax.random.uniform(subkeys[2], (hidden,), jnp.float32, -1, 1)
                beta = jnp.asarray(beta, dtype)
            else:
                beta = None

            def compute_loss(x):
                # Higher precision to compute the loss
                x_ = x.astype(jnp.float32)
                return jnp.mean(jnp.square(x_)).astype(x.dtype)

            jitted_primitive = jit(
                value_and_grad(
                    lambda x, gamma, beta: compute_loss(
                        layernorm(x, gamma, beta, ln_type, zero_centered_gamma, epsilon)),
                    (0, 1, 2)))

            jitted_reference = jit(
                value_and_grad(
                    lambda x, gamma, beta: compute_loss(
                        self.reference_layernorm(x, gamma, beta, zero_centered_gamma, epsilon)),
                    (0, 1, 2)))

            primitive_out, (primitive_dx, primitive_dgamma,
                            primitive_dbeta) = jitted_primitive(x, gamma, beta)
            reference_out, (reference_dx, reference_dgamma,
                            reference_dbeta) = jitted_reference(x, gamma, beta)

            assert_allclose(primitive_out, reference_out, dtype=dtype)
            assert_allclose(primitive_dx, reference_dx, dtype=dtype)
            assert_allclose(primitive_dgamma, reference_dgamma, dtype=dtype)
            if beta is not None:
                assert_allclose(primitive_dbeta, reference_dbeta, dtype=dtype)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('m,n,k', GEMM_CASES)
    @pytest.mark.parametrize('ln_type', ['layernorm', 'rmsnorm'])
    @pytest.mark.parametrize('zero_centered_gamma', [True, False])
    @pytest.mark.parametrize('epsilon', [1e-2, 1e-6])
    def test_ln_fp8_dot_forward_backward(self, m, n, k, ln_type, zero_centered_gamma, epsilon):
        """
        Test transformer_engine.jax.layernorm.layernorm_fp8_dot
        """
        expect_assert = False
        if ln_type == 'rmsnorm' and zero_centered_gamma:
            # zero_centered_gamma is not supported for rmsnorm, expect an assertion.
            expect_assert = True

        with pytest.raises(AssertionError, match=r".*zero_centered_gamma is not supported.*"
                          ) if expect_assert else nullcontext():
            key = jax.random.PRNGKey(0)
            subkeys = jax.random.split(key, 4)

            a = jax.random.normal(subkeys[0], (m, k)).astype(jnp.bfloat16)
            b = jax.random.normal(subkeys[1], (k, n)).astype(jnp.bfloat16)

            gamma = jax.random.normal(subkeys[2], (k,)).astype(jnp.bfloat16)
            if ln_type == 'layernorm':
                beta = jax.random.normal(subkeys[3], (k,)).astype(jnp.bfloat16)
            else:
                beta = None

            fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM)
            fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM, FP8Helper.AMAX_HISTORY_LEN),
                                       jnp.float32)
            fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
            fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)

            def primitive_func(x, y, gamma, beta, fp8_max, fp8_metas_amax, fp8_metas_scale,
                               fp8_metas_scale_inv):
                fp8_meta_pkg = FP8MetaPackage(1, fp8_max, fp8_metas_amax, fp8_metas_scale,
                                              fp8_metas_scale_inv)
                primitive_out = layernorm_fp8_dot(x, y, gamma, beta, fp8_meta_pkg, ln_type,
                                                  zero_centered_gamma)
                return jnp.mean(primitive_out)

            def ref_func(x, y, gamma, beta, zero_centered_gamma):
                x = self.reference_layernorm(x, gamma, beta, zero_centered_gamma, epsilon)
                return jnp.mean(jnp.dot(x, y))

            value_n_grad_primitive_func = value_and_grad(primitive_func, range(8))
            value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2, 3))

            ref_out, (ref_a_grad, ref_b_grad, ref_gamma_grad,
                      ref_beta_grad) = value_n_grad_ref_func(a, b, gamma, beta, zero_centered_gamma)

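            # Iterate so the amax history and FP8 scales update before comparing.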
            for _ in range(3):
                primitive_out, (primitive_a_grad, primitive_b_grad, primitive_gamma_grad,
                                primitive_beta_grad, fp8_max, fp8_metas_amax, fp8_metas_scale,
                                fp8_metas_scale_inv) = value_n_grad_primitive_func(
                                    a, b, gamma, beta, fp8_max, fp8_metas_amax, fp8_metas_scale,
                                    fp8_metas_scale_inv)

            assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
            assert_allclose(primitive_a_grad, ref_a_grad, dtype=FP8Helper.BWD_DTYPE)
            assert_allclose(primitive_b_grad, ref_b_grad, dtype=FP8Helper.BWD_DTYPE)
            assert_allclose(primitive_gamma_grad, ref_gamma_grad, dtype=FP8Helper.BWD_DTYPE)
            if beta is not None:
                assert_allclose(primitive_beta_grad, ref_beta_grad, dtype=FP8Helper.BWD_DTYPE)