# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import functools
import operator

import jax
import jax.numpy as jnp
import numpy as np
import pytest
from jax import lax
from jax import jit, value_and_grad
from flax import linen as nn

from utils import assert_allclose
from transformer_engine.jax.cpp_extensions import dgated_gelu, gated_gelu
from transformer_engine.jax.cpp_extensions import dgated_gelu_cast_transpose, gated_gelu_fp8
from transformer_engine.jax.dot import type_safe_dot_general, dequantize, quantize
from transformer_engine.jax.fp8 import FP8MetaPackage, FP8Helper
from transformer_engine.jax.fp8 import is_fp8_available
from transformer_engine.jax.layernorm import layernorm
from transformer_engine.jax.mlp import layernorm_geglu_fp8_mlp

GEMM_CASES = [
    (256, 256, 512),
    (32, 32, 32),
    (2048, 1024, 2048),
    (2048, 2048, 1024),
    (2048, 1024, 1024),
]
FP8_COMPUTE_TYPE = [jnp.float8_e4m3fn, jnp.float8_e5m2]
LN_CASES = [(512, 1024)]
DTYPES = [jnp.bfloat16, jnp.float32]
is_fp8_supported, reason = is_fp8_available()


@pytest.fixture(autouse=True, scope='function')
def clear_live_arrays():
    """
    Clear all live arrays to free device memory between tests
    """
    yield
    for arr in jax.live_arrays():
        arr.delete()


class TestFP8Dot:

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    def test_qdq(self):
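        # Round-trip check: the scale maps the tensor's amax onto the FP8 E4M3
        # maximum, so dequantize(quantize(x)) should recover x up to FP8 rounding.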
        FP8_E4M3_MAX = (jnp.finfo(jnp.float8_e4m3fn).max).astype(jnp.float32)
        x = jnp.asarray([[-1, 0.1], [2, 3]], jnp.float32)
        amax = jnp.max(jnp.abs(x)).reshape(1)
        scale = jnp.asarray(FP8_E4M3_MAX / amax, jnp.float32).reshape(1)
        scale_inv = (1 / scale).reshape(1)

        y, _ = quantize(x, q_dtype=jnp.float8_e4m3fn, scale=scale)
        z = dequantize(y, dq_dtype=jnp.float32, scale_inv=scale_inv)

        assert_allclose(z, x, dtype=jnp.float8_e4m3fn)

    @pytest.mark.parametrize('m,n,k', GEMM_CASES)
    def test_forward_bf16(self, m, n, k):
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)
        a = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16)
        b = jax.random.normal(subkeys[1], (k, n), jnp.bfloat16)

        primitive_out = type_safe_dot_general(a, b)
        ref_out = jnp.dot(a, b)

        assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('m,n,k', GEMM_CASES)
    def test_forward_fp8_randint(self, m, n, k):
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)

        # TODO(rewang): add float random test
        min_val, max_val = -8, 8
        a = jax.random.randint(subkeys[0], (m, k), min_val, max_val).astype(jnp.bfloat16)
        b = jax.random.randint(subkeys[1], (k, n), min_val, max_val).astype(jnp.bfloat16)

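        # One set of FP8 meta tensors (fp8 max, amax history, scale, scale_inv)
        # for a single GEMM, starting from identity scaling.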
        fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM)
        fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM, FP8Helper.AMAX_HISTORY_LEN),
                                   jnp.float32)
        fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
        fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
        fp8_meta_pkg = FP8MetaPackage(1, fp8_max, fp8_metas_amax, fp8_metas_scale,
                                      fp8_metas_scale_inv)

        primitive_out = type_safe_dot_general(a, b, fp8_meta_pkg)

        # Recompute the scaling factors from the amax history
        fp8_metas_scale, fp8_metas_scale_inv = FP8Helper.update_fp8_scale(
            fp8_max, fp8_metas_amax, fp8_metas_scale)
        fp8_meta_pkg = FP8MetaPackage(1, fp8_max, fp8_metas_amax, fp8_metas_scale,
                                      fp8_metas_scale_inv)

        primitive_out = type_safe_dot_general(a, b, fp8_meta_pkg)
        ref_out = jnp.dot(a, b)

        ref_out = ref_out.astype(jnp.float32)
        primitive_out = primitive_out.astype(jnp.float32)

        assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE)

    @pytest.mark.parametrize('m,n,k', GEMM_CASES)
    def test_grad_bf16(self, m, n, k):
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)
        a = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16)
        b = jax.random.normal(subkeys[1], (k, n), jnp.bfloat16)

        def primitive_func(x, y):
            primitive_out = type_safe_dot_general(x, y)
            return jnp.mean(primitive_out)

        def ref_func(x, y):
            return jnp.mean(jnp.dot(x, y))

        value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1))

        value_n_grad_ref_func = value_and_grad(ref_func, (0, 1))

        primitive_out, (primitive_a_grad, primitive_b_grad) = value_n_grad_primitive_func(a, b)
        ref_out, (ref_a_grad, ref_b_grad) = value_n_grad_ref_func(a, b)

        assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16)
        assert_allclose(primitive_a_grad, ref_a_grad, dtype=jnp.bfloat16)
        assert_allclose(primitive_b_grad, ref_b_grad, dtype=jnp.bfloat16)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('m,n,k', GEMM_CASES)
    def test_grad_fp8_dot(self, m, n, k):
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)

        a = jax.random.normal(subkeys[0], (m, k)).astype(jnp.bfloat16)
        b = jax.random.normal(subkeys[1], (k, n)).astype(jnp.bfloat16)

        fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM)
        fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM, FP8Helper.AMAX_HISTORY_LEN),
                                   jnp.float32)
        fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
        fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)

        def primitive_func(x, y, fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv):
            fp8_meta_pkg = FP8MetaPackage(1, fp8_max, fp8_metas_amax, fp8_metas_scale,
                                          fp8_metas_scale_inv)
            primitive_out = type_safe_dot_general(x, y, fp8_meta_pkg)
            return jnp.mean(primitive_out)

        def ref_func(x, y):
            return jnp.mean(jnp.dot(x, y))

        value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2, 3, 4, 5))
        value_n_grad_ref_func = value_and_grad(ref_func, (0, 1))

        ref_out, (ref_a_grad, ref_b_grad) = value_n_grad_ref_func(a, b)

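        # The cotangents of the FP8 meta arguments carry the updated amax/scale
        # state; feed them back in for a few steps so delayed scaling settles.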
        for _ in range(3):
            primitive_out, (primitive_a_grad, primitive_b_grad, fp8_max, fp8_metas_amax,
                            fp8_metas_scale, fp8_metas_scale_inv) = value_n_grad_primitive_func(
                                a, b, fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv)

        assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
        assert_allclose(primitive_a_grad, ref_a_grad, dtype=FP8Helper.BWD_DTYPE)
        assert_allclose(primitive_b_grad, ref_b_grad, dtype=FP8Helper.BWD_DTYPE)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('m,n,k', [(256, 256, 512), (16384, 1024, 2816), (16384, 2816, 1024),
                                       (16384, 1024, 1024)])
    def test_grad_ln_geglu_fp8_mlp(self, m, n, k):
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 4)
        activations = ('gelu', 'linear')

        a = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16)
        k1 = jax.random.normal(subkeys[1], (k, len(activations), n), jnp.bfloat16)
        k2 = jax.random.normal(subkeys[2], (n, k), jnp.bfloat16)
        s = jax.random.normal(subkeys[3], (k,), jnp.bfloat16)

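        # The MLP contains two GEMMs, so allocate two sets of FP8 meta tensors.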
        init_fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM * 2)
        init_fp8_metas_amax = jnp.zeros(
            (FP8Helper.NUM_META_PER_GEMM * 2, FP8Helper.AMAX_HISTORY_LEN), jnp.float32)
        init_fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM * 2, 1), jnp.float32)
        init_fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM * 2, 1), jnp.float32)

        def primitive_func(x, ln_s, y, z, fp8_max, fp8_metas_amax, fp8_metas_scale,
                           fp8_metas_scale_inv):
            # x is the 2D input tensor; y and z are 2D weight matrices
            # out = geglu(rmsnorm(x) @ y) @ z
            fp8_meta_pkg = FP8MetaPackage(2, fp8_max, fp8_metas_amax, fp8_metas_scale,
                                          fp8_metas_scale_inv)
            return jnp.mean(layernorm_geglu_fp8_mlp(x, ln_s, None, [y, z], fp8_meta_pkg, "rmsnorm"))

        def _convert_to_activation_function(fn_or_string):
            """Convert a string to an activation function."""
            if fn_or_string == 'linear':
                return lambda x: x
            if isinstance(fn_or_string, str):
                return getattr(nn, fn_or_string)
            if callable(fn_or_string):
                return fn_or_string
            raise ValueError(f"don't know how to convert {fn_or_string} to an activation function")

        def ln_geglu_fp8_mlp_ref(x: jnp.ndarray, ln_scale: jnp.ndarray, kernel_1: jnp.ndarray,
                                 kernel_2: jnp.ndarray, fp8_maxs: jnp.ndarray, amax: jnp.ndarray,
                                 scale: jnp.ndarray, scale_inv: jnp.ndarray) -> jnp.ndarray:

            x = jnp.asarray(x, jnp.float32)
            mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True)
            y = jnp.asarray(x * jax.lax.rsqrt(mean2 + 1e-6), jnp.bfloat16)
            ln_out = y * ln_scale
            ln_out = jnp.asarray(ln_out, jnp.bfloat16)

            fp8_gemm_1_pkg = FP8MetaPackage(1, fp8_maxs[:FP8Helper.NUM_META_PER_GEMM],
                                            amax[:FP8Helper.NUM_META_PER_GEMM],
                                            scale[:FP8Helper.NUM_META_PER_GEMM],
                                            scale_inv[:FP8Helper.NUM_META_PER_GEMM])
            linear_1_out = type_safe_dot_general(ln_out, kernel_1, fp8_gemm_1_pkg, ((1,), (0,)))

            x = jnp.split(linear_1_out, len(activations), axis=-2)
            acts = []
            for idx, act_fn in enumerate(activations):
                x_i = _convert_to_activation_function(act_fn)(x[idx])
                acts.append(x_i)
            x = functools.reduce(operator.mul, acts)
            x = jnp.asarray(jnp.squeeze(x, axis=-2), jnp.bfloat16)

            fp8_gemm_2_pkg = FP8MetaPackage(1, fp8_maxs[FP8Helper.NUM_META_PER_GEMM:],
                                            amax[FP8Helper.NUM_META_PER_GEMM:],
                                            scale[FP8Helper.NUM_META_PER_GEMM:],
                                            scale_inv[FP8Helper.NUM_META_PER_GEMM:])
            output = type_safe_dot_general(x, kernel_2, fp8_gemm_2_pkg, ((1,), (0,)))
            return output

        def ref_func(x, ln_s, y, z, fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv):
            return jnp.mean(
                ln_geglu_fp8_mlp_ref(x, ln_s, y, z, fp8_max, fp8_metas_amax, fp8_metas_scale,
                                     fp8_metas_scale_inv))

        value_n_grad_primitive_func = jit(value_and_grad(primitive_func, (0, 1, 2, 3, 4, 5, 6, 7)))
        value_n_grad_ref_func = jit(value_and_grad(ref_func, (0, 1, 2, 3, 4, 5, 6, 7)))

        ref_fp8_max = init_fp8_max
        ref_fp8_metas_amax = init_fp8_metas_amax
        ref_fp8_metas_scale = init_fp8_metas_scale
        ref_fp8_metas_scale_inv = init_fp8_metas_scale_inv

        pri_fp8_max = init_fp8_max
        pri_fp8_metas_amax = init_fp8_metas_amax
        pri_fp8_metas_scale = init_fp8_metas_scale
        pri_fp8_metas_scale_inv = init_fp8_metas_scale_inv

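        # Thread the FP8 meta state through a few steps so both implementations
        # evolve their scaling factors identically before comparing.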
        for _ in range(3):
            ref_out, (ref_a_grad, ref_s_grad, ref_k1_grad, ref_k2_grad, ref_fp8_max,
                      ref_fp8_metas_amax, ref_fp8_metas_scale,
                      ref_fp8_metas_scale_inv) = value_n_grad_ref_func(
                          a, s, k1, k2, ref_fp8_max, ref_fp8_metas_amax, ref_fp8_metas_scale,
                          ref_fp8_metas_scale_inv)

        for _ in range(3):
            primitive_out, (primitive_a_grad, primitive_s_grad, primitive_k1_grad,
                            primitive_k2_grad, pri_fp8_max, pri_fp8_metas_amax, pri_fp8_metas_scale,
                            pri_fp8_metas_scale_inv) = value_n_grad_primitive_func(
                                a, s, k1, k2, pri_fp8_max, pri_fp8_metas_amax, pri_fp8_metas_scale,
                                pri_fp8_metas_scale_inv)

        assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
        assert_allclose(jnp.asarray(primitive_a_grad, np.float32),
                        jnp.asarray(ref_a_grad, np.float32),
                        dtype=FP8Helper.BWD_DTYPE)
        assert_allclose(jnp.asarray(primitive_k1_grad, np.float32),
                        jnp.asarray(ref_k1_grad, np.float32),
                        dtype=FP8Helper.BWD_DTYPE)
        assert_allclose(jnp.asarray(primitive_k2_grad, np.float32),
                        jnp.asarray(ref_k2_grad, np.float32),
                        dtype=FP8Helper.BWD_DTYPE)
        assert_allclose(jnp.asarray(primitive_s_grad, np.float32),
                        jnp.asarray(ref_s_grad, np.float32),
                        dtype=FP8Helper.BWD_DTYPE)


@pytest.fixture(name="random_inputs")
def random_inputs_fixture(shape):
    key = jax.random.PRNGKey(0)
    subkeys = jax.random.split(key, 4)
    out = jax.random.uniform(subkeys[0], shape, jnp.bfloat16, 5, 8)
    return out


class TestGatedGeLu:

    def ref_func(self, inputs):

        def jax_gated_gelu(x):
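            # Reference gated GELU: split the joint tensor along the gate axis
            # into a GELU branch and a linear branch, then multiply elementwise.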
            x = jnp.split(x, 2, axis=-2)
            acts = [jax.nn.gelu(x[0]), x[1]]
            x = functools.reduce(operator.mul, acts)
            x = jnp.asarray(jnp.squeeze(x, -2), jnp.bfloat16)
            return x

        func = jit(value_and_grad(lambda x: jnp.mean(jax_gated_gelu(x))))
        return func(inputs)

    def prim_func(self, inputs):

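        # Wrap the TE forward/backward kernels in a custom_vjp so that
        # value_and_grad exercises both gated_gelu and dgated_gelu.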
        @jax.custom_vjp
        def primitive(x):
            out, _ = primitive_fwd(x)
            return out

        def primitive_fwd(x):
            out = gated_gelu(x)
            ctx = x
            return out, ctx

        def primitive_bwd(ctx, g):
            x = ctx
            out = dgated_gelu(g, x)
            return (out,)

        primitive.defvjp(primitive_fwd, primitive_bwd)
        func = value_and_grad(lambda x: jnp.mean(primitive(x)))
        return func(inputs)

    @pytest.mark.parametrize('shape', [(32, 2, 64), (64, 2, 256)])
    def test_gated_gelu(self, random_inputs):
        x = random_inputs
        prim_out, prim_grad = self.prim_func(x)
        ref_out, ref_grad = self.ref_func(x)

        assert_allclose(prim_out, ref_out, dtype=x.dtype)
        assert_allclose(prim_grad, ref_grad, dtype=x.dtype)


class TestGatedGeLuFP8(TestGatedGeLu):

    def prim_func(self, inputs):
        amax = self.amax
        scale = self.scale
        scale_inv = self.scale_inv
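        # Dummy operands: custom_vjp must return one cotangent per primal input,
        # so the extra slots carry the transposed dgelu and the updated amax.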
        no_use = jnp.zeros(1, jnp.float32)

        @jax.custom_vjp
        def primitive(x, y, z):
            out, _ = primitive_fwd(x, y, z)
            return out

        def primitive_fwd(x, y, z):
            out, _ = gated_gelu_fp8(x, amax, scale, scale_inv, jnp.float8_e4m3fn)
            out = dequantize(out, x.dtype, scale_inv)
            ctx = x
            return out, ctx

        def primitive_bwd(ctx, g):
            x = ctx
            dgelu, dgelu_trans, amax_out = dgated_gelu_cast_transpose(g, x, amax, scale, scale_inv,
                                                                      jnp.float8_e5m2, -1)
            dgelu = dequantize(dgelu, x.dtype, scale_inv)
            dgelu_trans = dequantize(dgelu_trans, x.dtype, scale_inv)
            return dgelu, dgelu_trans, amax_out

        primitive.defvjp(primitive_fwd, primitive_bwd)
        func = value_and_grad(lambda x, y, z: jnp.mean(primitive(x, y, z)), (0, 1, 2))

        return func(inputs, no_use, no_use)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('shape', [(32, 2, 64), (64, 2, 256)])
    def test_gated_gelu(self, random_inputs):
        self.amax = jnp.zeros(1, jnp.float32)
        self.scale = jnp.ones(1, jnp.float32)
        self.scale_inv = jnp.ones(1, jnp.float32)

        x = random_inputs
        prim_out, (prim_grad, prim_grad_trans, amax) = self.prim_func(x)
        ref_out, ref_grad = self.ref_func(x)

        assert_allclose(prim_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
        assert_allclose(amax, jnp.amax(jnp.abs(ref_grad)), rtol=1e-2)
        assert_allclose(prim_grad, ref_grad, dtype=FP8Helper.BWD_DTYPE)
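        # The fused backward kernel also emits a transposed gradient; compare it
        # against an explicitly transposed reference.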
        assert_allclose(prim_grad_trans,
                        jnp.transpose(ref_grad, (1, 2, 0)),
                        dtype=FP8Helper.BWD_DTYPE)


class TestRMSNorm:

    @pytest.mark.parametrize('n, hidden', LN_CASES)
    @pytest.mark.parametrize('dtype', DTYPES)
    def test_forward_backward(self, n, hidden, dtype):
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)

        x = jax.random.uniform(subkeys[0], (n, hidden), dtype, -2, 1)
        scale = jax.random.uniform(subkeys[1], (hidden,), jnp.float32, -2, 1)
        scale = jnp.asarray(scale, dtype)
        epsilon = 1e-6

        def reference_rmsnorm(x, scale):
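            # RMSNorm reference: scale by the reciprocal root-mean-square
            # (no mean subtraction), computed in float32 for stability.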
            x = jnp.asarray(x, jnp.float32)
            mean2 = jnp.mean(lax.square(x), axis=-1, keepdims=True)
            y = jnp.asarray(x * lax.rsqrt(mean2 + epsilon), dtype)
            return y * scale

        jitted_primitive = jit(
            value_and_grad(lambda x, scale: jnp.mean(layernorm(x, scale, None, "rmsnorm")), (0, 1)))

        jitted_reference = jit(
            value_and_grad(lambda x, scale: jnp.mean(reference_rmsnorm(x, scale)), (0, 1)))

        primitive_out, (primitive_dx, primitive_dgamma) = jitted_primitive(x, scale)
        reference_out, (reference_dx, reference_dgamma) = jitted_reference(x, scale)

        assert_allclose(primitive_out, reference_out, dtype=dtype)
        assert_allclose(primitive_dx, reference_dx, dtype=dtype)
        assert_allclose(primitive_dgamma, reference_dgamma, dtype=dtype)


class TestLayerNorm:

    @pytest.mark.parametrize('n, hidden', LN_CASES)
    @pytest.mark.parametrize('dtype', DTYPES)
    @pytest.mark.parametrize('zero_centered_gamma', [False, True])
    def test_forward_backward(self, n, hidden, zero_centered_gamma, dtype):
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 3)

        x = jax.random.uniform(subkeys[0], (n, hidden), dtype, -1, 1)
        scale_range = (-1, 1) if zero_centered_gamma else (0, 2)
        scale = jax.random.uniform(subkeys[1], (hidden,), jnp.float32, *scale_range)
        scale = jnp.asarray(scale, dtype)
        bias = jax.random.uniform(subkeys[2], (hidden,), jnp.float32, -1, 1)
        bias = jnp.asarray(bias, dtype)
        epsilon = 1e-6

        def reference_layernorm(x, scale, bias, zero_centered_gamma, eps):
            x_ = jnp.asarray(x, jnp.float32)
            mean = jnp.mean(x_, axis=-1, keepdims=True)
            var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True)
            normed_input = (x_ - mean) * jax.lax.rsqrt(var + eps)
            # Match the TE zero-centered-gamma convention
            if zero_centered_gamma:
                return jnp.asarray(normed_input * (scale + 1) + bias).astype(x.dtype)
            return jnp.asarray(normed_input * scale + bias).astype(x.dtype)

        def compute_loss(x):
            # Higher precision to compute the loss
            x_ = x.astype(jnp.float32)
            return jnp.mean(jnp.square(x_)).astype(x.dtype)

        jitted_primitive = jit(
            value_and_grad(
                lambda x, scale, bias: compute_loss(
                    layernorm(x, scale, bias, "layernorm", zero_centered_gamma, epsilon)),
                (0, 1, 2)))

        jitted_reference = jit(
            value_and_grad(
                lambda x, scale, bias: compute_loss(
                    reference_layernorm(x, scale, bias, zero_centered_gamma, epsilon)), (0, 1, 2)))

        primitive_out, (primitive_dx, primitive_dgamma,
                        primitive_dbeta) = jitted_primitive(x, scale, bias)
        reference_out, (reference_dx, reference_dgamma,
                        reference_dbeta) = jitted_reference(x, scale, bias)

        assert_allclose(primitive_out, reference_out, dtype=dtype)
        assert_allclose(primitive_dx, reference_dx, dtype=dtype)
        assert_allclose(primitive_dgamma, reference_dgamma, dtype=dtype)
        assert_allclose(primitive_dbeta, reference_dbeta, dtype=dtype)