# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import unittest
from functools import partial

import flax
import jax
import jax.numpy as jnp
import numpy as np
from flax import linen as nn

from utils import assert_allclose

from transformer_engine.common.recipe import (
    DelayedScaling,
    MXFP8BlockScaling,
    Float8CurrentScaling,
    NVFP4BlockScaling,
)
from transformer_engine.common.recipe import Format as FP8Format
from transformer_engine.jax import autocast
from transformer_engine.jax.quantize import (
    get_quantize_config,
    is_scaling_mode_supported,
    ScalingMode,
    update_collections,
    TensorSource,
    QuantizerFactory,
    QuantizeLayout,
)
from transformer_engine.jax.quantize.helper import _format2dtypes
from transformer_engine.jax.sharding import MeshResource, global_mesh_resource
from transformer_engine.jax.flax.module import TransformerEngineBase

is_fp8_supported, reason = is_scaling_mode_supported(ScalingMode.DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_scaling_mode_supported(ScalingMode.MXFP8_1D_SCALING)
is_nvfp4_supported, nvfp4_reason = is_scaling_mode_supported(ScalingMode.NVFP4_1D_SCALING)
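# These capability probes run once at import time; each test below is skipped
# (with the reported reason) on devices that do not support the corresponding
# scaling mode.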


def quantizer_check_vjp(outer_quantizer_set, assertion_func, x):
    """Check that the quantizers in the quantizer set are as expected and reconstructed
    correctly from flattened pytree representations across VJP boundaries."""

    # Define a function with a custom VJP (vector-Jacobian product). The
    # quantizer set crosses the custom_vjp boundary as a flattened pytree, so
    # running the assertions inside exercises its flatten/unflatten round trip.
    @partial(jax.custom_vjp, nondiff_argnums=(1,))
    def quantizer_check(inner_quantizer_set, assertion_func, x):
        return quantizer_check_fwd(inner_quantizer_set, assertion_func, x)[0]

    def quantizer_check_fwd(inner_quantizer_set, assertion_func, x):
        assertion_func(inner_quantizer_set.x, TensorSource.X)
        assertion_func(inner_quantizer_set.kernel, TensorSource.KERNEL)
        assertion_func(inner_quantizer_set.dgrad, TensorSource.DGRAD)
        # custom_vjp fwd must return (primal_output, residuals).
        return x, None

    def quantizer_check_bwd(assertion_func, ctx, g):
        # With nondiff_argnums=(1,), bwd receives the non-differentiable
        # argument first, then the residuals and the output cotangent. It must
        # return one cotangent per differentiable input: None (treated as
        # zero) for the quantizer set, and g for x.
        return (None, g)

    quantizer_check.defvjp(quantizer_check_fwd, quantizer_check_bwd)
    return quantizer_check(outer_quantizer_set, assertion_func, x)


class TestModule(TransformerEngineBase):
    """A simple module to test quantizer creation and reconstruction across VJP boundaries."""

    # Signature: (quantizer: Quantizer, tensor_source: TensorSource) -> None
    assertion_func: callable

    @nn.compact
    def __call__(self, x):
        quantizer_set = self.generate_quantizer_set()
        return quantizer_check_vjp(quantizer_set, self.assertion_func, x)

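# Rough usage sketch for the fixtures above (names here are this file's own
# test helpers, not library API):
#     module = TestModule(assertion_func=...)
#     variables = module.init(rngs, x)       # assertions run during init
#     module.apply(variables, x, rngs=rngs)  # and again on every apply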

class TestHelper(unittest.TestCase):

    @unittest.skipIf(not is_fp8_supported, reason=reason)
    def test_update_collections(self):
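        # update_collections should merge the updates into the original
        # mapping without touching other keys; the passes below exercise this
        # for both a plain dict and a flax FrozenDict.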
        original_val = 0.0
        updated_val = 10.0

        original_state = {
            "test1": original_val,
            "test2": original_val,
        }
        updated_state = update_collections({"test1": updated_val}, original_state)
        self.assertEqual(updated_state["test1"], updated_val)
        self.assertEqual(updated_state["test2"], original_val)

        original_state = flax.core.frozen_dict.FrozenDict(original_state)
        updated_state = update_collections({"test1": updated_val}, original_state)
        self.assertEqual(updated_state["test1"], updated_val)
        self.assertEqual(updated_state["test2"], original_val)


class TestFP8Functions(unittest.TestCase):

    def _check_default_state(self):
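        # Outside any autocast context the global quantize config must report
        # FP8 disabled; tests call this after each context exits to confirm
        # the previous state was restored.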
        self.assertFalse(get_quantize_config().is_fp8_enabled())

    def _compare_delay_scaling(self, test):
        self.assertEqual(get_quantize_config().MARGIN, test.margin)
        self.assertEqual(get_quantize_config().FWD_DTYPE, _format2dtypes(test.fp8_format)[0])
        self.assertEqual(get_quantize_config().BWD_DTYPE, _format2dtypes(test.fp8_format)[1])
        self.assertEqual(get_quantize_config().AMAX_HISTORY_LEN, test.amax_history_len)
        self.assertEqual(get_quantize_config().AMAX_COMPUTE_ALGO.value, test.amax_compute_algo)

    def _compare_current_scaling(self, test):
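        # Current scaling keeps no amax history, so only the forward/backward
        # dtypes and the per-tensor scaling mode are asserted.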
        self.assertEqual(get_quantize_config().FWD_DTYPE, _format2dtypes(test.fp8_format)[0])
        self.assertEqual(get_quantize_config().BWD_DTYPE, _format2dtypes(test.fp8_format)[1])
        for tensor_source in TensorSource:
            self.assertEqual(
                get_quantize_config().get_scaling_mode(tensor_source),
                ScalingMode.CURRENT_TENSOR_SCALING,
            )

    def _compare_mxfp8_scaling(self, test):
        self.assertEqual(get_quantize_config().FWD_DTYPE, _format2dtypes(test.fp8_format)[0])
        self.assertEqual(get_quantize_config().BWD_DTYPE, _format2dtypes(test.fp8_format)[1])
        for tensor_source in TensorSource:
            self.assertEqual(
                get_quantize_config().get_scaling_mode(tensor_source), ScalingMode.MXFP8_1D_SCALING
            )

    def _compare_nvfp4_scaling(self, test):
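        # The kernel is expected to use 2D block scaling unless the recipe
        # disables it; x and dgrad always use 1D scaling.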
        self.assertEqual(get_quantize_config().FWD_DTYPE, _format2dtypes(test.fp4_format)[0])
        self.assertEqual(get_quantize_config().BWD_DTYPE, _format2dtypes(test.fp4_format)[1])
        for tensor_source in TensorSource:
            target_scaling_mode = (
                ScalingMode.NVFP4_2D_SCALING
                if (not test.disable_2d_quantization) and tensor_source == TensorSource.KERNEL
                else ScalingMode.NVFP4_1D_SCALING
            )
            self.assertEqual(
                get_quantize_config().get_scaling_mode(tensor_source), target_scaling_mode
            )
        self.assertEqual(
            get_quantize_config().DISABLE_STOCHASTIC_ROUNDING, test.disable_stochastic_rounding
        )
        self.assertEqual(get_quantize_config().DISABLE_RHT, test.disable_rht)
        self.assertEqual(
            get_quantize_config().DISABLE_2D_QUANTIZATION, test.disable_2d_quantization
        )

    def _compare_nvfp4_scaling_quantizers(self, test):
        """Check that the quantizers created have the expected stochastic rounding state and the state is preserved across VJP boundaries."""

        def assertion_func(quantizer, tensor_source):
            if test.disable_stochastic_rounding or tensor_source != TensorSource.DGRAD:
                self.assertIsNone(quantizer.stochastic_rounding_rng_state)
            else:
                self.assertIsNotNone(quantizer.stochastic_rounding_rng_state)

            expected_rht = (
                quantizer.scaling_mode == ScalingMode.NVFP4_1D_SCALING
                and quantizer.q_layout in {QuantizeLayout.ROWWISE_COLWISE, QuantizeLayout.COLWISE}
                and not test.disable_rht
            )
            self.assertEqual(quantizer.use_rht, expected_rht)

        x = jnp.ones((), dtype=jnp.float32)
        test_module = TestModule(assertion_func=assertion_func)
        param_key, sr_key = jax.random.split(jax.random.PRNGKey(0))
        rngs = {"params": param_key, "sr_rng": sr_key}
        variables = test_module.init(rngs, x)

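        # Round-trip through jit + value_and_grad so the quantizer set is
        # flattened and unflattened across both the jit and custom_vjp
        # boundaries, re-running the assertions on the reconstructed set.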
        jax.jit(jax.value_and_grad(test_module.apply))(variables, x, rngs=rngs)

    @unittest.skipIf(not is_fp8_supported, reason=reason)
    def test_autocast_delayed_scaling(self):
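        # autocast(enabled=False) must be a no-op for the global quantize
        # config; enabled=True applies the recipe, and exiting the context
        # must restore the default (FP8-disabled) state each time.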
        self._check_default_state()

        with autocast(enabled=False, recipe=DelayedScaling(), mesh_resource=MeshResource()):
            self._check_default_state()

        self._check_default_state()

        ds = DelayedScaling(margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1)
        with autocast(enabled=True, recipe=ds, mesh_resource=MeshResource()):
            self.assertTrue(get_quantize_config().is_fp8_enabled())
            self._compare_delay_scaling(ds)

        self._check_default_state()

        ds = DelayedScaling(margin=3.0, fp8_format=FP8Format.HYBRID, amax_history_len=1)
        with autocast(enabled=True, recipe=ds, mesh_resource=MeshResource()):
            self.assertTrue(get_quantize_config().is_fp8_enabled())
            self._compare_delay_scaling(ds)

        self._check_default_state()

    @unittest.skipIf(not is_fp8_supported, reason=reason)
    def test_autocast_current_scaling(self):
        self._check_default_state()

        with autocast(enabled=False, recipe=Float8CurrentScaling(), mesh_resource=MeshResource()):
            self._check_default_state()

        self._check_default_state()

        cs = Float8CurrentScaling(fp8_format=FP8Format.E4M3)
        with autocast(enabled=True, recipe=cs, mesh_resource=MeshResource()):
            self.assertTrue(get_quantize_config().is_fp8_enabled())
            self._compare_current_scaling(cs)

        self._check_default_state()

        cs = Float8CurrentScaling(fp8_format=FP8Format.HYBRID)
        with autocast(enabled=True, recipe=cs, mesh_resource=MeshResource()):
            self.assertTrue(get_quantize_config().is_fp8_enabled())
            self._compare_current_scaling(cs)

        self._check_default_state()

    @unittest.skipIf(not is_mxfp8_supported, reason=mxfp8_reason)
    def test_autocast_mxfp8_block_scaling(self):
        self._check_default_state()

        with autocast(enabled=False, recipe=MXFP8BlockScaling(), mesh_resource=MeshResource()):
            self._check_default_state()

        self._check_default_state()

        bs = MXFP8BlockScaling()
        with autocast(enabled=True, recipe=bs, mesh_resource=MeshResource()):
            self.assertTrue(get_quantize_config().is_fp8_enabled())
            self._compare_mxfp8_scaling(bs)

        self._check_default_state()

    @unittest.skipIf(not is_nvfp4_supported, reason=nvfp4_reason)
    def test_autocast_nvfp4_block_scaling(self):
        self._check_default_state()

        with autocast(enabled=False, recipe=NVFP4BlockScaling(), mesh_resource=MeshResource()):
            self._check_default_state()

        self._check_default_state()

        bs = NVFP4BlockScaling()
        with autocast(enabled=True, recipe=bs, mesh_resource=MeshResource()):
            self.assertTrue(get_quantize_config().is_fp8_enabled())
            self._compare_nvfp4_scaling(bs)
            self._compare_nvfp4_scaling_quantizers(bs)

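        # The same checks must hold with stochastic rounding, RHT, and 2D
        # quantization all disabled through the recipe.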
        bs = NVFP4BlockScaling(
            disable_stochastic_rounding=True,
            disable_rht=True,
            disable_2d_quantization=True,
        )
        with autocast(enabled=True, recipe=bs, mesh_resource=MeshResource()):
            self.assertTrue(get_quantize_config().is_fp8_enabled())
            self._compare_nvfp4_scaling(bs)
            self._compare_nvfp4_scaling_quantizers(bs)

        self._check_default_state()
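

# Standard unittest entry point, assuming this module is also meant to be run
# directly rather than only through pytest discovery.
if __name__ == "__main__":
    unittest.main()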