# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import unittest
6
from functools import partial
7
8
9
10
11

import flax
import jax
import jax.numpy as jnp
import numpy as np
12
from flax import linen as nn
13

14
from utils import assert_allclose, pytest_parametrize_wrapper
15
16
17
18
19
20
from transformer_engine.common.recipe import (
    DelayedScaling,
    MXFP8BlockScaling,
    Float8CurrentScaling,
    NVFP4BlockScaling,
)
21
from transformer_engine.common.recipe import Format as FP8Format
22
from transformer_engine.jax import autocast
23
from transformer_engine.jax.quantize import (
24
    get_quantize_config,
25
    get_supported_quantization_recipes,
26
    is_scaling_mode_supported,
27
28
    ScalingMode,
    update_collections,
29
    TensorSource,
30
31
    QuantizerFactory,
    QuantizeLayout,
32
)
33
from transformer_engine.jax.quantize.helper import _format2dtypes
34
from transformer_engine.jax.sharding import MeshResource, global_mesh_resource
35
from transformer_engine.jax.flax.module import TransformerEngineBase
36
37
from transformer_engine.jax import flax as te_flax
import transformer_engine.jax as te
38

39
40
41
is_fp8_supported, reason = is_scaling_mode_supported(ScalingMode.DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_scaling_mode_supported(ScalingMode.MXFP8_1D_SCALING)
is_nvfp4_supported, nvfp4_reason = is_scaling_mode_supported(ScalingMode.NVFP4_1D_SCALING)
42

43
44
SUPPORTED_RECIPES = get_supported_quantization_recipes()

45

46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
def quantizer_check_vjp(outer_quantizer_set, assertion_func, x):
    """Check that the quantizers in the quantizer set are as expected and reconstructed correctly from flattened pytree representations across VJP boundaries."""

    # Define a function with a custom VJP (vector-Jacobian product).
    # The quantizer set is a differentiable argument, so it is flattened and
    # unflattened as a pytree when crossing the custom-VJP boundary;
    # assertion_func is marked non-differentiable via nondiff_argnums so it is
    # passed through as a plain Python callable.
    @partial(jax.custom_vjp, nondiff_argnums=(1,))
    def quantizer_check(inner_quantizer_set, assertion_func, x):
        return quantizer_check_fwd(inner_quantizer_set, assertion_func, x)

    def quantizer_check_fwd(inner_quantizer_set, assertion_func, x):
        # Run the caller-supplied assertions against each per-tensor quantizer
        # after pytree reconstruction; x itself is passed through unchanged.
        # NOTE(review): jax.custom_vjp conventionally expects the fwd rule to
        # return an (output, residuals) pair, while this returns x alone —
        # confirm this matches the JAX version in use.
        assertion_func(inner_quantizer_set.x, TensorSource.X)
        assertion_func(inner_quantizer_set.kernel, TensorSource.KERNEL)
        assertion_func(inner_quantizer_set.dgrad, TensorSource.DGRAD)
        return x

    def quantizer_check_bwd(ctx, g):
        # Identity backward: pass the incoming cotangent straight through.
        return (g,)

    quantizer_check.defvjp(quantizer_check_fwd, quantizer_check_bwd)
    return quantizer_check(outer_quantizer_set, assertion_func, x)


class TestModule(TransformerEngineBase):
    """Minimal TE flax module used to inspect the generated quantizer set across VJP boundaries."""

    # Callback invoked once per quantizer: (quantizer, tensor_source) -> None
    assertion_func: callable

    @nn.compact
    def __call__(self, x):
        # Generate the quantizer set and route it through the custom-VJP checker.
        return quantizer_check_vjp(self.generate_quantizer_set(), self.assertion_func, x)


79
class TestHelper(unittest.TestCase):
    """Unit tests for the quantize-helper utilities."""

    @unittest.skipIf(not is_fp8_supported, reason=reason)
    def test_update_collections(self):
        """update_collections overwrites listed keys and keeps the rest, for both plain dicts and FrozenDicts."""
        base_val = 0.0
        new_val = 10.0

        state = {
            "test1": base_val,
            "test2": base_val,
        }

        # Plain dict input.
        result = update_collections({"test1": new_val}, state)
        self.assertEqual(result["test1"], new_val)
        self.assertEqual(result["test2"], base_val)

        # FrozenDict input.
        frozen_state = flax.core.frozen_dict.FrozenDict(state)
        result = update_collections({"test1": new_val}, frozen_state)
        self.assertEqual(result["test1"], new_val)
        self.assertEqual(result["test2"], base_val)


class TestFP8Functions(unittest.TestCase):
    """Tests that `autocast` installs the expected global quantize config for each
    recipe and restores the disabled default state on exit."""

    def _check_default_state(self):
        # Outside any enabled autocast context, quantization must be off.
        self.assertFalse(get_quantize_config().is_fp8_enabled())

    def _compare_delay_scaling(self, test):
        """Assert the active quantize config mirrors the DelayedScaling recipe `test`."""
        self.assertEqual(get_quantize_config().MARGIN, test.margin)
        self.assertEqual(get_quantize_config().FWD_DTYPE, _format2dtypes(test.fp8_format)[0])
        self.assertEqual(get_quantize_config().BWD_DTYPE, _format2dtypes(test.fp8_format)[1])
        self.assertEqual(get_quantize_config().AMAX_HISTORY_LEN, test.amax_history_len)
        self.assertEqual(get_quantize_config().AMAX_COMPUTE_ALGO.value, test.amax_compute_algo)

    def _compare_current_scaling(self, test):
        """Assert the active quantize config mirrors the Float8CurrentScaling recipe `test`."""
        self.assertEqual(get_quantize_config().FWD_DTYPE, _format2dtypes(test.fp8_format)[0])
        self.assertEqual(get_quantize_config().BWD_DTYPE, _format2dtypes(test.fp8_format)[1])
        # Current scaling applies the same mode to every tensor source (x/kernel/dgrad).
        for tensor_source in TensorSource:
            self.assertEqual(
                get_quantize_config().get_scaling_mode(tensor_source),
                ScalingMode.CURRENT_TENSOR_SCALING,
            )

    def _compare_mxfp8_scaling(self, test):
        """Assert the active quantize config mirrors the MXFP8BlockScaling recipe `test`."""
        self.assertEqual(get_quantize_config().FWD_DTYPE, _format2dtypes(test.fp8_format)[0])
        self.assertEqual(get_quantize_config().BWD_DTYPE, _format2dtypes(test.fp8_format)[1])
        for tensor_source in TensorSource:
            self.assertEqual(
                get_quantize_config().get_scaling_mode(tensor_source), ScalingMode.MXFP8_1D_SCALING
            )

    def _compare_nvfp4_scaling(self, test):
        """Assert the active quantize config mirrors the NVFP4BlockScaling recipe `test`."""
        self.assertEqual(get_quantize_config().FWD_DTYPE, _format2dtypes(test.fp4_format)[0])
        self.assertEqual(get_quantize_config().BWD_DTYPE, _format2dtypes(test.fp4_format)[1])
        for tensor_source in TensorSource:
            # Unless 2D quantization is disabled, only the KERNEL tensor uses 2D scaling.
            target_scaling_mode = (
                ScalingMode.NVFP4_2D_SCALING
                if (not test.disable_2d_quantization) and tensor_source == TensorSource.KERNEL
                else ScalingMode.NVFP4_1D_SCALING
            )
            self.assertEqual(
                get_quantize_config().get_scaling_mode(tensor_source), target_scaling_mode
            )
        # NVFP4-specific feature toggles must be copied verbatim from the recipe.
        self.assertEqual(
            get_quantize_config().DISABLE_STOCHASTIC_ROUNDING, test.disable_stochastic_rounding
        )
        self.assertEqual(get_quantize_config().DISABLE_RHT, test.disable_rht)
        self.assertEqual(
            get_quantize_config().DISABLE_2D_QUANTIZATION, test.disable_2d_quantization
        )

    def _compare_nvfp4_scaling_quantizers(self, test):
        """Check that the quantizers created have the expected stochastic rounding state and the state is preserved across VJP boundaries."""

        def assertion_func(quantizer, tensor_source):
            # Stochastic rounding state is expected only on the DGRAD quantizer
            # and only when the recipe has not disabled it.
            if test.disable_stochastic_rounding or tensor_source != TensorSource.DGRAD:
                self.assertIsNone(quantizer.stochastic_rounding_rng_state)
            else:
                self.assertIsNotNone(quantizer.stochastic_rounding_rng_state)

            # RHT applies only to 1D-scaled quantizers with a colwise component,
            # and only when the recipe has not disabled it.
            expected_rht = (
                quantizer.scaling_mode == ScalingMode.NVFP4_1D_SCALING
                and quantizer.q_layout in {QuantizeLayout.ROWWISE_COLWISE, QuantizeLayout.COLWISE}
                and not test.disable_rht
            )
            self.assertEqual(quantizer.use_rht, expected_rht)

        x = jnp.ones((), dtype=jnp.float32)
        test_module = TestModule(assertion_func=assertion_func)
        param_key, sr_key = jax.random.split(jax.random.PRNGKey(0))
        rngs = {"params": param_key, "sr_rng": sr_key}
        variables = test_module.init(rngs, x)

        # Trace forward+backward under jit so the quantizer set crosses the
        # custom-VJP boundary and the assertions run on the reconstructed pytrees.
        jax.jit(jax.value_and_grad(test_module.apply), static_argnums=(2,))(variables, x, rngs=rngs)

    @unittest.skipIf(not is_fp8_supported, reason=reason)
    def test_autocast_delayed_scaling(self):
        """autocast with DelayedScaling enables/configures and then restores defaults."""
        self._check_default_state()

        # enabled=False must leave the default (disabled) state untouched.
        with autocast(enabled=False, recipe=DelayedScaling(), mesh_resource=MeshResource()):
            self._check_default_state()

        self._check_default_state()

        ds = DelayedScaling(margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1)
        with autocast(enabled=True, recipe=ds, mesh_resource=MeshResource()):
            self.assertTrue(get_quantize_config().is_fp8_enabled())
            self._compare_delay_scaling(ds)

        self._check_default_state()

        ds = DelayedScaling(margin=3.0, fp8_format=FP8Format.HYBRID, amax_history_len=1)
        with autocast(enabled=True, recipe=ds, mesh_resource=MeshResource()):
            self.assertTrue(get_quantize_config().is_fp8_enabled())
            self._compare_delay_scaling(ds)

        self._check_default_state()

    @unittest.skipIf(not is_fp8_supported, reason=reason)
    def test_autocast_current_scaling(self):
        """autocast with Float8CurrentScaling enables/configures and then restores defaults."""
        self._check_default_state()

        with autocast(enabled=False, recipe=Float8CurrentScaling(), mesh_resource=MeshResource()):
            self._check_default_state()

        self._check_default_state()

        cs = Float8CurrentScaling(fp8_format=FP8Format.E4M3)
        with autocast(enabled=True, recipe=cs, mesh_resource=MeshResource()):
            self.assertTrue(get_quantize_config().is_fp8_enabled())
            self._compare_current_scaling(cs)

        self._check_default_state()

        cs = Float8CurrentScaling(fp8_format=FP8Format.HYBRID)
        with autocast(enabled=True, recipe=cs, mesh_resource=MeshResource()):
            self.assertTrue(get_quantize_config().is_fp8_enabled())
            self._compare_current_scaling(cs)

        self._check_default_state()

    @unittest.skipIf(not is_mxfp8_supported, reason=mxfp8_reason)
    def test_autocast_mxfp8_block_scaling(self):
        """autocast with MXFP8BlockScaling enables/configures and then restores defaults."""
        self._check_default_state()

        with autocast(enabled=False, recipe=MXFP8BlockScaling(), mesh_resource=MeshResource()):
            self._check_default_state()

        self._check_default_state()

        bs = MXFP8BlockScaling()
        with autocast(enabled=True, recipe=bs, mesh_resource=MeshResource()):
            self.assertTrue(get_quantize_config().is_fp8_enabled())
            self._compare_mxfp8_scaling(bs)

        self._check_default_state()

    @unittest.skipIf(not is_nvfp4_supported, reason=nvfp4_reason)
    def test_autocast_nvfp4_block_scaling(self):
        """autocast with NVFP4BlockScaling (default and all-disabled variants) configures and restores state."""
        self._check_default_state()

        with autocast(enabled=False, recipe=NVFP4BlockScaling(), mesh_resource=MeshResource()):
            self._check_default_state()

        self._check_default_state()

        # Default recipe: stochastic rounding, RHT and 2D quantization active.
        bs = NVFP4BlockScaling()
        with autocast(enabled=True, recipe=bs, mesh_resource=MeshResource()):
            self.assertTrue(get_quantize_config().is_fp8_enabled())
            self._compare_nvfp4_scaling(bs)
            self._compare_nvfp4_scaling_quantizers(bs)

        # Variant with every optional NVFP4 feature turned off.
        bs = NVFP4BlockScaling(
            disable_stochastic_rounding=True,
            disable_rht=True,
            disable_2d_quantization=True,
        )
        with autocast(enabled=True, recipe=bs, mesh_resource=MeshResource()):
            self.assertTrue(get_quantize_config().is_fp8_enabled())
            self._compare_nvfp4_scaling(bs)
            self._compare_nvfp4_scaling_quantizers(bs)

        self._check_default_state()
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320


class TestJaxprAndHlo:
    """Tests to verify Jaxpr and/or HLO of compiled modules apply expected recipe functionality and optimizations."""

    @pytest_parametrize_wrapper(
        "quantization_recipe",
        [r for r in SUPPORTED_RECIPES if isinstance(r, NVFP4BlockScaling)],
    )
    def test_layernorm_mlp_reuses_amax_nvfp4(self, quantization_recipe):
        """Tests that layernorm_mlp reuses the amax computed in layernorm and the activation and does not recompute it during quantization."""

        with te.autocast(enabled=True, recipe=quantization_recipe, mesh_resource=te.MeshResource()):
            model = te_flax.LayerNormMLP(
                layernorm_type="rmsnorm",
                return_layernorm_output=False,
                intermediate_dropout_rate=0.0,
                dtype=jnp.bfloat16,
            )

            variables = model.init(
                jax.random.PRNGKey(0),
                jnp.ones((128, 128), dtype=jnp.bfloat16),
            )

            def loss_fn(x, rngs):
                return jnp.mean(model.apply(variables, x, rngs=rngs)[0])

            x = jax.random.normal(jax.random.PRNGKey(0), (128, 128), dtype=jnp.bfloat16)
            rngs = {"sr_rng": jax.random.PRNGKey(1), "dropout": jax.random.PRNGKey(2)}

            # Trace fwd+bwd without executing, then inspect the RHT/amax primitives.
            jaxpr = jax.make_jaxpr(jax.value_and_grad(loss_fn))(x, rngs=rngs)
            rht_amax_eqns = [
                eqn for eqn in jaxpr.jaxpr.eqns if eqn.primitive.name == "te_rht_amax_ffi_wrapper"
            ]
            assert len(rht_amax_eqns) == 4, f"Expected 4 rht_amax_eqns, got {len(rht_amax_eqns)}"

            def assert_param(index, tensor_name, expected_value: bool):
                # produce_regular_amax == False means the amax was fused/reused
                # from the preceding op instead of being recomputed.
                actual = rht_amax_eqns[index].params["produce_regular_amax"]
                if expected_value:
                    assert actual == True, (
                        f"Expected produce_regular_amax for {tensor_name} to be True, indicating no"
                        " reuse of amax as this tensor does not have a previous operation to fuse"
                        " with"
                    )
                else:
                    assert actual == False, (
                        f"Expected produce_regular_amax for {tensor_name} to be False, indicating"
                        " reuse of amax"
                    )

            assert_param(0, "fwd ln+q", False)
            assert_param(1, "fwd act+q", False)
            # No previous op before incoming dgrad in the backward so amax is not reused
            assert_param(2, "bwd dgrad", True)
            assert_param(3, "bwd dact+q", False)