# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""FP8 utilities for TransformerEngine"""
from contextlib import contextmanager
7
from collections import deque
8
from typing import Callable, List, Optional, Dict, Any, Tuple, Union, Deque
Przemek Tredak's avatar
Przemek Tredak committed
9
10
11
12
13
14
15
16

import torch
import transformer_engine_extensions as tex
from transformer_engine.common.recipe import DelayedScaling, Format

from .constants import dist_group_type

# Module-level FP8 state. The first four values are overwritten on entry to
# `fp8_autocast` and restored to their previous values when it exits.
_FP8_ENABLED = False
_FP8_CALIBRATION = False
_FP8_RECIPE = None
_FP8_DISTRIBUTED_GROUP = None
# Set when an outermost `fp8_autocast` is entered; read and cleared by
# `is_first_fp8_module`.
_IS_FIRST_FP8_MODULE = False
# Incremented each time an outermost `fp8_autocast` context is entered.
_FP8_AUTOCAST_COUNTER = 0
_FP8_CURRENT_CONTEXT_ID = 0
# Current nesting depth of `fp8_autocast` contexts.
_FP8_AUTOCAST_DEPTH = 0
# Maps amax buffer keys (see `get_amax_buffer_key`) to lists of amax tensors.
_global_fp8_buffer = {}
# List of deques; each deque entry is a stashed
# [amax_history, scale, scale_inv] triple used by the recompute helpers.
_fp8_tensors_recompute_buffer = []
# Optional callback invoked when the outermost `fp8_autocast` context exits.
_amax_forward_global_reduce_func = None
# Buffer keys recorded by `set_amax_buffer_key_deletion` and removed by
# `delete_key_from_amax_buffer`.
_buffer_delete_key_fwd = None
_buffer_delete_key_bwd = None


def get_meta_tensor_key(forward: bool = True) -> str:
    """Return the `fp8_meta` key of the scaling metadata.

    `"scaling_fwd"` is returned for the forward pass and `"scaling_bwd"`
    otherwise.
    """
    return "scaling_fwd" if forward else "scaling_bwd"


def get_buffer_position_key(forward: bool = True) -> str:
    """Return the `fp8_meta` key under which a module's position in the
    global amax buffer is stored (forward or backward variant).
    """
    return "global_fp8_buffer_pos_fwd" if forward else "global_fp8_buffer_pos_bwd"


def get_autocast_key(forward: bool = True) -> str:
    """Returns the autocast id key in `fp8_meta`.

    `"autocast_id_fwd"` is returned for the forward pass and
    `"autocast_id_bwd"` otherwise. These are the same `fp8_meta` entries
    that `get_amax_buffer_key` reads to build the amax buffer key.
    """
    if forward:
        return "autocast_id_fwd"
    return "autocast_id_bwd"


def get_global_fp8_buffer() -> Dict[str, List[torch.Tensor]]:
    """Return the module-level amax buffer.

    The dictionary object itself is returned, not a copy.
    """
    return _global_fp8_buffer


def set_global_fp8_buffer(buffer: Dict[str, List[torch.Tensor]]) -> None:
    """Install `buffer` as the module-level amax buffer.

    Every tensor in `buffer` is first replaced, in place in the passed
    dictionary, by the result of calling `.cuda()` on it.
    """
    global _global_fp8_buffer

    # Only values of existing keys are reassigned, so iterating the dict
    # while updating it is safe.
    for key in buffer:
        buffer[key] = [tensor.cuda() for tensor in buffer[key]]

    _global_fp8_buffer = buffer


68
69
70
71
72
def get_global_fp8_recompute_buffer() -> List[Deque[List[torch.Tensor]]]:
    """Returns global fp8 recompute buffer.

    The buffer is a list of deques whose entries are lists of tensors (the
    same structure accepted by `set_global_fp8_recompute_buffer`). The list
    object itself is returned, not a copy.
    """
    return _fp8_tensors_recompute_buffer


