# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Unit tests for the FP8 helper utilities in transformer_engine.jax."""

import unittest

import flax
import jax
import jax.numpy as jnp
import numpy as np

from utils import assert_allclose

from transformer_engine.common.recipe import DelayedScaling
from transformer_engine.common.recipe import Format as FP8Format
from transformer_engine.jax import fp8_autocast, get_delayed_scaling
from transformer_engine.jax.fp8 import FP8Helper, is_fp8_available, AmaxComputeAlgo
from transformer_engine.jax.sharding import infer_major_sharding_type
from transformer_engine.jax.sharding import MajorShardingType
from transformer_engine.jax.sharding import ShardingResource

# Probed once at import time; FP8 tests are skipped (with `reason`) on
# devices that lack FP8 support.
is_fp8_supported, reason = is_fp8_available()

class TestFP8Helper(unittest.TestCase):
    """Tests for the static configuration and metadata helpers on FP8Helper."""

    @unittest.skipIf(not is_fp8_supported, reason=reason)
    def test_initialize(self):
        """FP8Helper.initialize must store each argument on the matching class attribute."""
        margin = 5.0
        fp8_format = FP8Format.E4M3
        update_fp8meta_interval = 10
        amax_history_len = 10

        FP8Helper.initialize(margin=margin,
                             fp8_format=fp8_format,
                             update_fp8meta_interval=update_fp8meta_interval,
                             amax_history_len=amax_history_len)

        self.assertEqual(
            FP8Helper.MARGIN, margin, f"FP8Helper.MARGIN initialization failed, should be {margin}"
            f" but got {FP8Helper.MARGIN}.")
        self.assertEqual(
            FP8Helper.FP8_FORMAT, fp8_format,
            f"FP8Helper.FP8_FORMAT initialization failed, should be {fp8_format}"
            f" but got {FP8Helper.FP8_FORMAT}.")
        self.assertEqual(
            FP8Helper.UPDATE_FP8META_INTERVAL, update_fp8meta_interval,
            "FP8Helper.UPDATE_FP8META_INTERVAL initialization failed, should be"
            f"{update_fp8meta_interval} but got {FP8Helper.UPDATE_FP8META_INTERVAL}.")
        self.assertEqual(
            FP8Helper.AMAX_HISTORY_LEN, amax_history_len,
            f"FP8Helper.AMAX_HISTORY_LEN initialization failed, should be {amax_history_len}"
            f" but got {FP8Helper.AMAX_HISTORY_LEN}.")

        # Reset global FP8 state so later tests start from a clean slate.
        FP8Helper.finalize()

    @unittest.skipIf(not is_fp8_supported, reason=reason)
    def test_update_fp8_metas(self):
        """FP8Helper.update_fp8_metas must recompute scale/scale_inv from the amax history,
        matching a NumPy reference implementation of the delayed-scaling formula."""
        FP8Helper.initialize(margin=3.0, amax_history_len=3)

        seed = 0
        key1, key2 = jax.random.split(jax.random.PRNGKey(seed))
        num_of_gemm = 10
        num_of_meta = FP8Helper.NUM_META_PER_GEMM * num_of_gemm

        def select_amax(amaxes):
            # Mirror FP8Helper's amax reduction: either the max over the
            # history window or the most recent entry.
            if FP8Helper.AMAX_COMPUTE_ALGO == AmaxComputeAlgo.MAX:
                return jnp.max(amaxes, axis=-1, keepdims=True)
            return amaxes[:, 0:1]

        def get_fp8_scale(fp8_max, amax, scale):
            # NumPy reference for the delayed-scaling scale update:
            # sf = 2 ** floor(log2(fp8_max / amax) - margin), with the previous
            # scale kept wherever amax is zero or non-finite.
            fp8_max = np.array(fp8_max)
            amax = np.array(amax)
            scale = np.array(scale)

            exp = np.floor(np.log2(fp8_max / amax)) - FP8Helper.MARGIN
            sf = np.round(np.power(2, np.abs(exp)))
            sf = np.where(amax > 0.0, sf, scale)
            sf = np.where(np.isfinite(amax), sf, scale)
            return np.where(exp < 0, 1 / sf, sf)

        amax_meta_shape = (num_of_meta, FP8Helper.AMAX_HISTORY_LEN)
        scale_meta_shape = (num_of_meta, 1)
        fp8_max_array = FP8Helper.generate_fp8_max_array(num_of_meta)
        fp8_amax_array1 = jax.random.uniform(key1, shape=amax_meta_shape)
        fp8_scale_array1 = get_fp8_scale(fp8_max_array, select_amax(fp8_amax_array1),
                                         jnp.ones(scale_meta_shape))
        fp8_scale_inv_array1 = 1 / fp8_scale_array1
        fp8_amax_array2 = jax.random.uniform(key2, shape=amax_meta_shape)
        fp8_scale_array2 = get_fp8_scale(fp8_max_array, select_amax(fp8_amax_array2),
                                         jnp.ones(scale_meta_shape))
        fp8_scale_inv_array2 = 1 / fp8_scale_array2

        # Two independent GEMM meta groups inside one FP8 collection.
        state = flax.core.frozen_dict.FrozenDict({
            FP8Helper.FP8_COLLECTION_NAME: {
                "test_update_fp8_metas1": {
                    FP8Helper.FP8_MAX_NAME: fp8_max_array,
                    FP8Helper.FP8_AMAX_NAME: fp8_amax_array1,
                    FP8Helper.FP8_SCALE_NAME: jnp.ones(scale_meta_shape),
                    FP8Helper.FP8_SCALE_INV_NAME: jnp.ones(scale_meta_shape)
                },
                "test_update_fp8_metas2": {
                    FP8Helper.FP8_MAX_NAME: fp8_max_array,
                    FP8Helper.FP8_AMAX_NAME: fp8_amax_array2,
                    FP8Helper.FP8_SCALE_NAME: jnp.ones(scale_meta_shape),
                    FP8Helper.FP8_SCALE_INV_NAME: jnp.ones(scale_meta_shape)
                }
            }
        })

        updated_state = FP8Helper.update_fp8_metas(state)

        # tree_flatten orders each group's leaves as (max, amax, scale, scale_inv),
        # hence the +1 stride and the offsets of 2 and 3 below.
        state_array, _ = jax.tree_util.tree_flatten(updated_state)
        meta_per_gemm = FP8Helper.NUM_META_PER_GEMM + 1
        scale_shift = 2
        scale_inv_shift = 3
        assert_allclose(state_array[0 * meta_per_gemm + scale_shift], fp8_scale_array1)
        assert_allclose(state_array[0 * meta_per_gemm + scale_inv_shift], fp8_scale_inv_array1)
        assert_allclose(state_array[1 * meta_per_gemm + scale_shift], fp8_scale_array2)
        assert_allclose(state_array[1 * meta_per_gemm + scale_inv_shift], fp8_scale_inv_array2)

        FP8Helper.finalize()

    @unittest.skipIf(not is_fp8_supported, reason=reason)
    def test_generate_fp8_max_array(self):
        """generate_fp8_max_array must use max_bwd for gradient meta slots and
        max_fwd for all others, for every FP8 format."""
        num_of_meta = FP8Helper.NUM_META_PER_GEMM * 2

        def get_ref(format_for_test):
            refer_list = []
            for i in range(num_of_meta):
                val = format_for_test.value.max_bwd \
                    if i % FP8Helper.NUM_META_PER_GEMM == FP8Helper.GRAD_META_IDX_PER_GEMM \
                    else format_for_test.value.max_fwd
                refer_list.append([val])
            return jnp.asarray(refer_list)

        for fp8_format in FP8Format:
            FP8Helper.initialize(fp8_format=fp8_format)
            assert_allclose(get_ref(fp8_format), FP8Helper.generate_fp8_max_array(num_of_meta))
            FP8Helper.finalize()

    @unittest.skipIf(not is_fp8_supported, reason=reason)
    def test_update_collections(self):
        """update_collections must overwrite only the listed keys and leave the
        rest untouched, for both plain dicts and FrozenDicts."""
        original_val = 0.0
        updated_val = 10.0

        original_state = {
            "test1": original_val,
            "test2": original_val,
        }
        updated_state = FP8Helper.update_collections({"test1": updated_val}, original_state)
        self.assertEqual(updated_state["test1"], updated_val)
        self.assertEqual(updated_state["test2"], original_val)

        original_state = flax.core.frozen_dict.FrozenDict(original_state)
        updated_state = FP8Helper.update_collections({"test1": updated_val}, original_state)
        self.assertEqual(updated_state["test1"], updated_val)
        self.assertEqual(updated_state["test2"], original_val)


