# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
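"""Unit tests for the TransformerEngine/JAX quantize helpers and the autocast context."""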

import unittest

import flax
import jax
import jax.numpy as jnp
import numpy as np

from utils import assert_allclose
from transformer_engine.common.recipe import (
    DelayedScaling,
    MXFP8BlockScaling,
    Float8CurrentScaling,
    NVFP4BlockScaling,
)
from transformer_engine.common.recipe import Format as FP8Format
from transformer_engine.jax import autocast
from transformer_engine.jax.quantize import (
    get_quantize_config,
    is_scaling_mode_supported,
    ScalingMode,
    update_collections,
    TensorSource,
)
from transformer_engine.jax.quantize.helper import _format2dtypes
from transformer_engine.jax.sharding import MeshResource, global_mesh_resource

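# Query device support for each scaling mode up front; unsupported modes skip their tests below.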
is_fp8_supported, reason = is_scaling_mode_supported(ScalingMode.DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_scaling_mode_supported(ScalingMode.MXFP8_1D_SCALING)
is_nvfp4_supported, nvfp4_reason = is_scaling_mode_supported(ScalingMode.NVFP4_1D_SCALING)


class TestHelper(unittest.TestCase):

    @unittest.skipIf(not is_fp8_supported, reason=reason)
    def test_update_collections(self):
        original_val = 0.0
        updated_val = 10.0

        original_state = {
            "test1": original_val,
            "test2": original_val,
        }
        updated_state = update_collections({"test1": updated_val}, original_state)
        self.assertEqual(updated_state["test1"], updated_val)
        self.assertEqual(updated_state["test2"], original_val)

        original_state = flax.core.frozen_dict.FrozenDict(original_state)
        updated_state = update_collections({"test1": updated_val}, original_state)
        self.assertEqual(updated_state["test1"], updated_val)
        self.assertEqual(updated_state["test2"], original_val)


class TestFP8Functions(unittest.TestCase):

    def _check_default_state(self):
        # Outside of an autocast region, the global quantize config must report FP8 as disabled.
        self.assertFalse(get_quantize_config().is_fp8_enabled())

    def _compare_delay_scaling(self, test):
        # The active quantize config should mirror every field of the DelayedScaling recipe.
        self.assertEqual(get_quantize_config().MARGIN, test.margin)
        self.assertEqual(get_quantize_config().FWD_DTYPE, _format2dtypes(test.fp8_format)[0])
        self.assertEqual(get_quantize_config().BWD_DTYPE, _format2dtypes(test.fp8_format)[1])
        self.assertEqual(get_quantize_config().AMAX_HISTORY_LEN, test.amax_history_len)
        self.assertEqual(get_quantize_config().AMAX_COMPUTE_ALGO.value, test.amax_compute_algo)

    def _compare_current_scaling(self, test):
        self.assertEqual(get_quantize_config().FWD_DTYPE, _format2dtypes(test.fp8_format)[0])
        self.assertEqual(get_quantize_config().BWD_DTYPE, _format2dtypes(test.fp8_format)[1])
        for tensor_source in TensorSource:
            self.assertEqual(
                get_quantize_config().get_scaling_mode(tensor_source),
                ScalingMode.CURRENT_TENSOR_SCALING,
            )

    def _compare_mxfp8_scaling(self, test):
        self.assertEqual(get_quantize_config().FWD_DTYPE, _format2dtypes(test.fp8_format)[0])
        self.assertEqual(get_quantize_config().BWD_DTYPE, _format2dtypes(test.fp8_format)[1])
        for tensor_source in TensorSource:
            self.assertEqual(
                get_quantize_config().get_scaling_mode(tensor_source), ScalingMode.MXFP8_1D_SCALING
            )

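    # NVFP4 recipes quantize kernel (weight) tensors with 2D block scaling and every other
    # tensor source with 1D block scaling, so the expected mode depends on the source.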
    def _compare_nvfp4_scaling(self, test):
        self.assertEqual(get_quantize_config().FWD_DTYPE, _format2dtypes(test.fp4_format)[0])
        self.assertEqual(get_quantize_config().BWD_DTYPE, _format2dtypes(test.fp4_format)[1])
        for tensor_source in TensorSource:
            target_scaling_mode = (
                ScalingMode.NVFP4_2D_SCALING
                if tensor_source == TensorSource.KERNEL
                else ScalingMode.NVFP4_1D_SCALING
            )
            self.assertEqual(
                get_quantize_config().get_scaling_mode(tensor_source), target_scaling_mode
            )

    @unittest.skipIf(not is_fp8_supported, reason=reason)
    def test_autocast_delayed_scaling(self):
        self._check_default_state()

        with autocast(enabled=False, recipe=DelayedScaling(), mesh_resource=MeshResource()):
            self._check_default_state()

        self._check_default_state()

        ds = DelayedScaling(margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1)
        with autocast(enabled=True, recipe=ds, mesh_resource=MeshResource()):
            self.assertTrue(get_quantize_config().is_fp8_enabled())
            self._compare_delay_scaling(ds)

        self._check_default_state()

        ds = DelayedScaling(margin=3.0, fp8_format=FP8Format.HYBRID, amax_history_len=1)
        with autocast(enabled=True, recipe=ds, mesh_resource=MeshResource()):
            self.assertTrue(get_quantize_config().is_fp8_enabled())
            self._compare_delay_scaling(ds)

        self._check_default_state()

    @unittest.skipIf(not is_fp8_supported, reason=reason)
    def test_autocast_current_scaling(self):
        self._check_default_state()

        with autocast(enabled=False, recipe=Float8CurrentScaling(), mesh_resource=MeshResource()):
            self._check_default_state()

        self._check_default_state()

        cs = Float8CurrentScaling(fp8_format=FP8Format.E4M3)
        with autocast(enabled=True, recipe=cs, mesh_resource=MeshResource()):
            self.assertTrue(get_quantize_config().is_fp8_enabled())
            self._compare_current_scaling(cs)

        self._check_default_state()

        cs = Float8CurrentScaling(fp8_format=FP8Format.HYBRID)
        with autocast(enabled=True, recipe=cs, mesh_resource=MeshResource()):
            self.assertTrue(get_quantize_config().is_fp8_enabled())
            self._compare_current_scaling(cs)

        self._check_default_state()

    @unittest.skipIf(not is_mxfp8_supported, reason=mxfp8_reason)
    def test_autocast_mxfp8_block_scaling(self):
        self._check_default_state()

        with autocast(enabled=False, recipe=MXFP8BlockScaling(), mesh_resource=MeshResource()):
            self._check_default_state()

        self._check_default_state()

        bs = MXFP8BlockScaling()
        with autocast(enabled=True, recipe=bs, mesh_resource=MeshResource()):
            self.assertTrue(get_quantize_config().is_fp8_enabled())
            self._compare_mxfp8_scaling(bs)

        self._check_default_state()

    @unittest.skipIf(not is_nvfp4_supported, reason=nvfp4_reason)
    def test_autocast_nvfp4_block_scaling(self):
        self._check_default_state()

        with autocast(enabled=False, recipe=NVFP4BlockScaling(), mesh_resource=MeshResource()):
            self._check_default_state()

        self._check_default_state()

        bs = NVFP4BlockScaling()
        with autocast(enabled=True, recipe=bs, mesh_resource=MeshResource()):
            self.assertTrue(get_quantize_config().is_fp8_enabled())
            self._compare_nvfp4_scaling(bs)

        self._check_default_state()
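

# Convenience entry point so the module can be run directly; an assumption here, since the
# upstream suite may rely solely on pytest discovery.
if __name__ == "__main__":
    unittest.main()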