# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Helper module for fp8 meta management
"""
from contextlib import contextmanager
from enum import Enum
from typing import Dict, Optional, Tuple, Union

import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from flax.linen import fp8_ops

from transformer_engine_jax import DType
from transformer_engine_jax import get_cublasLt_version
from transformer_engine_jax import get_cuda_version, get_device_compute_capability
from transformer_engine.common.recipe import DelayedScaling, Format
from transformer_engine.jax.sharding import global_shard_guard
from transformer_engine.jax.sharding import MeshResource

_is_fp8_available = None
_reason_for_no_fp8 = ""
Collection = Union[Dict, FrozenDict]


def _check_fp8_support(gpu_id) -> Tuple[bool, str]:
    """Return if fp8 support is available"""
    gpu_arch = get_device_compute_capability(gpu_id)
    if gpu_arch >= 90:    # Hopper and above
        return True, ""
    if gpu_arch < 89:    # pre-Ada
        return False, "Device compute capability 8.9 or higher required for FP8 execution."
    if get_cublasLt_version() < 120103:
        return False, "cuBLASLt version 12.1.3.x or higher required for FP8 execution on Ada."
    if get_cuda_version() < 12010:
        return False, "CUDA version 12.1 or higher required for FP8 execution on Ada."
    return True, ""


def is_fp8_available(gpu_id=None) -> Tuple[bool, str]:
    """Return if fp8 support is available"""
    if gpu_id is not None:
        return _check_fp8_support(gpu_id)

    global _is_fp8_available, _reason_for_no_fp8
    if _is_fp8_available is None:
        _is_fp8_available = True
        # JAX doesn't provide the local GPU id.
        for local_gpu_id in range(len(jax.local_devices())):
            ret, msg = _check_fp8_support(local_gpu_id)
            if ret is False:
                _is_fp8_available = ret
                _reason_for_no_fp8 = msg
                break

    return _is_fp8_available, _reason_for_no_fp8


def _format2dtypes(format_: Format):
    """Map a recipe Format to a (fwd_dtype, bwd_dtype) pair."""
    if format_ == Format.E4M3:
        return jnp.float8_e4m3fn, jnp.float8_e4m3fn
    if format_ == Format.E5M2:
        return jnp.float8_e5m2, jnp.float8_e5m2
    if format_ == Format.HYBRID:
        return jnp.float8_e4m3fn, jnp.float8_e5m2
    return jnp.bfloat16, jnp.bfloat16


# fm32 is a custom dtype that specifies the "add" rule as a max operation.
# This is typically used in pipeline parallelism with micro-batching > 1,
# which is implemented via nn.scan. Without this custom dtype, nn.scan
# would sum gradients from all micro-batches, which is not the expected
# behavior for FP8 meta. Instead, the accumulation of FP8 meta gradients
# should be "MAX".
FlaxFloatMeta32 = fp8_ops.fm32


class FP8MetaPackage:
    """
    A container that holds all required metadata for FP8.
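
    For example, a package for a single GEMM could be constructed as below
    (an illustrative sketch; shapes follow the assertions in ``__init__``):

    .. code-block:: python

        num_meta = FP8Helper.NUM_META_PER_GEMM
        fp8_max = FP8Helper.generate_fp8_max_array(num_meta)
        amax = jnp.zeros((num_meta, FP8Helper.AMAX_HISTORY_LEN), jnp.float32)
        scale = jnp.ones((num_meta, 1), jnp.float32)
        scale_inv = jnp.ones((num_meta, 1), jnp.float32)
        package = FP8MetaPackage(1, fp8_max, amax, scale, scale_inv)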
    """

    def __init__(
        self,
        num_of_gemm: int,
        fp8_max: jnp.ndarray,
        amax: jnp.ndarray,
        scale: jnp.ndarray,
        scale_inv: jnp.ndarray,
    ) -> None:
        total_num_of_meta = num_of_gemm * FP8Helper.NUM_META_PER_GEMM
        self._num_of_gemm = num_of_gemm
        assert fp8_max.shape[0] == total_num_of_meta
        self._fp8_max = fp8_max
        assert amax.shape[0] == total_num_of_meta
        self._amax = amax
        assert scale.shape[0] == total_num_of_meta
        self._scale = scale
        assert scale_inv.shape[0] == total_num_of_meta
        self._scale_inv = scale_inv

    @property
    def num_of_gemm(self) -> int:
        """
        num_of_gemm of this package
        """
        return self._num_of_gemm

    @property
    def fp8_max(self) -> jnp.ndarray:
        """
        fp8_max of this package
        """
        return self._fp8_max

    @property
    def amax(self) -> jnp.ndarray:
        """
        amax of this package
        """
        return self._amax

    @property
    def scale(self) -> jnp.ndarray:
        """
        scale of this package
        """
        return self._scale

    @property
    def scale_inv(self) -> jnp.ndarray:
        """
        scale_inv of this package
        """
        return self._scale_inv

    def get_package_by_gemm_idx(self, gemm_idx):
        """
        Get a sub-package for the given gemm_idx.
        """
        assert self.num_of_gemm > gemm_idx

        meta_start_idx = gemm_idx * FP8Helper.NUM_META_PER_GEMM
        meta_end_idx = (gemm_idx + 1) * FP8Helper.NUM_META_PER_GEMM
        return FP8MetaPackage(1, self.fp8_max[meta_start_idx:meta_end_idx],
                              self.amax[meta_start_idx:meta_end_idx],
                              self.scale[meta_start_idx:meta_end_idx],
                              self.scale_inv[meta_start_idx:meta_end_idx])


class AmaxComputeAlgo(Enum):
    """AmaxComputeAlgo."""
    MAX = "max"
    MOST_RECENT = "most_recent"


NVTE_FP8_COLLECTION_NAME = "fp8_meta_collection"


class FP8Helper:
    """
    FP8 helper to manage the FP8 meta
    """
    INITIALIZED = False
    MARGIN: float = 0.0
    FP8_FORMAT: Format = Format.HYBRID
    FWD_DTYPE: DType = _format2dtypes(Format.HYBRID)[0]
    BWD_DTYPE: DType = _format2dtypes(Format.HYBRID)[1]
    UPDATE_FP8META_INTERVAL: int = 1
    AMAX_HISTORY_LEN: int = 1024
    AMAX_COMPUTE_ALGO: AmaxComputeAlgo = AmaxComputeAlgo.MAX
    NUM_META_PER_GEMM: int = 3
    INPUT_META_IDX_PER_GEMM: int = 0
    KERNEL_META_IDX_PER_GEMM: int = 1
    GRAD_META_IDX_PER_GEMM: int = 2
    FP8_COLLECTION_NAME: str = NVTE_FP8_COLLECTION_NAME
    FP8_AMAX_NAME: str = "fp8_meta_amax"
    FP8_SCALE_NAME: str = "fp8_meta_scale"
    FP8_SCALE_INV_NAME: str = "fp8_meta_scale_inv"
    FP8_MAX_NAME: str = "fp8_max"
    FP8_2X_ACC_FPROP: bool = False
    FP8_2X_ACC_DGRAD: bool = True
    FP8_2X_ACC_WGRAD: bool = True

    @staticmethod
    def is_fp8_enabled():
        """
        Indicate whether fp8 training is enabled or not.
        """
        return FP8Helper.INITIALIZED

    @staticmethod
    def initialize(margin: float = 0.0,
                   fp8_format: Format = Format.HYBRID,
                   update_fp8meta_interval: int = 1,
                   amax_history_len: int = 1,
                   amax_compute_algo: AmaxComputeAlgo = AmaxComputeAlgo.MAX) -> None:
        """
        Initialize the FP8 meta
        """
        FP8Helper.INITIALIZED = True
        FP8Helper.MARGIN = margin
        FP8Helper.FP8_FORMAT = fp8_format
        FP8Helper.FWD_DTYPE, FP8Helper.BWD_DTYPE = \
            _format2dtypes(FP8Helper.FP8_FORMAT)
        FP8Helper.UPDATE_FP8META_INTERVAL = update_fp8meta_interval
        FP8Helper.AMAX_HISTORY_LEN = amax_history_len
        FP8Helper.AMAX_COMPUTE_ALGO = amax_compute_algo
        FP8Helper.FP8_2X_ACC_FPROP = False
        FP8Helper.FP8_2X_ACC_DGRAD = True
        FP8Helper.FP8_2X_ACC_WGRAD = True

    @staticmethod
    def finalize() -> None:
        """
        FP8 helper finalize
        """
        FP8Helper.INITIALIZED = False
        FP8Helper.MARGIN = 0.0
        FP8Helper.FP8_FORMAT = Format.HYBRID
        FP8Helper.FWD_DTYPE, FP8Helper.BWD_DTYPE = \
            _format2dtypes(FP8Helper.FP8_FORMAT)
        FP8Helper.UPDATE_FP8META_INTERVAL = 1
        FP8Helper.AMAX_HISTORY_LEN = 1024
        FP8Helper.AMAX_COMPUTE_ALGO = AmaxComputeAlgo.MAX

    @staticmethod
    def update_collections(new: Collection, original: Collection) -> Collection:
        """
        Update the collections
        """
        assert isinstance(original, (dict, FrozenDict))
        assert isinstance(new, (dict, FrozenDict))
        frozen_original = FrozenDict(original) if not isinstance(original, FrozenDict) else original
        for key in new:
            if key in frozen_original:
                frozen_original, _ = frozen_original.pop(key)
        new_coll = FrozenDict({**new, **frozen_original})
        if not isinstance(original, FrozenDict):
            new_coll = new_coll.unfreeze()
        return new_coll

    @staticmethod
    def update_fp8_metas(state: Collection) -> Collection:
        """
        Update the FP8 metas
        """
        assert isinstance(state, (dict, FrozenDict))
        if FP8Helper.FP8_COLLECTION_NAME in state:
            frozen_state = FrozenDict(state) if not isinstance(state, FrozenDict) else state
            others, fp8_metas = frozen_state.pop(FP8Helper.FP8_COLLECTION_NAME)
            fp8_metas = FP8Helper._update_fp8_metas_impl(fp8_metas)
            new_state = FrozenDict({**others, FP8Helper.FP8_COLLECTION_NAME: fp8_metas})

            if not isinstance(state, FrozenDict):
                new_state = new_state.unfreeze()
            return new_state
        return state

    @staticmethod
    def generate_fp8_max_array(num_of_meta):
        """
        Generate the FP8 max array
        """
        num_of_gemm = num_of_meta // FP8Helper.NUM_META_PER_GEMM
        fp8_max_fwd = jnp.finfo(FP8Helper.FWD_DTYPE).max
        fp8_max_bwd = jnp.finfo(FP8Helper.BWD_DTYPE).max
        fp8_max_per_gemm = []
        for i in range(FP8Helper.NUM_META_PER_GEMM):
            val = fp8_max_bwd if i == FP8Helper.GRAD_META_IDX_PER_GEMM \
                else fp8_max_fwd
            fp8_max_per_gemm.append([val])
        fp8_max_per_gemm = jnp.asarray(fp8_max_per_gemm, dtype=jnp.float32)
        return jnp.vstack([fp8_max_per_gemm] * num_of_gemm)

    @staticmethod
    def get_fp8_meta_indices(gemm_idx: int) -> Tuple[int, int, int]:
        """
        Obtain the indices of the FP8 metas for the given GEMM index.
        """
        input_idx = FP8Helper.NUM_META_PER_GEMM * gemm_idx + FP8Helper.INPUT_META_IDX_PER_GEMM
        kernel_idx = FP8Helper.NUM_META_PER_GEMM * gemm_idx + FP8Helper.KERNEL_META_IDX_PER_GEMM
        grad_idx = FP8Helper.NUM_META_PER_GEMM * gemm_idx + FP8Helper.GRAD_META_IDX_PER_GEMM
        return input_idx, kernel_idx, grad_idx

    @staticmethod
    @jax.jit
    def _update_fp8_metas_impl(fp8_metas: Collection) -> Collection:
        fp8_meta_arrays, treedef = jax.tree_util.tree_flatten(fp8_metas)
        num_of_meta_with_max = FP8Helper.NUM_META_PER_GEMM + 1
        num_of_gemm = len(fp8_meta_arrays) // num_of_meta_with_max
        for i in range(num_of_gemm):
            # The flattened array is ordered alphabetically by collection name:
            # fp8_max < fp8_meta_amax < fp8_meta_scale < fp8_meta_scale_inv
            fp8_max_idx = i * num_of_meta_with_max
            fp8_amax_idx = fp8_max_idx + 1
            fp8_scale_idx = fp8_amax_idx + 1
            fp8_scale_inv_idx = fp8_scale_idx + 1

            fp8_max = fp8_meta_arrays[fp8_max_idx]
            if FP8Helper.AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX:
                amax = jnp.max(fp8_meta_arrays[fp8_amax_idx], axis=-1, keepdims=True)
            else:
                amax = fp8_meta_arrays[fp8_amax_idx][..., 0:1]
            scale = fp8_meta_arrays[fp8_scale_idx]

            sf = (fp8_max / amax) / (2**FP8Helper.MARGIN)
            sf = jnp.where(amax > 0.0, sf, scale)
            sf = jnp.where(jnp.isfinite(amax), sf, scale)
            fp8_meta_arrays[fp8_scale_idx] = sf
            fp8_meta_arrays[fp8_scale_inv_idx] = 1 / sf

        return jax.tree_util.tree_unflatten(treedef, fp8_meta_arrays)

    @staticmethod
    def generate_fp8_meta_dtype_converter_pair(*args):
        """
        Generate a pair of conversion functions between fm32 and fp32.
        """

        def identical_fun(*metas):
            return metas

        def fm32_to_fp32_fun(*metas):
            for meta in metas:
                assert meta.dtype == FlaxFloatMeta32
            return [jax.lax.convert_element_type(meta, jnp.float32) for meta in metas]

        def fp32_to_fm32_fun(*metas):
            for meta in metas:
                assert meta.dtype == jnp.float32
            return [jax.lax.convert_element_type(meta, FlaxFloatMeta32) for meta in metas]

        # Wrap the functions as jax.tree_util.Partial so they are valid JAX types
        partial_identical_fun = jax.tree_util.Partial(identical_fun)
        partial_fm32_to_fp32_fun = jax.tree_util.Partial(fm32_to_fp32_fun)
        partial_fp32_to_fm32_fun = jax.tree_util.Partial(fp32_to_fm32_fun)

        if len(args) < 1:
            return partial_identical_fun, partial_identical_fun

        original_dtype = args[0].dtype
        for arg in args:
            assert arg.dtype == original_dtype

        if original_dtype == FlaxFloatMeta32:
            return partial_fm32_to_fp32_fun, partial_fp32_to_fm32_fun

        return partial_identical_fun, partial_identical_fun

    @staticmethod
    def update_amax_history(amax: jnp.ndarray) -> jnp.ndarray:
        """
        Update the amax history
        """
        updated_amax = jnp.roll(amax, -1, -1)
        updated_amax = updated_amax.at[..., 0].set(0)
        return updated_amax

    @staticmethod
    @jax.jit
    def update_fp8_scale(fp8_max: jnp.ndarray, amax: jnp.ndarray,
                         scale: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """
        Calculate fp8 scale and scale_inv based on given amax.
        """
        if FP8Helper.AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX:
            amax = jnp.max(amax, axis=-1, keepdims=True)
        else:
            amax = amax[..., 0:1]

        sf = (fp8_max / amax) / (2**FP8Helper.MARGIN)
        sf = jnp.where(amax > 0.0, sf, scale)
        sf = jnp.where(jnp.isfinite(amax), sf, scale)
        scale = sf
        scale_inv = 1 / sf

        return scale, scale_inv


@contextmanager
def fp8_autocast(enabled: bool = False,
                 fp8_recipe: Optional[DelayedScaling] = None,
                 mesh_resource: Optional[MeshResource] = None) -> None:
    r"""
    Context manager for FP8 usage.

    .. code-block:: python

        mesh_shape = (4, 2)
        dp_mesh_axis_name = 'data_parallel'
        tp_mesh_axis_name = 'tensor_parallel'
        devices = np.asarray(jax.devices()).reshape(*mesh_shape)

        with maps.Mesh(devices, (dp_mesh_axis_name, tp_mesh_axis_name)):
            mesh_resource = MeshResource(dp_mesh_axis_name, tp_mesh_axis_name)

            with fp8_autocast(enabled=True, mesh_resource=mesh_resource):
                rules = extend_logical_axis_rules(tuple())
                transformer = TransformerLayer()

                with partitioning.axis_rules(rules):
                    pjit(transformer.init, ...)(...)

    .. note::
        We only support :attr:`margin`, :attr:`fp8_format`,
        :attr:`interval`, :attr:`amax_history_len` and
        :attr:`amax_compute_algo` (with values 'max' and 'most_recent')
        in recipe.DelayedScaling currently. Other parameters in
        recipe.DelayedScaling will trigger an assertion.

    Parameters
    ----------
    enabled: bool, default = False
        Whether or not to enable fp8
    fp8_recipe: recipe.DelayedScaling, default = None
        Recipe used for FP8 training.
    mesh_resource: MeshResource, default = None
        Specify the mesh axes for data and tensor parallelism to shard along.
        If set to None, then no data or tensor parallelism will be used.

    """
    if fp8_recipe is None:
        fp8_recipe = DelayedScaling()

    assert fp8_recipe.amax_compute_algo in [
        "max", "most_recent"
    ], ("DelayedScaling amax_compute_algo only supports max and most_recent with TE/JAX.")
    assert fp8_recipe.scaling_factor_compute_algo is None, (
        "DelayedScaling scaling_factor_compute_algo isn't supported by TE/JAX.")
    assert fp8_recipe.override_linear_precision == (False, False, False), (
        "DelayedScaling override_linear_precision isn't supported by TE/JAX.")
    assert fp8_recipe.reduce_amax, ("DelayedScaling reduce_amax should be enabled for TE/JAX.")

    if mesh_resource is None:
        mesh_resource = MeshResource()

    try:
        with global_shard_guard(mesh_resource):
            if enabled:
                fp8_available, reason_for_no_fp8 = is_fp8_available()
                assert fp8_available, reason_for_no_fp8

                amax_compute_algo = AmaxComputeAlgo.MOST_RECENT
                if fp8_recipe.amax_compute_algo == 'max':
                    amax_compute_algo = AmaxComputeAlgo.MAX

                FP8Helper.initialize(margin=fp8_recipe.margin,
                                     fp8_format=fp8_recipe.fp8_format,
                                     update_fp8meta_interval=fp8_recipe.interval,
                                     amax_history_len=fp8_recipe.amax_history_len,
                                     amax_compute_algo=amax_compute_algo)
            yield
    finally:
        FP8Helper.finalize()


# Function Wrappers
def update_collections(new: Collection, original: Collection) -> FrozenDict:
    r"""
    A helper to update Flax's Collection.

    Collection = [dict, flax.core.frozen_dict.FrozenDict]
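
    For example (an illustrative sketch; entries in `new` take precedence over
    overlapping keys in `original`):

    .. code-block:: python

        original = {"a": 1, "b": 2}
        new = {"b": 20, "c": 30}
        merged = update_collections(new, original)
        # merged == {"a": 1, "b": 20, "c": 30}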

    Parameters
    ----------
    new: Collection
        A collection that includes new data.
    original: Collection
        The base collection.

    Returns
    -------
    outputs : Collection
        The updated collection.
    """
    return FP8Helper.update_collections(new, original)


def update_fp8_metas(state: Collection) -> Collection:
    r"""
    Calculate new fp8 scales and their inverses via the following formula

    .. code-block:: python

        sf = (fp8_max / amax) / (2 ** margin)
        sf = sf if amax > 0.0 else original_scale
        updated_scale = sf if isfinite(amax) else original_scale
        updated_scale_inv = 1 / updated_scale
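
    For example, with fp8_max = 448 (the E4M3 maximum), amax = 2.0 and
    margin = 0, the updated scale is 224.0 and the updated scale_inv
    is 1 / 224.0.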

    Collection = [dict, flax.core.frozen_dict.FrozenDict]

    Parameters
    ----------
    state: Collection
        A collection that includes FP8 metas.

    Returns
    -------
    outputs : Collection
        The collection with updated FP8 metas.
    """
    return FP8Helper.update_fp8_metas(state)


def get_delayed_scaling():
    r"""
    Obtain an instance of DelayedScaling which is set via fp8_autocast.

    .. note::
        We only store :attr:`margin`, :attr:`fp8_format`, :attr:`interval`,
        :attr:`amax_history_len` and :attr:`amax_compute_algo` via fp8_autocast.
        Other parameters in recipe.DelayedScaling are returned with their
        default values.

    Returns
    -------
    delay_scaling : DelayedScaling
        An instance of DelayedScaling set via fp8_autocast.
    """
    amax_compute_algo = "max" if FP8Helper.AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX \
                        else "most_recent"
    return DelayedScaling(margin=int(FP8Helper.MARGIN),
                          interval=FP8Helper.UPDATE_FP8META_INTERVAL,
                          fp8_format=FP8Helper.FP8_FORMAT,
                          amax_history_len=FP8Helper.AMAX_HISTORY_LEN,
                          amax_compute_algo=amax_compute_algo)