# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import unittest

import flax
import jax
import jax.numpy as jnp
import numpy as np

from utils import assert_allclose
from transformer_engine.common.recipe import DelayedScaling, MXFP8BlockScaling, Float8CurrentScaling
from transformer_engine.common.recipe import Format as FP8Format
from transformer_engine.jax import fp8_autocast, get_delayed_scaling
from transformer_engine.jax.quantize import (
    QuantizeConfig,
    is_fp8_available,
    ScalingMode,
    update_collections,
)
from transformer_engine.jax.sharding import MeshResource, global_mesh_resource

# Probe device support once, at import time. Each probe is unpacked into a
# (supported, reason) pair; the ``unittest.skipIf`` decorators on the test
# methods below use the flag to decide whether to skip and the reason as the
# skip message.
is_fp8_supported, reason = is_fp8_available()
is_mxfp8_supported, mxfp8_reason = is_fp8_available(
    ScalingMode.MXFP8_1D_SCALING
)


class TestHelper(unittest.TestCase):
    """Tests for the ``update_collections`` helper."""

    @unittest.skipIf(not is_fp8_supported, reason=reason)
    def test_update_collections(self):
        """Check that ``update_collections`` overrides supplied keys and keeps the rest.

        The same assertions run first against a plain ``dict`` and then against a
        flax ``FrozenDict`` built from that dict.
        """
        original_val = 0.0
        updated_val = 10.0

        def check(state):
            # Only "test1" is passed as an update; "test2" must keep its old value.
            updated_state = update_collections({"test1": updated_val}, state)
            self.assertEqual(updated_state["test1"], updated_val)
            self.assertEqual(updated_state["test2"], original_val)

        original_state = {
            "test1": original_val,
            "test2": original_val,
        }
        check(original_state)
        check(flax.core.frozen_dict.FrozenDict(original_state))


class TestFP8Functions(unittest.TestCase):
    """Tests that ``fp8_autocast`` toggles and configures ``QuantizeConfig``.

    For each recipe family the tests check three things:

    * a disabled autocast leaves FP8 off, both inside and after the context;
    * an enabled autocast turns FP8 on and exposes the recipe's settings;
    * FP8 is off again once the enabled context exits.
    """

    def _check_default_state(self):
        """Assert that FP8 is currently disabled in ``QuantizeConfig``."""
        self.assertFalse(QuantizeConfig.is_fp8_enabled())

    def _reset_state(self):
        """Finalize ``QuantizeConfig`` so this test is not affected by previous tests."""
        QuantizeConfig.finalize()
        self._check_default_state()

    def _compare_delay_scaling(self, ref, test):
        """Assert that two delayed-scaling recipes agree on the fields under test."""
        # assertEqual (rather than assertTrue(a == b)) reports both values on failure.
        self.assertEqual(ref.margin, test.margin)
        self.assertEqual(ref.fp8_format, test.fp8_format)
        self.assertEqual(ref.amax_history_len, test.amax_history_len)
        self.assertEqual(ref.amax_compute_algo, test.amax_compute_algo)

    def _compare_current_scaling(self, test):
        """Assert that ``QuantizeConfig`` reflects the current-scaling recipe ``test``."""
        self.assertEqual(QuantizeConfig.FP8_FORMAT, test.fp8_format)
        self.assertEqual(QuantizeConfig.SCALING_MODE, ScalingMode.CURRENT_TENSOR_SCALING)

    def _compare_mxfp8_scaling(self, test):
        """Assert that ``QuantizeConfig`` reflects the MXFP8 block-scaling recipe ``test``."""
        self.assertEqual(QuantizeConfig.MARGIN, test.margin)
        self.assertEqual(QuantizeConfig.FP8_FORMAT, test.fp8_format)
        self.assertEqual(QuantizeConfig.SCALING_MODE, ScalingMode.MXFP8_1D_SCALING)

    def _check_disabled_autocast(self, recipe):
        """Assert that ``fp8_autocast(enabled=False)`` keeps FP8 off inside and after it."""
        with fp8_autocast(enabled=False, fp8_recipe=recipe, mesh_resource=MeshResource()):
            self._check_default_state()
        self._check_default_state()

    def _check_enabled_autocast(self, recipe, compare):
        """Assert that ``fp8_autocast(enabled=True)`` applies ``recipe`` and is undone on exit.

        Args:
            recipe: The recipe passed to ``fp8_autocast``.
            compare: Callable invoked with ``recipe`` while the context is active;
                it should assert that the active configuration matches the recipe.
        """
        with fp8_autocast(enabled=True, fp8_recipe=recipe, mesh_resource=MeshResource()):
            self.assertTrue(QuantizeConfig.is_fp8_enabled())
            compare(recipe)
        self._check_default_state()

    @unittest.skipIf(not is_fp8_supported, reason=reason)
    def test_fp8_autocast_delayed_scaling(self):
        self._reset_state()
        self._check_disabled_autocast(DelayedScaling())

        def compare(recipe):
            # get_delayed_scaling() must be read while the autocast context is active.
            self._compare_delay_scaling(get_delayed_scaling(), recipe)

        for ds in (
            DelayedScaling(margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1),
            DelayedScaling(margin=3.0, fp8_format=FP8Format.HYBRID, amax_history_len=1),
        ):
            self._check_enabled_autocast(ds, compare)

    @unittest.skipIf(not is_fp8_supported, reason=reason)
    def test_fp8_autocast_current_scaling(self):
        self._reset_state()
        self._check_disabled_autocast(Float8CurrentScaling())

        for cs in (
            Float8CurrentScaling(fp8_format=FP8Format.E4M3),
            Float8CurrentScaling(fp8_format=FP8Format.HYBRID),
        ):
            self._check_enabled_autocast(cs, self._compare_current_scaling)

    @unittest.skipIf(not is_mxfp8_supported, reason=mxfp8_reason)
    def test_fp8_autocast_mxfp8_block_scaling(self):
        self._reset_state()
        self._check_disabled_autocast(MXFP8BlockScaling())

        for bs in (
            MXFP8BlockScaling(margin=5.0, fp8_format=FP8Format.E4M3),
            MXFP8BlockScaling(margin=3.0, fp8_format=FP8Format.HYBRID),
        ):
            self._check_enabled_autocast(bs, self._compare_mxfp8_scaling)