# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import unittest

import flax
import jax
import jax.numpy as jnp
import numpy as np

from utils import assert_allclose
from transformer_engine.common.recipe import (
    DelayedScaling,
    MXFP8BlockScaling,
    Float8CurrentScaling,
    NVFP4BlockScaling,
)
from transformer_engine.common.recipe import Format as FP8Format
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax.quantize import (
    get_quantize_config,
    is_scaling_mode_supported,
    ScalingMode,
    update_collections,
    TensorSource,
)
from transformer_engine.jax.quantize.helper import _format2dtypes
from transformer_engine.jax.sharding import MeshResource, global_mesh_resource

# Probe platform support for each scaling mode once at import time. Each call
# returns a (supported, reason) pair; the flags drive the unittest.skipIf
# decorators below and the paired string is the human-readable skip reason.
is_fp8_supported, reason = is_scaling_mode_supported(ScalingMode.DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_scaling_mode_supported(ScalingMode.MXFP8_1D_SCALING)
is_nvfp4_supported, nvfp4_reason = is_scaling_mode_supported(ScalingMode.NVFP4_1D_SCALING)


class TestHelper(unittest.TestCase):
    """Tests for standalone helper utilities in transformer_engine.jax.quantize."""

    @unittest.skipIf(not is_fp8_supported, reason=reason)
    def test_update_collections(self):
        """update_collections must replace only the named entry and leave the
        other entries untouched, for plain dicts and flax FrozenDicts alike."""
        before = 0.0
        after = 10.0

        # Plain-dict collection.
        state = {
            "test1": before,
            "test2": before,
        }
        new_state = update_collections({"test1": after}, state)
        self.assertEqual(new_state["test1"], after)
        self.assertEqual(new_state["test2"], before)

        # Same check against an immutable flax FrozenDict.
        state = flax.core.frozen_dict.FrozenDict(state)
        new_state = update_collections({"test1": after}, state)
        self.assertEqual(new_state["test1"], after)
        self.assertEqual(new_state["test2"], before)


class TestFP8Functions(unittest.TestCase):
    """Tests that fp8_autocast installs the matching global quantize config for
    each recipe type and restores the disabled default on exit."""

    def _check_default_state(self):
        # Outside any enabled autocast context, fp8 must be reported disabled.
        self.assertFalse(get_quantize_config().is_fp8_enabled())

    def _compare_delay_scaling(self, recipe):
        """Assert the active config mirrors a DelayedScaling recipe."""
        cfg = get_quantize_config()
        fwd_dtype, bwd_dtype = _format2dtypes(recipe.fp8_format)
        self.assertEqual(cfg.MARGIN, recipe.margin)
        self.assertEqual(cfg.FWD_DTYPE, fwd_dtype)
        self.assertEqual(cfg.BWD_DTYPE, bwd_dtype)
        self.assertEqual(cfg.AMAX_HISTORY_LEN, recipe.amax_history_len)
        self.assertEqual(cfg.AMAX_COMPUTE_ALGO.value, recipe.amax_compute_algo)

    def _compare_current_scaling(self, recipe):
        """Assert the active config mirrors a Float8CurrentScaling recipe."""
        cfg = get_quantize_config()
        fwd_dtype, bwd_dtype = _format2dtypes(recipe.fp8_format)
        self.assertEqual(cfg.FWD_DTYPE, fwd_dtype)
        self.assertEqual(cfg.BWD_DTYPE, bwd_dtype)
        # Every tensor source quantizes with per-tensor current scaling.
        for source in TensorSource:
            self.assertEqual(
                cfg.get_scaling_mode(source),
                ScalingMode.CURRENT_TENSOR_SCALING,
            )

    def _compare_mxfp8_scaling(self, recipe):
        """Assert the active config mirrors an MXFP8BlockScaling recipe."""
        cfg = get_quantize_config()
        fwd_dtype, bwd_dtype = _format2dtypes(recipe.fp8_format)
        self.assertEqual(cfg.FWD_DTYPE, fwd_dtype)
        self.assertEqual(cfg.BWD_DTYPE, bwd_dtype)
        # Every tensor source quantizes with MXFP8 1D block scaling.
        for source in TensorSource:
            self.assertEqual(cfg.get_scaling_mode(source), ScalingMode.MXFP8_1D_SCALING)

    def _compare_nvfp4_scaling(self, recipe):
        """Assert the active config mirrors an NVFP4BlockScaling recipe."""
        cfg = get_quantize_config()
        fwd_dtype, bwd_dtype = _format2dtypes(recipe.fp4_format)
        self.assertEqual(cfg.FWD_DTYPE, fwd_dtype)
        self.assertEqual(cfg.BWD_DTYPE, bwd_dtype)
        for source in TensorSource:
            # Kernels get 2D block scaling; all other tensor sources use 1D.
            expected_mode = (
                ScalingMode.NVFP4_2D_SCALING
                if source == TensorSource.KERNEL
                else ScalingMode.NVFP4_1D_SCALING
            )
            self.assertEqual(cfg.get_scaling_mode(source), expected_mode)

    @unittest.skipIf(not is_fp8_supported, reason=reason)
    def test_fp8_autocast_delayed_scaling(self):
        """fp8_autocast with DelayedScaling enables/configures fp8 inside the
        context and restores the disabled default afterwards."""
        self._check_default_state()

        with fp8_autocast(enabled=False, fp8_recipe=DelayedScaling(), mesh_resource=MeshResource()):
            self._check_default_state()

        self._check_default_state()

        for recipe in (
            DelayedScaling(margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1),
            DelayedScaling(margin=3.0, fp8_format=FP8Format.HYBRID, amax_history_len=1),
        ):
            with fp8_autocast(enabled=True, fp8_recipe=recipe, mesh_resource=MeshResource()):
                self.assertTrue(get_quantize_config().is_fp8_enabled())
                self._compare_delay_scaling(recipe)

            self._check_default_state()

    @unittest.skipIf(not is_fp8_supported, reason=reason)
    def test_fp8_autocast_current_scaling(self):
        """fp8_autocast with Float8CurrentScaling enables/configures fp8 inside
        the context and restores the disabled default afterwards."""
        self._check_default_state()

        with fp8_autocast(
            enabled=False, fp8_recipe=Float8CurrentScaling(), mesh_resource=MeshResource()
        ):
            self._check_default_state()

        self._check_default_state()

        for recipe in (
            Float8CurrentScaling(fp8_format=FP8Format.E4M3),
            Float8CurrentScaling(fp8_format=FP8Format.HYBRID),
        ):
            with fp8_autocast(enabled=True, fp8_recipe=recipe, mesh_resource=MeshResource()):
                self.assertTrue(get_quantize_config().is_fp8_enabled())
                self._compare_current_scaling(recipe)

            self._check_default_state()

    @unittest.skipIf(not is_mxfp8_supported, reason=mxfp8_reason)
    def test_fp8_autocast_mxfp8_block_scaling(self):
        """fp8_autocast with MXFP8BlockScaling enables/configures fp8 inside
        the context and restores the disabled default afterwards."""
        self._check_default_state()

        with fp8_autocast(
            enabled=False, fp8_recipe=MXFP8BlockScaling(), mesh_resource=MeshResource()
        ):
            self._check_default_state()

        self._check_default_state()

        recipe = MXFP8BlockScaling()
        with fp8_autocast(enabled=True, fp8_recipe=recipe, mesh_resource=MeshResource()):
            self.assertTrue(get_quantize_config().is_fp8_enabled())
            self._compare_mxfp8_scaling(recipe)

        self._check_default_state()

    @unittest.skipIf(not is_nvfp4_supported, reason=nvfp4_reason)
    def test_fp8_autocast_nvfp4_block_scaling(self):
        """fp8_autocast with NVFP4BlockScaling enables/configures fp8 inside
        the context and restores the disabled default afterwards."""
        self._check_default_state()

        with fp8_autocast(
            enabled=False, fp8_recipe=NVFP4BlockScaling(), mesh_resource=MeshResource()
        ):
            self._check_default_state()

        self._check_default_state()

        recipe = NVFP4BlockScaling()
        with fp8_autocast(enabled=True, fp8_recipe=recipe, mesh_resource=MeshResource()):
            self.assertTrue(get_quantize_config().is_fp8_enabled())
            self._compare_nvfp4_scaling(recipe)

        self._check_default_state()