73
def set_global_fp8_recompute_buffer(buffer: List[Deque[List[torch.Tensor]]]) -> None:
    """Sets global fp8 recompute buffer.

    Each entry of `buffer` is replaced, in place in the passed list, by a
    new `deque` whose tensor lists hold the result of calling `.cuda()` on
    every tensor. The list is then installed as the module-level recompute
    buffer.
    """
    global _fp8_tensors_recompute_buffer

    # Map all tensors back to GPU.
    for index, deck in enumerate(buffer):
        buffer[index] = deque([[t.cuda() for t in tensors] for tensors in deck])

    _fp8_tensors_recompute_buffer = buffer


Przemek Tredak's avatar
Przemek Tredak committed
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
def setup_amax_forward_global_reduce_func(f: Callable) -> None:
    """Register `f` as the forward amax reduction callback.

    `fp8_autocast` calls the registered callback, with no arguments, when
    its outermost context exits.
    """
    global _amax_forward_global_reduce_func
    _amax_forward_global_reduce_func = f


def get_amax_buffer_key(fp8_meta: Dict[str, Any], forward: bool = True) -> str:
    """Build the `_global_fp8_buffer` key used for amax storage.

    The key is `FWD_AMAX_<id>` or `BWD_AMAX_<id>`, where `<id>` is read
    from `fp8_meta['autocast_id_fwd']` or `fp8_meta['autocast_id_bwd']`
    respectively.
    """
    if not forward:
        return f"BWD_AMAX_{fp8_meta['autocast_id_bwd']}"
    return f"FWD_AMAX_{fp8_meta['autocast_id_fwd']}"


def add_amax_to_global_buffer(fp8_meta: Dict[str, Any], forward: bool = True) -> None:
    """Append this module's `amax_history[0]` to the global amax buffer.

    The tensor is appended to the list stored under the key from
    `get_amax_buffer_key` (the list is created on first use). The first
    time this runs for a given `fp8_meta`, the index of the appended entry
    is recorded in `fp8_meta` under the key from `get_buffer_position_key`.
    """
    amax_key = get_amax_buffer_key(fp8_meta, forward=forward)
    position_key = get_buffer_position_key(forward=forward)
    newest_amax = fp8_meta[get_meta_tensor_key(forward=forward)].amax_history[0]

    entries = _global_fp8_buffer.setdefault(amax_key, [])
    entries.append(newest_amax)

    # Keep an already recorded position; only remember the first one.
    if position_key not in fp8_meta:
        fp8_meta[position_key] = len(entries) - 1


115
116
117
118
119
120
121
def copy_forward_fp8_meta_tensors_for_recompute(fp8_meta: Dict[str, Any]) -> None:
    """Copy the scaling factors and amaxes for recompute forward phase
    to ensure both forward steps are numerically same.

    Clones of `amax_history`, `scale` and `scale_inv` from
    `fp8_meta["scaling_fwd"]` are appended, as one list, to a deque in the
    module-level recompute buffer. The first call for a given `fp8_meta`
    creates a new deque and records its index in
    `fp8_meta["global_fp8_buffer_pos_fwd_recompute"]`; later calls reuse
    that deque.
    """
    global _fp8_tensors_recompute_buffer
    buffer_position_key = "global_fp8_buffer_pos_fwd_recompute"

    scaling_fwd = fp8_meta["scaling_fwd"]
    # Clone so that later updates to the live tensors do not alter the stash.
    to_copy = [
        scaling_fwd.amax_history.clone(),
        scaling_fwd.scale.clone(),
        scaling_fwd.scale_inv.clone(),
    ]

    if buffer_position_key in fp8_meta:
        _fp8_tensors_recompute_buffer[fp8_meta[buffer_position_key]].append(to_copy)
    else:
        if len(_fp8_tensors_recompute_buffer) == 0:
            _fp8_tensors_recompute_buffer = [deque()]
        else:
            _fp8_tensors_recompute_buffer.append(deque())
        _fp8_tensors_recompute_buffer[-1].append(to_copy)
        fp8_meta[buffer_position_key] = len(_fp8_tensors_recompute_buffer) - 1


