# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import unittest

import flax
import jax
import jax.numpy as jnp
import numpy as np

12
from utils import assert_allclose
13
from transformer_engine.common.recipe import DelayedScaling, MXFP8BlockScaling, Float8CurrentScaling
14
from transformer_engine.common.recipe import Format as FP8Format
Ming-Xu Huang's avatar
Ming-Xu Huang committed
15
from transformer_engine.jax import fp8_autocast, get_delayed_scaling
16
17
18
19
20
21
from transformer_engine.jax.quantize import (
    QuantizeConfig,
    is_fp8_available,
    ScalingMode,
    update_collections,
)
22
from transformer_engine.jax.sharding import MeshResource, global_mesh_resource
23

24
is_fp8_supported, reason = is_fp8_available()
25
is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
26

27

28
class TestHelper(unittest.TestCase):
29

30
    @unittest.skipIf(not is_fp8_supported, reason=reason)
31
32
33
34
35
36
37
38
    def test_update_collections(self):
        original_val = 0.0
        updated_val = 10.0

        original_state = {
            "test1": original_val,
            "test2": original_val,
        }
39
        updated_state = update_collections({"test1": updated_val}, original_state)
40
41
42
43
        self.assertEqual(updated_state["test1"], updated_val)
        self.assertEqual(updated_state["test2"], original_val)

        original_state = flax.core.frozen_dict.FrozenDict(original_state)
44
        updated_state = update_collections({"test1": updated_val}, original_state)
45
46
47
48
49
50
51
        self.assertEqual(updated_state["test1"], updated_val)
        self.assertEqual(updated_state["test2"], original_val)


class TestFP8Functions(unittest.TestCase):

    def _check_defult_state(self):
52
        self.assertFalse(QuantizeConfig.is_fp8_enabled())
53

Ming-Xu Huang's avatar
Ming-Xu Huang committed
54
55
56
57
58
59
    def _compare_delay_scaling(self, ref, test):
        self.assertTrue(ref.margin == test.margin)
        self.assertTrue(ref.fp8_format == test.fp8_format)
        self.assertTrue(ref.amax_history_len == test.amax_history_len)
        self.assertTrue(ref.amax_compute_algo == test.amax_compute_algo)

60
61
62
63
64
65
66
67
68
69
    def _compare_current_scaling(self, test):
        self.assertEqual(QuantizeConfig.MARGIN, test.margin)
        self.assertEqual(QuantizeConfig.FP8_FORMAT, test.fp8_format)
        self.assertEqual(QuantizeConfig.SCALING_MODE, ScalingMode.CURRENT_TENSOR_SCALING)

    def _compare_mxfp8_scaling(self, test):
        self.assertEqual(QuantizeConfig.MARGIN, test.margin)
        self.assertEqual(QuantizeConfig.FP8_FORMAT, test.fp8_format)
        self.assertEqual(QuantizeConfig.SCALING_MODE, ScalingMode.MXFP8_1D_SCALING)

70
    @unittest.skipIf(not is_fp8_supported, reason=reason)
71
    def test_fp8_autocast_delayed_scaling(self):
72
        QuantizeConfig.finalize()  # Ensure the testing not affect by previous tests.
73
74
        self._check_defult_state()

75
        with fp8_autocast(enabled=False, fp8_recipe=DelayedScaling()):
76
            self.assertFalse(QuantizeConfig.is_fp8_enabled())
Ming-Xu Huang's avatar
Ming-Xu Huang committed
77
            self._compare_delay_scaling(get_delayed_scaling(), DelayedScaling())
78
79
80

        self._check_defult_state()

81
        ds = DelayedScaling(margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1)
82
        with fp8_autocast(enabled=True, fp8_recipe=ds):
83
            self.assertTrue(QuantizeConfig.is_fp8_enabled())
Ming-Xu Huang's avatar
Ming-Xu Huang committed
84
85
            self._compare_delay_scaling(get_delayed_scaling(), ds)

86
87
        self._check_defult_state()

88
        ds = DelayedScaling(margin=3.0, fp8_format=FP8Format.HYBRID, amax_history_len=1)