class TestFP8Functions(unittest.TestCase):
    """Tests for the fp8_autocast context manager and its sharding integration."""

    def _check_default_state(self):
        """Assert FP8 is disabled and sharding is SINGLE (the state outside any autocast)."""
        self.assertFalse(FP8Helper.is_fp8_enabled())
        self.assertEqual(infer_major_sharding_type(), MajorShardingType.SINGLE)

    def _compare_delay_scaling(self, ref, test):
        """Assert two DelayedScaling recipes agree field by field."""
        self.assertTrue(ref.margin == test.margin)
        self.assertTrue(ref.interval == test.interval)
        self.assertTrue(ref.fp8_format == test.fp8_format)
        self.assertTrue(ref.amax_history_len == test.amax_history_len)
        self.assertTrue(ref.amax_compute_algo == test.amax_compute_algo)

    @unittest.skipIf(not is_fp8_supported, reason=reason)
    def test_fp8_autocast(self):
        """fp8_autocast must enable/disable FP8 per its `enabled` flag, expose the
        active recipe, and restore the default state on exit."""
        FP8Helper.finalize()    # Ensure this test is not affected by previous tests.
        self._check_default_state()

        with fp8_autocast(enabled=False, fp8_recipe=DelayedScaling()):
            self.assertFalse(FP8Helper.is_fp8_enabled())
            self._compare_delay_scaling(get_delayed_scaling(), DelayedScaling())

        self._check_default_state()

        ds = DelayedScaling(margin=5.0, interval=3, fp8_format=FP8Format.E4M3, amax_history_len=1)
        with fp8_autocast(enabled=True, fp8_recipe=ds):
            self.assertTrue(FP8Helper.is_fp8_enabled())
            self._compare_delay_scaling(get_delayed_scaling(), ds)

        self._check_default_state()

        ds = DelayedScaling(margin=3.0, interval=1, fp8_format=FP8Format.HYBRID, amax_history_len=1)
        with fp8_autocast(enabled=True, fp8_recipe=ds):
            self.assertTrue(FP8Helper.is_fp8_enabled())
            self._compare_delay_scaling(get_delayed_scaling(), ds)

        self._check_default_state()

    @unittest.skipIf(not is_fp8_supported, reason=reason)
    def test_fp8_autocast_with_sharding_resource(self):
        """fp8_autocast must accept a ShardingResource and report the expected
        major sharding type inside the context (single-GPU: always SINGLE)."""
        FP8Helper.finalize()    # Ensure this test is not affected by previous tests.
        self._check_default_state()

        ds = DelayedScaling(margin=5.0, interval=3, fp8_format=FP8Format.E4M3, amax_history_len=1)

        # TODO (Ming Huang): Support multi-GPUs testing. # pylint: disable=fixme
        # srs = (
        #     (ShardingResource(None, None), MajorShardingType.SINGLE),
        #     (ShardingResource('dp', None), MajorShardingType.DP),
        #     (ShardingResource(None, 'tp'), MajorShardingType.TP),
        #     (ShardingResource('dp', 'tp'), MajorShardingType.DPTP),
        # )
        srs = (
            (ShardingResource(None, None), MajorShardingType.SINGLE),
            (ShardingResource('dp', None), MajorShardingType.SINGLE),
            (ShardingResource(None, 'tp'), MajorShardingType.SINGLE),
            (ShardingResource('dp', 'tp'), MajorShardingType.SINGLE),
        )
        # TODO (Ming Huang): Support multi-GPUs testing. # pylint: disable=fixme
        mesh_shape = (1, 1)
        devices = np.asarray(jax.devices()[:1]).reshape(*mesh_shape)
        with jax.sharding.Mesh(devices, ('dp', 'tp')):
            for sr, mst in srs:
                with fp8_autocast(enabled=True, fp8_recipe=ds, sharding_resource=sr):
                    self.assertTrue(FP8Helper.is_fp8_enabled())
                    self._compare_delay_scaling(get_delayed_scaling(), ds)
                    self.assertEqual(infer_major_sharding_type(), mst)

                self._check_default_state()