# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Helper module for fp8 meta management
"""
from contextlib import contextmanager
from enum import Enum
from functools import partial
from typing import Dict, List, Optional, Tuple, Union

import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from flax.linen import fp8_ops

from transformer_engine.transformer_engine_jax import (
    DType,
    get_cublasLt_version,
    get_cuda_version,
    get_device_compute_capability,
)
from transformer_engine.common.recipe import DelayedScaling, Format
from transformer_engine.jax.sharding import global_shard_guard
from transformer_engine.jax.sharding import MeshResource

_is_fp8_available = None
_reason_for_no_fp8 = ""
Collection = Union[Dict, FrozenDict]


def _check_fp8_support(gpu_id) -> Tuple[bool, str]:
    """Return whether FP8 support is available on the given GPU."""
    gpu_arch = get_device_compute_capability(gpu_id)
    if gpu_arch >= 90:  # Hopper and above
        return True, ""
    if gpu_arch < 89:  # pre-Ada
        return False, "Device compute capability 8.9 or higher required for FP8 execution."
    if get_cublasLt_version() < 120103:
        return False, "CublasLt version 12.1.3.x or higher required for FP8 execution on Ada."
    if get_cuda_version() < 12010:
        return False, "Cuda version 12.1 or higher required for FP8 execution on Ada."
    return True, ""


def is_fp8_available(gpu_id=None) -> Tuple[bool, str]:
    """Return if fp8 support is available"""
    if gpu_id is not None:
        return _check_fp8_support(gpu_id)

    global _is_fp8_available, _reason_for_no_fp8
    if _is_fp8_available is None:
        _is_fp8_available = True
        # JAX doesn't provide the local GPU id.
        for local_gpu_id in range(len(jax.local_devices())):
            ret, msg = _check_fp8_support(local_gpu_id)
            if ret is False:
                _is_fp8_available = ret
                _reason_for_no_fp8 = msg
                break  # stop at the first local device lacking FP8 support

    return _is_fp8_available, _reason_for_no_fp8

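# Typical call pattern (illustrative sketch, mirroring fp8_autocast below):
#   fp8_available, reason_for_no_fp8 = is_fp8_available()
#   assert fp8_available, reason_for_no_fp8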

def _format2dtypes(format_: Format):
    """Map a recipe.Format to a (forward dtype, backward dtype) pair."""
    if format_ == Format.E4M3:
        return jnp.float8_e4m3fn, jnp.float8_e4m3fn
    if format_ == Format.E5M2:
        return jnp.float8_e5m2, jnp.float8_e5m2
    if format_ == Format.HYBRID:
        # HYBRID: e4m3 for forward tensors, e5m2 for gradients.
        return jnp.float8_e4m3fn, jnp.float8_e5m2
    return jnp.bfloat16, jnp.bfloat16


# fm32 is a custom dtype that defines the "add" rule as a max operation.
# It is typically used with Pipeline Parallelism + "MicroBatching > 1",
# which is implemented via nn.scan. Without this custom dtype, nn.scan
# would sum the gradients from all micro-batches, which is not the expected
# behavior for FP8 meta. Instead, FP8 meta gradients should be combined
# with "MAX".
FlaxFloatMeta32 = fp8_ops.fm32
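# Illustrative sketch (not part of the API): if two micro-batches produce
# amax contributions 3.0 and 5.0, summing under nn.scan would give 8.0,
# while the fm32 "add" rule keeps the maximum, 5.0, which is the statistic
# delayed scaling actually needs.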


class FP8MetaPackage:
    """
    A container of all the metadata required for FP8.
    """

    NUM_OF_META: int = 3
    INPUT_IDX: int = 0
    WEIGHT_IDX: int = 1
    GRAD_IDX: int = 2

    def __init__(
        self,
        input_amax: jnp.ndarray,
        input_scale: jnp.ndarray,
        weight_amax: jnp.ndarray,
        weight_scale: jnp.ndarray,
        grad_amax: jnp.ndarray,
        grad_scale: jnp.ndarray,
    ) -> None:

        self._amax_list = [None] * FP8MetaPackage.NUM_OF_META
        self._scale_list = [None] * FP8MetaPackage.NUM_OF_META

        self._amax_list[FP8MetaPackage.INPUT_IDX] = input_amax
        self._scale_list[FP8MetaPackage.INPUT_IDX] = input_scale
        self._amax_list[FP8MetaPackage.WEIGHT_IDX] = weight_amax
        self._scale_list[FP8MetaPackage.WEIGHT_IDX] = weight_scale
        self._amax_list[FP8MetaPackage.GRAD_IDX] = grad_amax
        self._scale_list[FP8MetaPackage.GRAD_IDX] = grad_scale

    @property
    def amax_list(self) -> List[jnp.ndarray]:
        """
        Get the amax list of this package.
        """
        return self._amax_list

    @property
    def scale_list(self) -> List[jnp.ndarray]:
        """
        Get the scale list of this package.
        """
        return self._scale_list

    @staticmethod
    def update_amax_list(amax_list: List[jnp.ndarray]) -> List[jnp.ndarray]:
        """
        Update the amax history list.
        """
        updated_amax_list = [FP8Helper.update_amax_history(amax) for amax in amax_list]
        return updated_amax_list

    @staticmethod
    def update_fp8_scale(
        amax_list: List[jnp.ndarray], scale_list: List[jnp.ndarray], fp8_dtype_list: List[DType]
    ) -> Tuple[List[jnp.ndarray], List[jnp.ndarray]]:
        """
        Get the updated scale and scale_inv lists.
        """
        update_scale_list = []
        update_scale_inv_list = []
        for amax, scale, fp8_dtype in zip(amax_list, scale_list, fp8_dtype_list):
            updated_scale, updated_scale_inv = FP8Helper.update_fp8_scale(amax, scale, fp8_dtype)
            update_scale_list.append(updated_scale)
            update_scale_inv_list.append(updated_scale_inv)
        return update_scale_list, update_scale_inv_list


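# Illustrative construction (sketch; array shapes depend on the recipe):
#   pkg = FP8MetaPackage(input_amax, input_scale,
#                        weight_amax, weight_scale,
#                        grad_amax, grad_scale)
#   assert pkg.amax_list[FP8MetaPackage.WEIGHT_IDX] is weight_amax
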
class AmaxComputeAlgo(Enum):
    """Algorithms for reducing the amax history to a single amax value."""

    MAX = "max"  # reduce the whole history with a maximum
    MOST_RECENT = "most_recent"  # use only the latest entry (slot 0)


NVTE_FP8_COLLECTION_NAME = "fp8_metas"


class FP8Helper:
    """
    FP8 helper to manage the FP8 meta
    """

    INITIALIZED = False
    MARGIN: float = 0.0
    FP8_FORMAT: Format = Format.HYBRID
    FWD_DTYPE: DType = _format2dtypes(Format.HYBRID)[0]
    BWD_DTYPE: DType = _format2dtypes(Format.HYBRID)[1]
    AMAX_HISTORY_LEN: int = 1024
    AMAX_COMPUTE_ALGO: AmaxComputeAlgo = AmaxComputeAlgo.MAX
    FP8_COLLECTION_NAME: str = NVTE_FP8_COLLECTION_NAME
    FP8_AMAX_NAME: str = "amax"
    FP8_SCALE_NAME: str = "scale"
    FP8_2X_ACC_FPROP: bool = False
    FP8_2X_ACC_DGRAD: bool = True
    FP8_2X_ACC_WGRAD: bool = True

    @staticmethod
    def is_fp8_enabled():
        """
        Indicate whether FP8 training is enabled.
        """
        return FP8Helper.INITIALIZED

    @staticmethod
    def initialize(
        margin: float = 0.0,
        fp8_format: Format = Format.HYBRID,
        amax_history_len: int = 1,
        amax_compute_algo: AmaxComputeAlgo = AmaxComputeAlgo.MAX,
    ) -> None:
        """
        Initialize the FP8 meta
        """
        FP8Helper.INITIALIZED = True
        FP8Helper.MARGIN = margin
        FP8Helper.FP8_FORMAT = fp8_format
        FP8Helper.FWD_DTYPE, FP8Helper.BWD_DTYPE = _format2dtypes(FP8Helper.FP8_FORMAT)
        FP8Helper.AMAX_HISTORY_LEN = amax_history_len
        FP8Helper.AMAX_COMPUTE_ALGO = amax_compute_algo
        FP8Helper.FP8_2X_ACC_FPROP = False
        FP8Helper.FP8_2X_ACC_DGRAD = True
        FP8Helper.FP8_2X_ACC_WGRAD = True

    @staticmethod
    def finalize() -> None:
        """
        Reset the FP8 helper to its default state.
        """
        FP8Helper.INITIALIZED = False
        FP8Helper.MARGIN = 0.0
        FP8Helper.FP8_FORMAT = Format.HYBRID
        FP8Helper.FWD_DTYPE, FP8Helper.BWD_DTYPE = _format2dtypes(FP8Helper.FP8_FORMAT)
        FP8Helper.AMAX_HISTORY_LEN = 1024
        FP8Helper.AMAX_COMPUTE_ALGO = AmaxComputeAlgo.MAX

    @staticmethod
    def update_collections(new: Collection, original: Collection) -> Collection:
        """
        Merge ``new`` into ``original``; entries in ``new`` take precedence.
        """
        assert isinstance(original, (dict, FrozenDict))
        assert isinstance(new, (dict, FrozenDict))
        frozen_original = FrozenDict(original) if not isinstance(original, FrozenDict) else original
        for key in new:
            if key in frozen_original:
                frozen_original, _ = frozen_original.pop(key)
        new_coll = FrozenDict({**new, **frozen_original})
        if not isinstance(original, FrozenDict):
            new_coll = new_coll.unfreeze()
        return new_coll

    @staticmethod
    def generate_fp8_meta_dtype_converter_pair(*args):
        """
        Generate a pair of conversion functions between fm32 and fp32.
        """

        def identical_fun(*metas):
            return list(metas)

        def fm32_to_fp32_fun(*metas):
            for meta in metas:
                assert meta.dtype == FlaxFloatMeta32
            return [jax.lax.convert_element_type(meta, jnp.float32) for meta in metas]

        def fp32_to_fm32_fun(*metas):
            for meta in metas:
                assert meta.dtype == jnp.float32
            return [jax.lax.convert_element_type(meta, FlaxFloatMeta32) for meta in metas]

        # Wrap the functions in jax.tree_util.Partial so they are valid JAX types
        partial_identical_fun = jax.tree_util.Partial(identical_fun)
        partial_fm32_to_fp32_fun = jax.tree_util.Partial(fm32_to_fp32_fun)
        partial_fp32_to_fm32_fun = jax.tree_util.Partial(fp32_to_fm32_fun)

        if len(args) < 1:
            return partial_identical_fun, partial_identical_fun

        original_dtype = args[0].dtype
        for arg in args:
            assert arg.dtype == original_dtype

        if original_dtype == FlaxFloatMeta32:
            return partial_fm32_to_fp32_fun, partial_fp32_to_fm32_fun

        return partial_identical_fun, partial_identical_fun

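    # Illustrative usage (sketch; ``metas`` stands for FP8 meta tensors):
    #   to_fp32, to_fm32 = FP8Helper.generate_fp8_meta_dtype_converter_pair(*metas)
    #   fp32_metas = to_fp32(*metas)    # compute on fp32 views
    #   metas = to_fm32(*fp32_metas)    # restore fm32 so nn.scan combines with MAX
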
    @staticmethod
    @jax.jit
    def update_amax_history(amax: jnp.ndarray) -> jnp.ndarray:
        """
        Update the amax history
        """
        updated_amax = jnp.roll(amax, -1, -1)
        updated_amax = updated_amax.at[0].set(0)
        return updated_amax

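    # Mechanical trace (illustrative, 1-D history of length 3):
    #   [a0, a1, a2] -> roll(-1) -> [a1, a2, a0] -> clear slot 0 -> [0, a2, a0]
    # Slot 0 is reserved for the next measured amax.
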
    @staticmethod
    @partial(jax.jit, static_argnums=(2,))
    def update_fp8_scale(
        amax: jnp.ndarray, scale: jnp.ndarray, fp8_dtype: DType
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """
        Calculate fp8 scale and scale_inv based on given amax.
        """
        fp8_max = jnp.astype(jnp.finfo(fp8_dtype).max, jnp.float32)

        if FP8Helper.AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX:
            amax = jnp.max(amax, axis=-1, keepdims=True)
        else:
            amax = amax[0:1]

        sf = (fp8_max / amax) / (2**FP8Helper.MARGIN)
        sf = jnp.where(amax > 0.0, sf, scale)
        sf = jnp.where(jnp.isfinite(amax), sf, scale)
        scale = sf
        scale_inv = 1 / sf

        return scale, scale_inv

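    # Worked example (illustrative): for e4m3 (jnp.finfo(...).max == 448),
    # amax == 2.0 and MARGIN == 0: sf = (448 / 2.0) / 2**0 = 224.0, so
    # scale = 224.0 and scale_inv = 1/224. When amax is zero or non-finite,
    # the previous scale is kept instead.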

@contextmanager
def fp8_autocast(
    enabled: bool = False,
    fp8_recipe: Optional[DelayedScaling] = None,
    mesh_resource: Optional[MeshResource] = None,
) -> None:
    r"""
    Context manager for FP8 usage.

    .. code-block:: python

        mesh_shape = (4, 2)
        dp_mesh_axis_name = 'data_parallel'
        tp_mesh_axis_name = 'tensor_parallel'
        devices = np.asarray(jax.devices()).reshape(*mesh_shape)

        with maps.Mesh(devices, (dp_mesh_axis_name, tp_mesh_axis_name)):
            mesh_resource = MeshResource(dp_mesh_axis_name, tp_mesh_axis_name)

            with fp8_autocast(enabled=True, mesh_resource=mesh_resource):
                rules = extend_logical_axis_rules(tuple())
                transformer = TransformerLayer()

                with partitioning.axis_rules(rules):
                    pjit(transformer.init, ...)(...)

    .. note::
        Currently only :attr:`margin`, :attr:`fp8_format`, :attr:`amax_history_len`,
        and :attr:`amax_compute_algo` (with values 'max' and 'most_recent') are
        supported in recipe.DelayedScaling. Non-default values for the other
        recipe.DelayedScaling parameters will trigger an assertion.

    Parameters
    ----------
    enabled: bool, default = False
        Whether or not to enable fp8
    fp8_recipe: recipe.DelayedScaling, default = None
        Recipe used for FP8 training.
    mesh_resource: MeshResource, default = None
        Specify the mesh axes for data and tensor parallelism to shard along.
        If set to None, then no data or tensor parallelism will be used.

    """
    if fp8_recipe is None:
        fp8_recipe = DelayedScaling()

    assert fp8_recipe.amax_compute_algo in [
        "max",
        "most_recent",
    ], "DelayedScaling amax_compute_algo only supports max and most_recent with TE/JAX."
    assert (
        fp8_recipe.scaling_factor_compute_algo is None
    ), "DelayedScaling scaling_factor_compute_algo isn't supported by TE/JAX."
    assert fp8_recipe.override_linear_precision == (
        False,
        False,
        False,
    ), "DelayedScaling override_linear_precision isn't supported by TE/JAX."
    assert fp8_recipe.reduce_amax, "DelayedScaling reduce_amax should be enabled for TE/JAX."

    if mesh_resource is None:
        mesh_resource = MeshResource()

    try:
        with global_shard_guard(mesh_resource):
            if enabled:
                fp8_available, reason_for_no_fp8 = is_fp8_available()
                assert fp8_available, reason_for_no_fp8

                amax_compute_algo = AmaxComputeAlgo.MOST_RECENT
                if fp8_recipe.amax_compute_algo == "max":
                    amax_compute_algo = AmaxComputeAlgo.MAX

                FP8Helper.initialize(
                    margin=fp8_recipe.margin,
                    fp8_format=fp8_recipe.fp8_format,
                    amax_history_len=fp8_recipe.amax_history_len,
                    amax_compute_algo=amax_compute_algo,
                )
            yield
    finally:
        FP8Helper.finalize()


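# Minimal usage sketch (illustrative; single device, default recipe):
#   with fp8_autocast(enabled=True):
#       ...  # build and run TE/JAX modules; FP8Helper is initialized inside
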
# Function Wrappers
def update_collections(new: Collection, original: Collection) -> FrozenDict:
    r"""
    A helper to update Flax's Collection.

    Collection = [dict, flax.core.frozen_dict.FrozenDict]

    Parameters
    ----------
    new: Collection
        A collection that includes new data.
    original: Collection
        The base collection.

    Returns
    -------
    outputs : Collection
        The updated collection.
    """
    return FP8Helper.update_collections(new, original)


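# Illustrative usage (hypothetical variable names):
#   variables = update_collections({NVTE_FP8_COLLECTION_NAME: new_fp8_metas}, variables)
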
def get_delayed_scaling():
    r"""
    Obtain the DelayedScaling instance that was set via fp8_autocast.

    .. note::
        Only :attr:`margin`, :attr:`fp8_format`, :attr:`amax_history_len`, and
        :attr:`amax_compute_algo` are stored via fp8_autocast. Other parameters
        in recipe.DelayedScaling are returned with their default values.

    Returns
    -------
    delay_scaling : DelayedScaling
        An instance of DelayedScaling as set via fp8_autocast.
    """
    amax_compute_algo = (
        "max" if FP8Helper.AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX else "most_recent"
    )
    return DelayedScaling(
        margin=int(FP8Helper.MARGIN),
        fp8_format=FP8Helper.FP8_FORMAT,
        amax_history_len=FP8Helper.AMAX_HISTORY_LEN,
        amax_compute_algo=amax_compute_algo,
    )