# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""FP8 utilities for TransformerEngine"""
import os
from contextlib import contextmanager
from collections import deque
from typing import Callable, List, Optional, Dict, Any, Tuple, Union

import torch
import transformer_engine_torch as tex

from transformer_engine.common.recipe import DelayedScaling, Format
from .constants import dist_group_type
from .utils import get_device_compute_capability
from .jit import jit_fuser


__all__ = ["fp8_autocast", "fp8_model_init"]


def check_fp8_support() -> Tuple[bool, str]:
    """Return if fp8 support is available"""
    if get_device_compute_capability() >= (9, 0):  # hopper and above
        return True, ""
    if get_device_compute_capability() < (8, 9):  # pre-ada
        return False, "Device compute capability 8.9 or higher required for FP8 execution."
    if tex.get_cublasLt_version() < 120103:
        return False, "CublasLt version 12.1.3.x or higher required for FP8 execution on Ada."
    if float(torch.version.cuda) < 12.1:
        return False, "Cuda version 12.1 or higher required for FP8 execution on Ada."
    return True, ""


def get_default_fp8_recipe() -> DelayedScaling:
    """FP8 recipe with default args."""
    return DelayedScaling()


def get_fp8_torch_dtype(fp8_recipe: DelayedScaling, fprop_tensor: bool = True) -> torch.dtype:
    """Get fp8 data type according to recipe and tensor"""
    if fp8_recipe.fp8_format == Format.E4M3 or (
        fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor
    ):
        return torch.float8_e4m3fn
    return torch.float8_e5m2
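
# Sanity sketch of the mapping above, assuming a HYBRID-format recipe
# (E4M3 for forward tensors, E5M2 for gradients):
#
#   recipe = DelayedScaling(fp8_format=Format.HYBRID)
#   get_fp8_torch_dtype(recipe, fprop_tensor=True)   # torch.float8_e4m3fn
#   get_fp8_torch_dtype(recipe, fprop_tensor=False)  # torch.float8_e5m2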


def get_fp8_te_dtype(fp8_recipe: DelayedScaling, fprop_tensor: bool = True) -> tex.DType:
    """Get fp8 data type according to recipe and tensor"""
    if fp8_recipe.fp8_format == Format.E4M3 or (
        fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor
    ):
        return tex.DType.kFloat8E4M3
    return tex.DType.kFloat8E5M2


def get_fp8_max(fp8_recipe: DelayedScaling, fprop_tensor: bool = True) -> float:
    """Get max representible FP8 value."""
    if fp8_recipe.fp8_format == Format.E4M3 or (
        fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor
    ):
        return Format.E4M3.value.max_fwd
    return Format.E5M2.value.max_fwd
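
# For reference, the lookup above resolves to the OCP FP8 maxima:
#
#   Format.E4M3.value.max_fwd  ->   448.0
#   Format.E5M2.value.max_fwd  -> 57344.0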


class FP8GlobalStateManager:
    """Class to keep track of and manipulate the global
    FP8 state at different stages of execution.
    """

    FP8_ENABLED = False
    FP8_CALIBRATION = False
    FP8_RECIPE = None
    FP8_DISTRIBUTED_GROUP = None
    FP8_PARAMETERS = False
    IS_FIRST_FP8_MODULE = False
    FP8_GRAPH_CAPTURING = False
    FP8_AUTOCAST_DEPTH = 0
    global_amax_buffer = {}
    global_amax_history_buffer = {}
    global_scale_buffer = {}
    global_scale_inv_buffer = {}
    fp8_tensors_recompute_buffer = []
    fp8_available = None
    reason_for_no_fp8 = ""
    autocast_arguments = {}
    autocast_to_fp8_params = {}
    fp8_param_to_autocast = {}
    skip_fp8_weight_update_tensor = None

    @classmethod
    def reset(cls) -> None:
        """Reset the global state"""
        cls.FP8_ENABLED = False
        cls.FP8_CALIBRATION = False
        cls.FP8_RECIPE = None
        cls.FP8_DISTRIBUTED_GROUP = None
        cls.FP8_PARAMETERS = False
        cls.IS_FIRST_FP8_MODULE = False
        cls.FP8_GRAPH_CAPTURING = False
        cls.FP8_AUTOCAST_DEPTH = 0
        cls.global_amax_buffer = {}
        cls.global_amax_history_buffer = {}
        cls.global_scale_buffer = {}
        cls.global_scale_inv_buffer = {}
        cls.fp8_tensors_recompute_buffer = []
        cls.fp8_available = None
        cls.reason_for_no_fp8 = ""
        cls.autocast_arguments = {}
        cls.skip_fp8_weight_update_tensor = None

    @classmethod
    def set_skip_fp8_weight_update_tensor(cls, skip: bool) -> None:
        """`skip_fp8_weight_update_tensor` inplace setter."""
        if cls.skip_fp8_weight_update_tensor is None:
            cls.skip_fp8_weight_update_tensor = torch.empty(1, dtype=torch.float32, device="cuda")
        cls.skip_fp8_weight_update_tensor.fill_(skip)

    @classmethod
    def get_skip_fp8_weight_update_tensor(cls) -> Optional[torch.Tensor]:
        """`skip_fp8_weight_update_tensor` getter."""
        return cls.skip_fp8_weight_update_tensor

    @classmethod
    def is_fp8_available(cls) -> Tuple[bool, str]:
        """Return if fp8 support is available"""
        if cls.fp8_available is None:
            cls.fp8_available, cls.reason_for_no_fp8 = check_fp8_support()
        return cls.fp8_available, cls.reason_for_no_fp8

    @staticmethod
    def get_meta_tensor_key(forward: bool = True) -> str:
        """Returns scaling key in `fp8_meta`."""
        if forward:
            return "scaling_fwd"
        return "scaling_bwd"

    @staticmethod
    def get_fwd_bwd_key(forward: bool = True) -> str:
        """Convert bool `forward` to string."""
        return "forward" if forward else "backward"

    @classmethod
    def get_buffer_info(cls) -> str:
        """
        Returns a key for `fp8_meta` that stores the module's index
        in the global buffers along with autocast information.
        """
        return "buffer_index_and_autocast_key"

    @classmethod
    def get_key_in_buffer(
        cls,
        forward: bool,
        fp8_recipe: DelayedScaling,
        fp8_group: dist_group_type,
    ) -> str:
        """Returns a key into the global FP8 buffers."""
        autocast_key = cls.get_unique_autocast_key(fp8_recipe, fp8_group)
        fwd_bwd_key = cls.get_fwd_bwd_key(forward)
        return f"{fwd_bwd_key}_{autocast_key}"

    @classmethod
    def split_key_in_buffer(cls, key: str) -> Tuple[bool, str]:
        """Splits buffer key into relevant parts."""
        forward, autocast_key = key.split("_", 1)
        forward = forward == "forward"
        return forward, autocast_key
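
    # Round-trip sketch for the two key helpers above (illustrative key text;
    # the autocast part comes from `get_unique_autocast_key` below):
    #
    #   key = FP8GlobalStateManager.get_key_in_buffer(True, recipe, group)
    #   # -> "forward_<recipe-repr>:<group-hash>"
    #   FP8GlobalStateManager.split_key_in_buffer(key)
    #   # -> (True, "<recipe-repr>:<group-hash>")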

    @classmethod
    def add_fp8_tensors_to_global_buffer(
        cls,
        fp8_meta: Dict[str, Any],
    ) -> None:
        """
        The amax reduction process happens completely outside the FP8 modules.
        To participate in the reduction, the only role played by a module is
        to call this function in order to append it's FP8 tensor into a global
        buffer. There are 5 global buffers maintained, one each for amax, amax
        history, scale, scale-inverse, and non-weight-mask. Each buffer has
        keys that hold FP8 tensors. Keys have a `forward_` or `backward_` prefix
        to indicate the type of FP8 tensor, since the forward and backward
        reductions happen separately.

        Note: For CG capture, this method is called from the graphed
        wrapper. For non CG case, it's called from within the module.
        """

        # Every module must call this function exactly once since
        # the amax tensors are static. This ensures compatibility
        # with non-graphed modules.
        index_in_buffer = cls.get_buffer_info()  # Same index for fwd/bwd fp8 tensors.
        if index_in_buffer in fp8_meta:
            return

        fp8_meta[index_in_buffer] = []
        for forward in (True, False):
            fp8_meta_tensor_key = cls.get_meta_tensor_key(forward=forward)
            if fp8_meta_tensor_key not in fp8_meta:
                # Handles non-parameter FP8 modules, e.g. DPA.
                continue

            key = cls.get_key_in_buffer(forward, fp8_meta["recipe"], fp8_meta["fp8_group"])

            if key not in cls.global_amax_buffer:
                cls.global_amax_buffer[key] = [fp8_meta[fp8_meta_tensor_key].amax_history[0]]
                cls.global_amax_history_buffer[key] = [fp8_meta[fp8_meta_tensor_key].amax_history]
                cls.global_scale_buffer[key] = [fp8_meta[fp8_meta_tensor_key].scale]
                cls.global_scale_inv_buffer[key] = [fp8_meta[fp8_meta_tensor_key].scale_inv]
            else:
                cls.global_amax_buffer[key].append(fp8_meta[fp8_meta_tensor_key].amax_history[0])
                cls.global_amax_history_buffer[key].append(
                    fp8_meta[fp8_meta_tensor_key].amax_history
                )
                cls.global_scale_buffer[key].append(fp8_meta[fp8_meta_tensor_key].scale)
                cls.global_scale_inv_buffer[key].append(fp8_meta[fp8_meta_tensor_key].scale_inv)
            fp8_meta[index_in_buffer].append(len(cls.global_amax_buffer[key]) - 1)
            fp8_meta[index_in_buffer].append(key)
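        # After this loop, fp8_meta["buffer_index_and_autocast_key"] holds the
        # module's positions in the global buffers interleaved with the buffer
        # keys, i.e. [fwd_position, fwd_key, bwd_position, bwd_key]; entries
        # for a direction are absent if its scaling meta is missing.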

    @classmethod
    def is_fp8_enabled(cls) -> bool:
        """Is FP8 enabled"""
        return cls.FP8_ENABLED

    @classmethod
    def is_fp8_calibration(cls) -> bool:
        """Is FP8 calibration"""
        return cls.FP8_CALIBRATION

    @classmethod
    def with_fp8_parameters(cls) -> bool:
        """Should the parameters be stored as FP8"""
        return cls.FP8_PARAMETERS

    @classmethod
    def fp8_graph_capturing(cls) -> bool:
        """Is CUDA graph capture under way?"""
        return cls.FP8_GRAPH_CAPTURING or torch.cuda.is_current_stream_capturing()

    @classmethod
    def is_first_fp8_module(cls):
        """Returns `True` only the first time when called multiple
        times from within the same `fp8_autocast` context.
        """
        tmp = cls.IS_FIRST_FP8_MODULE
        cls.IS_FIRST_FP8_MODULE = False
        return tmp

    @classmethod
    def get_fp8_recipe(cls) -> DelayedScaling:
        """Return the fp8 recipe"""
        if cls.FP8_RECIPE is not None:
            return cls.FP8_RECIPE
        return get_default_fp8_recipe()

    @classmethod
    def get_fp8_group(cls) -> Union[dist_group_type, None]:
        """Return the fp8 group for scale/amax comm"""
        return cls.FP8_DISTRIBUTED_GROUP

    @classmethod
    def get_fp8_autocast_state(cls) -> Tuple[bool, bool, DelayedScaling, dist_group_type, bool, bool]:
        """FP8 autocast state getter"""
        return (
            cls.FP8_ENABLED,
            cls.FP8_CALIBRATION,
            cls.FP8_RECIPE,
            cls.FP8_DISTRIBUTED_GROUP,
            cls.IS_FIRST_FP8_MODULE,
            cls.FP8_GRAPH_CAPTURING,
        )

    @classmethod
    def set_fp8_autocast_state(
        cls, fp8_state: Tuple[bool, bool, DelayedScaling, dist_group_type, bool, bool]
    ) -> None:
        """FP8 autocast state setter"""
        (
            cls.FP8_ENABLED,
            cls.FP8_CALIBRATION,
            cls.FP8_RECIPE,
            cls.FP8_DISTRIBUTED_GROUP,
            cls.IS_FIRST_FP8_MODULE,
            cls.FP8_GRAPH_CAPTURING,
        ) = fp8_state

    @staticmethod
    def reduce_tensor_across_group_op_max(tensor: torch.Tensor, group: dist_group_type) -> None:
        """Reduce tensor across given group."""
        if torch.distributed.is_initialized():
            torch.distributed.all_reduce(
                tensor,
                op=torch.distributed.ReduceOp.MAX,
                group=group,
                async_op=False,
            )

    @classmethod
    def reduce_and_update_fp8_tensors(
        cls,
        forward: bool = True,
    ) -> None:
        """Concatenate, reduce, and split amaxes in the global buffer."""
        for buffer_key, amax_buffer in cls.global_amax_buffer.items():
            # Check for forward or backward reduction.
            fwd_update, autocast_key = cls.split_key_in_buffer(buffer_key)
            if fwd_update != forward:
                continue
            if len(amax_buffer) == 0:
                continue

            # Retrieve autocast specific args and concat amaxes.
            recipe, group = cls.autocast_arguments[autocast_key]
            contiguous_amax = torch.cat(amax_buffer)

            # Reduction.
            if (
                recipe.reduce_amax
                and torch.distributed.is_initialized()
                and torch.distributed.get_world_size(group=group) > 1
            ):
                cls.reduce_tensor_across_group_op_max(contiguous_amax, group)

            # Amax and scale update.
            unfused_update = (
                bool(int(os.getenv("NVTE_UNFUSED_FP8_UPDATE", "0")))
                or callable(recipe.amax_compute_algo)
                or callable(recipe.scaling_factor_compute_algo)
            )

            if not unfused_update:
                tex.fused_amax_and_scale_update_after_reduction(
                    contiguous_amax,
                    cls.global_amax_history_buffer[buffer_key],
                    cls.global_scale_buffer[buffer_key],
                    cls.global_scale_inv_buffer[buffer_key],
                    recipe.amax_compute_algo,
                    get_fp8_te_dtype(recipe, forward),
                    recipe.margin,
                )
            else:
                split_and_copy(contiguous_amax, amax_buffer, [x.numel() for x in amax_buffer])

                for amax_history, scale, scale_inv in zip(
                    cls.global_amax_history_buffer[buffer_key],
                    cls.global_scale_buffer[buffer_key],
                    cls.global_scale_inv_buffer[buffer_key],
                ):
                    _amax_and_scale_update(
                        amax_history, scale, scale_inv, get_fp8_max(recipe, forward), recipe
                    )

    @classmethod
    def get_unique_autocast_key(
        cls,
        recipe: Optional[DelayedScaling] = None,
        group: Optional[dist_group_type] = None,
    ) -> str:
        """
        For FP8, each autocast can be uniquely identified by the recipe and fp8 group.
        Safely using `hash` as we never cross checkpoint boundaries.
        """
        return f"{str(recipe)}:{hash(group)}"

    @classmethod
    def fp8_autocast_enter(
        cls,
        enabled: bool = False,
        calibrating: bool = False,
        fp8_recipe: Optional[DelayedScaling] = None,
        fp8_group: Optional[dist_group_type] = None,
        _graph: bool = False,
    ) -> None:
        """Set state and tracking variables for entry into FP8 region."""

        fp8_recipe = get_default_fp8_recipe() if fp8_recipe is None else fp8_recipe
        autocast_key = cls.get_unique_autocast_key(fp8_recipe, fp8_group)
        cls.autocast_arguments[autocast_key] = (fp8_recipe, fp8_group)

        cls.FP8_ENABLED = enabled
        cls.FP8_CALIBRATION = calibrating
        cls.FP8_RECIPE = fp8_recipe
        cls.FP8_DISTRIBUTED_GROUP = fp8_group
        cls.FP8_GRAPH_CAPTURING = _graph

        if cls.FP8_AUTOCAST_DEPTH == 0:
            cls.IS_FIRST_FP8_MODULE = True
        cls.FP8_AUTOCAST_DEPTH += 1

        if enabled:
            fp8_available, reason_for_no_fp8 = cls.is_fp8_available()
            assert fp8_available, reason_for_no_fp8

    @classmethod
    def fp8_autocast_exit(cls, enabled: bool, _graph: bool) -> None:
        """Set state and tracking variables for exit from FP8 region."""
        cls.FP8_AUTOCAST_DEPTH -= 1
        # Reduce only the non-FP8 weight modules here.
        # FP8 weight modules are reduced at the end of the optimizer
        # step after the weight amax is populated.
        if enabled and cls.FP8_AUTOCAST_DEPTH == 0 and not _graph and torch.is_grad_enabled():
            cls.reduce_and_update_fp8_tensors(forward=True)

    @classmethod
    def copy_forward_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> None:
        """Copy the scaling factors and amaxes for recompute forward phase
        to ensure both forward steps are numerically same.
        """
        buffer_position_key = "global_fp8_buffer_pos_fwd_recompute"

        to_copy = [
            fp8_meta["scaling_fwd"].amax_history.clone(),
            fp8_meta["scaling_fwd"].scale.clone(),
            fp8_meta["scaling_fwd"].scale_inv.clone(),
        ]

        if buffer_position_key in fp8_meta:
            cls.fp8_tensors_recompute_buffer[fp8_meta[buffer_position_key]].append(to_copy)
        else:
            if len(cls.fp8_tensors_recompute_buffer) == 0:
                cls.fp8_tensors_recompute_buffer = [deque()]
            else:
                cls.fp8_tensors_recompute_buffer.append(deque())
            cls.fp8_tensors_recompute_buffer[-1].append(to_copy)
            fp8_meta[buffer_position_key] = len(cls.fp8_tensors_recompute_buffer) - 1

    @classmethod
    def get_old_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> None:
        """Switch to the copied scaling factors and amaxes from phase
        1 forward for indentical numerical outputs.
        """

        # Store updated amaxes and scales from phase 1 post forward.
        fp8_meta["updated_amax_history_fwd"] = fp8_meta["scaling_fwd"].amax_history
        fp8_meta["updated_scale_fwd"] = fp8_meta["scaling_fwd"].scale
        fp8_meta["updated_scale_inv_fwd"] = fp8_meta["scaling_fwd"].scale_inv

        # Retrieve stashed amaxes and scales from phase 1 pre forward.
        buffer_position_key = "global_fp8_buffer_pos_fwd_recompute"
        stashed_fp8_meta = cls.fp8_tensors_recompute_buffer[fp8_meta[buffer_position_key]].popleft()

        # Replace amaxes and scales with stashed values for phase 2 forward
        fp8_meta["scaling_fwd"].amax_history = stashed_fp8_meta[0]
        fp8_meta["scaling_fwd"].scale = stashed_fp8_meta[1]
        fp8_meta["scaling_fwd"].scale_inv = stashed_fp8_meta[2]

    @staticmethod
    def restore_fp8_meta_tensors(fp8_meta: Dict[str, Any]) -> None:
        """Restore latest scaling factors and amaxes after recompute forward run."""
        fp8_meta["scaling_fwd"].amax_history = fp8_meta["updated_amax_history_fwd"]
        fp8_meta["scaling_fwd"].scale = fp8_meta["updated_scale_fwd"]
        fp8_meta["scaling_fwd"].scale_inv = fp8_meta["updated_scale_inv_fwd"]


@contextmanager
def fp8_model_init(enabled: bool = True) -> None:
    """
    Context manager for FP8 initialization of parameters.

    Example usage:

    .. code-block:: python

        with fp8_model_init(enabled=True):
            model = transformer_engine.pytorch.Linear(768, 768)

    Parameters
    ----------
    enabled: bool, default = `True`
             when enabled, Transformer Engine modules created inside this `fp8_model_init`
             region will hold only FP8 copies of their parameters, as opposed to the default
             behavior where both higher precision and FP8 copies are present. Setting this
             option to `True` may result in lower memory consumption and is especially
             useful for scenarios like:

             * full model training using optimizer with master weights, where the high
               precision copies of weights are already present in the optimizer.
             * inference, where only the FP8 copies of the parameters are used.
             * LoRA-like fine-tuning, where the main parameters of the model do not change.

             This functionality is *EXPERIMENTAL*.
    """
    _fp8_parameters = FP8GlobalStateManager.FP8_PARAMETERS
    FP8GlobalStateManager.FP8_PARAMETERS = enabled
    try:
        yield
    finally:
        FP8GlobalStateManager.FP8_PARAMETERS = _fp8_parameters


@contextmanager
def fp8_autocast(
    enabled: bool = True,
    calibrating: bool = False,
    fp8_recipe: Optional[DelayedScaling] = None,
    fp8_group: Optional[dist_group_type] = None,
    _graph: bool = False,
) -> None:
    """
    Context manager for FP8 usage.

    .. code-block:: python

        with fp8_autocast(enabled=True):
            out = model(inp)

    .. note::

        Support for FP8 in the Linear layer of Transformer Engine is currently limited to tensors
        with shapes where both dimensions are divisible by 16. In terms of the input to the full
        Transformer network, this typically requires padding the sequence length to a multiple of 16.

    .. note::

        When :attr:`fp8_recipe.reduce_amax==True`, a module must not be invoked more than once
        inside a single `fp8_autocast` region. This is unsupported behavior because the amax
        reduction is handled during the exit of the `fp8_autocast` context. Calling the same
        module more than once inside an `fp8_autocast` region overwrites the amax tensors
        before the reduction can occur.

    Parameters
    ----------
    enabled: bool, default = `True`
             whether or not to enable fp8
    calibrating: bool, default = `False`
                 calibration mode allows collecting statistics such as amax and scale
                 data of fp8 tensors even when executing without fp8 enabled. This is
                 useful for saving an inference-ready fp8 checkpoint while training
                 using a higher precision.
    fp8_recipe: recipe.DelayedScaling, default = `None`
                recipe used for FP8 training.
    fp8_group: torch._C._distributed_c10d.ProcessGroup, default = `None`
               distributed group over which amaxes for the fp8 tensors
               are reduced at the end of each training step.
    """
    fp8_state = FP8GlobalStateManager.get_fp8_autocast_state()
    FP8GlobalStateManager.fp8_autocast_enter(
        enabled=enabled,
        calibrating=calibrating,
        fp8_recipe=fp8_recipe,
        fp8_group=fp8_group,
        _graph=_graph,
    )
    try:
        yield
    finally:
        FP8GlobalStateManager.set_fp8_autocast_state(fp8_state)
        FP8GlobalStateManager.fp8_autocast_exit(enabled, _graph=_graph)


def _update_amax_history(amax_history: torch.Tensor) -> torch.Tensor:
    """Update amax history and set next amax to zero."""
    if amax_history.shape[0] > 1:
        new_amax_history = torch.roll(amax_history, -1, 0)
        amax_history.copy_(new_amax_history)
    amax_history[0].fill_(0.0)
    return amax_history
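
# Illustrative roll for a length-3 history h = [a0, a1, a2], where row 0 is
# the slot holding the current iteration's amax:
#
#   torch.roll(h, -1, 0)  ->  [a1, a2, a0]   # newest moves to the end
#   h[0].fill_(0.0)       ->  [0.0, a2, a0]  # a1 (the oldest) is evicted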


@torch.jit.script
def _default_get_amax_and_update_history(
    amax_history: torch.Tensor,
    amax_compute_algo: str,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Default function to obtain amax from history."""
    if amax_compute_algo == "max":
        amax = torch.max(amax_history, dim=0).values
    else:  # amax_compute_algo == "most_recent"
        amax = amax_history[0].clone()

    amax_history = _update_amax_history(amax_history)
    return amax_history, amax
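
# Hedged example of the two built-in amax reductions over a history of shape
# [history_len, num_tensors] (values illustrative):
#
#   h = torch.tensor([[2.0], [5.0], [3.0]])
#   "max"          -> amax = 5.0   # max over the whole window
#   "most_recent"  -> amax = 2.0   # row 0 holds the latest recorded amax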


@jit_fuser
def _default_sf_compute(
    amax: torch.Tensor,
    scale: torch.Tensor,
    fp8_max: float,
    margin: int,
583
    _fp32_max: float = torch.finfo(torch.float32).max,  # finfo not available inside the jitted function
) -> torch.Tensor:
    """Default function to convert amax to scaling factor.
    Computing the scaling factor requires consideration of the following scenarios:
    1. amax == 0:
       No action is possible, set scale to the previous scale (or 1).
    2. 0 < amax < tiny_amax:
       The amax is so tiny that the scale becomes infinite in FP32.
       Set scale = FP32_max
    3. tiny_amax <= amax < FP32_max:
       Set scale = FP8_max (or scaled_max) / amax
    4. When amax == inf or amax == nan:
       No action is possible, set scale to the previous scale (or 1).
    """
    sf = (fp8_max / amax) / (2**margin)
    sf = torch.where(amax > 0.0, sf, scale)
    sf = torch.where(torch.isfinite(amax), sf, scale)
    sf = torch.where(torch.isinf(sf), torch.full_like(sf, _fp32_max), sf)
    scale.copy_(sf)
    return scale
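
# Worked example (margin=0, E4M3, fp8_max == 448.0): an observed amax of 0.5
# yields scale = 448.0 / 0.5 = 896.0. A zero, inf, or nan amax keeps the
# previous scale, and an infinite quotient is clamped to FP32 max.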


def _compute_amax_and_update_history(
    amax_history: torch.Tensor,
    amax_compute_algo: Union[Callable, str],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Obtain the amax from the history."""

    if callable(amax_compute_algo):
        amax = amax_compute_algo(amax_history)
        amax_history = _update_amax_history(amax_history)
        return amax_history, amax
    return _default_get_amax_and_update_history(
        amax_history,
        amax_compute_algo,
    )


def _compute_scaling_factor(
    amax: torch.Tensor,
    scale: torch.Tensor,
    fp8_max: float,
    recipe: DelayedScaling,
) -> torch.Tensor:
    """Convert amax to scaling factor."""

    if recipe.scaling_factor_compute_algo is None:
        return _default_sf_compute(
            amax,
            scale,
            fp8_max,
            recipe.margin,
        )
    return recipe.scaling_factor_compute_algo(amax, scale, fp8_max, recipe)
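
# A custom scaling-factor algorithm supplied on the recipe must match the call
# above. Hypothetical sketch (ignores margin for brevity):
#
#   def my_sf(amax, scale, fp8_max, recipe):
#       return torch.where(amax > 0, fp8_max / amax, scale)
#
#   recipe = DelayedScaling(scaling_factor_compute_algo=my_sf)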


def _amax_and_scale_update(
    amax_history: torch.Tensor,
    scale: torch.Tensor,
    scale_inv: torch.Tensor,
    fp8_max: float,
    recipe: DelayedScaling,
) -> None:
    """Updates FP8 meta tensors."""
    new_amax_history, amax = _compute_amax_and_update_history(
        amax_history,
        recipe.amax_compute_algo,
    )
    new_scale = _compute_scaling_factor(amax, scale, fp8_max, recipe)
    scale.copy_(new_scale)
    scale_inv.copy_(1.0 / new_scale)
    amax_history.copy_(new_amax_history)


def split_and_copy(
    buffer: torch.Tensor,
    outputs: List[torch.Tensor],
    chunk_sizes: List[int],
) -> None:
    """Split `buffer` by `chunk_sizes` and copy into `outputs`."""
    splits = buffer.split(chunk_sizes)
    torch._foreach_copy_(outputs, splits)
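
# Minimal sketch of the helper above (hypothetical sizes): splitting a reduced,
# concatenated amax buffer back into per-module tensors.
#
#   buf = torch.arange(6.0, device="cuda")
#   outs = [torch.empty(2, device="cuda"), torch.empty(4, device="cuda")]
#   split_and_copy(buf, outs, [2, 4])  # outs[0] -> [0., 1.]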