recipe.py 5.47 KB
Newer Older
Przemek Tredak's avatar
Przemek Tredak committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""This module provides predefined FP8 recipes."""
from __future__ import annotations
from enum import Enum
from typing import Literal, Optional, Union, Callable, NamedTuple
from pydantic.dataclasses import dataclass


class _FormatHelper(NamedTuple):
    """
    Stores max FP8 values for fprop and bprop a `Format`.
    """

    max_fwd: float
    max_bwd: float


class Format(Enum):
    """
    Supported FP8 formats.

    Values
    ------
    E4M3 :
          All FP8 tensors are in e4m3 format
    E5M2 :
          All FP8 tensors are in e5m2 format
    HYBRID :
            FP8 tensors in the forward pass are in e4m3 format,
            FP8 tensors in the backward pass are in e5m2 format
    """

    E4M3 = _FormatHelper(max_fwd=448, max_bwd=448)
    E5M2 = _FormatHelper(max_fwd=57344, max_bwd=57344)
    HYBRID = _FormatHelper(max_fwd=E4M3.max_fwd, max_bwd=E5M2.max_bwd)


class _OverrideLinearPrecision(NamedTuple):
    """
    Whether or not the execute the `fprop`, `dgrad`, and `wgrad`
    GEMMs in higher precision when using FP8.
    """

    fprop: bool = False
    dgrad: bool = False
    wgrad: bool = False


@dataclass()
class DelayedScaling:
    """
    Use the delayed scaling factor strategy.
    Use scale factor from previous iteration,
    recompute once every `interval`, and record
    amax history of `amax_history_len` steps.

    Parameters
    ----------
    margin : int, default = 0
            Margin for the scaling factor computation.
    interval : int, default = 1
              Controls how often the scaling factor is recomputed.
66
    fp8_format : {Format.E4M3, Format.HYBRID}, default = Format.HYBRID
Przemek Tredak's avatar
Przemek Tredak committed
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
                Controls the FP8 data format used during forward and backward
                pass.
    amax_history_len : int, default = 1
                      The length of the amax history window used for
                      scaling factor computation.
    amax_compute_algo : {'max', 'most_recent', Callable}, default = 'most_recent'
                       Algorithm used for choosing the `amax` value for the
                       scaling factor computation. There are 2 predefined
                       choices: `max` chooses the largest `amax` in the history
                       window, while `most_recent` always chooses the most recently
                       seen value. Alternatively, one may pass a function of the
                       signature:

                       .. code-block:: python

                         def amax_compute(amax_history: Tensor) -> Tensor

                       where `Tensor` is a framework tensor type.
    scaling_factor_compute_algo : Callable, default = None
                                 Algorithm used for computing the new scaling
                                 factor based on the value of `amax`. It should
                                 be a function of the signature:

                                 .. code-block:: python

                                   def scaling_factor_compute(amax: Tensor,
                                                              old_scaling_factor: Tensor,
                                                              fp8_max: Tensor,
                                                              recipe: DelayedScaling) -> Tensor

                                 where `Tensor` is a framework tensor type.
    override_linear_precision: Tuple(bool, bool, bool), default=(False, False, False)
                              Whether or not the execute the `fprop`, `dgrad`, and `wgrad`
                              GEMMs (respectively) in higher precision when using FP8.
101
102
103
104
105
106
107
108
    reduce_amax: bool, default = `True`
                By default, if `torch.distributed` is initialized, the `amax` value for FP8
                tensors is reduced across the `fp8_group` (specified in the `fp8_autocast`
                call). This keeps the amaxes and scaling factors synced across the given
                distributed group. If set to `False`, this reduction is skipped and every
                GPU maintains local amaxes and scaling factors. To ensure results are
                numerically identical across checkpointing boundaries in this case, all
                ranks must checkpoint in order to store the local tensors.
Przemek Tredak's avatar
Przemek Tredak committed
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131

    Notes
    -----
    * By default (when `scaling_factor_compute_algo` is left as `None`) the scaling
      factor is computed from the final `amax` value using the formula:

      .. code-block:: python

          FP8_MAX = maximum_representable_value(fp8_format)
          exp = get_exponent(FP8_MAX / amax) - margin
          new_scaling_factor = 2.0 ^ exp

    * The scaling factor should always be a power of 2 to not introduce numerical
      error during the conversion from FP8 to higher precision format.
    """

    margin: int = 0
    interval: int = 1
    fp8_format: Format = Format.HYBRID
    amax_history_len: int = 1
    amax_compute_algo: Union[Literal["max", "most_recent"], Callable] = "most_recent"
    override_linear_precision: _OverrideLinearPrecision = _OverrideLinearPrecision()
    scaling_factor_compute_algo: Optional[Callable] = None
132
    reduce_amax: bool = True
Przemek Tredak's avatar
Przemek Tredak committed
133
134
135
136
137
138
139

    def __post_init__(self) -> None:
        assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported."
        assert self.override_linear_precision in (
            (False, False, False),
            (False, False, True),
        ), "Only wgrad GEMM override is currently supported."