# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import unittest

import flax
import jax
import jax.numpy as jnp
import numpy as np

from utils import assert_allclose
from transformer_engine.common.recipe import DelayedScaling, MXFP8BlockScaling, Float8CurrentScaling
from transformer_engine.common.recipe import Format as FP8Format
from transformer_engine.jax import fp8_autocast, get_delayed_scaling
from transformer_engine.jax.quantize import (
    get_quantize_config,
    is_fp8_available,
    ScalingMode,
    update_collections,
    TensorSource,
)
from transformer_engine.jax.sharding import MeshResource, global_mesh_resource

# Hardware/driver capability probes, evaluated once at import time.
# `reason` / `mxfp8_reason` carry the human-readable skip message used by the
# @unittest.skipIf decorators below when the capability is absent.
is_fp8_supported, reason = is_fp8_available()
is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
class TestHelper(unittest.TestCase):
    """Unit tests for small helper utilities in transformer_engine.jax.quantize."""

    @unittest.skipIf(not is_fp8_supported, reason=reason)
    def test_update_collections(self):
        """update_collections must overwrite listed keys and leave others intact.

        The helper is exercised twice: once with a plain dict and once with a
        flax FrozenDict, since callers pass both kinds of collections.
        """
        original_val = 0.0
        updated_val = 10.0

        original_state = {
            "test1": original_val,
            "test2": original_val,
        }
        # Plain-dict input: only "test1" should change.
        updated_state = update_collections({"test1": updated_val}, original_state)
        self.assertEqual(updated_state["test1"], updated_val)
        self.assertEqual(updated_state["test2"], original_val)

        # Same check with an immutable FrozenDict input.
        original_state = flax.core.frozen_dict.FrozenDict(original_state)
        updated_state = update_collections({"test1": updated_val}, original_state)
        self.assertEqual(updated_state["test1"], updated_val)
        self.assertEqual(updated_state["test2"], original_val)


class TestFP8Functions(unittest.TestCase):
    """Tests that fp8_autocast correctly installs and tears down each FP8 recipe.

    Each test enters fp8_autocast with a recipe, verifies the active quantize
    config reflects that recipe, and verifies the default (FP8-disabled) state
    is restored after the context exits.
    """

    def _check_default_state(self):
        # Outside any enabled fp8_autocast context, FP8 must be off.
        self.assertFalse(get_quantize_config().is_fp8_enabled())

    def _compare_delay_scaling(self, ref, test):
        # Field-by-field comparison of two DelayedScaling recipes.
        self.assertTrue(ref.margin == test.margin)
        self.assertTrue(ref.fp8_format == test.fp8_format)
        self.assertTrue(ref.amax_history_len == test.amax_history_len)
        self.assertTrue(ref.amax_compute_algo == test.amax_compute_algo)

    def _compare_current_scaling(self, test):
        # The active config must carry the recipe's format and use
        # current-tensor scaling for every tensor source.
        self.assertEqual(get_quantize_config().FP8_FORMAT, test.fp8_format)
        for tensor_source in TensorSource:
            self.assertEqual(
                get_quantize_config().get_scaling_mode(tensor_source),
                ScalingMode.CURRENT_TENSOR_SCALING,
            )

    def _compare_mxfp8_scaling(self, test):
        # The active config must carry the recipe's margin/format and use
        # MXFP8 1-D block scaling for every tensor source.
        self.assertEqual(get_quantize_config().MARGIN, test.margin)
        self.assertEqual(get_quantize_config().FP8_FORMAT, test.fp8_format)
        for tensor_source in TensorSource:
            self.assertEqual(
                get_quantize_config().get_scaling_mode(tensor_source), ScalingMode.MXFP8_1D_SCALING
            )

    @unittest.skipIf(not is_fp8_supported, reason=reason)
    def test_fp8_autocast_delayed_scaling(self):
        """fp8_autocast with DelayedScaling: enable, verify, and restore defaults."""
        self._check_default_state()

        # enabled=False must leave FP8 off even with a recipe supplied.
        with fp8_autocast(enabled=False, fp8_recipe=DelayedScaling(), mesh_resource=MeshResource()):
            self._check_default_state()

        self._check_default_state()

        ds = DelayedScaling(margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1)
        with fp8_autocast(enabled=True, fp8_recipe=ds, mesh_resource=MeshResource()):
            self.assertTrue(get_quantize_config().is_fp8_enabled())
            self._compare_delay_scaling(get_delayed_scaling(), ds)

        self._check_default_state()

        ds = DelayedScaling(margin=3.0, fp8_format=FP8Format.HYBRID, amax_history_len=1)
        with fp8_autocast(enabled=True, fp8_recipe=ds, mesh_resource=MeshResource()):
            self.assertTrue(get_quantize_config().is_fp8_enabled())
            self._compare_delay_scaling(get_delayed_scaling(), ds)

        self._check_default_state()

    @unittest.skipIf(not is_fp8_supported, reason=reason)
    def test_fp8_autocast_current_scaling(self):
        """fp8_autocast with Float8CurrentScaling: enable, verify, and restore defaults."""
        self._check_default_state()

        # enabled=False must leave FP8 off even with a recipe supplied.
        with fp8_autocast(
            enabled=False, fp8_recipe=Float8CurrentScaling(), mesh_resource=MeshResource()
        ):
            self._check_default_state()

        self._check_default_state()

        cs = Float8CurrentScaling(fp8_format=FP8Format.E4M3)
        with fp8_autocast(enabled=True, fp8_recipe=cs, mesh_resource=MeshResource()):
            self.assertTrue(get_quantize_config().is_fp8_enabled())
            self._compare_current_scaling(cs)

        self._check_default_state()

        cs = Float8CurrentScaling(fp8_format=FP8Format.HYBRID)
        with fp8_autocast(enabled=True, fp8_recipe=cs, mesh_resource=MeshResource()):
            self.assertTrue(get_quantize_config().is_fp8_enabled())
            self._compare_current_scaling(cs)

        self._check_default_state()

    @unittest.skipIf(not is_mxfp8_supported, reason=mxfp8_reason)
    def test_fp8_autocast_mxfp8_block_scaling(self):
        """fp8_autocast with MXFP8BlockScaling: enable, verify, and restore defaults."""
        self._check_default_state()

        # enabled=False must leave FP8 off even with a recipe supplied.
        with fp8_autocast(
            enabled=False, fp8_recipe=MXFP8BlockScaling(), mesh_resource=MeshResource()
        ):
            self._check_default_state()

        self._check_default_state()

        bs = MXFP8BlockScaling(margin=5.0, fp8_format=FP8Format.E4M3)
        with fp8_autocast(enabled=True, fp8_recipe=bs, mesh_resource=MeshResource()):
            self.assertTrue(get_quantize_config().is_fp8_enabled())
            self._compare_mxfp8_scaling(bs)

        self._check_default_state()

        bs = MXFP8BlockScaling(margin=3.0, fp8_format=FP8Format.HYBRID)
        with fp8_autocast(enabled=True, fp8_recipe=bs, mesh_resource=MeshResource()):
            self.assertTrue(get_quantize_config().is_fp8_enabled())
            self._compare_mxfp8_scaling(bs)

        self._check_default_state()