def get_old_fp8_meta_tensors_for_recompute(fp8_meta: Dict[str, Any]) -> None:
    """Swap in the amaxes and scales stashed before the phase 1 forward so
    that the recompute forward produces identical numerical outputs.

    The current forward tensors are kept in `fp8_meta` under the
    `updated_*_fwd` keys so that `restore_fp8_meta_tensors` can put them
    back afterwards.
    """
    scaling = fp8_meta["scaling_fwd"]

    # Keep the post-forward (updated) values from phase 1.
    fp8_meta["updated_amax_history_fwd"] = scaling.amax_history
    fp8_meta["updated_scale_fwd"] = scaling.scale
    fp8_meta["updated_scale_inv_fwd"] = scaling.scale_inv

    # Take the oldest stashed pre-forward values for this module.
    position = fp8_meta["global_fp8_buffer_pos_fwd_recompute"]
    stashed = _fp8_tensors_recompute_buffer[position].popleft()

    # Use the stashed values for the phase 2 forward.
    scaling.amax_history = stashed[0]
    scaling.scale = stashed[1]
    scaling.scale_inv = stashed[2]


def restore_fp8_meta_tensors(fp8_meta: Dict[str, Any]) -> None:
    """Put the latest amaxes and scaling factors back after a recompute
    forward run.

    The values are read from the `updated_*_fwd` entries of `fp8_meta` and
    assigned to the matching attributes of `fp8_meta["scaling_fwd"]`.
    """
    latest_amax_history = fp8_meta["updated_amax_history_fwd"]
    fp8_meta["scaling_fwd"].amax_history = latest_amax_history

    latest_scale = fp8_meta["updated_scale_fwd"]
    fp8_meta["scaling_fwd"].scale = latest_scale

    latest_scale_inv = fp8_meta["updated_scale_inv_fwd"]
    fp8_meta["scaling_fwd"].scale_inv = latest_scale_inv


Przemek Tredak's avatar
Przemek Tredak committed
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
def copy_amax_from_global_buffer(
    fp8_meta: Dict[str, Any], forward: bool = True
) -> None:
    """Write this module's entry of the global amax buffer into
    `amax_history[0]` of its scaling metadata.

    Nothing happens when `fp8_meta` holds no buffer position for the
    requested direction.
    """
    position_key = get_buffer_position_key(forward=forward)
    if position_key in fp8_meta:
        buffered = _global_fp8_buffer[get_amax_buffer_key(fp8_meta, forward=forward)]
        scaling = fp8_meta[get_meta_tensor_key(forward=forward)]
        scaling.amax_history[0] = buffered[fp8_meta[position_key]]


def set_amax_buffer_key_deletion(
    fp8_meta: Dict[str, Any], forward: bool = True
) -> None:
    """Record this module's amax buffer key as the one to delete later.

    The key is stored in the module-level forward or backward deletion
    slot, which `delete_key_from_amax_buffer` reads. Nothing is recorded
    when `fp8_meta` has no autocast id for the requested direction.
    """
    global _buffer_delete_key_fwd, _buffer_delete_key_bwd

    if get_autocast_key(forward=forward) in fp8_meta:
        pending_key = get_amax_buffer_key(fp8_meta, forward=forward)
        if forward:
            _buffer_delete_key_fwd = pending_key
        else:
            _buffer_delete_key_bwd = pending_key


def get_default_fp8_recipe() -> DelayedScaling:
    """Return the FP8 recipe used when the user does not provide one.

    This is a `DelayedScaling` built with its default arguments.
    NOTE(review): the original note lists those defaults as margin = 0,
    interval = 1, E4M3 -- confirm against `DelayedScaling`.
    """
    default_recipe = DelayedScaling()
    return default_recipe


