# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Helper module for fp8 meta management
"""
from contextlib import contextmanager
from enum import Enum
from typing import Dict, Optional, Tuple, Union

import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict

from transformer_engine_jax import DType
from transformer_engine_jax import get_cublasLt_version
from transformer_engine_jax import get_cuda_version, get_device_compute_capability

from transformer_engine.common.recipe import DelayedScaling, Format
from transformer_engine.jax.sharding import global_shard_guard
from transformer_engine.jax.sharding import MeshResource

_is_fp8_available = None
_reason_for_no_fp8 = ""

Collection = Union[Dict, FrozenDict]


def _check_fp8_support(gpu_id) -> Tuple[bool, str]:
    """Return if fp8 support is available"""
    gpu_arch = get_device_compute_capability(gpu_id)
    if gpu_arch >= 90:    # hopper and above
        return True, ""
    if gpu_arch < 89:    # pre-ada
        return False, "Device compute capability 8.9 or higher required for FP8 execution."
    if get_cublasLt_version() < 120103:
        return False, "CublasLt version 12.1.3.x or higher required for FP8 execution on Ada."
    if get_cuda_version() < 12010:
        return False, "Cuda version 12.1 or higher required for FP8 execution on Ada."
    return True, ""


def is_fp8_available(gpu_id=None) -> Tuple[bool, str]:
    """Return if fp8 support is available"""
    if gpu_id is not None:
        return _check_fp8_support(gpu_id)

    global _is_fp8_available, _reason_for_no_fp8
    if _is_fp8_available is None:
        _is_fp8_available = True
        # JAX doesn't provide the local GPU id.
        for local_gpu_id in range(len(jax.local_devices())):
            ret, msg = _check_fp8_support(local_gpu_id)
            if ret is False:
                _is_fp8_available = ret
                _reason_for_no_fp8 = msg
                break    # stop at the first device lacking FP8 support

    return _is_fp8_available, _reason_for_no_fp8
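
# A minimal availability-check sketch (illustrative only, not part of this
# module's API):
#
#     available, reason = is_fp8_available()
#     assert available, reason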


def _format2dtypes(format_: Format):
    """Map a recipe.Format to a (forward dtype, backward dtype) pair."""
    if format_ == Format.E4M3:
        return jnp.float8_e4m3fn, jnp.float8_e4m3fn
    if format_ == Format.E5M2:
        return jnp.float8_e5m2, jnp.float8_e5m2
    if format_ == Format.HYBRID:
        return jnp.float8_e4m3fn, jnp.float8_e5m2
    # Fall back to bfloat16 pairs when FP8 is not in use.
    return jnp.bfloat16, jnp.bfloat16


class FP8MetaPackage:
    """
    A container carrying all of the metadata (fp8_max, amax history, scale,
    and scale_inv) required for FP8 GEMMs.
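
    A shape sketch (illustrative values; assumes the default
    ``NUM_META_PER_GEMM == 3``):

    .. code-block:: python

        num_meta = 1 * FP8Helper.NUM_META_PER_GEMM
        package = FP8MetaPackage(
            1,
            fp8_max=jnp.ones((num_meta, 1)),
            amax=jnp.zeros((num_meta, FP8Helper.AMAX_HISTORY_LEN)),
            scale=jnp.ones((num_meta, 1)),
            scale_inv=jnp.ones((num_meta, 1)),
        )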
    """

    def __init__(
        self,
        num_of_gemm: int,
        fp8_max: jnp.ndarray,
        amax: jnp.ndarray,
        scale: jnp.ndarray,
        scale_inv: jnp.ndarray,
    ) -> None:
        total_num_of_meta = num_of_gemm * FP8Helper.NUM_META_PER_GEMM
        self._num_of_gemm = num_of_gemm
        assert fp8_max.shape[0] == total_num_of_meta
        self._fp8_max = fp8_max
        assert amax.shape[0] == total_num_of_meta
        self._amax = amax
        assert scale.shape[0] == total_num_of_meta
        self._scale = scale
        assert scale_inv.shape[0] == total_num_of_meta
        self._scale_inv = scale_inv

    @property
    def num_of_gemm(self) -> int:
        """
        num_of_gemm of this package
        """
        return self._num_of_gemm

    @property
    def fp8_max(self) -> jnp.ndarray:
        """
        fp8_max of this package
        """
        return self._fp8_max

    @property
    def amax(self) -> jnp.ndarray:
        """
        amax of this package
        """
        return self._amax

    @property
    def scale(self) -> jnp.ndarray:
        """
        scale of this package
        """
        return self._scale

    @property
    def scale_inv(self) -> jnp.ndarray:
        """
        scale_inv of this package
        """
        return self._scale_inv

    def get_package_by_gemm_idx(self, gemm_idx):
        """
        Get a sub-package holding the metadata slice for the given gemm_idx.
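
        E.g. ``get_package_by_gemm_idx(0)`` returns a package holding the
        first ``NUM_META_PER_GEMM`` rows of each meta array.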
        """
        assert self.num_of_gemm > gemm_idx

        meta_start_idx = gemm_idx * FP8Helper.NUM_META_PER_GEMM
        meta_end_idx = (gemm_idx + 1) * FP8Helper.NUM_META_PER_GEMM
        return FP8MetaPackage(1, self.fp8_max[meta_start_idx:meta_end_idx],
                              self.amax[meta_start_idx:meta_end_idx],
                              self.scale[meta_start_idx:meta_end_idx],
                              self.scale_inv[meta_start_idx:meta_end_idx])

class AmaxComputeAlgo(Enum):
    """AmaxComputeAlgo."""
    MAX = "max"
    MOST_RECENT = "most_recent"


NVTE_FP8_COLLECTION_NAME = "fp8_meta_collection"


class FP8Helper:
    """
    FP8 helper to manage the FP8 metadata and the global FP8 configuration.
    """
    INITIALIZED = False
    MARGIN: float = 0.0
    FP8_FORMAT: Format = Format.HYBRID
    FWD_DTYPE: DType = _format2dtypes(Format.HYBRID)[0]
    BWD_DTYPE: DType = _format2dtypes(Format.HYBRID)[1]
    UPDATE_FP8META_INTERVAL: int = 1
    AMAX_HISTORY_LEN: int = 1024
    AMAX_COMPUTE_ALGO: AmaxComputeAlgo = AmaxComputeAlgo.MAX
    NUM_META_PER_GEMM: int = 3
    INPUT_META_IDX_PER_GEMM: int = 0
    KERNEL_META_IDX_PER_GEMM: int = 1
    GRAD_META_IDX_PER_GEMM: int = 2
    FP8_COLLECTION_NAME: str = NVTE_FP8_COLLECTION_NAME
    FP8_AMAX_NAME: str = "fp8_meta_amax"
    FP8_SCALE_NAME: str = "fp8_meta_scale"
    FP8_SCALE_INV_NAME: str = "fp8_meta_scale_inv"
    FP8_MAX_NAME: str = "fp8_max"
    FP8_2X_ACC_FPROP: bool = False
    FP8_2X_ACC_DGRAD: bool = True
    FP8_2X_ACC_WGRAD: bool = True

    @staticmethod
    def is_fp8_enabled():
        """
        Indicate whether fp8 training is enabled.
        """
        return FP8Helper.INITIALIZED

    @staticmethod
    def initialize(margin: float = 0.0,
                   fp8_format: Format = Format.HYBRID,
                   update_fp8meta_interval: int = 1,
                   amax_history_len: int = 1,
                   amax_compute_algo: AmaxComputeAlgo = AmaxComputeAlgo.MAX) -> None:
        """
        Initialize the FP8 meta. Normally entered via :func:`fp8_autocast`
        rather than called directly.
        """
        FP8Helper.INITIALIZED = True
        FP8Helper.MARGIN = margin
        FP8Helper.FP8_FORMAT = fp8_format
        FP8Helper.FWD_DTYPE, FP8Helper.BWD_DTYPE = \
            _format2dtypes(FP8Helper.FP8_FORMAT)
        FP8Helper.UPDATE_FP8META_INTERVAL = update_fp8meta_interval
        FP8Helper.AMAX_HISTORY_LEN = amax_history_len
        FP8Helper.AMAX_COMPUTE_ALGO = amax_compute_algo
        FP8Helper.FP8_2X_ACC_FPROP = False
        FP8Helper.FP8_2X_ACC_DGRAD = True
        FP8Helper.FP8_2X_ACC_WGRAD = True

    @staticmethod
    def finalize() -> None:
        """
        FP8 helper finalize
        """
        FP8Helper.INITIALIZED = False
        FP8Helper.MARGIN = 0.0
        FP8Helper.FP8_FORMAT = Format.HYBRID
        FP8Helper.FWD_DTYPE, FP8Helper.BWD_DTYPE = \
            _format2dtypes(FP8Helper.FP8_FORMAT)
        FP8Helper.UPDATE_FP8META_INTERVAL = 1
        FP8Helper.AMAX_HISTORY_LEN = 1024
        FP8Helper.AMAX_COMPUTE_ALGO = AmaxComputeAlgo.MAX

    @staticmethod
    def update_collections(new: Collection, original: Collection) -> Collection:
        """
        Update the original collection with the new one; on key collisions,
        entries from ``new`` take precedence.
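
        For example:

        .. code-block:: python

            merged = FP8Helper.update_collections({"b": 2}, {"a": 1, "b": 0})
            # merged == {"b": 2, "a": 1}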
        """
        assert isinstance(original, (dict, FrozenDict))
        assert isinstance(new, (dict, FrozenDict))
        frozen_original = FrozenDict(original) if not isinstance(original, FrozenDict) else original
        for key in new:
            if key in frozen_original:
                frozen_original, _ = frozen_original.pop(key)
        new_coll = FrozenDict({**new, **frozen_original})
        if not isinstance(original, FrozenDict):
            new_coll = new_coll.unfreeze()
        return new_coll

    @staticmethod
    def update_fp8_metas(state: Collection) -> Collection:
        """
        Update the FP8 metas in ``state``, recomputing scale and scale_inv
        from the amax history.
        """
        assert isinstance(state, (dict, FrozenDict))
        if FP8Helper.FP8_COLLECTION_NAME in state:
            frozen_state = FrozenDict(state) if not isinstance(state, FrozenDict) else state
            others, fp8_metas = frozen_state.pop(FP8Helper.FP8_COLLECTION_NAME)
            fp8_metas = FP8Helper._update_fp8_metas_impl(fp8_metas)
            new_state = FrozenDict({**others, FP8Helper.FP8_COLLECTION_NAME: fp8_metas})

            if not isinstance(state, FrozenDict):
                new_state = new_state.unfreeze()
            return new_state
        return state

    @staticmethod
    def generate_fp8_max_array(num_of_meta):
        """
        Generate the FP8 max array
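
        With the HYBRID format, each GEMM contributes the rows
        ``[E4M3_MAX, E4M3_MAX, E5M2_MAX]`` for its input, kernel and grad
        metas respectively.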
        """
        num_of_gemm = num_of_meta // FP8Helper.NUM_META_PER_GEMM
        fp8_max_fwd = jnp.finfo(FP8Helper.FWD_DTYPE).max
        fp8_max_bwd = jnp.finfo(FP8Helper.BWD_DTYPE).max
        fp8_max_per_gemm = []
        for i in range(FP8Helper.NUM_META_PER_GEMM):
            val = fp8_max_bwd if i == FP8Helper.GRAD_META_IDX_PER_GEMM \
                else fp8_max_fwd
            fp8_max_per_gemm.append([val])
        fp8_max_per_gemm = jnp.asarray(fp8_max_per_gemm, dtype=jnp.float32)
        return jnp.vstack([fp8_max_per_gemm] * num_of_gemm)

    @staticmethod
    def get_fp8_meta_indices(gemm_idx: int) -> Tuple[int, int, int]:
        """
        Obtain the (input, kernel, grad) FP8-meta indices for the given GEMM index.
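
        For example, with ``NUM_META_PER_GEMM == 3``:

        .. code-block:: python

            FP8Helper.get_fp8_meta_indices(1)    # -> (3, 4, 5)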
        """
        input_idx = FP8Helper.NUM_META_PER_GEMM * gemm_idx + FP8Helper.INPUT_META_IDX_PER_GEMM
        kernel_idx = FP8Helper.NUM_META_PER_GEMM * gemm_idx + FP8Helper.KERNEL_META_IDX_PER_GEMM
        grad_idx = FP8Helper.NUM_META_PER_GEMM * gemm_idx + FP8Helper.GRAD_META_IDX_PER_GEMM
        return input_idx, kernel_idx, grad_idx

    @staticmethod
    @jax.jit
    def _update_fp8_metas_impl(fp8_metas: Collection) -> Collection:
        fp8_meta_arrays, treedef = jax.tree_util.tree_flatten(fp8_metas)
        num_of_meta_with_max = FP8Helper.NUM_META_PER_GEMM + 1
        num_of_gemm = len(fp8_meta_arrays) // num_of_meta_with_max
        for i in range(num_of_gemm):
            # The flattened arrays are ordered alphabetically by collection
            # name: fp8_max, fp8_meta_amax, fp8_meta_scale, fp8_meta_scale_inv.
            fp8_max_idx = i * num_of_meta_with_max
            fp8_amax_idx = fp8_max_idx + 1
            fp8_scale_idx = fp8_amax_idx + 1
            fp8_scale_inv_idx = fp8_scale_idx + 1

            fp8_max = fp8_meta_arrays[fp8_max_idx]
            if FP8Helper.AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX:
                amax = jnp.max(fp8_meta_arrays[fp8_amax_idx], axis=-1, keepdims=True)
            else:
                amax = fp8_meta_arrays[fp8_amax_idx][..., 0:1]
            scale = fp8_meta_arrays[fp8_scale_idx]

            # Delayed scaling: sf = (fp8_max / amax) / 2**margin, keeping the
            # previous scale where amax is zero or non-finite.
            sf = (fp8_max / amax) / (2**FP8Helper.MARGIN)
            sf = jnp.where(amax > 0.0, sf, scale)
            sf = jnp.where(jnp.isfinite(amax), sf, scale)
            fp8_meta_arrays[fp8_scale_idx] = sf
            fp8_meta_arrays[fp8_scale_inv_idx] = 1 / sf

        return jax.tree_util.tree_unflatten(treedef, fp8_meta_arrays)

    @staticmethod
    def update_amax_history(amax: jnp.ndarray) -> jnp.ndarray:
        """
        Update the amax history
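
        The head of the history is reserved for the incoming amax, e.g.:

        .. code-block:: python

            FP8Helper.update_amax_history(jnp.array([[1., 2., 3.]]))
            # -> [[0., 3., 1.]]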
        """
        updated_amax = jnp.roll(amax, -1, -1)
        updated_amax = updated_amax.at[..., 0].set(0)
        return updated_amax

    @staticmethod
    @jax.jit
    def update_fp8_scale(fp8_max: jnp.ndarray, amax: jnp.ndarray,
                         scale: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """
        Calculate fp8 scale and scale_inv based on given amax.
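
        Follows the delayed-scaling formula, keeping the previous scale
        wherever amax is zero or non-finite:

        .. code-block:: python

            sf = (fp8_max / amax) / (2 ** margin)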
        """
        if FP8Helper.AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX:
            amax = jnp.max(amax, axis=-1, keepdims=True)
        else:
            amax = amax[..., 0:1]

        sf = (fp8_max / amax) / (2**FP8Helper.MARGIN)
        sf = jnp.where(amax > 0.0, sf, scale)
        sf = jnp.where(jnp.isfinite(amax), sf, scale)
        scale = sf
        scale_inv = 1 / sf

        return scale, scale_inv


@contextmanager
def fp8_autocast(enabled: bool = False,
                 fp8_recipe: Optional[DelayedScaling] = None,
                 mesh_resource: Optional[MeshResource] = None) -> None:
    r"""
    Context manager for FP8 usage.

    .. code-block:: python

        mesh_shape = (4, 2)
        dp_mesh_axis_name = 'data_parallel'
        tp_mesh_axis_name = 'tensor_parallel'
        devices = np.asarray(jax.devices()).reshape(*mesh_shape)

        with maps.Mesh(devices, (dp_mesh_axis_name, tp_mesh_axis_name)):
            mesh_resource = MeshResource(dp_mesh_axis_name, tp_mesh_axis_name)

            with fp8_autocast(enabled=True, mesh_resource=mesh_resource):
                rules = extend_logical_axis_rules(tuple())
                transformer = TransformerLayer()

                with partitioning.axis_rules(rules):
                    pjit(transformer.init, ...)(...)

    .. note::
        We only support :attr:`margin`, :attr:`fp8_format`,
        :attr:`interval`, :attr:`amax_history_len` and
        :attr:`amax_compute_algo` (with values 'max' and 'most_recent')
        in recipe.DelayedScaling currently. Other parameters in
        recipe.DelayedScaling will trigger an assertion.

    Parameters
    ----------
    enabled: bool, default = False
        Whether or not to enable fp8.
    fp8_recipe: recipe.DelayedScaling, default = None
        Recipe used for FP8 training.
    mesh_resource: MeshResource, default = None
        Specify the mesh axes for data and tensor parallelism to shard along.
        If set to None, then no data or tensor parallelism will be used.

    """
    if fp8_recipe is None:
        fp8_recipe = DelayedScaling()

    assert fp8_recipe.amax_compute_algo in [
        "max", "most_recent"
    ], ("DelayedScaling amax_compute_algo only supports max and most_recent with TE/JAX.")
    assert fp8_recipe.scaling_factor_compute_algo is None, (
        "DelayedScaling scaling_factor_compute_algo isn't supported by TE/JAX.")
    assert fp8_recipe.override_linear_precision == (False, False, False), (
        "DelayedScaling override_linear_precision isn't supported by TE/JAX.")
    assert fp8_recipe.reduce_amax, ("DelayedScaling reduce_amax should be enabled for TE/JAX.")

    if mesh_resource is None:
        mesh_resource = MeshResource()

    try:
        with global_shard_guard(mesh_resource):
            if enabled:
                fp8_available, reason_for_no_fp8 = is_fp8_available()
                assert fp8_available, reason_for_no_fp8

                amax_compute_algo = AmaxComputeAlgo.MOST_RECENT
                if fp8_recipe.amax_compute_algo == 'max':
                    amax_compute_algo = AmaxComputeAlgo.MAX

                FP8Helper.initialize(margin=fp8_recipe.margin,
                                     fp8_format=fp8_recipe.fp8_format,
                                     update_fp8meta_interval=fp8_recipe.interval,
                                     amax_history_len=fp8_recipe.amax_history_len,
                                     amax_compute_algo=amax_compute_algo)
            yield
    finally:
        FP8Helper.finalize()


# Function Wrappers
def update_collections(new: Collection, original: Collection) -> Collection:
    r"""
    A helper to update Flax's Collection.

    Collection = [dict, flax.core.frozen_dict.FrozenDict]
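
    For example:

    .. code-block:: python

        updated = update_collections({"b": 2}, {"a": 1, "b": 0})
        # updated == {"b": 2, "a": 1}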

    Parameters
    ----------
    new: Collection
        A collection that includes new data.
    original: Collection
        The base collection.

    Returns
    -------
    outputs : Collection
        The updated collection.
    """
    return FP8Helper.update_collections(new, original)


def update_fp8_metas(state: Collection) -> Collection:
    r"""
    Calculate new fp8 scales and their inverses via the following formula:

    .. code-block:: python

        sf = (fp8_max / amax) / (2 ** margin)
        sf = sf if amax > 0.0 else original_scale
        updated_scale = sf if isfinite(amax) else original_scale
        updated_scale_inv = 1 / updated_scale

    Collection = [dict, flax.core.frozen_dict.FrozenDict]
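
    A typical call site (sketch; ``state`` is the Flax variable collection
    holding the FP8 metas):

    .. code-block:: python

        state = update_fp8_metas(state)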

    Parameters
    ----------
    state: Collection
        A collection that includes FP8 metas.

    Returns
    -------
    outputs : Collection
        The collection with updated FP8 metas.
    """
    return FP8Helper.update_fp8_metas(state)


def get_delayed_scaling():
    r"""
    Obtain an instance of DelayedScaling which is set via fp8_autocast.

    .. note::
        We only store :attr:`margin`, :attr:`fp8_format`, :attr:`interval`,
        :attr:`amax_history_len` and :attr:`amax_compute_algo` via fp8_autocast.
        Other parameters in recipe.DelayedScaling would be returned as the default
        values.

    Returns
    -------
    delay_scaling : DelayedScaling
        an instance of DelayedScaling which is set via fp8_autocast.
    """
    amax_compute_algo = "max" if FP8Helper.AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX \
                        else "most_recent"
    return DelayedScaling(margin=int(FP8Helper.MARGIN),
                          interval=FP8Helper.UPDATE_FP8META_INTERVAL,
                          fp8_format=FP8Helper.FP8_FORMAT,
                          amax_history_len=FP8Helper.AMAX_HISTORY_LEN,
                          amax_compute_algo=amax_compute_algo)