# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""FP8 meta buffer for FP8 amax reduction"""

from abc import ABC, abstractmethod
from collections import deque
from functools import partial
import os
from typing import Dict, Any, List, Union

import numpy as np
import paddle
from transformer_engine import transformer_engine_paddle as tex

from .constants import dist_group_type, RecomputeFunctionNames


class FP8MetaBufferBase(ABC):
    """
    A global buffer that holds FP8 meta for reduction across trainers.
    """

    def __init__(self):
        self._global_amax = {}
        self._buffer_delete_key = None
        self._amax_reduce_wait_func = None
        self._amax_reduce_handle = None  # returned by get_amax_reduce_handle(); initialized so the getter never raises
        self._dp_amax_reduce_interval = None
        self._contiguous_amax = None
        self._use_cudagraph = False
        self._dp_amax_reduce_idx = 0

    @staticmethod
    @abstractmethod
    def _get_meta_tensor_key():
        """Returns scaling key in `fp8_meta`."""

    @staticmethod
    @abstractmethod
    def _get_buffer_position_key():
        """Returns module position key in `fp8_meta`."""

    @staticmethod
    @abstractmethod
    def _get_autocast_key():
        """Returns autocast id key in `fp8_meta`."""

    def _get_amax_buffer_key(self, fp8_meta: Dict[str, Any]) -> str:
        """Return a key in `_global_amax` for the AMAX storage."""
        return f"AMAX_{fp8_meta[self._get_autocast_key()]}"

    def _execute_deletion(self) -> None:
        """Delete the key from global amax buffer."""
        if self._buffer_delete_key is not None and self._buffer_delete_key in self._global_amax:
            del self._global_amax[self._buffer_delete_key]

    def _wait_handle_and_split(
        self,
        contiguous_amax: paddle.Tensor,
        chunk_sizes: List[int],
        amax_buffer_key: str,
        wait_handle: Union[bool, None],
    ) -> None:
        """Wait for amax reduction to finish and then copy reduced amax to buffer"""
        if wait_handle is not None:
            wait_handle.wait()
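        # Under CUDA Graphs the per-module amax tensors must keep their addresses,
        # so copy the reduced chunks back in place; otherwise simply rebind the
        # buffer entries to the split views of the reduced tensor.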
        if self._use_cudagraph:
            split_list = list(contiguous_amax.split(chunk_sizes))
            for amax, split in zip(self._global_amax[amax_buffer_key], split_list):
                amax.copy_(split, False)
        else:
            self._global_amax[amax_buffer_key] = list(contiguous_amax.split(chunk_sizes))

    def _global_amax_reduction(
        self,
        fp8_meta: Dict[str, Any],
        tp_group: dist_group_type,
        tp_size: int,
    ) -> None:
        """Concatenate, reduce, and split amaxes in the global buffer."""

        def _reduce_tensor_across_group_op_max(tensor, group, sync_op):
            if paddle.distributed.is_initialized():
                wait_handle = paddle.distributed.all_reduce(
                    tensor,
                    op=paddle.distributed.ReduceOp.MAX,
                    group=group,
                    sync_op=sync_op,
                )
                return wait_handle
            return None

        amax_buffer_key = self._get_amax_buffer_key(fp8_meta)
        # Key already deleted.
        if amax_buffer_key not in self._global_amax:
            return None

        # Reduce AMAX in DP-domain at an interval.
        if self._dp_amax_reduce_interval is None:
            self._dp_amax_reduce_interval = int(os.getenv("NVTE_DP_AMAX_REDUCE_INTERVAL", "1"))

        tp_amax_reduce = False
        reduce_group = -1  # Sentinel that raises if never overwritten; `None` itself is a valid group.
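        # NVTE_DP_AMAX_REDUCE_INTERVAL controls how often the reduction spans the
        # full fp8_group (typically the data- and tensor-parallel ranks together).
        # For example, with an interval of 4, iteration 0 reduces over fp8_group and
        # iterations 1-3 reduce only over the tensor-parallel group (or skip the
        # reduction entirely when tp_size == 1).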
        if self._dp_amax_reduce_idx == 0:
            reduce_group = fp8_meta["fp8_group"]
        else:
            tp_amax_reduce = True
        self._dp_amax_reduce_idx = (self._dp_amax_reduce_idx + 1) % self._dp_amax_reduce_interval

        if tp_amax_reduce:
            if tp_size > 1:
                reduce_group = tp_group
            else:
                return None

        chunk_sizes = [x.shape[0] for x in self._global_amax[amax_buffer_key]]
        if self._use_cudagraph:
            # `_contiguous_amax` must keep a stable address under CUDA Graphs: the
            # captured all-reduce kernel records the buffer pointer, so we reuse the
            # same tensor and copy into it instead of reallocating.
            if self._contiguous_amax is None:
                self._contiguous_amax = paddle.concat(self._global_amax[amax_buffer_key])
            else:
                self._contiguous_amax.copy_(
                    paddle.concat(self._global_amax[amax_buffer_key]), False
                )
        else:
            self._contiguous_amax = paddle.concat(self._global_amax[amax_buffer_key])

        wait_handle = _reduce_tensor_across_group_op_max(
            self._contiguous_amax,
            reduce_group,
            not fp8_meta["async_amax_reduction"],
        )

        if wait_handle is not None and self._use_cudagraph:
            # The record/wait pair must not cross the graph capture boundary, so wait
            # synchronously here instead of deferring to the returned callback.
            wait_handle.wait()
            wait_handle = None

        return partial(
            self._wait_handle_and_split,
            self._contiguous_amax,
            chunk_sizes,
            amax_buffer_key,
            wait_handle,
        )

    def add_amax(self, fp8_meta: Dict[str, Any]) -> None:
        """Append `amax_history` to global buffer."""
        buffer_key = self._get_amax_buffer_key(fp8_meta)
        fp8_meta_tensor_key = self._get_meta_tensor_key()
        buffer_position_key = self._get_buffer_position_key()

        if buffer_key not in self._global_amax:
            self._global_amax[buffer_key] = [fp8_meta[fp8_meta_tensor_key].amax_history[0]]
        else:
            self._global_amax[buffer_key].append(fp8_meta[fp8_meta_tensor_key].amax_history[0])

        if buffer_position_key not in fp8_meta:
            fp8_meta[buffer_position_key] = len(self._global_amax[buffer_key]) - 1

        # Catch incorrect fp8_autocast usage.
        assert fp8_meta[buffer_position_key] == len(self._global_amax[buffer_key]) - 1, (
            "Same module is being invoked more than once inside an `fp8_autocast` "
            "region when using FP8 with amax reduction. This behavior is currently "
            "unsupported. For more details and correct usage, please see "
            "https://github.com/NVIDIA/TransformerEngine/pull/93."
        )

    def copy_amax_from_buffer(self, fp8_meta: Dict[str, Any]) -> None:
        """Populate current amax with the correct location from buffer."""
        fp8_meta_tensor_key = self._get_meta_tensor_key()
        buffer_position_key = self._get_buffer_position_key()
        if buffer_position_key not in fp8_meta:
            return

        amax_buffer_key = self._get_amax_buffer_key(fp8_meta)
        assert amax_buffer_key in self._global_amax, "TE internal error."

        # Copy amax to amax_history[0]
        tex.update_latest_amax_history_inplace(
            _history=fp8_meta[fp8_meta_tensor_key].amax_history,
            amax=self._global_amax[amax_buffer_key][fp8_meta[buffer_position_key]],
        )

    def set_for_deletion(self, fp8_meta: Dict[str, Any]) -> None:
        """Delete this amax key from global buffer during autocast end."""
        if self._get_autocast_key() not in fp8_meta:
            return
        self._buffer_delete_key = self._get_amax_buffer_key(fp8_meta)

    def get_amax_reduce_handle(self) -> Union[bool, None]:
        """Return AMAX reduction wait handle."""
        return self._amax_reduce_handle

    def wait(self) -> None:
        """Wait for reduced amax to be available in buffer."""
        if self._amax_reduce_wait_func is not None:
            self._amax_reduce_wait_func()  # pylint: disable=not-callable
            self._amax_reduce_wait_func = None

    def to_numpy(self) -> Dict[str, List[np.ndarray]]:
        """Convert to numpy arrays"""
        out = {}
        for k, v in self._global_amax.items():
            out[k] = [tensor.numpy() for tensor in v]
        return out

    def from_numpy(self, buffer: Dict[str, List[np.ndarray]]) -> None:
        """Set buffer values from numpy arrays"""
        for k, v in buffer.items():
            self._global_amax[k] = [paddle.to_tensor(arr) for arr in v]

    def enable_cudagraph(self):
        """Enable CUDA Graphs."""
        self._use_cudagraph = True
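        # Once enabled, reductions reuse preallocated tensors and copy results in
        # place (see _global_amax_reduction and _wait_handle_and_split) so buffer
        # addresses stay stable across graph replays.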


class FP8MetaFwdBuffer(FP8MetaBufferBase):
    """FP8Meta Buffer for forward"""

    @staticmethod
    def _get_meta_tensor_key() -> str:
        """Returns scaling key in `fp8_meta`."""
        return "scaling_fwd"

    @staticmethod
    def _get_buffer_position_key() -> str:
        """Returns module position key in `fp8_meta`."""
        return "global_fp8_buffer_pos_fwd"

    @staticmethod
    def _get_autocast_key() -> str:
        """Returns module position key in `fp8_meta`."""
        return "autocast_id_fwd"

    def set_for_amax_reduction(
        self,
        fp8_meta: Dict[str, Any],
        tp_group: dist_group_type,
        tp_size: int,
    ) -> None:
        """Sets up the function to call during autocast exit."""
        self._amax_global_reduce_func = partial(
            self._global_amax_reduction,
            fp8_meta,
            tp_group,
            tp_size,
        )

    def finalize(self) -> None:
        """
        Called at FP8 autocast end.
        Performs AMAX reduction and deletes unused buffer entries.
        """
        if hasattr(self, "_amax_global_reduce_func") and callable(self._amax_global_reduce_func):
            self._amax_reduce_wait_func = self._amax_global_reduce_func()
        self._execute_deletion()


class FP8MetaBwdBuffer(FP8MetaBufferBase):
    """FP8Meta Buffer for backward"""

    @staticmethod
    def _get_meta_tensor_key() -> str:
        """Returns scaling key in `fp8_meta`."""
        return "scaling_bwd"

    @staticmethod
    def _get_buffer_position_key() -> str:
        """Returns module position key in `fp8_meta`."""
        return "global_fp8_buffer_pos_bwd"

    @staticmethod
    def _get_autocast_key() -> str:
        """Returns module position key in `fp8_meta`."""
        return "autocast_id_bwd"

    def finalize(
        self,
        fp8_meta: Dict[str, Any],
        tp_group: dist_group_type,
        tp_size: int,
    ) -> None:
        """
        Called at FP8 autocast end in backward.
        Performs AMAX reduction and deletes unused buffer entries.
        """
        self._amax_reduce_wait_func = self._global_amax_reduction(
            fp8_meta, tp_group, tp_size
        )  # _wait_handle_and_split
        self._execute_deletion()


class FP8RecomputeBuffer:
    """Buffer used to hold FP8 meta tensors for recompute"""

    def __init__(self):
        self._global_amax = []

    @staticmethod
    def get_buffer_position_key():
        """Returns the key (in fp8_meta) for recompute buffer position"""
        return "recompute_buffer_pos"

    def stash_fp8_meta_tensors(self, fp8_meta: Dict[str, Any]) -> None:
        """Stash the scaling factors and amaxes for recompute"""
        buffer_position_key = self.get_buffer_position_key()

        to_copy = [
            fp8_meta["scaling_fwd"].amax_history.clone(),
            fp8_meta["scaling_fwd"].scale.clone(),
            fp8_meta["scaling_fwd"].scale_inv.clone(),
        ]

        if buffer_position_key in fp8_meta:
            self._global_amax[fp8_meta[buffer_position_key]].append(to_copy)
        else:
            self._global_amax.append(deque())
            self._global_amax[-1].append(to_copy)
            fp8_meta[buffer_position_key] = len(self._global_amax) - 1

    def retrieve_fp8_meta_tensors(self, fp8_meta: Dict[str, Any]) -> None:
        """Switch to the previously saved scaling factors and amaxes"""
        # Store updated amaxes and scales from phase 1 post forward.
        fp8_meta["updated_amax_history_fwd"] = fp8_meta["scaling_fwd"].amax_history
        fp8_meta["updated_scale_fwd"] = fp8_meta["scaling_fwd"].scale
        fp8_meta["updated_scale_inv_fwd"] = fp8_meta["scaling_fwd"].scale_inv

        # Retrieve stashed amaxes and scales from phase 1 pre forward.
        buffer_position_key = self.get_buffer_position_key()
        stashed_fp8_meta = self._global_amax[fp8_meta[buffer_position_key]].popleft()

        # Replace amaxes and scales with stashed values for phase 2 forward
        fp8_meta["scaling_fwd"].amax_history = stashed_fp8_meta[0]
        fp8_meta["scaling_fwd"].scale = stashed_fp8_meta[1]
        fp8_meta["scaling_fwd"].scale_inv = stashed_fp8_meta[2]

    @staticmethod
    def restore_fp8_meta_tensors(fp8_meta: Dict[str, Any]) -> None:
        """Restore latest scaling factors and amaxes after recompute forward run."""
        assert "updated_amax_history_fwd" in fp8_meta, (
            "Recompute internal error."
            " If you are not using recompute, please check if"
            " the forward function is called from one of these functions: "
            f"{RecomputeFunctionNames}. If so, consider change the function name "
            "or set NVTE_DISABLE_RECOMPUTE=1."
        )
        fp8_meta["scaling_fwd"].amax_history = fp8_meta["updated_amax_history_fwd"]
        fp8_meta["scaling_fwd"].scale = fp8_meta["updated_scale_fwd"]
        fp8_meta["scaling_fwd"].scale_inv = fp8_meta["updated_scale_inv_fwd"]