# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import unittest

import flax
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import maps

from utils import assert_allclose
from transformer_engine.common.recipe import DelayedScaling
from transformer_engine.common.recipe import Format as FP8Format
from transformer_engine.jax import fp8_autocast, get_delayed_scaling
from transformer_engine.jax.fp8 import FP8Helper, is_fp8_available
from transformer_engine.jax.sharding import infer_major_sharding_type
from transformer_engine.jax.sharding import MajorShardingType
from transformer_engine.jax.sharding import ShardingResource

# Probe FP8 availability once at import time. The (supported, reason) pair is
# consumed by the ``unittest.skipIf`` guards on the FP8-dependent tests below.
(is_fp8_supported, reason) = is_fp8_available()


class TestFP8Helper(unittest.TestCase):
    """Tests for the class-level FP8 configuration and utilities on ``FP8Helper``.

    ``FP8Helper`` keeps its configuration in class attributes. Every test that
    calls ``FP8Helper.initialize`` therefore guarantees that ``FP8Helper.finalize``
    runs, via ``addCleanup`` or ``try``/``finally``, even when an assertion
    fails. Otherwise a failing test would leak its settings into later tests.
    """

    @unittest.skipIf(not is_fp8_supported, reason=reason)
    def test_initialize(self):
        """``initialize`` stores each keyword argument on the matching class attribute."""
        margin = 5.0
        fp8_format = FP8Format.E4M3
        update_fp8meta_interval = 10
        amax_history_len = 10

        # Register the reset first so it runs even if initialize() or an assertion fails.
        self.addCleanup(FP8Helper.finalize)
        FP8Helper.initialize(margin=margin,
                             fp8_format=fp8_format,
                             update_fp8meta_interval=update_fp8meta_interval,
                             amax_history_len=amax_history_len)

        self.assertEqual(
            FP8Helper.MARGIN, margin, f"FP8Helper.MARGIN initialization failed, should be {margin}"
            f" but got {FP8Helper.MARGIN}.")
        self.assertEqual(
            FP8Helper.FP8_FORMAT, fp8_format,
            f"FP8Helper.FP8_FORMAT initialization failed, should be {fp8_format}"
            f" but got {FP8Helper.FP8_FORMAT}.")
        self.assertEqual(
            FP8Helper.UPDATE_FP8META_INTERVAL, update_fp8meta_interval,
            "FP8Helper.UPDATE_FP8META_INTERVAL initialization failed, should be "
            f"{update_fp8meta_interval} but got {FP8Helper.UPDATE_FP8META_INTERVAL}.")
        self.assertEqual(
            FP8Helper.AMAX_HISTORY_LEN, amax_history_len,
            f"FP8Helper.AMAX_HISTORY_LEN initialization failed, should be {amax_history_len}"
            f" but got {FP8Helper.AMAX_HISTORY_LEN}.")

    @unittest.skipIf(not is_fp8_supported, reason=reason)
    def test_update_fp8_metas(self):
        """``update_fp8_metas`` recomputes scale and scale_inv for every meta entry.

        The reference scales below are computed from the first column of each
        amax history (``[:, 0:1]``).
        """
        self.addCleanup(FP8Helper.finalize)
        FP8Helper.initialize(margin=3.0, amax_history_len=3)

        seed = 0
        key1, key2 = jax.random.split(jax.random.PRNGKey(seed))
        num_of_gemm = 10
        num_of_meta = FP8Helper.NUM_META_PER_GEMM * num_of_gemm

        def get_fp8_scale(fp8_max, amax, scale):
            """NumPy reference for the delayed-scaling factor.

            Returns a power of two, 2**exp with
            exp = floor(log2(fp8_max / amax)) - FP8Helper.MARGIN. Falls back to
            ``scale`` wherever ``amax`` is non-positive or non-finite.
            """
            fp8_max = np.array(fp8_max)
            amax = np.array(amax)
            scale = np.array(scale)

            exp = np.floor(np.log2(fp8_max / amax)) - FP8Helper.MARGIN
            sf = np.round(np.power(2, np.abs(exp)))
            sf = np.where(amax > 0.0, sf, scale)
            sf = np.where(np.isfinite(amax), sf, scale)
            # A negative exponent means a scale below 1, so invert the magnitude.
            return np.where(exp < 0, 1 / sf, sf)

        meta_shape = (num_of_meta, FP8Helper.AMAX_HISTORY_LEN)

        fp8_max_array = FP8Helper.generate_fp8_max_array(num_of_meta)
        fp8_amax_array1 = jax.random.uniform(key1, shape=meta_shape)
        fp8_scale_array1 = get_fp8_scale(fp8_max_array, fp8_amax_array1[:, 0:1],
                                         jnp.ones(meta_shape))
        fp8_scale_inv_array1 = 1 / fp8_scale_array1
        fp8_amax_array2 = jax.random.uniform(key2, shape=meta_shape)
        fp8_scale_array2 = get_fp8_scale(fp8_max_array, fp8_amax_array2[:, 0:1],
                                         jnp.ones(meta_shape))
        fp8_scale_inv_array2 = 1 / fp8_scale_array2

        state = flax.core.frozen_dict.FrozenDict({
            FP8Helper.FP8_COLLECTION_NAME: {
                "test_update_fp8_metas1": {
                    FP8Helper.FP8_MAX_NAME: fp8_max_array,
                    FP8Helper.FP8_AMAX_NAME: fp8_amax_array1,
                    FP8Helper.FP8_SCALE_NAME: jnp.ones(meta_shape),
                    FP8Helper.FP8_SCALE_INV_NAME: jnp.ones(meta_shape)
                },
                "test_update_fp8_metas2": {
                    FP8Helper.FP8_MAX_NAME: fp8_max_array,
                    FP8Helper.FP8_AMAX_NAME: fp8_amax_array2,
                    FP8Helper.FP8_SCALE_NAME: jnp.ones(meta_shape),
                    FP8Helper.FP8_SCALE_INV_NAME: jnp.ones(meta_shape)
                }
            }
        })

        updated_state = FP8Helper.update_fp8_metas(state)

        # Look the updated metas up by key instead of by position in the flattened
        # pytree, so the checks do not depend on leaf ordering or per-entry leaf count.
        updated_metas = updated_state[FP8Helper.FP8_COLLECTION_NAME]
        expected = {
            "test_update_fp8_metas1": (fp8_scale_array1, fp8_scale_inv_array1),
            "test_update_fp8_metas2": (fp8_scale_array2, fp8_scale_inv_array2),
        }
        for name, (scale, scale_inv) in expected.items():
            with self.subTest(meta=name):
                assert_allclose(updated_metas[name][FP8Helper.FP8_SCALE_NAME], scale)
                assert_allclose(updated_metas[name][FP8Helper.FP8_SCALE_INV_NAME], scale_inv)

    @unittest.skipIf(not is_fp8_supported, reason=reason)
    def test_generate_fp8_max_array(self):
        """Gradient metas get ``max_bwd``; all other metas get ``max_fwd``, per format."""
        num_of_meta = FP8Helper.NUM_META_PER_GEMM * 2

        def get_ref(format_for_test):
            """Build the expected (num_of_meta, 1) column of per-meta maxima."""
            fmt = format_for_test.value
            return jnp.asarray([[
                fmt.max_bwd if i % FP8Helper.NUM_META_PER_GEMM == FP8Helper.GRAD_META_IDX_PER_GEMM
                else fmt.max_fwd
            ] for i in range(num_of_meta)])

        for fp8_format in FP8Format:
            with self.subTest(fp8_format=fp8_format):
                FP8Helper.initialize(fp8_format=fp8_format)
                # Reset before the next format even if this comparison fails.
                try:
                    assert_allclose(get_ref(fp8_format),
                                    FP8Helper.generate_fp8_max_array(num_of_meta))
                finally:
                    FP8Helper.finalize()

    @unittest.skipIf(not is_fp8_supported, reason=reason)
    def test_update_collections(self):
        """``update_collections`` overrides the given keys and keeps the others.

        Checked for both a plain ``dict`` and a ``FrozenDict`` original state.
        """
        original_val = 0.0
        updated_val = 10.0

        plain_state = {
            "test1": original_val,
            "test2": original_val,
        }
        for original_state in (plain_state, flax.core.frozen_dict.FrozenDict(plain_state)):
            with self.subTest(state_type=type(original_state).__name__):
                updated_state = FP8Helper.update_collections({"test1": updated_val},
                                                             original_state)
                self.assertEqual(updated_state["test1"], updated_val)
                self.assertEqual(updated_state["test2"], original_val)


