# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

from contextlib import nullcontext
import functools
import operator
from typing import Callable, Sequence, Union

import jax
import jax.numpy as jnp
import numpy as np
import pytest
from jax import jit, value_and_grad
from flax import linen as nn

from utils import assert_allclose
from transformer_engine.jax.dot import type_safe_dot_general, dequantize, quantize
from transformer_engine.jax.fp8 import FP8MetaPackage, FP8Helper
from transformer_engine.jax.fp8 import is_fp8_available
from transformer_engine.jax.layernorm import layernorm, layernorm_fp8_dot
from transformer_engine.jax.mlp import activation_lu, activation_lu_fp8, fused_layernorm_fp8_mlp


GEMM_CASES = [
    (256, 256, 512),
    (32, 32, 32),
    (2048, 1024, 2048),
    (2048, 2048, 1024),
    (2048, 1024, 1024),
]
FP8_COMPUTE_TYPE = [jnp.float8_e4m3fn, jnp.float8_e5m2]
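# (e4m3 is conventionally used for forward-pass tensors, e5m2 for gradients.)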
LN_CASES = [(512, 1024)]
DTYPES = [jnp.bfloat16, jnp.float32]
is_fp8_supported, reason = is_fp8_available()

def _convert_to_activation_function(fn_or_string):
    """Convert a string to an activation function."""
    if fn_or_string == 'linear':
        return lambda x: x
    if isinstance(fn_or_string, str):
        return getattr(nn, fn_or_string)
    if callable(fn_or_string):
        return fn_or_string
    raise ValueError(f"don't know how to convert {fn_or_string} to an activation function")
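
# For example, _convert_to_activation_function('gelu') resolves to flax.linen.gelu,
# and 'linear' maps to the identity; the tests below multiply the resulting branches
# elementwise to realize gated activations such as ('gelu', 'linear').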


class TestFP8Dot:

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    def test_qdq(self):
        FP8_E4M3_MAX = (jnp.finfo(jnp.float8_e4m3fn).max).astype(jnp.float32)
        x = jnp.asarray([[-1, 0.1], [2, 3]], jnp.float32)
        amax = jnp.max(jnp.abs(x)).reshape(1)
        scale = jnp.asarray(FP8_E4M3_MAX / amax, jnp.float32).reshape(1)
        scale_inv = (1 / scale).reshape(1)
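        # Delayed scaling in brief: scale = FP8_MAX / amax maps the largest |x| onto
        # the edge of the e4m3 range and scale_inv undoes it after the cast, roughly:
        #   y = (x * scale).astype(jnp.float8_e4m3fn)   # quantize
        #   z = y.astype(jnp.float32) * scale_inv       # dequantize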

        y, _ = quantize(x, q_dtype=jnp.float8_e4m3fn, scale=scale)
        z = dequantize(y, dq_dtype=jnp.float32, scale_inv=scale_inv)

        assert_allclose(z, x, dtype=jnp.float8_e4m3fn)
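        # (assert_allclose derives its tolerances from the dtype argument, so the
        # comparison is only as tight as FP8 rounding allows.)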

    @pytest.mark.parametrize('m,n,k', GEMM_CASES)
    def test_forward_bf16(self, m, n, k):
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)
        a = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16)
        b = jax.random.normal(subkeys[1], (k, n), jnp.bfloat16)

        primitive_out = type_safe_dot_general(a, b)
        ref_out = jnp.dot(a, b)

        assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('m,n,k', GEMM_CASES)
    def test_forward_fp8_randint(self, m, n, k):
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)

        # TODO(rewang): add float random test
        min_val, max_val = -8, 8
        a = jax.random.randint(subkeys[0], (m, k), min_val, max_val).astype(jnp.bfloat16)
        b = jax.random.randint(subkeys[1], (k, n), min_val, max_val).astype(jnp.bfloat16)

        fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM)
        fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM, FP8Helper.AMAX_HISTORY_LEN),
                                   jnp.float32)
        fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
        fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
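        # Each GEMM carries NUM_META_PER_GEMM sets of FP8 metadata: an amax history
        # plus per-tensor scale/scale_inv, all initialized to identity scaling here.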
        fp8_meta_pkg = FP8MetaPackage(1, fp8_max, fp8_metas_amax, fp8_metas_scale,
                                      fp8_metas_scale_inv)

        primitive_out = type_safe_dot_general(a, b, fp8_meta_pkg)

        # Recompute scale/scale_inv from the amax history via update_fp8_scale, then
        # repeat the GEMM with the refreshed metadata.
        fp8_metas_scale, fp8_metas_scale_inv = FP8Helper.update_fp8_scale(
            fp8_max, fp8_metas_amax, fp8_metas_scale)
        fp8_meta_pkg = FP8MetaPackage(1, fp8_max, fp8_metas_amax, fp8_metas_scale,
                                      fp8_metas_scale_inv)
        primitive_out = type_safe_dot_general(a, b, fp8_meta_pkg)
        ref_out = jnp.dot(a, b)

        ref_out = ref_out.astype(jnp.float32)
        primitive_out = primitive_out.astype(jnp.float32)

        assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE)

    @pytest.mark.parametrize('m,n,k', GEMM_CASES)
    def test_grad_bf16(self, m, n, k):
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)
        a = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16)
        b = jax.random.normal(subkeys[1], (k, n), jnp.bfloat16)

        def primitive_func(x, y):
            primitive_out = type_safe_dot_general(x, y)
            return jnp.mean(primitive_out)

        def ref_func(x, y):
            return jnp.mean(jnp.dot(x, y))

        value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1))

        value_n_grad_ref_func = value_and_grad(ref_func, (0, 1))

        primitive_out, (primitive_a_grad, primitive_b_grad) = value_n_grad_primitive_func(a, b)
        ref_out, (ref_a_grad, ref_b_grad) = value_n_grad_ref_func(a, b)

        assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16)
        assert_allclose(primitive_a_grad, ref_a_grad, dtype=jnp.bfloat16)
        assert_allclose(primitive_b_grad, ref_b_grad, dtype=jnp.bfloat16)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('m,n,k', GEMM_CASES)
    def test_grad_fp8_dot(self, m, n, k):
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)

        a = jax.random.normal(subkeys[0], (m, k)).astype(jnp.bfloat16)
        b = jax.random.normal(subkeys[1], (k, n)).astype(jnp.bfloat16)

        fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM)
        fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM, FP8Helper.AMAX_HISTORY_LEN),
                                   jnp.float32)
        fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
        fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)

        def primitive_func(x, y, fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv):
            fp8_meta_pkg = FP8MetaPackage(1, fp8_max, fp8_metas_amax, fp8_metas_scale,
                                          fp8_metas_scale_inv)
            primitive_out = type_safe_dot_general(x, y, fp8_meta_pkg)
            return jnp.mean(primitive_out)

        def ref_func(x, y):
            return jnp.mean(jnp.dot(x, y))

        value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2, 3, 4, 5))
        value_n_grad_ref_func = value_and_grad(ref_func, (0, 1))

        ref_out, (ref_a_grad, ref_b_grad) = value_n_grad_ref_func(a, b)
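
        # With TE's custom VJP, the "gradients" returned for the FP8 meta arguments
        # carry the updated metadata, so each iteration of this loop threads the
        # refreshed amax history and scales back into the next call.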

        for _ in range(3):
            primitive_out, (primitive_a_grad, primitive_b_grad, fp8_max, fp8_metas_amax,
                            fp8_metas_scale, fp8_metas_scale_inv) = value_n_grad_primitive_func(
                                a, b, fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv)
        assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
        assert_allclose(primitive_a_grad, ref_a_grad, dtype=FP8Helper.BWD_DTYPE)
        assert_allclose(primitive_b_grad, ref_b_grad, dtype=FP8Helper.BWD_DTYPE)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('m,n,k', [(128, 256, 512),
                                       (16384, 1024, 2816),
                                       (16384, 2816, 1024),
                                       (16384, 1024, 1024)])
    @pytest.mark.parametrize('activation_type', [('gelu', ),
                                                 ('gelu', 'linear'),
                                                 ('silu', ),
                                                 ('silu', 'linear')])
    @pytest.mark.parametrize('use_bias', [True, False])
    def test_grad_fused_layernorm_fp8_mlp(self, m, n, k,
            activation_type: Sequence[Union[str, Callable]], use_bias: bool):
        """Test fused_layernorm_fp8_mlp gradients against a pure-JAX reference."""
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 6)

        a = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16)
        k1 = jax.random.normal(subkeys[1], (k, len(activation_type), n), jnp.bfloat16)
        k2 = jax.random.normal(subkeys[2], (n, k), jnp.bfloat16)
        s = jax.random.normal(subkeys[5], (k,), jnp.bfloat16)
        if use_bias:
            b1 = jax.random.normal(subkeys[3], (len(activation_type), n), jnp.bfloat16)
            b2 = jax.random.normal(subkeys[4], (k,), jnp.bfloat16)
        else:
            b1 = None
            b2 = None

        init_fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM * 2)
        init_fp8_metas_amax = jnp.zeros(
            (FP8Helper.NUM_META_PER_GEMM * 2, FP8Helper.AMAX_HISTORY_LEN), jnp.float32)
        init_fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM * 2, 1), jnp.float32)
        init_fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM * 2, 1), jnp.float32)
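        # Twice NUM_META_PER_GEMM because the fused MLP below performs two GEMMs.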

        def primitive_func(x, ln_s, y, z, w, v, fp8_max, fp8_metas_amax, fp8_metas_scale,
                           fp8_metas_scale_inv):
            # x is the 2D input tensor; y and z are the 2D weight matrices,
            # w and v the (optional) biases. Conceptually:
            #   out = act((rmsnorm(x) @ y) + w) @ z + v
            fp8_meta_pkg = FP8MetaPackage(2, fp8_max, fp8_metas_amax, fp8_metas_scale,
                                          fp8_metas_scale_inv)
            return jnp.mean(
                fused_layernorm_fp8_mlp(x, ln_s, None, [y, z], [w, v], fp8_meta_pkg, "rmsnorm",
                                        activation_type=activation_type, use_bias=use_bias))
        def layernorm_fp8_mlp_ref(x: jnp.ndarray, ln_scale: jnp.ndarray, kernel_1: jnp.ndarray,
                                kernel_2: jnp.ndarray, bias_1: jnp.ndarray, bias_2: jnp.ndarray,
                                fp8_maxs: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray,
                                scale_inv: jnp.ndarray) -> jnp.ndarray:

            x = jnp.asarray(x, jnp.float32)
            mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True)
            y = jnp.asarray(x * jax.lax.rsqrt(mean2 + 1e-6), jnp.bfloat16)
            ln_out = y * ln_scale
            ln_out = jnp.asarray(ln_out, jnp.bfloat16)

            fp8_gemm_1_pkg = FP8MetaPackage(1, fp8_maxs[:FP8Helper.NUM_META_PER_GEMM],
                                            amax[:FP8Helper.NUM_META_PER_GEMM],
                                            scale[:FP8Helper.NUM_META_PER_GEMM],
                                            scale_inv[:FP8Helper.NUM_META_PER_GEMM])
            linear_1_out = type_safe_dot_general(ln_out, kernel_1, fp8_gemm_1_pkg, ((1,), (0,)))

            if use_bias:
                bias_1_shape = (1,) * (linear_1_out.ndim - bias_1.ndim) + bias_1.shape
                linear_1_out += jnp.reshape(bias_1, bias_1_shape)

            x = jnp.split(linear_1_out, len(activation_type), axis=-2)
            acts = []
            for idx, act_fn in enumerate(activation_type):
                x_i = _convert_to_activation_function(act_fn)(x[idx])
                acts.append(x_i)
            x = functools.reduce(operator.mul, acts)
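            # For a pair like ('gelu', 'linear') this multiplies a GELU branch with a
            # linear branch, i.e. a GeGLU-style gated activation.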

            x = jnp.asarray(jnp.squeeze(x, axis=-2), jnp.bfloat16)

            fp8_gemm_2_pkg = FP8MetaPackage(1, fp8_maxs[FP8Helper.NUM_META_PER_GEMM:],
                                            amax[FP8Helper.NUM_META_PER_GEMM:],
                                            scale[FP8Helper.NUM_META_PER_GEMM:],
                                            scale_inv[FP8Helper.NUM_META_PER_GEMM:])
            output = type_safe_dot_general(x, kernel_2, fp8_gemm_2_pkg, ((1,), (0,)))

            if use_bias:
                bias_2_shape = (1,) * (output.ndim - bias_2.ndim) + bias_2.shape
                output += jnp.reshape(bias_2, bias_2_shape)

            return output

        def ref_func(x, ln_s, y, z, w, v, fp8_max, fp8_metas_amax, fp8_metas_scale,
                     fp8_metas_scale_inv):
            return jnp.mean(
                layernorm_fp8_mlp_ref(x, ln_s, y, z, w, v, fp8_max, fp8_metas_amax, fp8_metas_scale,
                                    fp8_metas_scale_inv))

        value_n_grad_primitive_func = jit(
            value_and_grad(primitive_func, (0, 1, 2, 3, 4, 5, 6, 7, 8, 9)))
        value_n_grad_ref_func = jit(value_and_grad(ref_func, (0, 1, 2, 3, 4, 5, 6, 7, 8, 9)))

        ref_fp8_max = init_fp8_max
        ref_fp8_metas_amax = init_fp8_metas_amax
        ref_fp8_metas_scale = init_fp8_metas_scale
        ref_fp8_metas_scale_inv = init_fp8_metas_scale_inv

        pri_fp8_max = init_fp8_max
        pri_fp8_metas_amax = init_fp8_metas_amax
        pri_fp8_metas_scale = init_fp8_metas_scale
        pri_fp8_metas_scale_inv = init_fp8_metas_scale_inv

        # Run both jitted functions for several iterations so the FP8 scaling state
        # threaded through the gradient outputs settles before comparing.
        for _ in range(3):
            ref_out, (ref_a_grad, ref_s_grad, ref_k1_grad, ref_k2_grad, ref_b1_grad, ref_b2_grad,
                      ref_fp8_max, ref_fp8_metas_amax, ref_fp8_metas_scale,
                      ref_fp8_metas_scale_inv) = value_n_grad_ref_func(
                          a, s, k1, k2, b1, b2, ref_fp8_max, ref_fp8_metas_amax,
                            ref_fp8_metas_scale, ref_fp8_metas_scale_inv)

        for _ in range(3):
            primitive_out, (primitive_a_grad, primitive_s_grad, primitive_k1_grad,
                            primitive_k2_grad, primitive_b1_grad, primitive_b2_grad, pri_fp8_max,
                            pri_fp8_metas_amax, pri_fp8_metas_scale,
                            pri_fp8_metas_scale_inv) = value_n_grad_primitive_func(
                                a, s, k1, k2, b1, b2, pri_fp8_max, pri_fp8_metas_amax,
                                pri_fp8_metas_scale, pri_fp8_metas_scale_inv)

        assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
        assert_allclose(jnp.asarray(primitive_a_grad, np.float32),
                        jnp.asarray(ref_a_grad, np.float32),
                        dtype=FP8Helper.BWD_DTYPE)
        assert_allclose(jnp.asarray(primitive_k1_grad, np.float32),
                        jnp.asarray(ref_k1_grad, np.float32),
                        dtype=FP8Helper.BWD_DTYPE)
        assert_allclose(jnp.asarray(primitive_s_grad, np.float32),
                        jnp.asarray(ref_s_grad, np.float32),
                        dtype=FP8Helper.BWD_DTYPE)
        assert_allclose(jnp.asarray(primitive_k2_grad, np.float32),
                        jnp.asarray(ref_k2_grad, np.float32),
                        dtype=FP8Helper.BWD_DTYPE)
        if use_bias:
            assert_allclose(jnp.asarray(primitive_b2_grad, np.float32),
                            jnp.asarray(ref_b2_grad, np.float32),
                            dtype=FP8Helper.BWD_DTYPE)
            assert_allclose(jnp.asarray(primitive_b1_grad, np.float32),
                            jnp.asarray(ref_b1_grad, np.float32),
                            dtype=FP8Helper.BWD_DTYPE)


@pytest.fixture(name="random_inputs")
def random_inputs_fixture(shape):
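    # `shape` is supplied by each test's @pytest.mark.parametrize('shape', ...).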
    key = jax.random.PRNGKey(0)
    subkeys = jax.random.split(key, 4)
    out = jax.random.uniform(subkeys[0], shape, jnp.bfloat16, 5, 8)
    return out


class TestActivationLu:
    def ref_func(self, x, activation_type):
        def ref_act_lu(inputs):
            x = jnp.split(inputs, len(activation_type), axis=-2)
            acts = []
            for idx, act_fn in enumerate(activation_type):
                x_i = _convert_to_activation_function(act_fn)(x[idx])
                acts.append(x_i)
            x = functools.reduce(operator.mul, acts)
            return jnp.mean(x)
        ref_act_func = jit(value_and_grad(ref_act_lu, (0,)))
        return ref_act_func(x)
    def primitive_func(self, inputs):
        return jnp.mean(activation_lu(inputs, activation_type=self.activation_type))
    @pytest.mark.parametrize('shape', [(32, 1, 64), (64, 1, 256)])
    @pytest.mark.parametrize('activation_type', [('gelu',),
                                                 ('gelu', 'linear'),
                                                 ('silu',),
                                                 ('silu', 'linear')])
    def test_activation_lu(self, random_inputs, activation_type):
        x = random_inputs
        x = jnp.repeat(x, len(activation_type), axis=1)
        self.activation_type = activation_type
        value_n_grad_primitive_func = jit(
            value_and_grad(self.primitive_func, (0,)))
        prim_out, (prim_grad,) = value_n_grad_primitive_func(x)
        ref_out, (ref_grad,) = self.ref_func(x, activation_type)
        assert_allclose(prim_out, ref_out, dtype=x.dtype)
        assert_allclose(prim_grad, ref_grad, dtype=x.dtype)


class TestActivationLuFP8(TestActivationLu):
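
    # dx_trans_no_use / dbias_no_use are placeholder arguments: differentiating with
    # respect to them gives value_and_grad slots through which activation_lu_fp8
    # returns the transposed activation gradient and the bias gradient.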
    def primitive_func(self, inputs, dx_trans_no_use, dbias_no_use, amax, scale, scale_inv):
        return jnp.mean(
            activation_lu_fp8(inputs,
                              amax, scale, scale_inv,
                              jnp.float8_e4m3fn, jnp.float8_e5m2,
                              activation_type=self.activation_type))
    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('shape', [(32, 1, 64), (64, 1, 256)])
    @pytest.mark.parametrize('activation_type', [('gelu',),
                                                 ('gelu', 'linear'),
                                                 ('silu',),
                                                 ('silu', 'linear')])
    def test_activation_lu(self, random_inputs, activation_type):
        self.amax = jnp.zeros(1, jnp.float32)
        self.scale = jnp.ones(1, jnp.float32)
        self.scale_inv = jnp.ones(1, jnp.float32)
        self.activation_type = activation_type
        x = random_inputs
        x = jnp.repeat(x, len(activation_type), axis=1)
        value_n_grad_primitive_func = jit(value_and_grad(self.primitive_func, (0, 1, 2, 3, 4, 5)))
        transpose_indices = (1, 2, 0) if len(activation_type) > 1 else (2, 0, 1)
        dx_trans_no_use = jnp.zeros([x.shape[i] for i in transpose_indices], dtype=x.dtype)
        dbias_no_use = jnp.zeros(x.shape[-1], dtype=x.dtype)
        prim_out, (prim_grad, prim_grad_trans, dbias, amax, _, _) = \
            value_n_grad_primitive_func(x, dx_trans_no_use, dbias_no_use,
                                        self.amax, self.scale, self.scale_inv)
        ref_out, (ref_grad,) = self.ref_func(x, activation_type)
        assert_allclose(prim_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
        assert_allclose(amax, jnp.amax(jnp.abs(ref_grad)), rtol=1e-2)
        if 'linear' not in activation_type:
            assert_allclose(dbias, jnp.sum(ref_grad, axis=tuple(range(x.ndim - 1))))
        assert_allclose(prim_grad, ref_grad, dtype=FP8Helper.BWD_DTYPE)
        assert_allclose(prim_grad_trans,
                        jnp.transpose(ref_grad, transpose_indices),
                        dtype=FP8Helper.BWD_DTYPE)


class TestNorm:
    """
    Test transformer_engine.jax.layernorm APIs
    """
    def reference_layernorm(self, x, scale, bias, zero_centered_gamma, eps):
        """
        JAX-native layernorm reference implementation:
        - bias is not None: layernorm
        - bias is None: rmsnorm
        """
        x_ = jnp.asarray(x, jnp.float32)
        if bias is None:
            mean = 0.
        else:
            mean = jnp.mean(x_, axis=-1, keepdims=True)
        var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True)
        normed_input = (x_ - mean) * jax.lax.rsqrt(var + eps)
        if zero_centered_gamma:
            scale += 1.
        if bias is None:
            bias = 0.
        return jnp.asarray(normed_input * scale + bias).astype(x.dtype)
    @pytest.mark.parametrize('n, hidden', LN_CASES)
    @pytest.mark.parametrize('dtype', DTYPES)
    @pytest.mark.parametrize('ln_type', ['layernorm', 'rmsnorm'])
    @pytest.mark.parametrize('zero_centered_gamma', [False, True])
    @pytest.mark.parametrize('epsilon', [1e-2, 1e-6])
    def test_layernorm_forward_backward(self, n, hidden, ln_type, zero_centered_gamma, epsilon,
                                        dtype):
        """
        Test transformer_engine.jax.layernorm.layernorm
        """
        expect_assert = False
        if ln_type == 'rmsnorm' and zero_centered_gamma:
            # zero_centered_gamma is not supported for rmsnorm, expect an assertion.
            expect_assert = True

        with pytest.raises(AssertionError, match=r".*zero_centered_gamma is not supported.*"
                          ) if expect_assert else nullcontext():
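            # pytest.raises is entered only for the unsupported rmsnorm +
            # zero_centered_gamma combination; nullcontext() lets valid cases share
            # the same block.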
            key = jax.random.PRNGKey(0)
            subkeys = jax.random.split(key, 3)

            x = jax.random.uniform(subkeys[0], (n, hidden), dtype, -1, 1)
            gamma_range = (-1, 1) if zero_centered_gamma else (0, 2)
            gamma = jax.random.uniform(subkeys[1], (hidden,), jnp.float32, *gamma_range)
            gamma = jnp.asarray(gamma, dtype)
            if ln_type == 'layernorm':
                beta = jax.random.uniform(subkeys[2], (hidden,), jnp.float32, -1, 1)
                beta = jnp.asarray(beta, dtype)
            else:
                beta = None

            def compute_loss(x):
                # Higher precision to compute the loss
                x_ = x.astype(jnp.float32)
                return jnp.mean(jnp.square(x_)).astype(x.dtype)

            jitted_primitive = jit(
                value_and_grad(
                    lambda x, gamma, beta: compute_loss(
                        layernorm(x, gamma, beta, ln_type, zero_centered_gamma, epsilon)),
                    (0, 1, 2)))

            jitted_reference = jit(
                value_and_grad(
                    lambda x, gamma, beta: compute_loss(
                        self.reference_layernorm(x, gamma, beta, zero_centered_gamma, epsilon)),
                    (0, 1, 2)))

            primitive_out, (primitive_dx, primitive_dgamma,
                            primitive_dbeta) = jitted_primitive(x, gamma, beta)
            reference_out, (reference_dx, reference_dgamma,
                            reference_dbeta) = jitted_reference(x, gamma, beta)

            assert_allclose(primitive_out, reference_out, dtype=dtype)
            assert_allclose(primitive_dx, reference_dx, dtype=dtype)
            assert_allclose(primitive_dgamma, reference_dgamma, dtype=dtype)
            if beta is not None:
                assert_allclose(primitive_dbeta, reference_dbeta, dtype=dtype)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('m,n,k', GEMM_CASES)
    @pytest.mark.parametrize('ln_type', ['layernorm', 'rmsnorm'])
    @pytest.mark.parametrize('zero_centered_gamma', [True, False])
    @pytest.mark.parametrize('epsilon', [1e-2, 1e-6])
    def test_ln_fp8_dot_forward_backward(self, m, n, k, ln_type, zero_centered_gamma, epsilon):
        """
        Test transformer_engine.jax.layernorm.layernorm_fp8_dot
        """
        expect_assert = False
        if ln_type == 'rmsnorm' and zero_centered_gamma:
            # zero_centered_gamma is not supported for rmsnorm, expect an assertion.
            expect_assert = True

        with pytest.raises(AssertionError, match=r".*zero_centered_gamma is not supported.*"
                          ) if expect_assert else nullcontext():
            key = jax.random.PRNGKey(0)
            subkeys = jax.random.split(key, 4)

            a = jax.random.normal(subkeys[0], (m, k)).astype(jnp.bfloat16)
            b = jax.random.normal(subkeys[1], (k, n)).astype(jnp.bfloat16)

            gamma = jax.random.normal(subkeys[2], (k,)).astype(jnp.bfloat16)
            if ln_type == 'layernorm':
                beta = jax.random.normal(subkeys[3], (k,)).astype(jnp.bfloat16)
            else:
                beta = None

            fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM)
            fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM, FP8Helper.AMAX_HISTORY_LEN),
                                       jnp.float32)
            fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
            fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)

            def primitive_func(x, y, gamma, beta, fp8_max, fp8_metas_amax, fp8_metas_scale,
                               fp8_metas_scale_inv):
                fp8_meta_pkg = FP8MetaPackage(1, fp8_max, fp8_metas_amax, fp8_metas_scale,
                                              fp8_metas_scale_inv)
                primitive_out = layernorm_fp8_dot(x, y, gamma, beta, fp8_meta_pkg, ln_type,
                                                  zero_centered_gamma)
                return jnp.mean(primitive_out)

            def ref_func(x, y, gamma, beta, zero_centered_gamma):
                x = self.reference_layernorm(x, gamma, beta, zero_centered_gamma, epsilon)
                return jnp.mean(jnp.dot(x, y))

            value_n_grad_primitive_func = value_and_grad(primitive_func, range(8))
            value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2, 3))

            ref_out, (ref_a_grad, ref_b_grad, ref_gamma_grad,
                      ref_beta_grad) = value_n_grad_ref_func(a, b, gamma, beta, zero_centered_gamma)

            for _ in range(3):
                primitive_out, (primitive_a_grad, primitive_b_grad, primitive_gamma_grad,
                                primitive_beta_grad, fp8_max, fp8_metas_amax, fp8_metas_scale,
                                fp8_metas_scale_inv) = value_n_grad_primitive_func(
                                    a, b, gamma, beta, fp8_max, fp8_metas_amax, fp8_metas_scale,
                                    fp8_metas_scale_inv)

            assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
            assert_allclose(primitive_a_grad, ref_a_grad, dtype=FP8Helper.BWD_DTYPE)
            assert_allclose(primitive_b_grad, ref_b_grad, dtype=FP8Helper.BWD_DTYPE)
            assert_allclose(primitive_gamma_grad, ref_gamma_grad, dtype=FP8Helper.BWD_DTYPE)
            if beta is not None:
                assert_allclose(primitive_beta_grad, ref_beta_grad, dtype=FP8Helper.BWD_DTYPE)