# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Helper module for fp8 meta management
"""
from contextlib import contextmanager
from enum import Enum
from typing import Dict, List, Optional, Tuple, Union

import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict

from transformer_engine_jax import DType
from transformer_engine_jax import get_cublasLt_version
from transformer_engine_jax import get_cuda_version, get_device_compute_capability
from transformer_engine.common.recipe import DelayedScaling, Format
from transformer_engine.jax.sharding import global_shard_guard
from transformer_engine.jax.sharding import ShardingResource

_is_fp8_available = None
_reason_for_no_fp8 = ""
Collection = Union[Dict, FrozenDict]


def _check_fp8_support(gpu_id) -> Tuple[bool, str]:
    """Return if fp8 support is available"""
    gpu_arch = get_device_compute_capability(gpu_id)
    if gpu_arch >= 90:    # hopper and above
        return True, ""
    if gpu_arch < 89:    # pre-ada
        return False, "Device compute capability 8.9 or higher required for FP8 execution."
    if get_cublasLt_version() < 120103:
        return False, "CublasLt version 12.1.3.x or higher required for FP8 execution on Ada."
    if get_cuda_version() < 12010:
        return False, "Cuda version 12.1 or higher required for FP8 execution on Ada."
    return True, ""


def is_fp8_available(gpu_id=None) -> Tuple[bool, str]:
    """Return if fp8 support is available"""
    if gpu_id is not None:
        return _check_fp8_support(gpu_id)

    global _is_fp8_available, _reason_for_no_fp8
    if _is_fp8_available is None:
        _is_fp8_available = True
        # JAX doesn't expose the local GPU id; probe the first local device.
        for local_gpu_id in range(len(jax.local_devices())):
            ret, msg = _check_fp8_support(local_gpu_id)
            if ret is False:
                _is_fp8_available = ret
                _reason_for_no_fp8 = msg
            break

    return _is_fp8_available, _reason_for_no_fp8


def _format2dtypes(format_: Format):
    if format_ == Format.E4M3:
        return DType.kFloat8E4M3, DType.kFloat8E4M3
    if format_ == Format.E5M2:
        return DType.kFloat8E5M2, DType.kFloat8E5M2
    if format_ == Format.HYBRID:
        return DType.kFloat8E4M3, DType.kFloat8E5M2
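    # Not an FP8 format: fall back to BF16 for both forward and backward tensors.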
    return DType.kBFloat16, DType.kBFloat16


class FP8GemmPackage:
    """
    A container that holds all data required for FP8 GEMM: the inputs, the
    kernels, and the per-GEMM FP8 metas (fp8_max, amax, scale, scale_inv).
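
    A minimal construction sketch for a single GEMM (the shapes below are
    illustrative assumptions; only the leading dimension of the meta arrays,
    ``num_of_gemm * FP8Helper.NUM_META_PER_GEMM``, is checked):

    .. code-block:: python

        num_meta = FP8Helper.NUM_META_PER_GEMM    # 3 metas per GEMM
        package = FP8GemmPackage(
            num_of_gemm=1,
            inputs=jnp.zeros((16, 32), jnp.bfloat16),
            kernels=[jnp.zeros((32, 64), jnp.bfloat16)],
            fp8_max=jnp.ones((num_meta, 1), jnp.float32),
            amax=jnp.zeros((num_meta, FP8Helper.AMAX_HISTORY_LEN), jnp.float32),
            scale=jnp.ones((num_meta, 1), jnp.float32),
            scale_inv=jnp.ones((num_meta, 1), jnp.float32),
        )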
    """

    def __init__(
        self,
        num_of_gemm: int,
        inputs: jnp.ndarray,
        kernels: List[jnp.ndarray],
        fp8_max: jnp.ndarray,
        amax: jnp.ndarray,
        scale: jnp.ndarray,
        scale_inv: jnp.ndarray,
    ) -> None:
        self._num_of_gemm = num_of_gemm
        self._inputs = inputs

        assert len(kernels) == self._num_of_gemm
        self._kernels = kernels

        total_num_of_meta = self._num_of_gemm * FP8Helper.NUM_META_PER_GEMM
        assert fp8_max.shape[0] == total_num_of_meta
        self._fp8_max = fp8_max
        assert amax.shape[0] == total_num_of_meta
        self._amax = amax
        assert scale.shape[0] == total_num_of_meta
        self._scale = scale
        assert scale_inv.shape[0] == total_num_of_meta
        self._scale_inv = scale_inv

    @property
    def num_of_gemm(self) -> int:
        """
        num_of_gemm of this package
        """
        return self._num_of_gemm

    @property
    def inputs(self) -> jnp.ndarray:
        """
        inputs of this package
        """
        return self._inputs

    @property
    def kernels(self) -> List[jnp.ndarray]:
        """
        kernels of this package
        """
        return self._kernels

    @property
    def fp8_max(self) -> jnp.ndarray:
        """
        fp8_max of this package
        """
        return self._fp8_max

    @property
    def amax(self) -> jnp.ndarray:
        """
        amax of this package
        """
        return self._amax

    @property
    def scale(self) -> jnp.ndarray:
        """
        scale of this package
        """
        return self._scale

    @property
    def scale_inv(self) -> jnp.ndarray:
        """
        scale_inv of this package
        """
        return self._scale_inv


class AmaxComputeAlgo(Enum):
    """AmaxComputeAlgo."""
    MAX = "max"
    MOST_RECENT = "most_recent"


class FP8Helper:
    """
    Helper that manages the global FP8 metadata and configuration.
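
    Normally driven indirectly through :func:`fp8_autocast`; a direct
    lifecycle sketch (for illustration only):

    .. code-block:: python

        FP8Helper.initialize(margin=0.0, fp8_format=Format.HYBRID)
        assert FP8Helper.is_fp8_enabled()
        # ... construct and run FP8-enabled modules ...
        FP8Helper.finalize()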
    """
    INITIALIZED = False
    MARGIN: float = 0.0
    FP8_FORMAT: Format = Format.HYBRID
    FWD_DTYPE: DType = DType.kFloat8E4M3
    BWD_DTYPE: DType = DType.kFloat8E5M2
    UPDATE_FP8META_INTERVAL: int = 1
    AMAX_HISTORY_LEN: int = 1024
    AMAX_COMPUTE_ALGO: AmaxComputeAlgo = AmaxComputeAlgo.MAX
    NUM_META_PER_GEMM: int = 3
    INPUT_META_IDX_PER_GEMM: int = 0
    KERNEL_META_IDX_PER_GEMM: int = 1
    GRAD_META_IDX_PER_GEMM: int = 2
    FP8_COLLECTION_NAME: str = "fp8_meta_collection"
    FP8_AMAX_NAME: str = "fp8_meta_amax"
    FP8_SCALE_NAME: str = "fp8_meta_scale"
    FP8_SCALE_INV_NAME: str = "fp8_meta_scale_inv"
    FP8_MAX_NAME: str = "fp8_max"
    FP8_2X_ACC_FPROP: bool = False
    FP8_2X_ACC_DGRAD: bool = True
    FP8_2X_ACC_WGRAD: bool = True

    @staticmethod
    def is_fp8_enabled():
        """
        Indicate whether FP8 training is enabled.
        """
        return FP8Helper.INITIALIZED

    @staticmethod
    def initialize(margin: float = 0.0,
                   fp8_format: Format = Format.HYBRID,
                   update_fp8meta_interval: int = 1,
                   amax_history_len: int = 1,
                   amax_compute_algo: AmaxComputeAlgo = AmaxComputeAlgo.MOST_RECENT) -> None:
        """
        Initialize the global FP8 meta configuration.
        """
        FP8Helper.INITIALIZED = True
        FP8Helper.MARGIN = margin
        FP8Helper.FP8_FORMAT = fp8_format
        FP8Helper.FWD_DTYPE, FP8Helper.BWD_DTYPE = \
            _format2dtypes(FP8Helper.FP8_FORMAT)
        FP8Helper.UPDATE_FP8META_INTERVAL = update_fp8meta_interval
        FP8Helper.AMAX_HISTORY_LEN = amax_history_len
        FP8Helper.AMAX_COMPUTE_ALGO = amax_compute_algo
        FP8Helper.FP8_2X_ACC_FPROP = False
        FP8Helper.FP8_2X_ACC_DGRAD = True
        FP8Helper.FP8_2X_ACC_WGRAD = True

    @staticmethod
    def finalize() -> None:
        """
        Reset the FP8 helper to its default, uninitialized state.
        """
        FP8Helper.INITIALIZED = False
        FP8Helper.MARGIN = 0.0
        FP8Helper.FP8_FORMAT = Format.HYBRID
        FP8Helper.FWD_DTYPE = DType.kFloat8E4M3
        FP8Helper.BWD_DTYPE = DType.kFloat8E5M2
        FP8Helper.UPDATE_FP8META_INTERVAL = 1
        FP8Helper.AMAX_HISTORY_LEN = 1024
        FP8Helper.AMAX_COMPUTE_ALGO = AmaxComputeAlgo.MAX

    @staticmethod
    def update_amax_history(amax_buffers: jnp.ndarray) -> jnp.ndarray:
        """
        Roll the amax history along the history axis and zero the slot for the next amax.
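
        For example (illustrative), with a single meta and a history of length 3::

            [[3., 1., 2.]] -> roll left -> [[1., 2., 3.]] -> zero slot 0 -> [[0., 2., 3.]]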
        """
        updated_amax_buffers = jnp.roll(amax_buffers, -1, 1)
        updated_amax_buffers = updated_amax_buffers.at[:, 0].set(0)
        return updated_amax_buffers

    @staticmethod
    def update_collections(new: Collection, original: Collection) -> Collection:
        """
        Merge ``new`` into ``original``; entries in ``new`` take precedence.
        """
        assert isinstance(original, (dict, FrozenDict))
        assert isinstance(new, (dict, FrozenDict))
        frozen_original = FrozenDict(original) if not isinstance(original, FrozenDict) else original
        for key in new:
            if key in frozen_original:
                frozen_original, _ = frozen_original.pop(key)
        new_coll = FrozenDict({**new, **frozen_original})
        if not isinstance(original, FrozenDict):
            new_coll = new_coll.unfreeze()
        return new_coll

    @staticmethod
    def update_fp8_metas(state: Collection) -> Collection:
        """
        Update the FP8 metas
        """
        assert isinstance(state, (dict, FrozenDict))
        if FP8Helper.FP8_COLLECTION_NAME in state:
            frozen_state = FrozenDict(state) if not isinstance(state, FrozenDict) else state
            others, fp8_metas = frozen_state.pop(FP8Helper.FP8_COLLECTION_NAME)
            fp8_metas = FP8Helper._update_fp8_metas_impl(fp8_metas)
            new_state = FrozenDict({**others, FP8Helper.FP8_COLLECTION_NAME: fp8_metas})

            if not isinstance(state, FrozenDict):
                new_state = new_state.unfreeze()
            return new_state
        return state

    @staticmethod
    def generate_fp8_max_array(num_of_meta):
        """
        Generate the FP8 max array (forward max for input/kernel metas, backward max for grad metas).
        """
        num_of_gemm = num_of_meta // FP8Helper.NUM_META_PER_GEMM
        fp8_max_fwd = FP8Helper.FP8_FORMAT.value.max_fwd
        fp8_max_bwd = FP8Helper.FP8_FORMAT.value.max_bwd
        fp8_max_per_gemm = []
        for i in range(FP8Helper.NUM_META_PER_GEMM):
            val = fp8_max_bwd if i == FP8Helper.GRAD_META_IDX_PER_GEMM \
                else fp8_max_fwd
            fp8_max_per_gemm.append([val])
        fp8_max_per_gemm = jnp.asarray(fp8_max_per_gemm, dtype=jnp.float32)
        return jnp.vstack([fp8_max_per_gemm] * num_of_gemm)

    @staticmethod
    def get_fp8_meta_indices(gemm_idx: int) -> Tuple[int, int, int]:
        """
        Obtain the indices of the FP8 metas for the given GEMM index.
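
        For example, with ``NUM_META_PER_GEMM == 3``, ``get_fp8_meta_indices(1)``
        returns ``(3, 4, 5)``.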
        """
        input_idx = FP8Helper.NUM_META_PER_GEMM * gemm_idx + FP8Helper.INPUT_META_IDX_PER_GEMM
        kernel_idx = FP8Helper.NUM_META_PER_GEMM * gemm_idx + FP8Helper.KERNEL_META_IDX_PER_GEMM
        grad_idx = FP8Helper.NUM_META_PER_GEMM * gemm_idx + FP8Helper.GRAD_META_IDX_PER_GEMM
        return input_idx, kernel_idx, grad_idx

    @staticmethod
    @jax.jit
    def _update_fp8_metas_impl(fp8_metas: Collection) -> Collection:
        fp8_meta_arrays, treedef = jax.tree_util.tree_flatten(fp8_metas)
        num_of_meta_with_max = FP8Helper.NUM_META_PER_GEMM + 1
        num_of_gemm = len(fp8_meta_arrays) // num_of_meta_with_max
        for i in range(num_of_gemm):
            # The flattened arrays are ordered alphabetically by meta name
            # (fp8_max, fp8_meta_amax, fp8_meta_scale, fp8_meta_scale_inv).
            fp8_max_idx = i * num_of_meta_with_max
            fp8_amax_idx = fp8_max_idx + 1
            fp8_scale_idx = fp8_amax_idx + 1
            fp8_scale_inv_idx = fp8_scale_idx + 1

            fp8_max = fp8_meta_arrays[fp8_max_idx]
            if FP8Helper.AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX:
                amax = jnp.max(fp8_meta_arrays[fp8_amax_idx], axis=1, keepdims=True)
            else:
                amax = fp8_meta_arrays[fp8_amax_idx][:, 0:1]
            scale = fp8_meta_arrays[fp8_scale_idx]

            # Delayed scaling: the new scale is roughly
            # 2 ** (floor(log2(fp8_max / amax)) - MARGIN); keep the previous
            # scale when amax is zero or non-finite.
            exp = jnp.floor(jnp.log2(fp8_max / amax)) - FP8Helper.MARGIN
            sf = jnp.round(jnp.power(2, jnp.abs(exp)))
            sf = jnp.where(amax > 0.0, sf, scale)
            sf = jnp.where(jnp.isfinite(amax), sf, scale)
            scale = jnp.where(exp < 0, 1 / sf, sf)
            fp8_meta_arrays[fp8_scale_idx] = scale
            fp8_meta_arrays[fp8_scale_inv_idx] = 1 / scale

        return jax.tree_util.tree_unflatten(treedef, fp8_meta_arrays)


@contextmanager
def fp8_autocast(enabled: bool = False,
                 fp8_recipe: Optional[DelayedScaling] = None,
                 sharding_resource: Optional[ShardingResource] = None) -> None:
    r"""
    Context manager for FP8 usage.

    .. code-block:: python

        mesh_shape = (4, 2)
        dp_mesh_axis_name = 'data_parallel'
        tp_mesh_axis_name = 'tensor_parallel'
        devices = np.asarray(jax.devices()).reshape(*mesh_shape)

        with maps.Mesh(devices, (dp_mesh_axis_name, tp_mesh_axis_name)):
            sharding_resource = ShardingResource(dp_mesh_axis_name, tp_mesh_axis_name)

            with fp8_autocast(enabled=True, sharding_resource=sharding_resource):
                rules = extend_logical_axis_rules(tuple())
                transformer = TransformerLayer()

                with partitioning.axis_rules(rules):
                    pjit(transformer.init, ...)(...)

    .. note::
        We only support :attr:`margin`, :attr:`fp8_format`, :attr:`interval` and
        :attr:`amax_history_len` in recipe.DelayedScaling currently. Other parameters
        in recipe.DelayedScaling are ignored, even if set.

    Parameters
    ----------
    enabled: bool, default = False
        Whether or not to enable FP8.
    fp8_recipe: recipe.DelayedScaling, default = None
        Recipe used for FP8 training.
    sharding_resource: ShardingResource, default = None
        Specify the mesh axes for data and tensor parallelism to shard along.
        If set to None, a default ShardingResource() will be created.
    """
    if fp8_recipe is None:
        fp8_recipe = DelayedScaling()

    if sharding_resource is None:
        sharding_resource = ShardingResource()

    try:
        with global_shard_guard(sharding_resource):
            if enabled:
                fp8_available, reason_for_no_fp8 = is_fp8_available()
                assert fp8_available, reason_for_no_fp8

                amax_compute_algo = AmaxComputeAlgo.MOST_RECENT
                if fp8_recipe.amax_compute_algo == 'max':
                    amax_compute_algo = AmaxComputeAlgo.MAX

                FP8Helper.initialize(margin=fp8_recipe.margin,
                                     fp8_format=fp8_recipe.fp8_format,
                                     update_fp8meta_interval=fp8_recipe.interval,
                                     amax_history_len=fp8_recipe.amax_history_len,
                                     amax_compute_algo=amax_compute_algo)
            yield
    finally:
        FP8Helper.finalize()


# Function Wrappers
def update_collections(new: Collection, original: Collection) -> FrozenDict:
    r"""
    A helper to update Flax's Collection.

    Collection = [dict, flax.core.frozen_dict.FrozenDict]

    Parameters
    ----------
    new: Collection
        A collection that includes new data.
    original: Collection
        The base collection.

    Returns
    -------
    outputs : Collection
        The updated collection.
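
    A usage sketch (collection contents are placeholders):

    .. code-block:: python

        variables = {'params': {...}, FP8Helper.FP8_COLLECTION_NAME: {...}}
        new_metas = {FP8Helper.FP8_COLLECTION_NAME: {...}}
        variables = update_collections(new_metas, variables)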
    """
    return FP8Helper.update_collections(new, original)


def update_fp8_metas(state: Collection) -> Collection:
    r"""
    Calculate new FP8 scales and their inverses via the following formula

    .. code-block:: python

        exp = floor(log2(fp8_max / amax)) - margin
        sf = round(power(2, abs(exp)))
        sf = sf if amax > 0.0 else original_scale
        sf = sf if isfinite(amax) else original_scale
        updated_scale = 1/sf if exp < 0 else sf
        updated_scale_inv = 1/updated_scale

    Collection = [dict, flax.core.frozen_dict.FrozenDict]

    Parameters
    ----------
    state: Collection
        A collection that includes FP8 metas.

    Returns
    -------
    outputs : Collection
        The collection with updated FP8 metas.
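
    A usage sketch (illustrative):

    .. code-block:: python

        # `variables` is a Flax variable collection that holds the FP8 meta
        # sub-collection with freshly recorded amax values.
        variables = update_fp8_metas(variables)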
    """
    return FP8Helper.update_fp8_metas(state)


def get_delayed_scaling():
    r"""
    Obtain the DelayedScaling instance configured via fp8_autocast.

    .. note::
        We only store :attr:`margin`, :attr:`fp8_format`, :attr:`interval`,
        :attr:`amax_history_len` and :attr:`amax_compute_algo` via fp8_autocast.
        Other parameters in recipe.DelayedScaling are returned with their default
        values.

    Returns
    -------
    delay_scaling : DelayedScaling
        An instance of DelayedScaling configured via fp8_autocast.
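
    A usage sketch (assumes FP8-capable hardware):

    .. code-block:: python

        with fp8_autocast(enabled=True, fp8_recipe=DelayedScaling(margin=1.0)):
            recipe = get_delayed_scaling()
            assert recipe.margin == 1.0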
    """
    amax_compute_algo = "max" if FP8Helper.AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX \
                        else "most_recent"
    return DelayedScaling(margin=FP8Helper.MARGIN,
                          interval=FP8Helper.UPDATE_FP8META_INTERVAL,
                          fp8_format=FP8Helper.FP8_FORMAT,
                          amax_history_len=FP8Helper.AMAX_HISTORY_LEN,
                          amax_compute_algo=amax_compute_algo)