# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import functools
import operator
from typing import Callable, Sequence, Union

import jax
import jax.numpy as jnp
import numpy as np
import pytest
from jax import lax
from jax import jit, value_and_grad
from flax import linen as nn

from utils import assert_allclose

from transformer_engine.jax.dot import type_safe_dot_general, dequantize, quantize
from transformer_engine.jax.fp8 import FP8MetaPackage, FP8Helper
from transformer_engine.jax.fp8 import is_fp8_available
from transformer_engine.jax.layernorm import layernorm
from transformer_engine.jax.mlp import activation_lu, activation_lu_fp8, fused_layernorm_fp8_mlp

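# (m, n, k) problem sizes exercised by the GEMM/dot tests below.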
GEMM_CASES = [
    (256, 256, 512),
    (32, 32, 32),
    (2048, 1024, 2048),
    (2048, 2048, 1024),
    (2048, 1024, 1024),
]
FP8_COMPUTE_TYPE = [jnp.float8_e4m3fn, jnp.float8_e5m2]
LN_CASES = [(512, 1024)]
DTYPES = [jnp.bfloat16, jnp.float32]
is_fp8_supported, reason = is_fp8_available()

def _convert_to_activation_function(fn_or_string):
    """Convert a string to an activation function."""
    if fn_or_string == 'linear':
        return lambda x: x
    if isinstance(fn_or_string, str):
        return getattr(nn, fn_or_string)
    if callable(fn_or_string):
        return fn_or_string
    raise ValueError(f"don't know how to convert {fn_or_string} to an activation function")


@pytest.fixture(autouse=True, scope='function')
def clear_live_arrays():
    """
    Clear all live arrays between tests to keep device memory clean.
    """
    yield
    for arr in jax.live_arrays():
        arr.delete()


class TestFP8Dot:

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    def test_qdq(self):
        FP8_E4M3_MAX = (jnp.finfo(jnp.float8_e4m3fn).max).astype(jnp.float32)
        x = jnp.asarray([[-1, 0.1], [2, 3]], jnp.float32)
        amax = jnp.max(jnp.abs(x)).reshape(1)
        scale = jnp.asarray(FP8_E4M3_MAX / amax, jnp.float32).reshape(1)
        scale_inv = (1 / scale).reshape(1)

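        # Quantize to FP8 (E4M3) with a scale derived from the amax, then dequantize back;
        # the round trip should match the input to within FP8 rounding error.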
        y, _ = quantize(x, q_dtype=jnp.float8_e4m3fn, scale=scale)
        z = dequantize(y, dq_dtype=jnp.float32, scale_inv=scale_inv)

        assert_allclose(z, x, dtype=jnp.float8_e4m3fn)

    @pytest.mark.parametrize('m,n,k', GEMM_CASES)
    def test_forward_bf16(self, m, n, k):
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)
        a = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16)
        b = jax.random.normal(subkeys[1], (k, n), jnp.bfloat16)

        primitive_out = type_safe_dot_general(a, b)
        ref_out = jnp.dot(a, b)

        assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('m,n,k', GEMM_CASES)
    def test_forward_fp8_randint(self, m, n, k):
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)

        # TODO(rewang): add float random test
        min_val, max_val = -8, 8
        a = jax.random.randint(subkeys[0], (m, k), min_val, max_val).astype(jnp.bfloat16)
        b = jax.random.randint(subkeys[1], (k, n), min_val, max_val).astype(jnp.bfloat16)

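        # Fresh FP8 metadata for a single GEMM: max representable values, amax history,
        # and scale / scale_inv tensors.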
        fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM)
        fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM, FP8Helper.AMAX_HISTORY_LEN),
                                   jnp.float32)
        fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
        fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
        fp8_meta_pkg = FP8MetaPackage(1, fp8_max, fp8_metas_amax, fp8_metas_scale,
                                      fp8_metas_scale_inv)

        primitive_out = type_safe_dot_general(a, b, fp8_meta_pkg)

        # Recompute scale/scale_inv from the amax history, then run the FP8 GEMM again
        fp8_metas_scale, fp8_metas_scale_inv = FP8Helper.update_fp8_scale(
            fp8_max, fp8_metas_amax, fp8_metas_scale)
        fp8_meta_pkg = FP8MetaPackage(1, fp8_max, fp8_metas_amax, fp8_metas_scale,
                                      fp8_metas_scale_inv)

        primitive_out = type_safe_dot_general(a, b, fp8_meta_pkg)
        ref_out = jnp.dot(a, b)

        ref_out = ref_out.astype(jnp.float32)
        primitive_out = primitive_out.astype(jnp.float32)

        assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE)

    @pytest.mark.parametrize('m,n,k', GEMM_CASES)
    def test_grad_bf16(self, m, n, k):
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)
        a = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16)
        b = jax.random.normal(subkeys[1], (k, n), jnp.bfloat16)

        def primitive_func(x, y):
            primitive_out = type_safe_dot_general(x, y)
            return jnp.mean(primitive_out)

        def ref_func(x, y):
            return jnp.mean(jnp.dot(x, y))

        value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1))

        value_n_grad_ref_func = value_and_grad(ref_func, (0, 1))

        primitive_out, (primitive_a_grad, primitive_b_grad) = value_n_grad_primitive_func(a, b)
        ref_out, (ref_a_grad, ref_b_grad) = value_n_grad_ref_func(a, b)

        assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16)
        assert_allclose(primitive_a_grad, ref_a_grad, dtype=jnp.bfloat16)
        assert_allclose(primitive_b_grad, ref_b_grad, dtype=jnp.bfloat16)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('m,n,k', GEMM_CASES)
    def test_grad_fp8_dot(self, m, n, k):
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)

        a = jax.random.normal(subkeys[0], (m, k)).astype(jnp.bfloat16)
        b = jax.random.normal(subkeys[1], (k, n)).astype(jnp.bfloat16)

        fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM)
        fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM, FP8Helper.AMAX_HISTORY_LEN),
                                   jnp.float32)
        fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
        fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)

        def primitive_func(x, y, fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv):
            fp8_meta_pkg = FP8MetaPackage(1, fp8_max, fp8_metas_amax, fp8_metas_scale,
                                          fp8_metas_scale_inv)
            primitive_out = type_safe_dot_general(x, y, fp8_meta_pkg)
            return jnp.mean(primitive_out)

        def ref_func(x, y):
            return jnp.mean(jnp.dot(x, y))

        value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2, 3, 4, 5))
        value_n_grad_ref_func = value_and_grad(ref_func, (0, 1))

        ref_out, (ref_a_grad, ref_b_grad) = value_n_grad_ref_func(a, b)

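        # Run a few steps: the "gradients" w.r.t. the FP8 meta arguments carry the updated
        # amax history and scales, which are fed back into the next iteration.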
        for _ in range(3):
            primitive_out, (primitive_a_grad, primitive_b_grad, fp8_max, fp8_metas_amax,
                            fp8_metas_scale, fp8_metas_scale_inv) = value_n_grad_primitive_func(
                                a, b, fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv)

        assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
        assert_allclose(primitive_a_grad, ref_a_grad, dtype=FP8Helper.BWD_DTYPE)
        assert_allclose(primitive_b_grad, ref_b_grad, dtype=FP8Helper.BWD_DTYPE)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('m,n,k', [(128, 256, 512),
                                       (16384, 1024, 2816),
                                       (16384, 2816, 1024),
                                       (16384, 1024, 1024)])
    @pytest.mark.parametrize('activation_type', [('gelu', ),
                                                 ('gelu', 'linear'),
                                                 ('silu', ),
                                                 ('silu', 'linear')])
    @pytest.mark.parametrize('use_bias', [True, False])
    def test_grad_fused_layernorm_fp8_mlp(self, m, n, k,
            activation_type: Sequence[Union[str, Callable]], use_bias: bool):
        """Test forward and grads of fused_layernorm_fp8_mlp against a reference."""
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 6)

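        # a: MLP input; k1 carries an extra axis of length len(activation_type) (one kernel
        # per gated branch); k2: second-layer kernel; s: RMSNorm scale.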
        a = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16)
        k1 = jax.random.normal(subkeys[1], (k, len(activation_type), n), jnp.bfloat16)
        k2 = jax.random.normal(subkeys[2], (n, k), jnp.bfloat16)
        s = jax.random.normal(subkeys[5], (k,), jnp.bfloat16)
        if use_bias:
            b1 = jax.random.normal(subkeys[3], (len(activation_type), n), jnp.bfloat16)
            b2 = jax.random.normal(subkeys[4], (k,), jnp.bfloat16)
        else:
            b1 = jax.random.normal(subkeys[3], (0,), jnp.bfloat16)
            b2 = jax.random.normal(subkeys[4], (0,), jnp.bfloat16)

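        # FP8 metas for the two GEMMs in the fused MLP, hence NUM_META_PER_GEMM * 2.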
        init_fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM * 2)
        init_fp8_metas_amax = jnp.zeros(
            (FP8Helper.NUM_META_PER_GEMM * 2, FP8Helper.AMAX_HISTORY_LEN), jnp.float32)
        init_fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM * 2, 1), jnp.float32)
        init_fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM * 2, 1), jnp.float32)

        def primitive_func(x, ln_s, y, z, w, v, fp8_max, fp8_metas_amax, fp8_metas_scale,
                           fp8_metas_scale_inv):
            # x: 2D input, ln_s: RMSNorm scale, y/z: kernels, w/v: biases
            # out = act(rmsnorm(x, ln_s) @ y + w) @ z + v
            fp8_meta_pkg = FP8MetaPackage(2, fp8_max, fp8_metas_amax, fp8_metas_scale,
                                          fp8_metas_scale_inv)
            return jnp.mean(
                fused_layernorm_fp8_mlp(x, ln_s, None, [y, z], [w, v], fp8_meta_pkg, "rmsnorm",
                                        activation_type=activation_type, use_bias=use_bias))


        def layernorm_fp8_mlp_ref(x: jnp.ndarray, ln_scale: jnp.ndarray, kernel_1: jnp.ndarray,
                                kernel_2: jnp.ndarray, bias_1: jnp.ndarray, bias_2: jnp.ndarray,
                                fp8_maxs: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray,
                                scale_inv: jnp.ndarray) -> jnp.ndarray:

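            # Reference RMSNorm computed in float32, then cast back to bfloat16.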
            x = jnp.asarray(x, jnp.float32)
            mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True)
            y = jnp.asarray(x * jax.lax.rsqrt(mean2 + 1e-6), jnp.bfloat16)
            ln_out = y * ln_scale
            ln_out = jnp.asarray(ln_out, jnp.bfloat16)

            fp8_gemm_1_pkg = FP8MetaPackage(1, fp8_maxs[:FP8Helper.NUM_META_PER_GEMM],
                                            amax[:FP8Helper.NUM_META_PER_GEMM],
                                            scale[:FP8Helper.NUM_META_PER_GEMM],
                                            scale_inv[:FP8Helper.NUM_META_PER_GEMM])
            linear_1_out = type_safe_dot_general(ln_out, kernel_1, fp8_gemm_1_pkg, ((1,), (0,)))

            if use_bias:
                bias_1_shape = (1,) * (linear_1_out.ndim - bias_1.ndim) + bias_1.shape
                linear_1_out += jnp.reshape(bias_1, bias_1_shape)

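            # Split along the activation axis, apply each activation, and multiply the results
            # elementwise (e.g. ('gelu', 'linear') corresponds to GEGLU, ('silu', 'linear') to SwiGLU).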
            x = jnp.split(linear_1_out, len(activation_type), axis=-2)
            acts = []
            for idx, act_fn in enumerate(activation_type):
                x_i = _convert_to_activation_function(act_fn)(x[idx])
                acts.append(x_i)
            x = functools.reduce(operator.mul, acts)

            x = jnp.asarray(jnp.squeeze(x, axis=-2), jnp.bfloat16)

            fp8_gemm_2_pkg = FP8MetaPackage(1, fp8_maxs[FP8Helper.NUM_META_PER_GEMM:],
                                            amax[FP8Helper.NUM_META_PER_GEMM:],
                                            scale[FP8Helper.NUM_META_PER_GEMM:],
                                            scale_inv[FP8Helper.NUM_META_PER_GEMM:])
            output = type_safe_dot_general(x, kernel_2, fp8_gemm_2_pkg, ((1,), (0,)))

            if use_bias:
                bias_2_shape = (1,) * (output.ndim - bias_2.ndim) + bias_2.shape
                output += jnp.reshape(bias_2, bias_2_shape)

            return output

        def ref_func(x, ln_s, y, z, w, v, fp8_max, fp8_metas_amax, fp8_metas_scale,
                     fp8_metas_scale_inv):
            return jnp.mean(
                layernorm_fp8_mlp_ref(x, ln_s, y, z, w, v, fp8_max, fp8_metas_amax, fp8_metas_scale,
                                    fp8_metas_scale_inv))

        value_n_grad_primitive_func = jit(
            value_and_grad(primitive_func, (0, 1, 2, 3, 4, 5, 6, 7, 8, 9)))
        value_n_grad_ref_func = jit(value_and_grad(ref_func, (0, 1, 2, 3, 4, 5, 6, 7, 8, 9)))

        ref_fp8_max = init_fp8_max
        ref_fp8_metas_amax = init_fp8_metas_amax
        ref_fp8_metas_scale = init_fp8_metas_scale
        ref_fp8_metas_scale_inv = init_fp8_metas_scale_inv

        pri_fp8_max = init_fp8_max
        pri_fp8_metas_amax = init_fp8_metas_amax
        pri_fp8_metas_scale = init_fp8_metas_scale
        pri_fp8_metas_scale_inv = init_fp8_metas_scale_inv

        # Run a few iterations so the FP8 metas (amax history and scales), which come back
        # through the gradients, are updated and fed into the next step.
        for _ in range(3):
            ref_out, (ref_a_grad, ref_s_grad, ref_k1_grad, ref_k2_grad, ref_b1_grad, ref_b2_grad,
                      ref_fp8_max, ref_fp8_metas_amax, ref_fp8_metas_scale,
                      ref_fp8_metas_scale_inv) = value_n_grad_ref_func(
                          a, s, k1, k2, b1, b2, ref_fp8_max, ref_fp8_metas_amax,
                          ref_fp8_metas_scale, ref_fp8_metas_scale_inv)

        for _ in range(3):
            primitive_out, (primitive_a_grad, primitive_s_grad, primitive_k1_grad,
                            primitive_k2_grad, primitive_b1_grad, primitive_b2_grad, pri_fp8_max,
                            pri_fp8_metas_amax, pri_fp8_metas_scale,
                            pri_fp8_metas_scale_inv) = value_n_grad_primitive_func(
                                a, s, k1, k2, b1, b2, pri_fp8_max, pri_fp8_metas_amax,
                                pri_fp8_metas_scale, pri_fp8_metas_scale_inv)

        assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
        assert_allclose(jnp.asarray(primitive_a_grad, np.float32),
                        jnp.asarray(ref_a_grad, np.float32),
                        dtype=FP8Helper.BWD_DTYPE)
        assert_allclose(jnp.asarray(primitive_k1_grad, np.float32),
                        jnp.asarray(ref_k1_grad, np.float32),
                        dtype=FP8Helper.BWD_DTYPE)
        assert_allclose(jnp.asarray(primitive_k2_grad, np.float32),
                        jnp.asarray(ref_k2_grad, np.float32),
                        dtype=FP8Helper.BWD_DTYPE)
        assert_allclose(jnp.asarray(primitive_s_grad, np.float32),
                        jnp.asarray(ref_s_grad, np.float32),
                        dtype=FP8Helper.BWD_DTYPE)
        if use_bias:
            assert_allclose(jnp.asarray(primitive_b1_grad, np.float32),
                            jnp.asarray(ref_b1_grad, np.float32),
                            dtype=jnp.bfloat16)
            assert_allclose(jnp.asarray(primitive_b2_grad, np.float32),
                            jnp.asarray(ref_b2_grad, np.float32),
                            dtype=jnp.bfloat16)


@pytest.fixture(name="random_inputs")
def random_inputs_fixture(shape):
    key = jax.random.PRNGKey(0)
    subkeys = jax.random.split(key, 4)
    out = jax.random.uniform(subkeys[0], shape, jnp.bfloat16, 5, 8)
    return out


class TestActivationLu:

    def ref_func(self, x, activation_type):
        def ref_act_lu(inputs):
            x = jnp.split(inputs, len(activation_type), axis=-2)
            acts = []
            for idx, act_fn in enumerate(activation_type):
                x_i = _convert_to_activation_function(act_fn)(x[idx])
                acts.append(x_i)
            x = functools.reduce(operator.mul, acts)
            return jnp.mean(x)

        ref_act_func = jit(value_and_grad(ref_act_lu, (0,)))
        return ref_act_func(x)

    def primitive_func(self, inputs):
        return jnp.mean(activation_lu(inputs, activation_type=self.activation_type))

    @pytest.mark.parametrize('shape', [(32, 2, 64), (64, 2, 256)])
    @pytest.mark.parametrize('activation_type', [('gelu',),
                                                 ('gelu', 'linear'),
                                                 ('silu',),
                                                 ('silu', 'linear')])
    def test_activation_lu(self, random_inputs, activation_type):
        x = random_inputs
        self.activation_type = activation_type

        value_n_grad_primitive_func = jit(
            value_and_grad(self.primitive_func, (0,)))

        prim_out, (prim_grad,) = value_n_grad_primitive_func(x)
        ref_out, (ref_grad,) = self.ref_func(x, activation_type)

        assert_allclose(prim_out, ref_out, dtype=x.dtype)
        assert_allclose(prim_grad, ref_grad, dtype=x.dtype)


class TestActivationLuFP8(TestActivationLu):

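    # 'dx_trans_no_use' and 'dbias_no_use' are placeholder arguments: the test reads the
    # transposed input gradient and the bias gradient from the gradients taken w.r.t.
    # these positions (see test_activation_lu below).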
    def primitive_func(self, inputs, dx_trans_no_use, dbias_no_use, amax, scale, scale_inv):
        return jnp.mean(
            activation_lu_fp8(inputs,
                              amax, scale, scale_inv,
                              jnp.float8_e4m3fn, jnp.float8_e5m2,
                              activation_type=self.activation_type))

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('shape', [(32, 2, 64), (64, 2, 256)])
    @pytest.mark.parametrize('activation_type', [('gelu',),
                                                 ('gelu', 'linear'),
                                                 ('silu',),
                                                 ('silu', 'linear')])
    def test_activation_lu(self, random_inputs, activation_type):
        self.amax = jnp.zeros(1, jnp.float32)
        self.scale = jnp.ones(1, jnp.float32)
        self.scale_inv = jnp.ones(1, jnp.float32)
        self.activation_type = activation_type

        x = random_inputs

        value_n_grad_primitive_func = jit(
            value_and_grad(self.primitive_func, (0, 1, 2, 3, 4, 5)))

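        # Zero-filled placeholders for the dummy arguments; transpose_indices gives the
        # layout of the transposed gradient returned by the fused primitive.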
        transpose_indices = (1, 2, 0) if len(activation_type) > 1 else (2, 0, 1)
        dx_trans_no_use = jnp.zeros([x.shape[i] for i in transpose_indices], dtype=x.dtype)
        dbias_no_use = jnp.zeros(x.shape[-1], dtype=x.dtype)

        prim_out, (prim_grad, prim_grad_trans, dbias, amax, _, _) = \
            value_n_grad_primitive_func(x, dx_trans_no_use, dbias_no_use,
                                    self.amax, self.scale, self.scale_inv)
        ref_out, (ref_grad,) = self.ref_func(x, activation_type)

        assert_allclose(prim_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
        assert_allclose(amax, jnp.amax(jnp.abs(ref_grad)), rtol=1e-2)
        if 'linear' not in activation_type:
            assert_allclose(dbias, jnp.sum(ref_grad, axis=tuple(range(x.ndim - 1))))
        assert_allclose(prim_grad, ref_grad, dtype=FP8Helper.BWD_DTYPE)
        assert_allclose(prim_grad_trans,
                        jnp.transpose(ref_grad, transpose_indices),
                        dtype=FP8Helper.BWD_DTYPE)


class TestRMSNorm:

    @pytest.mark.parametrize('n, hidden', LN_CASES)
    @pytest.mark.parametrize('dtype', DTYPES)
    def test_forward_backward(self, n, hidden, dtype):
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)

        x = jax.random.uniform(subkeys[0], (n, hidden), dtype, -2, 1)
        scale = jax.random.uniform(subkeys[1], (hidden,), jnp.float32, -2, 1)
        scale = jnp.asarray(scale, dtype)
        epsilon = 1e-6

        def reference_rmsnorm(x, scale):
            x = jnp.asarray(x, jnp.float32)
            mean2 = jnp.mean(lax.square(x), axis=-1, keepdims=True)
            y = jnp.asarray(x * lax.rsqrt(mean2 + epsilon), dtype)
            return y * scale

        jitted_primitive = jit(
            value_and_grad(lambda x, scale: jnp.mean(layernorm(x, scale, None, "rmsnorm")), (0, 1)))

        jitted_reference = jit(
            value_and_grad(lambda x, scale: jnp.mean(reference_rmsnorm(x, scale)), (0, 1)))

        primitive_out, (primitive_dx, primitive_dgamma) = jitted_primitive(x, scale)
        reference_out, (reference_dx, reference_dgamma) = jitted_reference(x, scale)

        assert_allclose(primitive_out, reference_out, dtype=dtype)
        assert_allclose(primitive_dx, reference_dx, dtype=dtype)
        assert_allclose(primitive_dgamma, reference_dgamma, dtype=dtype)


class TestLayerNorm:

    @pytest.mark.parametrize('n, hidden', LN_CASES)
    @pytest.mark.parametrize('dtype', DTYPES)
    @pytest.mark.parametrize('zero_centered_gamma', [False, True])
    def test_forward_backward(self, n, hidden, zero_centered_gamma, dtype):
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 3)

        x = jax.random.uniform(subkeys[0], (n, hidden), dtype, -1, 1)
        scale_range = (-1, 1) if zero_centered_gamma else (0, 2)
        scale = jax.random.uniform(subkeys[1], (hidden,), jnp.float32, *scale_range)
        scale = jnp.asarray(scale, dtype)
        bias = jax.random.uniform(subkeys[2], (hidden,), jnp.float32, -1, 1)
        bias = jnp.asarray(bias, dtype)
        epsilon = 1e-6

        def reference_layernorm(x, scale, bias, zero_centered_gamma, eps):
            x_ = jnp.asarray(x, jnp.float32)
            mean = jnp.mean(x_, axis=-1, keepdims=True)
            var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True)
            normed_input = (x_ - mean) * jax.lax.rsqrt(var + eps)
            # Align with the TE implementation: zero-centered gamma uses (scale + 1)
            if zero_centered_gamma:
                return jnp.asarray(normed_input * (scale + 1) + bias).astype(x.dtype)
            return jnp.asarray(normed_input * scale + bias).astype(x.dtype)

        def compute_loss(x):
            # Higher precision to compute the loss
            x_ = x.astype(jnp.float32)
            return jnp.mean(jnp.square(x_)).astype(x.dtype)

        jitted_primitive = jit(
            value_and_grad(
                lambda x, scale, bias: compute_loss(
                    layernorm(x, scale, bias, "layernorm", zero_centered_gamma, epsilon)),
                (0, 1, 2)))

        jitted_reference = jit(
            value_and_grad(
                lambda x, scale, bias: compute_loss(
                    reference_layernorm(x, scale, bias, zero_centered_gamma, epsilon)), (0, 1, 2)))

        primitive_out, (primitive_dx, primitive_dgamma,
                        primitive_dbeta) = jitted_primitive(x, scale, bias)
        reference_out, (reference_dx, reference_dgamma,
                        reference_dbeta) = jitted_reference(x, scale, bias)

        assert_allclose(primitive_out, reference_out, dtype=dtype)
        assert_allclose(primitive_dx, reference_dx, dtype=dtype)
        assert_allclose(primitive_dgamma, reference_dgamma, dtype=dtype)
        assert_allclose(primitive_dbeta, reference_dbeta, dtype=dtype)