fp8.py
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Helper module for fp8 meta management
"""
import os
from contextlib import contextmanager
from typing import Optional, Union, Dict, List, Tuple
from flax.core.frozen_dict import FrozenDict
import jax
import jax.numpy as jnp
from transformer_engine_jax import DType
from transformer_engine.common.recipe import DelayedScaling, Format
from transformer_engine.jax.sharding import global_shard_guard
from transformer_engine.jax.sharding import ShardingResource

Collection = Union[Dict, FrozenDict]


def _format2dtypes(format_: Format):
    if format_ == Format.E4M3:
        return DType.kFloat8E4M3, DType.kFloat8E4M3
    if format_ == Format.E5M2:
        return DType.kFloat8E5M2, DType.kFloat8E5M2
    if format_ == Format.HYBRID:
        return DType.kFloat8E4M3, DType.kFloat8E5M2
    return DType.kBFloat16, DType.kBFloat16

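# For reference, the forward/backward dtype pair simply follows the recipe
# format: _format2dtypes(Format.HYBRID) yields
# (DType.kFloat8E4M3, DType.kFloat8E5M2), while any unrecognized format falls
# back to (DType.kBFloat16, DType.kBFloat16).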

class FP8GemmPackage:
    """
    A container of all data required for FP8 GEMM.
    """

    def __init__(
        self,
        num_of_gemm: int,
        inputs: jnp.ndarray,
        kernels: List[jnp.ndarray],
        fp8_max: jnp.ndarray,
        amax: jnp.ndarray,
        scale: jnp.ndarray,
        scale_inv: jnp.ndarray,
    ) -> None:
        self._num_of_gemm = num_of_gemm
        self._inputs = inputs

        assert len(kernels) == self._num_of_gemm
        self._kernels = kernels

        total_num_of_meta = self._num_of_gemm * FP8Helper.NUM_META_PER_GEMM
        assert fp8_max.shape[0] == total_num_of_meta
        self._fp8_max = fp8_max
        assert amax.shape[0] == total_num_of_meta
        self._amax = amax
        assert scale.shape[0] == total_num_of_meta
        self._scale = scale
        assert scale_inv.shape[0] == total_num_of_meta
        self._scale_inv = scale_inv

    @property
    def num_of_gemm(self) -> int:
        """
        num_of_gemm of this package
        """
        return self._num_of_gemm

    @property
    def inputs(self) -> jnp.ndarray:
        """
        inputs of this package
        """
        return self._inputs

    @property
    def kernels(self) -> List[jnp.ndarray]:
        """
        kernels of this package
        """
        return self._kernels

    @property
    def fp8_max(self) -> jnp.ndarray:
        """
        fp8_max of this package
        """
        return self._fp8_max

    @property
    def amax(self) -> jnp.ndarray:
        """
        amax of this package
        """
        return self._amax

    @property
    def scale(self) -> jnp.ndarray:
        """
        scale of this package
        """
        return self._scale

    @property
    def scale_inv(self) -> jnp.ndarray:
        """
        scale_inv of this package
        """
        return self._scale_inv

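# A minimal construction sketch for a single GEMM (not part of the module;
# `x` and `kernel` stand in for caller-provided jnp.ndarrays). The arrays only
# need to satisfy the shape[0] checks in FP8GemmPackage.__init__:
#
#     meta_len = FP8Helper.NUM_META_PER_GEMM
#     pkg = FP8GemmPackage(
#         num_of_gemm=1,
#         inputs=x,
#         kernels=[kernel],
#         fp8_max=FP8Helper.generate_fp8_max_array(meta_len),
#         amax=jnp.zeros((meta_len, FP8Helper.AMAX_HISTORY_SIZE), jnp.float32),
#         scale=jnp.ones((meta_len, 1), jnp.float32),
#         scale_inv=jnp.ones((meta_len, 1), jnp.float32),
#     )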

class FP8Helper:
    """
    FP8 helper to manage the FP8 meta
    """
    INITIALIZED = False
    MARGIN: float = 0.0
    FP8_FORMAT: Format = Format.HYBRID
    FWD_DTYPE: DType = DType.kFloat8E4M3
    BWD_DTYPE: DType = DType.kFloat8E5M2
    UPDATE_FP8META_INTERVAL: int = 1
    AMAX_HISTORY_SIZE: int = 1
    NUM_META_PER_GEMM: int = 3
    INPUT_META_IDX_PER_GEMM: int = 0
    KERNEL_META_IDX_PER_GEMM: int = 1
    GRAD_META_IDX_PER_GEMM: int = 2
    FP8_COLLECTION_NAME: str = "fp8_meta_collection"
    FP8_AMAX_NAME: str = "fp8_meta_amax"
    FP8_SCALE_NAME: str = "fp8_meta_scale"
    FP8_SCALE_INV_NAME: str = "fp8_meta_scale_inv"
    FP8_MAX_NAME: str = "fp8_max"
    FP8_2X_ACC_FPROP_ENV_VAR_NAME = "NVTE_JAX_FP8_2X_ACC_FPROP"
    FP8_2X_ACC_DGRAD_ENV_VAR_NAME = "NVTE_JAX_FP8_2X_ACC_DGRAD"
    FP8_2X_ACC_WGRAD_ENV_VAR_NAME = "NVTE_JAX_FP8_2X_ACC_WGRAD"
    FP8_2X_ACC_FPROP: bool = False
    FP8_2X_ACC_DGRAD: bool = False
    FP8_2X_ACC_WGRAD: bool = False

    @staticmethod
    def enable_fp8():
        """
        Indicate whether fp8 training is enabled or not.
        """
        return FP8Helper.INITIALIZED

    @staticmethod
    def initialize(margin: float = 0.0,
                   fp8_format: Format = Format.HYBRID,
                   update_fp8meta_interval: int = 1,
                   amax_history_size: int = 1) -> None:
        """
        Initialize the FP8 meta
        """
        FP8Helper.INITIALIZED = True
        FP8Helper.MARGIN = margin
        FP8Helper.FP8_FORMAT = fp8_format
        FP8Helper.FWD_DTYPE, FP8Helper.BWD_DTYPE = \
            _format2dtypes(FP8Helper.FP8_FORMAT)
        FP8Helper.UPDATE_FP8META_INTERVAL = update_fp8meta_interval
        FP8Helper.AMAX_HISTORY_SIZE = amax_history_size
        FP8Helper.FP8_2X_ACC_FPROP = bool(
            int(os.environ.get(FP8Helper.FP8_2X_ACC_FPROP_ENV_VAR_NAME, "0")))
        FP8Helper.FP8_2X_ACC_DGRAD = bool(
            int(os.environ.get(FP8Helper.FP8_2X_ACC_DGRAD_ENV_VAR_NAME, "0")))
        FP8Helper.FP8_2X_ACC_WGRAD = bool(
            int(os.environ.get(FP8Helper.FP8_2X_ACC_WGRAD_ENV_VAR_NAME, "0")))
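        # Note: the three NVTE_JAX_FP8_2X_ACC_* environment variables are read
        # once here; e.g. exporting NVTE_JAX_FP8_2X_ACC_FPROP=1 before calling
        # initialize() sets FP8Helper.FP8_2X_ACC_FPROP to True.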

    @staticmethod
    def finalize() -> None:
        """
        FP8 helper finalize
        """
        FP8Helper.INITIALIZED = False
        FP8Helper.MARGIN = 0.0
        FP8Helper.FP8_FORMAT = Format.HYBRID
        FP8Helper.FWD_DTYPE = DType.kFloat8E4M3
        FP8Helper.BWD_DTYPE = DType.kFloat8E5M2
        FP8Helper.UPDATE_FP8META_INTERVAL = 1
        FP8Helper.AMAX_HISTORY_SIZE = 1

    @staticmethod
    def update_collections(new: Collection, original: Collection) -> Collection:
        """
        Update the collections, with entries in `new` overriding those in `original`.
        """
        if not isinstance(original, FrozenDict):
            original = FrozenDict(original)
        for key in new:
            if key in original:
                original, _ = original.pop(key)
        return FrozenDict({**new, **original})
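    # For instance, refreshed FP8 metas can be merged back into an existing
    # variable collection (a sketch; `new_fp8_metas` and `other_vars` are
    # hypothetical placeholders):
    #
    #     vars_ = FP8Helper.update_collections(
    #         {FP8Helper.FP8_COLLECTION_NAME: new_fp8_metas}, other_vars)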

    @staticmethod
    def update_fp8_metas(state: Collection) -> Collection:
        """
        Update the FP8 metas
        """
        if FP8Helper.FP8_COLLECTION_NAME in state:
            if not isinstance(state, FrozenDict):
                state = FrozenDict(state)
            others, fp8_metas = state.pop(FP8Helper.FP8_COLLECTION_NAME)
            fp8_metas = FP8Helper._update_fp8_metas_impl(fp8_metas)
            return FrozenDict({**others, FP8Helper.FP8_COLLECTION_NAME: fp8_metas})
        return state
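    # In a training loop, this would typically be applied once per step to the
    # variable collection carrying the FP8 metas, e.g. (a sketch, where `state`
    # is such a collection):
    #
    #     state = FP8Helper.update_fp8_metas(state)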

    @staticmethod
    def generate_fp8_max_array(num_of_meta):
        """
        Generate the FP8 max array
        """
        num_of_gemm = num_of_meta // FP8Helper.NUM_META_PER_GEMM
        fp8_max_fwd = FP8Helper.FP8_FORMAT.value.max_fwd
        fp8_max_bwd = FP8Helper.FP8_FORMAT.value.max_bwd
        fp8_max_per_gemm = []
        for i in range(FP8Helper.NUM_META_PER_GEMM):
            val = fp8_max_bwd if i == FP8Helper.GRAD_META_IDX_PER_GEMM \
                else fp8_max_fwd
            fp8_max_per_gemm.append([val])
        fp8_max_per_gemm = jnp.asarray(fp8_max_per_gemm, dtype=jnp.float32)
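        # For example, with the default HYBRID format (E4M3 forward, E5M2
        # backward) and num_of_meta == 3 (one GEMM), this returns a (3, 1)
        # float32 array holding the E4M3 maximum (448.0) for the input and
        # kernel metas and the E5M2 maximum (57344.0) for the grad meta.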
        return jnp.vstack([fp8_max_per_gemm] * num_of_gemm)

    @staticmethod
    def get_fp8_meta_indices(gemm_idx: int) -> Tuple[int, int, int]:
        """
        Obtain the FP8 meta indices for the given GEMM index.
        """
        input_idx = FP8Helper.NUM_META_PER_GEMM * gemm_idx + FP8Helper.INPUT_META_IDX_PER_GEMM
        kernel_idx = FP8Helper.NUM_META_PER_GEMM * gemm_idx + FP8Helper.KERNEL_META_IDX_PER_GEMM
        grad_idx = FP8Helper.NUM_META_PER_GEMM * gemm_idx + FP8Helper.GRAD_META_IDX_PER_GEMM
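        # For example, with NUM_META_PER_GEMM == 3, gemm_idx == 1 maps to
        # (input_idx, kernel_idx, grad_idx) == (3, 4, 5).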
        return input_idx, kernel_idx, grad_idx

    @staticmethod
    @jax.jit
    def _update_fp8_metas_impl(fp8_metas: Collection) -> Collection:
        fp8_meta_arrays, treedef = jax.tree_util.tree_flatten(fp8_metas)
        num_of_meta_with_max = FP8Helper.NUM_META_PER_GEMM + 1
        num_of_gemm = len(fp8_meta_arrays) // num_of_meta_with_max
        for i in range(num_of_gemm):
            # The flattened arrays are ordered alphabetically by collection name
            fp8_max_idx = i * num_of_meta_with_max
            fp8_amax_idx = fp8_max_idx + 1
            fp8_scale_idx = fp8_amax_idx + 1
            fp8_scale_inv_idx = fp8_scale_idx + 1

            fp8_max = fp8_meta_arrays[fp8_max_idx]
            amax = fp8_meta_arrays[fp8_amax_idx]
            scale = fp8_meta_arrays[fp8_scale_idx]

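            # Delayed-scaling update: choose a power-of-two scale so that the
            # scaled amax sits just below fp8_max, backed off by MARGIN powers
            # of two. When amax is zero or non-finite, keep the previous scale.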
            exp = jnp.floor(jnp.log2(fp8_max / amax)) - FP8Helper.MARGIN
            sf = jnp.round(jnp.power(2, jnp.abs(exp)))
            sf = jnp.where(amax > 0.0, sf, scale)
            sf = jnp.where(jnp.isfinite(amax), sf, scale)
            scale = jnp.where(exp < 0, 1 / sf, sf)
            fp8_meta_arrays[fp8_scale_idx] = scale
            fp8_meta_arrays[fp8_scale_inv_idx] = 1 / scale

        return jax.tree_util.tree_unflatten(treedef, fp8_meta_arrays)


@contextmanager
def fp8_autocast(enabled: bool = False,
                 fp8_recipe: Optional[DelayedScaling] = None,
                 sharding_resource: Optional[ShardingResource] = None) -> None:
    """
    Context manager for FP8 usage.

    .. code-block:: python

        mesh_shape = (4, 2)
        dp_mesh_axis_name = 'data_parallel'
        tp_mesh_axis_name = 'tensor_parallel'
        devices = np.asarray(jax.devices()).reshape(*mesh_shape)

        with maps.Mesh(devices, (dp_mesh_axis_name, tp_mesh_axis_name)):
            sharding_resource = ShardingResource(dp_mesh_axis_name, tp_mesh_axis_name)

            with fp8_autocast(enabled=True, sharding_resource=sharding_resource):
                rules = extend_logical_axis_rules(tuple())
                transformer = TransformerLayer()

                with partitioning.axis_rules(rules):
                    pjit(transformer.init, ...)(...)

    .. note::
        Currently we only support :attr:`margin`, :attr:`fp8_format`, :attr:`interval`
        and :attr:`amax_history_len=1` in recipe.DelayedScaling. Other parameters
        in recipe.DelayedScaling are ignored, even if set.

    Parameters
    ----------
    enabled: bool, default = False
             Whether or not to enable fp8.
    fp8_recipe: recipe.DelayedScaling, default = None
                Recipe used for FP8 training.
    sharding_resource: ShardingResource, default = None
        Specify the mesh axes for data and tensor parallelism to shard along.
        If set to None, then a default ShardingResource() will be created.
    """
    if fp8_recipe is None:
        fp8_recipe = DelayedScaling()

    assert fp8_recipe.amax_history_len == 1, \
        "Only amax_history_len == 1 is supported for now."

    if sharding_resource is None:
        sharding_resource = ShardingResource()

    try:
        with global_shard_guard(sharding_resource):
            if enabled:
                FP8Helper.initialize(margin=fp8_recipe.margin,
                                     fp8_format=fp8_recipe.fp8_format,
                                     update_fp8meta_interval=fp8_recipe.interval,
                                     amax_history_size=fp8_recipe.amax_history_len)
            yield
    finally:
        FP8Helper.finalize()