# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import unittest

import flax
import jax
import jax.numpy as jnp
import numpy as np

12
from utils import assert_allclose
13
14
from transformer_engine.common.recipe import DelayedScaling
from transformer_engine.common.recipe import Format as FP8Format
from transformer_engine.jax import fp8_autocast, get_delayed_scaling
16
from transformer_engine.jax.quantize import QuantizeConfig, is_fp8_available, AmaxComputeAlgo
17
from transformer_engine.jax.sharding import MeshResource, global_mesh_resource
18

19
20
# Probe FP8 support once at import time; the (flag, message) pair drives the
# ``unittest.skipIf`` decorators on the FP8 tests below.
(is_fp8_supported, reason) = is_fp8_available()


class TestQuantizeConfig(unittest.TestCase):
    """Tests for the global ``QuantizeConfig`` settings and its collection helper."""

    @unittest.skipIf(not is_fp8_supported, reason=reason)
    def test_initialize(self):
        """``QuantizeConfig.initialize`` should store the FP8 settings it is given."""
        margin = 5.0
        fp8_format = FP8Format.E4M3
        amax_history_len = 10

        # QuantizeConfig is process-global state: start from a clean slate and
        # always reset it afterwards -- even when an assertion below fails --
        # so this test cannot leak its settings into later tests.
        QuantizeConfig.finalize()
        self.addCleanup(QuantizeConfig.finalize)

        QuantizeConfig.initialize(
            margin=margin, fp8_format=fp8_format, amax_history_len=amax_history_len
        )

        self.assertEqual(
            QuantizeConfig.MARGIN,
            margin,
            f"QuantizeConfig.MARGIN initialization failed, should be {margin}"
            f" but got {QuantizeConfig.MARGIN}.",
        )
        self.assertEqual(
            QuantizeConfig.FP8_FORMAT,
            fp8_format,
            f"QuantizeConfig.FP8_FORMAT initialization failed, should be {fp8_format}"
            f" but got {QuantizeConfig.FP8_FORMAT}.",
        )
        self.assertEqual(
            QuantizeConfig.AMAX_HISTORY_LEN,
            amax_history_len,
            f"QuantizeConfig.AMAX_HISTORY_LEN initialization failed, should be {amax_history_len}"
            f" but got {QuantizeConfig.AMAX_HISTORY_LEN}.",
        )

    @unittest.skipIf(not is_fp8_supported, reason=reason)
    def test_update_collections(self):
        """``update_collections`` overrides matching keys and keeps the others.

        Exercised with both a plain ``dict`` and a flax ``FrozenDict`` as the
        original collection.
        """
        original_val = 0.0
        updated_val = 10.0

        original_state = {
            "test1": original_val,
            "test2": original_val,
        }
        # Build the frozen variant up front so both cases start from the same,
        # untouched original values.
        collections = (
            original_state,
            flax.core.frozen_dict.FrozenDict(original_state),
        )
        for collection in collections:
            with self.subTest(collection_type=type(collection).__name__):
                updated_state = QuantizeConfig.update_collections(
                    {"test1": updated_val}, collection
                )
                self.assertEqual(updated_state["test1"], updated_val)
                self.assertEqual(updated_state["test2"], original_val)


class TestFP8Functions(unittest.TestCase):
    """Tests for ``fp8_autocast`` and the global FP8 state it manages."""

    def _check_default_state(self):
        """Assert the global FP8 state is back to its default (FP8 disabled)."""
        self.assertFalse(QuantizeConfig.is_fp8_enabled())

    # Backward-compatible alias for the previous, misspelled helper name.
    _check_defult_state = _check_default_state

    def _compare_delay_scaling(self, ref, test):
        """Assert two ``DelayedScaling`` recipes agree on every field under test.

        ``assertEqual`` is used (rather than ``assertTrue(a == b)``) so that a
        failure reports the two mismatching values.
        """
        self.assertEqual(ref.margin, test.margin)
        self.assertEqual(ref.fp8_format, test.fp8_format)
        self.assertEqual(ref.amax_history_len, test.amax_history_len)
        self.assertEqual(ref.amax_compute_algo, test.amax_compute_algo)

    @unittest.skipIf(not is_fp8_supported, reason=reason)
    def test_fp8_autocast(self):
        """``fp8_autocast`` enables FP8 only inside its scope and applies the recipe."""
        QuantizeConfig.finalize()  # Ensure this test is not affected by previous tests.
        self._check_default_state()

        # A disabled autocast must leave FP8 off while still exposing the recipe.
        with fp8_autocast(enabled=False, fp8_recipe=DelayedScaling()):
            self.assertFalse(QuantizeConfig.is_fp8_enabled())
            self._compare_delay_scaling(get_delayed_scaling(), DelayedScaling())

        self._check_default_state()

        recipes = (
            DelayedScaling(margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1),
            DelayedScaling(margin=3.0, fp8_format=FP8Format.HYBRID, amax_history_len=1),
        )
        for ds in recipes:
            with self.subTest(recipe=ds):
                with fp8_autocast(enabled=True, fp8_recipe=ds):
                    self.assertTrue(QuantizeConfig.is_fp8_enabled())
                    self._compare_delay_scaling(get_delayed_scaling(), ds)

                # Leaving the scope must restore the default state.
                self._check_default_state()

    @unittest.skipIf(not is_fp8_supported, reason=reason)
    def test_fp8_autocast_with_sharding_resource(self):
        """``fp8_autocast`` exposes the given ``MeshResource`` globally inside its scope."""
        QuantizeConfig.finalize()  # Ensure this test is not affected by previous tests.
        self._check_default_state()

        ds = DelayedScaling(margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1)

        mesh_s = (
            MeshResource(None, None),
            MeshResource("dp", None),
            MeshResource(None, "tp"),
            MeshResource("dp", "tp"),
        )

        # TODO (Ming Huang): Support multi-GPUs testing. # pylint: disable=fixme
        mesh_shape = (1, 1)
        devices = np.asarray(jax.devices()[:1]).reshape(*mesh_shape)
        with jax.sharding.Mesh(devices, ("dp", "tp")):
            for sr in mesh_s:
                with self.subTest(mesh_resource=sr):
                    with fp8_autocast(enabled=True, fp8_recipe=ds, mesh_resource=sr):
                        self.assertTrue(QuantizeConfig.is_fp8_enabled())
                        self._compare_delay_scaling(get_delayed_scaling(), ds)
                        self.assertEqual(sr, global_mesh_resource())

                    self._check_default_state()