# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import unittest
from functools import partial
from abc import ABC, abstractmethod

import flax
import jax
import jax.numpy as jnp
import numpy as np
from flax import linen as nn

from utils import assert_allclose, pytest_parametrize_wrapper
from transformer_engine.common.recipe import (
    Recipe,
    DelayedScaling,
    MXFP8BlockScaling,
    Float8CurrentScaling,
    NVFP4BlockScaling,
)
from transformer_engine.common.recipe import Format as FP8Format
from transformer_engine.jax import autocast
from transformer_engine.jax.quantize import (
    get_global_quantize_recipe,
    get_quantize_config_with_recipe,
    get_supported_quantization_recipes,
    is_scaling_mode_supported,
    ScalingMode,
    update_collections,
    TensorSource,
    QuantizeLayout,
)
from transformer_engine.jax.quantize.helper import _format2dtypes
from transformer_engine.jax.sharding import MeshResource, global_mesh_resource
from transformer_engine.jax.flax.module import TransformerEngineBase

from transformer_engine.jax import flax as te_flax
import transformer_engine.jax as te

is_fp8_supported, reason = is_scaling_mode_supported(ScalingMode.DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_scaling_mode_supported(ScalingMode.MXFP8_1D_SCALING)
is_nvfp4_supported, nvfp4_reason = is_scaling_mode_supported(ScalingMode.NVFP4_1D_SCALING)

SUPPORTED_RECIPES = get_supported_quantization_recipes()


def quantizer_check_vjp(outer_quantizer_set, assertion_func, x):
    """Check that the quantizers in the quantizer set are as expected and reconstructed correctly from flattened pytree representations across VJP boundaries."""

    # Define a function with a custom VJP (vector-Jacobian product)
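    # nondiff_argnums=(1,) marks assertion_func as a static argument, so JAX forwards it to the
    # fwd/bwd rules instead of differentiating through it. The quantizer set itself is a
    # differentiable argument and is therefore flattened and unflattened as a pytree when it
    # crosses the custom-VJP boundary, which is exactly what this helper verifies.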
    @partial(jax.custom_vjp, nondiff_argnums=(1,))
    def quantizer_check(inner_quantizer_set, assertion_func, x):
        return quantizer_check_fwd(inner_quantizer_set, assertion_func, x)[0]

    def quantizer_check_fwd(inner_quantizer_set, assertion_func, x):
        assertion_func(inner_quantizer_set.x, TensorSource.X)
        assertion_func(inner_quantizer_set.kernel, TensorSource.KERNEL)
        assertion_func(inner_quantizer_set.dgrad, TensorSource.DGRAD)
        return x, (inner_quantizer_set,)

    def quantizer_check_bwd(assertion_func, ctx, g):
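        # ctx holds the quantizer set saved as residuals in the forward pass; it is returned as
        # the cotangent for the quantizer_set argument (with g as the cotangent for x) so the
        # custom-VJP signature stays consistent.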
        (inner_quantizer_set,) = ctx
        return (inner_quantizer_set, g)

    quantizer_check.defvjp(quantizer_check_fwd, quantizer_check_bwd)
    return quantizer_check(outer_quantizer_set, assertion_func, x)


class TestModule(TransformerEngineBase):
    """A simple module to test quantizer creation and reconstruction across VJP boundaries."""

    # Signature: (quantizer: Quantizer, tensor_source: TensorSource) -> None
    assertion_func: callable
    direct_recipe: Recipe

    @nn.compact
    def __call__(self, x):
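        # direct_recipe is None when the recipe is supplied via the surrounding autocast
        # context instead (see TestFP8Functions._test_quantizer_in_model).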
        quantizer_set = self.generate_quantizer_set(fp8_recipe=self.direct_recipe)
        return quantizer_check_vjp(quantizer_set, self.assertion_func, x)


class TestHelper(unittest.TestCase):

    @unittest.skipIf(not is_fp8_supported, reason=reason)
    def test_update_collections(self):
        original_val = 0.0
        updated_val = 10.0

        original_state = {
            "test1": original_val,
            "test2": original_val,
        }
        updated_state = update_collections({"test1": updated_val}, original_state)
        self.assertEqual(updated_state["test1"], updated_val)
        self.assertEqual(updated_state["test2"], original_val)

        original_state = flax.core.frozen_dict.FrozenDict(original_state)
        updated_state = update_collections({"test1": updated_val}, original_state)
        self.assertEqual(updated_state["test1"], updated_val)
        self.assertEqual(updated_state["test2"], original_val)


def assert_fp8_format(quantizer, tensor_source, fp8_format):
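    # HYBRID uses E4M3 for forward tensors and the wider-range E5M2 for gradients (DGRAD);
    # E4M3 uses the same dtype for every tensor source.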
    if fp8_format == FP8Format.HYBRID:
        if tensor_source == TensorSource.DGRAD:
            assert quantizer.q_dtype == jnp.float8_e5m2
        else:
            assert quantizer.q_dtype == jnp.float8_e4m3fn
    elif fp8_format == FP8Format.E4M3:
        assert quantizer.q_dtype == jnp.float8_e4m3fn
    else:
        raise ValueError(f"Unsupported FP8 format: {fp8_format}")


class RecipeAssertionBase(ABC):
    """Base class for defining recipe assertions."""

    @abstractmethod
    def assert_context(self, ref_recipe, quantize_config):
        """Asserts that the quantize_config matches the expected properties from the reference recipe when the recipe is used with an autocast context.

        Args:
            ref_recipe: The reference quantization recipe.
            quantize_config: The quantization configuration to be checked.
        """
        pass

    @abstractmethod
    def assert_quantizers(self, ref_recipe, quantizer, tensor_source):
        """Asserts that the quantizer matches the expected properties from the reference recipe. The quantizers are created in a small test Flax module TestModule and passed through a VJP boundary to ensure correct reconstruction.

        Args:
            ref_recipe: The reference quantization recipe.
            quantizer: The quantizer to be checked.
            tensor_source: The source of the tensor (e.g., KERNEL, X, DGRAD).
        """
        pass


class DelayedScalingRecipeAssertion(RecipeAssertionBase):

    def assert_context(self, ref_recipe, quantize_config):
        assert quantize_config.MARGIN == ref_recipe.margin
        assert quantize_config.FWD_DTYPE == _format2dtypes(ref_recipe.fp8_format)[0]
        assert quantize_config.BWD_DTYPE == _format2dtypes(ref_recipe.fp8_format)[1]
        assert quantize_config.AMAX_HISTORY_LEN == ref_recipe.amax_history_len
        assert quantize_config.AMAX_COMPUTE_ALGO.value == ref_recipe.amax_compute_algo
        for tensor_source in TensorSource:
            assert (
                quantize_config.get_scaling_mode(tensor_source)
                == ScalingMode.DELAYED_TENSOR_SCALING
            )

    def assert_quantizers(self, ref_recipe: DelayedScaling, quantizer, tensor_source):
        assert quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING
        assert quantizer.margin == ref_recipe.margin
        assert quantizer.amax_compute_algo.value == ref_recipe.amax_compute_algo
        assert quantizer.amax_history.shape == (ref_recipe.amax_history_len,)
        assert_fp8_format(quantizer, tensor_source, ref_recipe.fp8_format)


class CurrentScalingRecipeAssertion(RecipeAssertionBase):

    def assert_context(self, ref_recipe, quantize_config):
        assert quantize_config.FWD_DTYPE == _format2dtypes(ref_recipe.fp8_format)[0]
        assert quantize_config.BWD_DTYPE == _format2dtypes(ref_recipe.fp8_format)[1]
        for tensor_source in TensorSource:
            assert (
                quantize_config.get_scaling_mode(tensor_source)
                == ScalingMode.CURRENT_TENSOR_SCALING
            )

    def assert_quantizers(self, ref_recipe: Float8CurrentScaling, quantizer, tensor_source):
        assert quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING
        assert_fp8_format(quantizer, tensor_source, ref_recipe.fp8_format)


class MXFP8RecipeAssertion(RecipeAssertionBase):

    def assert_context(self, ref_recipe, quantize_config):
        assert quantize_config.FWD_DTYPE == _format2dtypes(ref_recipe.fp8_format)[0]
        assert quantize_config.BWD_DTYPE == _format2dtypes(ref_recipe.fp8_format)[1]
        for tensor_source in TensorSource:
            assert quantize_config.get_scaling_mode(tensor_source) == ScalingMode.MXFP8_1D_SCALING

    def assert_quantizers(self, ref_recipe: MXFP8BlockScaling, quantizer, tensor_source):
        assert quantizer.scaling_mode == ScalingMode.MXFP8_1D_SCALING
        assert_fp8_format(quantizer, tensor_source, ref_recipe.fp8_format)


class NVFP4RecipeAssertion(RecipeAssertionBase):

    def assert_context(self, ref_recipe, quantize_config):
        assert quantize_config.FWD_DTYPE == _format2dtypes(ref_recipe.fp4_format)[0]
        assert quantize_config.BWD_DTYPE == _format2dtypes(ref_recipe.fp4_format)[1]
        for tensor_source in TensorSource:
            target_scaling_mode = (
                ScalingMode.NVFP4_2D_SCALING
                if (not ref_recipe.disable_2d_quantization) and tensor_source == TensorSource.KERNEL
                else ScalingMode.NVFP4_1D_SCALING
            )
            assert quantize_config.get_scaling_mode(tensor_source) == target_scaling_mode
        assert quantize_config.DISABLE_STOCHASTIC_ROUNDING == ref_recipe.disable_stochastic_rounding
        assert quantize_config.DISABLE_RHT == ref_recipe.disable_rht
        assert quantize_config.DISABLE_2D_QUANTIZATION == ref_recipe.disable_2d_quantization

    def assert_quantizers(self, ref_recipe: NVFP4BlockScaling, quantizer, tensor_source):
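        # For NVFP4: 2D scaling applies only to the kernel (when enabled), stochastic rounding
        # only to DGRAD, and the random Hadamard transform (RHT) only to 1D-scaled tensors
        # that produce a column-wise output, unless the recipe disables it.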
        if tensor_source == TensorSource.KERNEL and not ref_recipe.disable_2d_quantization:
            assert quantizer.scaling_mode == ScalingMode.NVFP4_2D_SCALING
        else:
            assert quantizer.scaling_mode == ScalingMode.NVFP4_1D_SCALING

        if ref_recipe.disable_stochastic_rounding or tensor_source != TensorSource.DGRAD:
            assert quantizer.stochastic_rounding_rng_state is None
        else:
            assert quantizer.stochastic_rounding_rng_state is not None

        expected_rht = (
            quantizer.scaling_mode == ScalingMode.NVFP4_1D_SCALING
            and quantizer.q_layout in {QuantizeLayout.ROWWISE_COLWISE, QuantizeLayout.COLWISE}
            and not ref_recipe.disable_rht
        )
        assert quantizer.use_rht == expected_rht


class TestFP8Functions(unittest.TestCase):

    def _check_default_state(self):
        self.assertEqual(get_global_quantize_recipe(), None)

    def _test_recipe(self, quantization_recipe: Recipe, cls: RecipeAssertionBase):
        """Tests a quantization recipe by verifying its behavior in both autocast and direct application contexts."""
        assert_context_func = cls().assert_context
        assert_quantizer_func = partial(cls().assert_quantizers, quantization_recipe)
        self._test_recipe_autocast(quantization_recipe, assert_context_func, assert_quantizer_func)
        self._test_recipe_direct(quantization_recipe, assert_quantizer_func)

    def _test_recipe_autocast(
        self, quantization_recipe, assert_context_func, assert_quantizer_func
    ):
        """Tests a quantization recipe within an autocast context by verifying the quantize config and quantizers in a test module."""
        self._check_default_state()
        with autocast(enabled=False, recipe=quantization_recipe, mesh_resource=MeshResource()):
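            # With enabled=False the context must not install a global recipe.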
            self._check_default_state()
        with autocast(enabled=True, recipe=quantization_recipe, mesh_resource=MeshResource()):
            quantize_config = self._get_global_quantize_config()
            assert_context_func(quantization_recipe, quantize_config)
            self._test_quantizer_in_model(assert_quantizer_func)
        self._check_default_state()

    def _test_recipe_direct(self, quantization_recipe, assert_quantizer_func):
        """Tests a quantization recipe by directly passing it to a test module and verifying the quantizers."""
        self._check_default_state()
        self._test_quantizer_in_model(assert_quantizer_func, direct_recipe=quantization_recipe)
        self._check_default_state()

    def _test_quantizer_in_model(self, assert_quantizer_func, direct_recipe=None):
        """Tests that the quantizers created in a test module match the expected properties by passing them through a VJP boundary.

        Args:
            assert_quantizer_func: A function that asserts the properties of the quantizers. The function signature is (quantizer: Quantizer, tensor_source: TensorSource) -> None.
            direct_recipe: An optional quantization recipe to be passed directly to the test module. This is an alternative API to using autocast contexts.
        """
        x = jnp.ones((), dtype=jnp.float32)
        test_module = TestModule(assertion_func=assert_quantizer_func, direct_recipe=direct_recipe)
        param_key, sr_key = jax.random.split(jax.random.PRNGKey(0))
        rngs = {"params": param_key, "sr_rng": sr_key}
        variables = test_module.init(rngs, x)

        jax.jit(jax.value_and_grad(test_module.apply), static_argnums=(2,))(variables, x, rngs=rngs)

    def _get_global_quantize_config(self):
        quantization_recipe = get_global_quantize_recipe()
        assert quantization_recipe is not None, "No global quantization recipe set"
        quantize_config = get_quantize_config_with_recipe(quantization_recipe)
        assert (
            quantize_config.is_fp8_enabled()
        ), "Quantization not enabled in global quantize config"
        return quantize_config

    @unittest.skipIf(not is_fp8_supported, reason=reason)
    def test_autocast_delayed_scaling(self):
        self._test_recipe(
            quantization_recipe=DelayedScaling(),
            cls=DelayedScalingRecipeAssertion,
        )
        self._test_recipe(
            quantization_recipe=DelayedScaling(
                margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1
            ),
            cls=DelayedScalingRecipeAssertion,
        )
        self._test_recipe(
            quantization_recipe=DelayedScaling(
                margin=3.0, fp8_format=FP8Format.HYBRID, amax_history_len=1
            ),
            cls=DelayedScalingRecipeAssertion,
        )

    @unittest.skipIf(not is_fp8_supported, reason=reason)
    def test_autocast_current_scaling(self):
        self._test_recipe(
            quantization_recipe=Float8CurrentScaling(),
            cls=CurrentScalingRecipeAssertion,
        )
        self._test_recipe(
            quantization_recipe=Float8CurrentScaling(margin=5.0, fp8_format=FP8Format.E4M3),
            cls=CurrentScalingRecipeAssertion,
        )
        self._test_recipe(
            quantization_recipe=Float8CurrentScaling(margin=3.0, fp8_format=FP8Format.HYBRID),
            cls=CurrentScalingRecipeAssertion,
        )

    @unittest.skipIf(not is_mxfp8_supported, reason=mxfp8_reason)
    def test_autocast_mxfp8_block_scaling(self):
        self._test_recipe(
            quantization_recipe=MXFP8BlockScaling(),
            cls=MXFP8RecipeAssertion,
        )

    @unittest.skipIf(not is_nvfp4_supported, reason=nvfp4_reason)
    def test_autocast_nvfp4_block_scaling(self):
        self._test_recipe(
            quantization_recipe=NVFP4BlockScaling(),
            cls=NVFP4RecipeAssertion,
        )
        self._test_recipe(
            quantization_recipe=NVFP4BlockScaling(
                disable_stochastic_rounding=True,
                disable_rht=True,
                disable_2d_quantization=True,
            ),
            cls=NVFP4RecipeAssertion,
        )


class TestJaxprAndHlo:
    """Tests to verify Jaxpr and/or HLO of compiled modules apply expected recipe functionality and optimizations."""

    def _generate_jaxpr_for_layernorm_mlp_fwd_bwd(self, quantization_recipe, ln_mlp_kwargs=None):
        """Generates the jaxpr for a forward and backward pass of LayerNormMLP under the given quantization recipe."""
        ln_mlp_kwargs = ln_mlp_kwargs or {}
        with te.autocast(enabled=True, recipe=quantization_recipe, mesh_resource=te.MeshResource()):
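            # Build, initialize, and trace the model inside the autocast context so the
            # recipe's quantization primitives show up in the captured jaxpr.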
            model = te_flax.LayerNormMLP(
                layernorm_type="rmsnorm",
                return_layernorm_output=False,
                intermediate_dropout_rate=0.0,
                dtype=jnp.bfloat16,
                **ln_mlp_kwargs,
            )

            var_collect = model.init(
                jax.random.PRNGKey(0),
                jnp.ones((128, 128), dtype=jnp.bfloat16),
            )

            def loss_fn(x, rngs):
                return jnp.mean(model.apply(var_collect, x, rngs=rngs)[0])

            x = jax.random.normal(jax.random.PRNGKey(0), (128, 128), dtype=jnp.bfloat16)
            rngs = {"sr_rng": jax.random.PRNGKey(1), "dropout": jax.random.PRNGKey(2)}
            return jax.make_jaxpr(jax.value_and_grad(loss_fn))(x, rngs=rngs)

    @pytest_parametrize_wrapper(
        "quantization_recipe",
        [
            quantization_recipe
            for quantization_recipe in SUPPORTED_RECIPES
            if isinstance(quantization_recipe, NVFP4BlockScaling)
        ],
    )
    def test_layernorm_mlp_reuses_amax_nvfp4(self, quantization_recipe):
        """Tests that layernorm_mlp reuses the amax computed in layernorm and the activation and does not recompute it during quantizaton."""

        jaxpr = self._generate_jaxpr_for_layernorm_mlp_fwd_bwd(quantization_recipe)

        rht_amax_eqns = [
            eqn for eqn in jaxpr.jaxpr.eqns if eqn.primitive.name == "te_rht_amax_ffi_wrapper"
        ]

        assert len(rht_amax_eqns) == 4, f"Expected 4 rht_amax_eqns, got {len(rht_amax_eqns)}"

        def assert_param(index, tensor_name, expected_value: bool):
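            # produce_regular_amax=True means the RHT kernel computes its own amax (no prior
            # op to fuse with); False means it reuses an amax produced earlier in the graph.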
            if expected_value:
                assert rht_amax_eqns[index].params["produce_regular_amax"] == True, (
                    f"Expected produce_regular_amax for {tensor_name} to be True, indicating no"
                    " reuse of amax as this tensor does not have a previous operation to fuse"
                    " with"
                )
            else:
                assert rht_amax_eqns[index].params["produce_regular_amax"] == False, (
                    f"Expected produce_regular_amax for {tensor_name} to be False, indicating"
                    " reuse of amax"
                )

        assert_param(0, "fwd ln+q", False)
        assert_param(1, "fwd act+q", False)
        # No previous op before incoming dgrad in the backward so amax is not reused
        assert_param(2, "bwd dgrad", True)
        assert_param(3, "bwd dact+q", False)

    @pytest_parametrize_wrapper("quantization_recipe", SUPPORTED_RECIPES)
    @pytest_parametrize_wrapper(
        "quantization_checkpoint_name",
        [None, "quantization", "some_arbitrary_user_checkpoint_name"],
    )
    def test_recipe_supports_quantization_checkpointing(
        self, quantization_recipe, quantization_checkpoint_name
    ):
        """Tests that all supported quantization recipes correctly use checkpoint_name."""

        kwargs = {
            "quantization_checkpoint_name": quantization_checkpoint_name,
        }
        jaxpr = self._generate_jaxpr_for_layernorm_mlp_fwd_bwd(quantization_recipe, kwargs)

        checkpoint_name_eqns = [
            eqn
            for eqn in jaxpr.jaxpr.eqns
            if eqn.primitive.name == "name" and eqn.params["name"] == quantization_checkpoint_name
        ]

        if quantization_checkpoint_name is None:
            assert len(checkpoint_name_eqns) == 0, (
                "Expected 0 checkpoint_name eqns when quantization_checkpoint_name is None, got"
                f" {len(checkpoint_name_eqns)}"
            )
            return

        # 12 checkpointed values:
        # - Fwd pass:
        #   - Input RMSNorm+Q -> 3 possible output tensors that will be used in the backward
        #   - Kernel Q -> 3 possible output tensors that will be used in the backward
        #   - Input Activation+Q -> 3 possible output tensors that will be used in the backward
        #   - Kernel Q -> 3 possible output tensors that will be used in the backward
        expected_checkpoint_eqn_count = 12

        assert len(checkpoint_name_eqns) == expected_checkpoint_eqn_count, (
            f"Expected {expected_checkpoint_eqn_count} checkpoint_name eqns when"
            f" quantization_checkpoint_name is set, got {len(checkpoint_name_eqns)}"
        )