# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import unittest

import flax
import jax
import jax.numpy as jnp
import numpy as np

from utils import assert_allclose
from transformer_engine.common.recipe import DelayedScaling, MXFP8BlockScaling, Float8CurrentScaling
from transformer_engine.common.recipe import Format as FP8Format
from transformer_engine.jax import fp8_autocast, get_delayed_scaling
from transformer_engine.jax.quantize import (
    QuantizeConfig,
    is_fp8_available,
    ScalingMode,
    update_collections,
)
from transformer_engine.jax.sharding import MeshResource, global_mesh_resource
24
is_fp8_supported, reason = is_fp8_available()
25
is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
26

27

28
class TestHelper(unittest.TestCase):
    """Unit tests for small quantize-collection helpers."""

    @unittest.skipIf(not is_fp8_supported, reason=reason)
    def test_update_collections(self):
        """update_collections must overwrite only the given keys and leave the
        rest untouched, for both plain dicts and flax FrozenDicts."""
        old_val = 0.0
        new_val = 10.0

        base = {"test1": old_val, "test2": old_val}

        # Plain-dict input.
        result = update_collections({"test1": new_val}, base)
        self.assertEqual(result["test1"], new_val)
        self.assertEqual(result["test2"], old_val)

        # FrozenDict input must behave identically.
        frozen_base = flax.core.frozen_dict.FrozenDict(base)
        result = update_collections({"test1": new_val}, frozen_base)
        self.assertEqual(result["test1"], new_val)
        self.assertEqual(result["test2"], old_val)