@contextmanager
def fp8_autocast(
    enabled: bool = False,
    calibrating: bool = False,
    fp8_recipe: Optional[DelayedScaling] = None,
    fp8_group: Optional[dist_group_type] = None,
) -> None:
    """
    Context manager for FP8 usage.

    .. code-block:: python

        with fp8_autocast(enabled=True):
            out = model(inp)

    .. note::

        Support for FP8 in the Linear layer of Transformer Engine is currently limited to tensors
        with shapes where both dimensions are divisible by 16. In terms of the input to the full
        Transformer network, this typically requires padding sequence length to be multiple of 16.

    Contexts may be nested: the enabled/calibration/recipe/group settings of the
    enclosing context are restored on exit, and the forward amax reduction callback
    and forward buffer-key deletion run only when the outermost context exits.

    Parameters
    ----------
    enabled: bool, default = `False`
             whether or not to enable fp8
    calibrating: bool, default = `False`
                 calibration mode allows collecting statistics such as amax and scale
                 data of fp8 tensors even when executing without fp8 enabled. This is
                 useful for saving an inference ready fp8 checkpoint while training
                 using a higher precision.
    fp8_recipe: recipe.DelayedScaling, default = `None`
                recipe used for FP8 training. When `None`, the recipe returned by
                `get_default_fp8_recipe` is used.
    fp8_group: torch._C._distributed_c10d.ProcessGroup, default = `None`
               distributed group over which amaxes for the fp8 tensors
               are reduced at the end of each training step.

    Raises
    ------
    AssertionError
        if `enabled` is `True` and the current CUDA device reports a compute
        capability major version below 9.
    """

    global _FP8_ENABLED, _FP8_CALIBRATION, _FP8_RECIPE, _FP8_DISTRIBUTED_GROUP
    global _FP8_AUTOCAST_DEPTH, _IS_FIRST_FP8_MODULE, _FP8_AUTOCAST_COUNTER

    # Snapshot the current settings so they can be restored on exit.
    fp8_state = (_FP8_ENABLED, _FP8_CALIBRATION, _FP8_RECIPE, _FP8_DISTRIBUTED_GROUP)
    try:
        _FP8_ENABLED = enabled
        _FP8_CALIBRATION = calibrating
        _FP8_RECIPE = get_default_fp8_recipe() if fp8_recipe is None else fp8_recipe
        _FP8_DISTRIBUTED_GROUP = fp8_group

        # Only the outermost context starts a new autocast id.
        if _FP8_AUTOCAST_DEPTH == 0:
            _IS_FIRST_FP8_MODULE = True
            _FP8_AUTOCAST_COUNTER += 1
        _FP8_AUTOCAST_DEPTH += 1

        if enabled:
            assert (
                torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 9
            ), "Device compute capability 9.x required for FP8 execution."
        yield
    finally:
        _FP8_ENABLED, _FP8_CALIBRATION, _FP8_RECIPE, _FP8_DISTRIBUTED_GROUP = fp8_state
        _IS_FIRST_FP8_MODULE = False
        _FP8_AUTOCAST_DEPTH -= 1

        # Leaving the outermost context: run the registered forward amax
        # reduction (if any) and drop the forward key marked for deletion.
        if _FP8_AUTOCAST_DEPTH == 0:
            if callable(_amax_forward_global_reduce_func):
                _amax_forward_global_reduce_func()
            delete_key_from_amax_buffer(forward=True)


def get_fp8_context_id() -> int:
    """Return the ID most recently stored by `set_fp8_context_id`
    (0 if it has never been set).
    """
    return _FP8_CURRENT_CONTEXT_ID


def set_fp8_context_id(ctx_id: int) -> None:
    """Store `ctx_id` as the current FP8 context ID, which
    `get_fp8_context_id` returns.
    """
    global _FP8_CURRENT_CONTEXT_ID
    _FP8_CURRENT_CONTEXT_ID = ctx_id