class TestFP8Functions(unittest.TestCase):
    """Tests for the ``fp8_autocast`` context manager and ``get_delayed_scaling``."""

    def _check_default_state(self):
        """Assert the state expected outside ``fp8_autocast``: FP8 off, SINGLE sharding."""
        self.assertFalse(FP8Helper.is_fp8_enabled())
        self.assertEqual(infer_major_sharding_type(), MajorShardingType.SINGLE)

    def _compare_delay_scaling(self, ref, test):
        """Assert two ``DelayedScaling`` recipes agree on the fields checked by these tests.

        Compares margin, interval, fp8_format, amax_history_len and amax_compute_algo.
        """
        self.assertEqual(ref.margin, test.margin)
        self.assertEqual(ref.interval, test.interval)
        self.assertEqual(ref.fp8_format, test.fp8_format)
        self.assertEqual(ref.amax_history_len, test.amax_history_len)
        self.assertEqual(ref.amax_compute_algo, test.amax_compute_algo)

    @unittest.skipIf(not is_fp8_supported, reason=reason)
    def test_fp8_autocast(self):
        """``fp8_autocast`` applies the given recipe and restores the default state on exit."""
        FP8Helper.finalize()    # Reset global FP8 state so earlier tests cannot affect this one.
        self._check_default_state()

        with fp8_autocast(enabled=False, fp8_recipe=DelayedScaling()):
            self.assertFalse(FP8Helper.is_fp8_enabled())
            self._compare_delay_scaling(get_delayed_scaling(), DelayedScaling())

        self._check_default_state()

        ds = DelayedScaling(margin=5.0, interval=3, fp8_format=FP8Format.E4M3, amax_history_len=1)
        with fp8_autocast(enabled=True, fp8_recipe=ds):
            self.assertTrue(FP8Helper.is_fp8_enabled())
            self._compare_delay_scaling(get_delayed_scaling(), ds)

        self._check_default_state()

        ds = DelayedScaling(margin=3.0, interval=1, fp8_format=FP8Format.HYBRID, amax_history_len=1)
        with fp8_autocast(enabled=True, fp8_recipe=ds):
            self.assertTrue(FP8Helper.is_fp8_enabled())
            self._compare_delay_scaling(get_delayed_scaling(), ds)

        self._check_default_state()

    @unittest.skipIf(not is_fp8_supported, reason=reason)
    def test_fp8_autocast_with_sharding_resource(self):
        """``fp8_autocast`` with a ``sharding_resource`` sets the expected major sharding type."""
        FP8Helper.finalize()    # Reset global FP8 state so earlier tests cannot affect this one.
        self._check_default_state()

        ds = DelayedScaling(margin=5.0, interval=3, fp8_format=FP8Format.E4M3, amax_history_len=1)

        # TODO (Ming Huang): Support multi-GPUs testing. # pylint: disable=fixme
        # srs = (
        #     (ShardingResource(None, None), MajorShardingType.SINGLE),
        #     (ShardingResource('dp', None), MajorShardingType.DP),
        #     (ShardingResource(None, 'tp'), MajorShardingType.TP),
        #     (ShardingResource('dp', 'tp'), MajorShardingType.DPTP),
        # )
        # On a single-device (1, 1) mesh every resource combination collapses to SINGLE.
        srs = (
            (ShardingResource(None, None), MajorShardingType.SINGLE),
            (ShardingResource('dp', None), MajorShardingType.SINGLE),
            (ShardingResource(None, 'tp'), MajorShardingType.SINGLE),
            (ShardingResource('dp', 'tp'), MajorShardingType.SINGLE),
        )
        # TODO (Ming Huang): Support multi-GPUs testing. # pylint: disable=fixme
        mesh_shape = (1, 1)
        devices = np.asarray(jax.devices()[:1]).reshape(*mesh_shape)
        with maps.Mesh(devices, ('dp', 'tp')):
            for sr, mst in srs:
                with self.subTest(sharding_resource=sr):
                    with fp8_autocast(enabled=True, fp8_recipe=ds, sharding_resource=sr):
                        self.assertTrue(FP8Helper.is_fp8_enabled())
                        self._compare_delay_scaling(get_delayed_scaling(), ds)
                        self.assertEqual(infer_major_sharding_type(), mst)

                    self._check_default_state()