class TestFP8Functions(unittest.TestCase):
    """Tests that fp8_autocast correctly installs and tears down QuantizeConfig
    state for delayed-, current-, and MXFP8-block-scaling recipes."""

    def _check_default_state(self):
        # Outside any fp8_autocast context, FP8 must be disabled.
        self.assertFalse(QuantizeConfig.is_fp8_enabled())

    def _compare_delay_scaling(self, ref, test):
        """Assert two DelayedScaling recipes carry identical settings."""
        self.assertTrue(ref.margin == test.margin)
        self.assertTrue(ref.fp8_format == test.fp8_format)
        self.assertTrue(ref.amax_history_len == test.amax_history_len)
        self.assertTrue(ref.amax_compute_algo == test.amax_compute_algo)

    def _compare_current_scaling(self, test):
        """Assert QuantizeConfig was populated from a Float8CurrentScaling recipe."""
        self.assertEqual(QuantizeConfig.MARGIN, test.margin)
        self.assertEqual(QuantizeConfig.FP8_FORMAT, test.fp8_format)
        self.assertEqual(QuantizeConfig.SCALING_MODE, ScalingMode.CURRENT_TENSOR_SCALING)

    def _compare_mxfp8_scaling(self, test):
        """Assert QuantizeConfig was populated from an MXFP8BlockScaling recipe."""
        self.assertEqual(QuantizeConfig.MARGIN, test.margin)
        self.assertEqual(QuantizeConfig.FP8_FORMAT, test.fp8_format)
        self.assertEqual(QuantizeConfig.SCALING_MODE, ScalingMode.MXFP8_1D_SCALING)

    @unittest.skipIf(not is_fp8_supported, reason=reason)
    def test_fp8_autocast_delayed_scaling(self):
        QuantizeConfig.finalize()  # Ensure the testing not affect by previous tests.
        self._check_default_state()

        # enabled=False must leave the default (disabled) state untouched.
        with fp8_autocast(enabled=False, fp8_recipe=DelayedScaling()):
            self._check_default_state()

        self._check_default_state()

        ds = DelayedScaling(margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1)
        with fp8_autocast(enabled=True, fp8_recipe=ds):
            self.assertTrue(QuantizeConfig.is_fp8_enabled())
            self._compare_delay_scaling(get_delayed_scaling(), ds)

        self._check_default_state()

        ds = DelayedScaling(margin=3.0, fp8_format=FP8Format.HYBRID, amax_history_len=1)
        with fp8_autocast(enabled=True, fp8_recipe=ds):
            self.assertTrue(QuantizeConfig.is_fp8_enabled())
            self._compare_delay_scaling(get_delayed_scaling(), ds)

        self._check_default_state()

    # FIX: this test was previously also named test_fp8_autocast_mxfp8_scaling,
    # so Python's class-body binding let the MXFP8 test below silently shadow
    # it and the current-scaling path was never executed. It was additionally
    # gated on MXFP8 support although Float8CurrentScaling only requires plain
    # FP8 support.
    @unittest.skipIf(not is_fp8_supported, reason=reason)
    def test_fp8_autocast_current_scaling(self):
        QuantizeConfig.finalize()  # Ensure the testing not affect by previous tests.
        self._check_default_state()

        with fp8_autocast(enabled=False, fp8_recipe=Float8CurrentScaling()):
            self._check_default_state()

        self._check_default_state()

        cs = Float8CurrentScaling(margin=5.0, fp8_format=FP8Format.E4M3)
        with fp8_autocast(enabled=True, fp8_recipe=cs):
            self.assertTrue(QuantizeConfig.is_fp8_enabled())
            self._compare_current_scaling(cs)

        self._check_default_state()

        cs = Float8CurrentScaling(margin=3.0, fp8_format=FP8Format.HYBRID)
        with fp8_autocast(enabled=True, fp8_recipe=cs):
            self.assertTrue(QuantizeConfig.is_fp8_enabled())
            self._compare_current_scaling(cs)

        self._check_default_state()

    @unittest.skipIf(not is_mxfp8_supported, reason=mxfp8_reason)
    def test_fp8_autocast_mxfp8_scaling(self):
        QuantizeConfig.finalize()  # Ensure the testing not affect by previous tests.
        self._check_default_state()

        with fp8_autocast(enabled=False, fp8_recipe=MXFP8BlockScaling()):
            self._check_default_state()

        self._check_default_state()

        bs = MXFP8BlockScaling(margin=5.0, fp8_format=FP8Format.E4M3)
        with fp8_autocast(enabled=True, fp8_recipe=bs):
            self.assertTrue(QuantizeConfig.is_fp8_enabled())
            self._compare_mxfp8_scaling(bs)

        self._check_default_state()

        bs = MXFP8BlockScaling(margin=3.0, fp8_format=FP8Format.HYBRID)
        with fp8_autocast(enabled=True, fp8_recipe=bs):
            self.assertTrue(QuantizeConfig.is_fp8_enabled())
            self._compare_mxfp8_scaling(bs)

        self._check_default_state()

    @unittest.skipIf(not is_fp8_supported, reason=reason)
    def test_fp8_autocast_with_sharding_resource(self):
        """fp8_autocast must install the given MeshResource for every
        dp/tp axis combination and restore defaults on exit."""
        QuantizeConfig.finalize()  # Ensure the testing not affect by previous tests.
        self._check_default_state()

        ds = DelayedScaling(margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1)

        mesh_s = (
            (MeshResource(None, None)),
            (MeshResource("dp", None)),
            (MeshResource(None, "tp")),
            (MeshResource("dp", "tp")),
        )
        # TODO (Ming Huang): Support multi-GPUs testing. # pylint: disable=fixme
        mesh_shape = (1, 1)
        devices = np.asarray(jax.devices()[:1]).reshape(*mesh_shape)
        with jax.sharding.Mesh(devices, ("dp", "tp")):
            for sr in mesh_s:
                with fp8_autocast(enabled=True, fp8_recipe=ds, mesh_resource=sr):
                    self.assertTrue(QuantizeConfig.is_fp8_enabled())
                    self._compare_delay_scaling(get_delayed_scaling(), ds)
                    self.assertEqual(sr, global_mesh_resource())

                self._check_default_state()