89
        with fp8_autocast(enabled=True, fp8_recipe=ds):
90
            self.assertTrue(QuantizeConfig.is_fp8_enabled())
Ming-Xu Huang's avatar
Ming-Xu Huang committed
91
92
            self._compare_delay_scaling(get_delayed_scaling(), ds)

93
94
        self._check_defult_state()

95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
    @unittest.skipIf(not is_mxfp8_supported, reason=mxfp8_reason)
    def test_fp8_autocast_mxfp8_scaling(self):
        QuantizeConfig.finalize()  # Ensure the testing not affect by previous tests.
        self._check_defult_state()

        with fp8_autocast(enabled=False, fp8_recipe=Float8CurrentScaling()):
            self.assertFalse(QuantizeConfig.is_fp8_enabled())
            self._compare_current_scaling(Float8CurrentScaling())

        self._check_defult_state()

        cs = Float8CurrentScaling(margin=5.0, fp8_format=FP8Format.E4M3)
        with fp8_autocast(enabled=True, fp8_recipe=cs):
            self.assertTrue(QuantizeConfig.is_fp8_enabled())
            self._compare_current_scaling(cs)

        self._check_defult_state()

        cs = Float8CurrentScaling(margin=3.0, fp8_format=FP8Format.HYBRID)
        with fp8_autocast(enabled=True, fp8_recipe=cs):
            self.assertTrue(QuantizeConfig.is_fp8_enabled())
            self._compare_current_scaling(cs)

        self._check_defult_state()

    @unittest.skipIf(not is_mxfp8_supported, reason=mxfp8_reason)
    def test_fp8_autocast_mxfp8_scaling(self):
        QuantizeConfig.finalize()  # Ensure the testing not affect by previous tests.
        self._check_defult_state()

        with fp8_autocast(enabled=False, fp8_recipe=MXFP8BlockScaling()):
            self.assertFalse(QuantizeConfig.is_fp8_enabled())
            self._compare_mxfp8_scaling(MXFP8BlockScaling())

        self._check_defult_state()

        bs = MXFP8BlockScaling(margin=5.0, fp8_format=FP8Format.E4M3)
        with fp8_autocast(enabled=True, fp8_recipe=bs):
            self.assertTrue(QuantizeConfig.is_fp8_enabled())
            self._compare_mxfp8_scaling(bs)

        self._check_defult_state()

        bs = MXFP8BlockScaling(margin=3.0, fp8_format=FP8Format.HYBRID)
        with fp8_autocast(enabled=True, fp8_recipe=bs):
            self.assertTrue(QuantizeConfig.is_fp8_enabled())
            self._compare_mxfp8_scaling(bs)

        self._check_defult_state()

145
    @unittest.skipIf(not is_fp8_supported, reason=reason)
146
    def test_fp8_autocast_with_sharding_resource(self):
147
        QuantizeConfig.finalize()  # Ensure the testing not affect by previous tests.
148
149
        self._check_defult_state()

150
        ds = DelayedScaling(margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1)
151

152
153
        mesh_s = (
            (MeshResource(None, None)),
154
155
156
            (MeshResource("dp", None)),
            (MeshResource(None, "tp")),
            (MeshResource("dp", "tp")),
157
        )
hugo-syn's avatar
hugo-syn committed
158
        # TODO (Ming Huang): Support multi-GPUs testing. # pylint: disable=fixme
159
160
        mesh_shape = (1, 1)
        devices = np.asarray(jax.devices()[:1]).reshape(*mesh_shape)
161
        with jax.sharding.Mesh(devices, ("dp", "tp")):
162
163
            for sr in mesh_s:
                with fp8_autocast(enabled=True, fp8_recipe=ds, mesh_resource=sr):
164
                    self.assertTrue(QuantizeConfig.is_fp8_enabled())
Ming-Xu Huang's avatar
Ming-Xu Huang committed
165
                    self._compare_delay_scaling(get_delayed_scaling(), ds)
166
                    self.assertEqual(sr, global_mesh_resource())
167
168

                self._check_defult_state()