test_helper.py 6.32 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import unittest

import flax
import jax
import jax.numpy as jnp
import numpy as np

12
from utils import assert_allclose
13
from transformer_engine.common.recipe import DelayedScaling, MXFP8BlockScaling, Float8CurrentScaling
14
from transformer_engine.common.recipe import Format as FP8Format
Ming-Xu Huang's avatar
Ming-Xu Huang committed
15
from transformer_engine.jax import fp8_autocast, get_delayed_scaling
16
17
18
19
20
21
from transformer_engine.jax.quantize import (
    QuantizeConfig,
    is_fp8_available,
    ScalingMode,
    update_collections,
)
22
from transformer_engine.jax.sharding import MeshResource, global_mesh_resource
23

24
# Query FP8 support once at import time; the resulting flags and reason strings
# are consumed by the ``unittest.skipIf`` decorators on the tests.
(is_fp8_supported, reason) = is_fp8_available()
(is_mxfp8_supported, mxfp8_reason) = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
26

27

28
class TestHelper(unittest.TestCase):
    """Tests for the ``update_collections`` helper."""

    @unittest.skipIf(not is_fp8_supported, reason=reason)
    def test_update_collections(self):
        """Keys named in the update are replaced; the remaining keys keep their values.

        Exercised first with a plain ``dict`` and then with a flax ``FrozenDict``
        as the original collection.
        """
        old_value = 0.0
        new_value = 10.0

        def check_merge(state):
            # Only "test1" is updated, so "test2" must come through untouched.
            merged = update_collections({"test1": new_value}, state)
            self.assertEqual(merged["test1"], new_value)
            self.assertEqual(merged["test2"], old_value)

        plain_state = {"test1": old_value, "test2": old_value}
        check_merge(plain_state)
        check_merge(flax.core.frozen_dict.FrozenDict(plain_state))