def new_fp8_context_id() -> int:
    """Return the global autocast counter, used as a proxy for the
    autocast ID of FP8 modules.

    The counter is incremented each time an outermost `fp8_autocast`
    context is entered; this function only reads it.
    """
    return _FP8_AUTOCAST_COUNTER


def is_fp8_enabled() -> bool:
    """Return the `enabled` value of the innermost active `fp8_autocast`
    context (`False` outside of any context).
    """
    return _FP8_ENABLED

schetlur-nv's avatar
schetlur-nv committed
292
293
294
def is_fp8_calibration() -> bool:
    """Return the `calibrating` value of the innermost active
    `fp8_autocast` context (`False` outside of any context).
    """
    return _FP8_CALIBRATION
Przemek Tredak's avatar
Przemek Tredak committed
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354

def is_first_fp8_module() -> bool:
    """Returns `True` only the first time when called multiple
    times from within the same `fp8_autocast` context.

    The flag is set when an outermost `fp8_autocast` context is entered and
    is cleared by this call, so reading it consumes it.
    """
    global _IS_FIRST_FP8_MODULE
    was_first = _IS_FIRST_FP8_MODULE
    _IS_FIRST_FP8_MODULE = False
    return was_first


def get_fp8_recipe() -> DelayedScaling:
    """Return the recipe of the innermost active `fp8_autocast` context.

    Outside of any context this is the module's initial value, `None`.
    """
    return _FP8_RECIPE


def get_fp8_group() -> Union[dist_group_type, None]:
    """Return the distributed group passed to the innermost active
    `fp8_autocast` context (used for scale/amax communication), or `None`
    when none was given.
    """
    return _FP8_DISTRIBUTED_GROUP


def update_amax_history(amax_history: torch.Tensor) -> torch.Tensor:
    """Roll the amax history by one step and zero its first entry.

    The history is shifted by -1 along dim 0 into a new tensor, whose
    index 0 is then filled with zeros for the next amax. The input tensor
    is left unmodified.
    """
    rolled = torch.roll(amax_history, shifts=-1, dims=0)
    rolled[0].fill_(0.0)
    return rolled


