# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import unittest

import flax
import jax
import jax.numpy as jnp
import numpy as np

from utils import assert_allclose

from transformer_engine.common.recipe import DelayedScaling
from transformer_engine.common.recipe import Format as FP8Format
from transformer_engine.jax import fp8_autocast, get_delayed_scaling
from transformer_engine.jax.fp8 import FP8Helper, is_fp8_available, AmaxComputeAlgo
from transformer_engine.jax.sharding import MeshResource, global_mesh_resource

# Probe FP8 hardware support once at import time; `reason` carries the
# human-readable explanation used by the skipIf decorators below.
is_fp8_supported, reason = is_fp8_available()
class TestFP8Helper(unittest.TestCase):
    """Unit tests for FP8Helper's global-state initialization and collection updates."""

    @unittest.skipIf(not is_fp8_supported, reason=reason)
    def test_initialize(self):
        """FP8Helper.initialize must publish margin/format/amax-history globals."""
        margin = 5.0
        fp8_format = FP8Format.E4M3
        amax_history_len = 10

        FP8Helper.initialize(
            margin=margin, fp8_format=fp8_format, amax_history_len=amax_history_len
        )
        # Restore global state even if an assertion below fails, so a failure
        # here cannot leak FP8 configuration into later tests.
        self.addCleanup(FP8Helper.finalize)

        self.assertEqual(
            FP8Helper.MARGIN,
            margin,
            f"FP8Helper.MARGIN initialization failed, should be {margin}"
            f" but got {FP8Helper.MARGIN}.",
        )
        self.assertEqual(
            FP8Helper.FP8_FORMAT,
            fp8_format,
            f"FP8Helper.FP8_FORMAT initialization failed, should be {fp8_format}"
            f" but got {FP8Helper.FP8_FORMAT}.",
        )
        self.assertEqual(
            FP8Helper.AMAX_HISTORY_LEN,
            amax_history_len,
            f"FP8Helper.AMAX_HISTORY_LEN initialization failed, should be {amax_history_len}"
            f" but got {FP8Helper.AMAX_HISTORY_LEN}.",
        )

    @unittest.skipIf(not is_fp8_supported, reason=reason)
    def test_update_collections(self):
        """update_collections must overwrite the given keys and keep the rest,
        for both plain dicts and flax FrozenDicts."""
        original_val = 0.0
        updated_val = 10.0

        original_state = {
            "test1": original_val,
            "test2": original_val,
        }
        updated_state = FP8Helper.update_collections({"test1": updated_val}, original_state)
        self.assertEqual(updated_state["test1"], updated_val)
        self.assertEqual(updated_state["test2"], original_val)

        # The same contract must hold for immutable flax collections.
        original_state = flax.core.frozen_dict.FrozenDict(original_state)
        updated_state = FP8Helper.update_collections({"test1": updated_val}, original_state)
        self.assertEqual(updated_state["test1"], updated_val)
        self.assertEqual(updated_state["test2"], original_val)


class TestFP8Functions(unittest.TestCase):
    """Tests for the fp8_autocast context manager and the global FP8 state it drives."""

    def _check_default_state(self):
        # Outside any fp8_autocast context, FP8 must be disabled.
        self.assertFalse(FP8Helper.is_fp8_enabled())

    def _compare_delay_scaling(self, ref, test):
        # Compare only the recipe fields fp8_autocast is expected to propagate.
        # assertEqual (vs assertTrue on a comparison) reports both values on failure.
        self.assertEqual(ref.margin, test.margin)
        self.assertEqual(ref.fp8_format, test.fp8_format)
        self.assertEqual(ref.amax_history_len, test.amax_history_len)
        self.assertEqual(ref.amax_compute_algo, test.amax_compute_algo)

    @unittest.skipIf(not is_fp8_supported, reason=reason)
    def test_fp8_autocast(self):
        """Entering/leaving fp8_autocast must toggle FP8 state and install the recipe."""
        FP8Helper.finalize()  # Ensure the test is not affected by previous tests.
        self._check_default_state()

        # enabled=False: recipe is visible but FP8 stays off.
        with fp8_autocast(enabled=False, fp8_recipe=DelayedScaling()):
            self.assertFalse(FP8Helper.is_fp8_enabled())
            self._compare_delay_scaling(get_delayed_scaling(), DelayedScaling())

        self._check_default_state()

        ds = DelayedScaling(margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1)
        with fp8_autocast(enabled=True, fp8_recipe=ds):
            self.assertTrue(FP8Helper.is_fp8_enabled())
            self._compare_delay_scaling(get_delayed_scaling(), ds)

        self._check_default_state()

        ds = DelayedScaling(margin=3.0, fp8_format=FP8Format.HYBRID, amax_history_len=1)
        with fp8_autocast(enabled=True, fp8_recipe=ds):
            self.assertTrue(FP8Helper.is_fp8_enabled())
            self._compare_delay_scaling(get_delayed_scaling(), ds)

        self._check_default_state()

    @unittest.skipIf(not is_fp8_supported, reason=reason)
    def test_fp8_autocast_with_sharding_resource(self):
        """fp8_autocast must install the given MeshResource as the global resource."""
        FP8Helper.finalize()  # Ensure the test is not affected by previous tests.
        self._check_default_state()

        ds = DelayedScaling(margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1)

        # Every combination of present/absent data-parallel and tensor-parallel axes.
        mesh_s = (
            MeshResource(None, None),
            MeshResource("dp", None),
            MeshResource(None, "tp"),
            MeshResource("dp", "tp"),
        )

        # TODO (Ming Huang): Support multi-GPUs testing. # pylint: disable=fixme
        mesh_shape = (1, 1)
        devices = np.asarray(jax.devices()[:1]).reshape(*mesh_shape)
        with jax.sharding.Mesh(devices, ("dp", "tp")):
            for sr in mesh_s:
                with fp8_autocast(enabled=True, fp8_recipe=ds, mesh_resource=sr):
                    self.assertTrue(FP8Helper.is_fp8_enabled())
                    self._compare_delay_scaling(get_delayed_scaling(), ds)
                    self.assertEqual(sr, global_mesh_resource())

                self._check_default_state()