# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""FP8 utilities for TransformerEngine"""

from contextlib import contextmanager
from typing import Tuple, Optional, Dict, Any, Union

import numpy as np

import paddle
import transformer_engine_paddle as tex
from transformer_engine.common.recipe import DelayedScaling, Format

from .constants import dist_group_type
from .fp8_buffer import FP8MetaFwdBuffer, FP8MetaBwdBuffer, FP8RecomputeBuffer

# FP8 support
_is_fp8_available = None
_reason_for_no_fp8 = ""


def _check_fp8_support() -> Tuple[bool, str]:
    """Return whether FP8 support is available, and the reason if it is not"""

    # Check GPU arch
    arch = paddle.device.cuda.get_device_capability()
    if arch >= (9, 0):    # Hopper and above
        return True, ""
    if arch < (8, 9):    # pre-Ada
        return False, "Device compute capability 8.9 or higher required for FP8 execution."

    # Special handling for Ada (compute capability 8.9)
    if tex.get_cublasLt_version() < 120103:
        return False, "CublasLt version 12.1.3.x or higher required for FP8 execution on Ada."
    if not paddle.version.cuda():
        # Paddle was built without CUDA, so no CUDA version string is available.
        return False, "Cuda version 12.1 or higher required for FP8 execution on Ada."
    if tuple(int(v) for v in paddle.version.cuda().split(".")) < (12, 1):
        return False, "Cuda version 12.1 or higher required for FP8 execution on Ada."
    return True, ""


def is_fp8_available() -> Tuple[bool, str]:
    """Return whether FP8 support is available, caching the result of the first check"""
    global _is_fp8_available, _reason_for_no_fp8
    if _is_fp8_available is None:
        _is_fp8_available, _reason_for_no_fp8 = _check_fp8_support()
    return _is_fp8_available, _reason_for_no_fp8
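
# A minimal sketch of the intended call pattern (the error handling is
# illustrative, not part of this module):
#
#     available, reason = is_fp8_available()
#     if not available:
#         raise RuntimeError(f"FP8 unsupported: {reason}")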


class FP8State:
    """Stores FP8 state"""

    def __init__(self):
        self._fp8_enabled = False
        self._fp8_calibration = False
        self._fp8_recipe = None
        self._fp8_distributed_group = None
        self._is_first_fp8_module = False
        self._fp8_autocast_counter = 0
        self._fp8_autocast_depth = 0
        self._fp8_recompute_enabled = False
        self._fp8_fwd_buffer = FP8MetaFwdBuffer()
        self._fp8_bwd_buffer = FP8MetaBwdBuffer()
        self._fp8_recompute_buffer = FP8RecomputeBuffer()

    def is_fp8_enabled(self) -> bool:
        """Is FP8 enabled"""
        return self._fp8_enabled

    def is_fp8_calibration(self) -> bool:
        """Is FP8 calibration enabled"""
        return self._fp8_calibration

    def get_fp8_recipe(self) -> DelayedScaling:
        """Return the fp8 recipe"""
        return self._fp8_recipe

    @staticmethod
    def get_default_fp8_recipe() -> DelayedScaling:
        """FP8 recipe if not provided by user
        Margin = 0, interval = 1, E4M3
        """
        return DelayedScaling()

    def get_autocast_id(self) -> int:
        """Returns the number of times of entering the `fp8_autocast` context.
        as a unique ID for different training steps."""
        return self._fp8_autocast_counter

    def is_first_fp8_module(self) -> bool:
        """Returns `True` only on the first call within a given
        `fp8_autocast` context; subsequent calls return `False`.
        """
        tmp = self._is_first_fp8_module
        self._is_first_fp8_module = False
        return tmp

    def get_fp8_group(self) -> Union[dist_group_type, None]:
        """Return the fp8 group for scale/amax comm"""
        return self._fp8_distributed_group

    def get_fp8_fwd_buffer(self) -> FP8MetaFwdBuffer:
        """Returns global fp8 forward buffer."""
        return self._fp8_fwd_buffer

    def get_fp8_bwd_buffer(self) -> FP8MetaBwdBuffer:
        """Returns global fp8 backward buffer."""
        return self._fp8_bwd_buffer

    def is_fp8_recompute_enabled(self) -> bool:
        """Is FP8 recompute enabled"""
        return self._fp8_recompute_enabled

    def get_fp8_recompute_buffer(self) -> FP8RecomputeBuffer:
        """Returns global fp8 recompute buffer."""
        return self._fp8_recompute_buffer

    def enter(
        self,
        enabled: bool,
        calibrating: bool,
        fp8_recipe: Optional[DelayedScaling],
        fp8_group: Optional[dist_group_type],
    ) -> None:
        """Called when entering 'fp8_autocast'"""
        self.saved_states = (self._fp8_enabled, self._fp8_calibration, self._fp8_recipe,
                             self._fp8_distributed_group, self._is_first_fp8_module)

        self._fp8_enabled = enabled
        self._fp8_calibration = calibrating
        self._fp8_recipe = self.get_default_fp8_recipe() if fp8_recipe is None else fp8_recipe
        self._fp8_distributed_group = fp8_group

        if self._fp8_autocast_depth == 0:
            self._is_first_fp8_module = True
            self._fp8_autocast_counter += 1
        self._fp8_autocast_depth += 1

    def exit(self):
        """Called when exiting 'fp8_autocast'"""
        # Restore saved states
        (self._fp8_enabled, self._fp8_calibration, self._fp8_recipe, self._fp8_distributed_group,
         self._is_first_fp8_module) = self.saved_states

        self._fp8_autocast_depth -= 1

        if self._fp8_autocast_depth == 0:
            self._fp8_fwd_buffer.finalize()


_global_fp8_state = FP8State()


def get_global_fp8_state() -> FP8State:
    """Get global fp8 state"""
    return _global_fp8_state
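
# Modules typically consult this global state during their forward pass.
# A sketch of the query pattern (variable names are illustrative):
#
#     state = get_global_fp8_state()
#     if state.is_fp8_enabled():
#         recipe = state.get_fp8_recipe()
#         fwd_dtype = get_fp8_te_dtype(recipe, fprop_tensor=True)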


@contextmanager
def fp8_autocast(
    enabled: bool = False,
    calibrating: bool = False,
    fp8_recipe: Optional[DelayedScaling] = None,
    fp8_group: Optional[dist_group_type] = None,
) -> None:
    """
    Context manager for FP8 usage.
    """
    try:
        _global_fp8_state.enter(enabled, calibrating, fp8_recipe, fp8_group)

        if enabled:
            fp8_available, reason_for_no_fp8 = is_fp8_available()
            assert fp8_available, reason_for_no_fp8
        yield
    finally:
        _global_fp8_state.exit()
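
# A minimal usage sketch. `te.Linear` and `inp` are illustrative stand-ins, not
# part of this module, and the recipe values are arbitrary examples. Only the
# forward pass runs under the context; the backward pass reuses the FP8 state
# recorded during the forward.
#
#     recipe = DelayedScaling(margin=0, fp8_format=Format.HYBRID, amax_history_len=16)
#     with fp8_autocast(enabled=True, fp8_recipe=recipe):
#         out = te.Linear(1024, 1024)(inp)
#     out.mean().backward()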


def get_fp8_te_dtype(fp8_recipe: DelayedScaling, fprop_tensor: bool = True) -> tex.DType:
    """Get fp8 data type according to recipe and tensor"""
    if fp8_recipe.fp8_format == Format.E4M3 or (fp8_recipe.fp8_format == Format.HYBRID
                                                and fprop_tensor):
        return tex.DType.kFloat8E4M3
    return tex.DType.kFloat8E5M2
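
# For example, with the HYBRID format the forward (fprop) tensors use E4M3
# while the backward (gradient) tensors use E5M2:
#
#     recipe = DelayedScaling(fp8_format=Format.HYBRID)
#     get_fp8_te_dtype(recipe, fprop_tensor=True)     # tex.DType.kFloat8E4M3
#     get_fp8_te_dtype(recipe, fprop_tensor=False)    # tex.DType.kFloat8E5M2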


def amax_and_scale_update(
    fp8_meta: Dict[str, Any],
    fwd_update: bool,
) -> None:
    """Updates fp8 amaxes/scales for fwd | bwd."""
    amax_compute = fp8_meta["recipe"].amax_compute_algo
    sf_compute = fp8_meta["recipe"].scaling_factor_compute_algo
    fp8_meta_tensor_key = "scaling_fwd" if fwd_update else "scaling_bwd"
    fp8_max_key = "fp8_max_fwd" if fwd_update else "fp8_max_bwd"

    if not callable(amax_compute) and sf_compute is None:
        tex.amax_and_scale_update_inplace(_amax_history=fp8_meta[fp8_meta_tensor_key].amax_history,
                                          _scale=fp8_meta[fp8_meta_tensor_key].scale,
                                          _scale_inv=fp8_meta[fp8_meta_tensor_key].scale_inv,
                                          fp8_max=fp8_meta[fp8_max_key],
                                          margin=float(fp8_meta["recipe"].margin),
                                          amax_compute=amax_compute)
    else:
        raise ValueError("Only 'max' and 'most_recent' amax_compute_algo with the default "
                         "scaling_factor_compute_algo are currently supported.")


class FP8TensorMeta:
    """Holds FP8 scaling and amax history for FP8 layers"""

    def __init__(self, is_forward: bool):
        self.scale = paddle.Tensor()
        self.scale_inv = paddle.Tensor()
        self.amax_history = paddle.Tensor()
        self.is_initialized = False
        self.is_forward = is_forward

    def prepare(self, num_gemms: int, amax_history_len: int) -> None:
        """Prepare scales and amax tensors. It is called during fprop in each iteration.
        If the meta tensors are not initialized yet, initialization is performed. If already
        initialized, resize the meta tensors if amax_history_len has changed."""

        if self.is_initialized:
            # Handle changed amax history size.
            curr_len = self.amax_history.shape[0]
            num_fp8_tensors = self.amax_history.shape[1]
            if amax_history_len < curr_len:
                self.amax_history = self.amax_history[:amax_history_len]
            elif amax_history_len > curr_len:
                extra_rows = amax_history_len - curr_len
                self.amax_history = paddle.concat(
                    [self.amax_history,
                     paddle.zeros((extra_rows, num_fp8_tensors), dtype='float32')],
                    axis=0)
            return

        # Max. number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and
        # 2 (grad_output and grad_input) for bwd
        num_fp8_tensors = (num_gemms * 3 if self.is_forward else num_gemms * 2)

        self.scale = paddle.ones([num_fp8_tensors], dtype='float32')
        self.scale_inv = paddle.ones([num_fp8_tensors], dtype='float32')
        self.amax_history = paddle.zeros([amax_history_len, num_fp8_tensors], dtype='float32')
        self.is_initialized = True

    def to_numpy(self):
        """Convert FP8 meta tensors to numpy."""
        assert self.is_initialized, "FP8TensorMeta is not initialized yet."
        return {
            'scale': self.scale.numpy(),
            'scale_inv': self.scale_inv.numpy(),
            'amax_history': self.amax_history.numpy(),
        }

    def from_numpy(self, data: Dict[str, np.ndarray]):
        """Set FP8 meta tensors from numpy"""
        self.scale = paddle.to_tensor(data['scale'])
        self.scale_inv = paddle.to_tensor(data['scale_inv'])
        self.amax_history = paddle.to_tensor(data['amax_history'])
        self.is_initialized = True
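
# A minimal round-trip sketch for checkpointing FP8 metadata via numpy
# (`path` is illustrative):
#
#     meta = FP8TensorMeta(is_forward=True)
#     meta.prepare(num_gemms=1, amax_history_len=4)
#     np.savez(path, **meta.to_numpy())
#     restored = FP8TensorMeta(is_forward=True)
#     restored.from_numpy(dict(np.load(path)))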