@torch.jit.script
def _default_get_amax(
    amax_history: torch.Tensor,
    amax_compute_algo: str,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Pick the amax from the history and advance the history.

    With `amax_compute_algo == "max"` the amax is the maximum over dim 0
    of the history; any other value selects `amax_history[0]` (the
    "most_recent" behaviour). Returns the rolled history from
    `update_amax_history` together with the selected amax.
    """
    if amax_compute_algo != "max":
        # Every value other than "max" is treated as "most_recent".
        selected = amax_history[0]
    else:
        selected = torch.max(amax_history, dim=0).values

    rolled_history = update_amax_history(amax_history)
    return rolled_history, selected


@torch.jit.script
def _default_sf_compute(
    amax: torch.Tensor,
    scale: torch.Tensor,
    fp8_max: float,
    margin: int,
) -> torch.Tensor:
    """Default function to convert amax to scaling factor."""
    exp = torch.floor(torch.log2(fp8_max / amax)) - margin
    sf = torch.round(torch.pow(2, torch.abs(exp)))
    sf = torch.where(amax > 0.0, sf, scale)
    sf = torch.where(torch.isfinite(amax), sf, scale)
    sf = torch.where(exp < 0, 1 / sf, sf)

    return sf


355
356
357
358
359
360
361
362
363
364
365
366
367
@torch.jit.script
def _compute_scaling_factor_inverse(
    scale: torch.Tensor,
    scale_inv: torch.Tensor,
    non_weight_mask: torch.Tensor,
    update_weight_scale_inv: bool,
) -> torch.Tensor:
    """Compute inverse of scaling factor."""
    if update_weight_scale_inv:
        return 1.0 / scale
    return torch.where(non_weight_mask, 1.0 / scale, scale_inv)


Przemek Tredak's avatar
Przemek Tredak committed
368
369
370
371
@torch.jit.script
def fused_amax_and_scale_update(
    amax_history: torch.Tensor,
    scale: torch.Tensor,
    scale_inv: torch.Tensor,
    fp8_max: float,
    margin: int,
    amax_compute_algo: str,
    non_weight_mask: torch.Tensor,
    update_weight_scale_inv: bool,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Amax to scale conversion.

    Runs the three default steps in sequence: select the amax and roll the
    history (`_default_get_amax`), derive the new scaling factor
    (`_default_sf_compute`), and compute its inverse
    (`_compute_scaling_factor_inverse`). Returns the updated
    `(amax_history, scale, scale_inv)`.
    """

    # Get amax from history.
    amax_history, amax = _default_get_amax(
        amax_history,
        amax_compute_algo,
    )

    # Calculate new scaling factor.
    scale = _default_sf_compute(
        amax,
        scale,
        fp8_max,
        margin,
    )

    # Calculate new inverse of scaling factor.
    scale_inv = _compute_scaling_factor_inverse(
        scale,
        scale_inv,
        non_weight_mask,
        update_weight_scale_inv,
    )

    return amax_history, scale, scale_inv

Przemek Tredak's avatar
Przemek Tredak committed
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442

def _compute_amax(
    amax_history: torch.Tensor,
    recipe: DelayedScaling,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Obtain the amax from the history and advance the history.

    When `recipe.amax_compute_algo` is callable it is invoked on the
    history to produce the amax; otherwise it is passed on to
    `_default_get_amax`. Returns `(updated amax_history, amax)`.
    """
    algo = recipe.amax_compute_algo
    if not callable(algo):
        return _default_get_amax(amax_history, algo)

    # User-supplied algorithm: compute the amax first, then roll the history.
    amax = algo(amax_history)
    return update_amax_history(amax_history), amax


def _compute_scaling_factor(
    amax: torch.Tensor,
    scale: torch.Tensor,
    fp8_max: float,
    recipe: DelayedScaling,
) -> torch.Tensor:
    """Convert amax to a scaling factor.

    A `recipe.scaling_factor_compute_algo` that is not `None` is called as
    `algo(amax, scale, fp8_max, recipe)`; otherwise `_default_sf_compute`
    is used with `recipe.margin`.
    """
    custom_algo = recipe.scaling_factor_compute_algo
    if custom_algo is not None:
        return custom_algo(amax, scale, fp8_max, recipe)
    return _default_sf_compute(amax, scale, fp8_max, recipe.margin)


def amax_and_scale_update(
    fp8_meta: Dict[str, Any],
    fwd_update: bool,
    update_weight_scale_inv: bool = True,
) -> None:
    """Updates fp8 amaxes/scales for fwd | bwd.

    The `amax_history`, `scale` and `scale_inv` attributes of
    `fp8_meta["scaling_fwd"]` (or `fp8_meta["scaling_bwd"]` when
    `fwd_update` is `False`) are replaced by their updated values.

    Parameters
    ----------
    fp8_meta: dict
              must hold `"recipe"`, the scaling entry for the chosen direction,
              the matching `"fp8_max_fwd"` / `"fp8_max_bwd"` value and the
              `"<scaling key>_non_weight_mask"` tensor.
    fwd_update: bool
                `True` updates the forward tensors, `False` the backward ones.
    update_weight_scale_inv: bool, default = `True`
                             when `False`, `scale_inv` is recomputed only for the
                             entries selected by the non-weight mask.
    """
    recipe = fp8_meta["recipe"]
    amax_compute = recipe.amax_compute_algo
    sf_compute = recipe.scaling_factor_compute_algo
    fp8_meta_tensor_key = "scaling_fwd" if fwd_update else "scaling_bwd"
    fp8_max_key = "fp8_max_fwd" if fwd_update else "fp8_max_bwd"
    meta = fp8_meta[fp8_meta_tensor_key]
    non_weight_mask = fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"]

    if not callable(amax_compute) and sf_compute is None:
        # Both algorithms are the defaults: use the fused scripted path.
        (
            meta.amax_history,
            meta.scale,
            meta.scale_inv,
        ) = fused_amax_and_scale_update(
            meta.amax_history,
            meta.scale,
            meta.scale_inv,
            fp8_meta[fp8_max_key],
            recipe.margin,
            amax_compute,
            non_weight_mask,
            update_weight_scale_inv,
        )
    else:
        # At least one user-supplied algorithm: run the steps one by one.
        meta.amax_history, amax = _compute_amax(
            meta.amax_history,
            recipe,
        )
        meta.scale = _compute_scaling_factor(
            amax,
            meta.scale,
            fp8_meta[fp8_max_key],
            recipe,
        )
        meta.scale_inv = _compute_scaling_factor_inverse(
            meta.scale,
            meta.scale_inv,
            non_weight_mask,
            update_weight_scale_inv,
        )
Przemek Tredak's avatar
Przemek Tredak committed
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548


def get_fp8_te_dtype(
    fp8_recipe: DelayedScaling, fprop_tensor: bool = True
) -> tex.DType:
    """Select the fp8 data type for a tensor according to the recipe.

    `kFloat8E4M3` is returned when the recipe format is `Format.E4M3`, or
    when it is `Format.HYBRID` and `fprop_tensor` is true. In every other
    case `kFloat8E5M2` is returned.
    """
    fmt = fp8_recipe.fp8_format
    use_e4m3 = fmt == Format.E4M3 or (fmt == Format.HYBRID and fprop_tensor)
    return tex.DType.kFloat8E4M3 if use_e4m3 else tex.DType.kFloat8E5M2


def reduce_tensor_across_group_op_max(
    tensor: torch.Tensor, group: dist_group_type
) -> None:
    """Max-reduce `tensor` across `group` with a blocking all-reduce.

    This is a no-op when `torch.distributed` has not been initialized.
    """
    if not torch.distributed.is_initialized():
        return

    torch.distributed.all_reduce(
        tensor,
        op=torch.distributed.ReduceOp.MAX,
        group=group,
        async_op=False,
    )


def global_amax_reduction(
    fp8_meta: Dict[str, Any],
    reduce_amax_across_tp_group: bool = False,
    tp_group: Optional[dist_group_type] = None,
    forward: bool = True,
) -> None:
    """Concatenate, max-reduce, and split the amaxes in the global buffer.

    The tensors stored under this module's amax buffer key are concatenated
    and reduced across `fp8_meta["fp8_group"]` and, when
    `reduce_amax_across_tp_group` is set, also across `tp_group`. The result
    is split back into chunks of the original sizes and stored under the
    same key. Nothing happens when the key is no longer in the buffer.
    """
    key = get_amax_buffer_key(fp8_meta, forward=forward)

    # Key already deleted.
    if key not in _global_fp8_buffer:
        return

    amax_list = _global_fp8_buffer[key]
    sizes = [amax.numel() for amax in amax_list]
    flat_amax = torch.cat(amax_list)

    groups = [fp8_meta["fp8_group"]]
    if reduce_amax_across_tp_group:
        groups.append(tp_group)
    for group in groups:
        reduce_tensor_across_group_op_max(flat_amax, group)

    _global_fp8_buffer[key] = list(flat_amax.split(sizes))


def delete_key_from_amax_buffer(forward: bool = True) -> None:
    """Remove the amax buffer entry that was marked for deletion.

    The forward or backward key recorded by `set_amax_buffer_key_deletion`
    is removed from the global amax buffer. Nothing happens when no key is
    recorded or when the key is not in the buffer.
    """
    pending_key = _buffer_delete_key_fwd if forward else _buffer_delete_key_bwd
    if pending_key is not None and pending_key in _global_fp8_buffer:
        del _global_fp8_buffer[pending_key]