class TestFP8Functions(unittest.TestCase):
    """Tests that ``fp8_autocast`` applies a recipe while active and restores the default.

    Each test resets ``QuantizeConfig``, checks that FP8 is disabled, enters
    ``fp8_autocast`` with a recipe, verifies the global quantization settings
    inside the context, and checks that FP8 is disabled again after exiting.
    """

    def _check_default_state(self):
        """Assert that FP8 quantization is currently disabled globally."""
        self.assertFalse(QuantizeConfig.is_fp8_enabled())

    def _compare_delay_scaling(self, ref, test):
        """Assert two delayed-scaling recipes agree on the fields these tests set.

        ``assertEqual`` is used instead of ``assertTrue(a == b)`` so a failure
        reports both mismatching values.
        """
        self.assertEqual(ref.margin, test.margin)
        self.assertEqual(ref.fp8_format, test.fp8_format)
        self.assertEqual(ref.amax_history_len, test.amax_history_len)
        self.assertEqual(ref.amax_compute_algo, test.amax_compute_algo)

    def _compare_current_scaling(self, test):
        """Assert the global config reflects current-scaling recipe ``test``."""
        self.assertEqual(QuantizeConfig.FP8_FORMAT, test.fp8_format)
        self.assertEqual(QuantizeConfig.SCALING_MODE, ScalingMode.CURRENT_TENSOR_SCALING)

    def _compare_mxfp8_scaling(self, test):
        """Assert the global config reflects MXFP8 block-scaling recipe ``test``."""
        self.assertEqual(QuantizeConfig.MARGIN, test.margin)
        self.assertEqual(QuantizeConfig.FP8_FORMAT, test.fp8_format)
        self.assertEqual(QuantizeConfig.SCALING_MODE, ScalingMode.MXFP8_1D_SCALING)

    @unittest.skipIf(not is_fp8_supported, reason=reason)
    def test_fp8_autocast_delayed_scaling(self):
        """``fp8_autocast`` with a DelayedScaling recipe enables FP8 only while active."""
        QuantizeConfig.finalize()  # Reset global state so earlier tests cannot affect this one.
        self._check_default_state()

        # A disabled autocast must leave FP8 off even when a recipe is given.
        with fp8_autocast(enabled=False, fp8_recipe=DelayedScaling()):
            self._check_default_state()
        self._check_default_state()

        recipes = (
            DelayedScaling(margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1),
            DelayedScaling(margin=3.0, fp8_format=FP8Format.HYBRID, amax_history_len=1),
        )
        for ds in recipes:
            with self.subTest(recipe=ds):
                with fp8_autocast(enabled=True, fp8_recipe=ds):
                    self.assertTrue(QuantizeConfig.is_fp8_enabled())
                    self._compare_delay_scaling(get_delayed_scaling(), ds)
                # Leaving the context must restore the disabled default.
                self._check_default_state()

    @unittest.skipIf(not is_fp8_supported, reason=reason)
    def test_fp8_autocast_current_scaling(self):
        """``fp8_autocast`` with a Float8CurrentScaling recipe enables FP8 only while active."""
        QuantizeConfig.finalize()  # Reset global state so earlier tests cannot affect this one.
        self._check_default_state()

        # A disabled autocast must leave FP8 off even when a recipe is given.
        with fp8_autocast(enabled=False, fp8_recipe=Float8CurrentScaling()):
            self._check_default_state()
        self._check_default_state()

        recipes = (
            Float8CurrentScaling(fp8_format=FP8Format.E4M3),
            Float8CurrentScaling(fp8_format=FP8Format.HYBRID),
        )
        for cs in recipes:
            with self.subTest(recipe=cs):
                with fp8_autocast(enabled=True, fp8_recipe=cs):
                    self.assertTrue(QuantizeConfig.is_fp8_enabled())
                    self._compare_current_scaling(cs)
                # Leaving the context must restore the disabled default.
                self._check_default_state()

    @unittest.skipIf(not is_mxfp8_supported, reason=mxfp8_reason)
    def test_fp8_autocast_mxfp8_block_scaling(self):
        """``fp8_autocast`` with an MXFP8BlockScaling recipe enables FP8 only while active."""
        QuantizeConfig.finalize()  # Reset global state so earlier tests cannot affect this one.
        self._check_default_state()

        # A disabled autocast must leave FP8 off even when a recipe is given.
        with fp8_autocast(enabled=False, fp8_recipe=MXFP8BlockScaling()):
            self._check_default_state()
        self._check_default_state()

        recipes = (
            MXFP8BlockScaling(margin=5.0, fp8_format=FP8Format.E4M3),
            MXFP8BlockScaling(margin=3.0, fp8_format=FP8Format.HYBRID),
        )
        for bs in recipes:
            with self.subTest(recipe=bs):
                with fp8_autocast(enabled=True, fp8_recipe=bs):
                    self.assertTrue(QuantizeConfig.is_fp8_enabled())
                    self._compare_mxfp8_scaling(bs)
                # Leaving the context must restore the disabled default.
                self._check_default_state()

    @unittest.skipIf(not is_fp8_supported, reason=reason)
    def test_fp8_autocast_with_sharding_resource(self):
        """``fp8_autocast`` installs the given MeshResource as the global mesh resource."""
        QuantizeConfig.finalize()  # Reset global state so earlier tests cannot affect this one.
        self._check_default_state()

        ds = DelayedScaling(margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1)

        mesh_s = (
            MeshResource(None, None),
            MeshResource("dp", None),
            MeshResource(None, "tp"),
            MeshResource("dp", "tp"),
        )
        # TODO (Ming Huang): Support multi-GPUs testing. # pylint: disable=fixme
        mesh_shape = (1, 1)
        devices = np.asarray(jax.devices()[:1]).reshape(*mesh_shape)
        with jax.sharding.Mesh(devices, ("dp", "tp")):
            for sr in mesh_s:
                with self.subTest(mesh_resource=sr):
                    with fp8_autocast(enabled=True, fp8_recipe=ds, mesh_resource=sr):
                        self.assertTrue(QuantizeConfig.is_fp8_enabled())
                        self._compare_delay_scaling(get_delayed_scaling(), ds)
                        self.assertEqual(sr, global_mesh_resource())
                    # Leaving the context must restore the disabled default.
                    self._check_default_state()