fp8.py 31.5 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Przemek Tredak's avatar
Przemek Tredak committed
2
3
4
#
# See LICENSE for license information.

5
"""FP8 utilities for TransformerEngine"""
6
7
8
from __future__ import annotations

import abc
Sangkug Lym's avatar
Sangkug Lym committed
9
import os
Przemek Tredak's avatar
Przemek Tredak committed
10
from contextlib import contextmanager
11
from collections import deque
12
from typing import Callable, List, Optional, Dict, Any, Tuple, Union
Przemek Tredak's avatar
Przemek Tredak committed
13
14

import torch
15
import transformer_engine_torch as tex
16
17
18
19
20
21
22
from transformer_engine.common.recipe import (
    Recipe,
    DelayedScaling,
    Format,
    MXFP8BlockScaling,
    Float8CurrentScaling,
)
Przemek Tredak's avatar
Przemek Tredak committed
23
24

from .constants import dist_group_type
25
from .utils import get_device_compute_capability
26
from .jit import jit_fuser
yuguo's avatar
yuguo committed
27
from torch.utils.cpp_extension import IS_HIP_EXTENSION
Przemek Tredak's avatar
Przemek Tredak committed
28

29

30
# Public API of this module: the names exported by a wildcard import.
__all__ = ["fp8_autocast", "fp8_model_init"]
31
32
33


def _parse_major_minor(version: str) -> Tuple[int, int]:
    """Parse the leading ``major.minor`` components of a version string.

    Comparing integer tuples avoids the pitfalls of ``float(version)``, which
    misorders versions such as ``"12.10"`` (parsed as 12.1) and ``"12.9"``.
    A missing minor component is treated as 0.
    """
    parts = version.split(".")
    major = int(parts[0])
    minor = int(parts[1]) if len(parts) > 1 else 0
    return major, minor


def check_fp8_support() -> Tuple[bool, str]:
    """Return whether FP8 execution is supported on the current device.

    Returns
    -------
    Tuple[bool, str]
        ``(True, "")`` when FP8 is supported, otherwise ``(False, reason)``
        with a human-readable explanation.
    """
    # Query the device once instead of once per check.
    compute_capability = get_device_compute_capability()

    if IS_HIP_EXTENSION:
        # HIP builds: only compute capability (9, 4) is accepted.
        if compute_capability == (9, 4):
            return True, ""
        return False, "DCU not support fp8 for now"

    if compute_capability >= (9, 0):  # hopper and above
        return True, ""
    if compute_capability < (8, 9):  # pre-ada
        return False, "Device compute capability 8.9 or higher required for FP8 execution."
    # Ada (8.9) additionally requires recent cuBLASLt and CUDA versions.
    if tex.get_cublasLt_version() < 120103:
        return False, "CublasLt version 12.1.3.x or higher required for FP8 execution on Ada."
    if _parse_major_minor(torch.version.cuda) < (12, 1):
        return False, "Cuda version 12.1 or higher required for FP8 execution on Ada."
    return True, ""


def check_mxfp8_support() -> Tuple[bool, str]:
    """Return whether MXFP8 execution is supported on the current device.

    Returns
    -------
    Tuple[bool, str]
        ``(True, "")`` on devices with compute capability 10.0 or higher,
        otherwise ``(False, reason)`` with a human-readable explanation.
    """
    if get_device_compute_capability() >= (10, 0):  # blackwell and above
        return True, ""
    return False, "Device compute capability 10.0 or higher required for MXFP8 execution."


def get_default_fp8_recipe() -> Recipe:
    """Build the recipe used when no FP8 recipe is supplied explicitly.

    Devices with compute capability 10.0 or newer get MXFP8 block scaling;
    every other device gets delayed scaling.
    """
    supports_mxfp8_default = get_device_compute_capability() >= (10, 0)  # blackwell and above
    return MXFP8BlockScaling() if supports_mxfp8_default else DelayedScaling()


def get_fp8_torch_dtype(fp8_recipe: Recipe, fprop_tensor: bool = True) -> torch.dtype:
    """Return the torch FP8 dtype that ``fp8_recipe`` prescribes for a tensor.

    E4M3 is chosen when the recipe format is E4M3, or when it is HYBRID and
    the tensor is a forward-pass tensor (``fprop_tensor=True``); otherwise
    E5M2 is chosen.
    """
    recipe_format = fp8_recipe.fp8_format
    pick_e4m3 = recipe_format == Format.E4M3 or (
        recipe_format == Format.HYBRID and fprop_tensor
    )
    if not pick_e4m3:
        return torch.float8_e5m2
    return torch.float8_e4m3fn


def get_fp8_te_dtype(fp8_recipe: Recipe, fprop_tensor: bool = True) -> tex.DType:
    """Return the TransformerEngine FP8 dtype enum ``fp8_recipe`` prescribes.

    E4M3 is chosen when the recipe format is E4M3, or when it is HYBRID and
    the tensor is a forward-pass tensor (``fprop_tensor=True``); otherwise
    E5M2 is chosen.
    """
    recipe_format = fp8_recipe.fp8_format
    wants_e4m3 = recipe_format == Format.E4M3 or (
        recipe_format == Format.HYBRID and fprop_tensor
    )
    return tex.DType.kFloat8E4M3 if wants_e4m3 else tex.DType.kFloat8E5M2


def get_fp8_max(fp8_recipe: Recipe, fprop_tensor: bool = True) -> float:
    """Return the largest representable value of the FP8 format for a tensor.

    The format is E4M3 when the recipe format is E4M3, or when it is HYBRID
    and the tensor is a forward-pass tensor (``fprop_tensor=True``); otherwise
    it is E5M2. The returned value is that format's ``max_fwd`` entry.
    """
    if fp8_recipe.fp8_format == Format.E4M3 or (
        fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor
    ):
        return Format.E4M3.value.max_fwd
    return Format.E5M2.value.max_fwd


class FP8GlobalStateManager:
    """Class to keep track of and manipulate the global
    FP8 state at different stages of execution.

    All state is stored in class attributes and accessed through
    classmethods/staticmethods.
    """

    # State of the currently active `fp8_autocast` / `fp8_model_init` region.
    FP8_ENABLED = False
    FP8_CALIBRATION = False
    FP8_RECIPE = None
    FP8_DISTRIBUTED_GROUP = None
    FP8_PARAMETERS = False
    IS_FIRST_FP8_MODULE = False
    FP8_GRAPH_CAPTURING = False
    FP8_AUTOCAST_DEPTH = 0  # nesting level of `fp8_autocast_enter` calls
    # Delayed-scaling buffers keyed by `get_key_in_buffer`; each value is a
    # list with one entry per registered module.
    global_amax_buffer = {}
    global_amax_history_buffer = {}
    global_scale_buffer = {}
    # One deque of stashed [amax_history, scale] copies per module, used to
    # replay identical scaling factors during activation recompute.
    fp8_tensors_recompute_buffer = []
    # Cached result of `check_fp8_support` (None until first queried).
    fp8_available = None
    reason_for_no_fp8 = ""
    # Maps an autocast key (see `get_unique_autocast_key`) to (recipe, group).
    autocast_arguments = {}
    autocast_to_fp8_params = {}
    fp8_param_to_autocast = {}
    skip_fp8_weight_update_tensor = None
    # Cached result of `check_mxfp8_support` (None until first queried).
    mxfp8_available = None
    reason_for_no_mxfp8 = ""

    @classmethod
    def reset(cls) -> None:
        """Reset the global state"""
        cls.FP8_ENABLED = False
        cls.FP8_CALIBRATION = False
        cls.FP8_RECIPE = None
        cls.FP8_DISTRIBUTED_GROUP = None
        cls.FP8_PARAMETERS = False
        cls.IS_FIRST_FP8_MODULE = False
        cls.FP8_GRAPH_CAPTURING = False
        cls.FP8_AUTOCAST_DEPTH = 0
        cls.global_amax_buffer = {}
        cls.global_amax_history_buffer = {}
        cls.global_scale_buffer = {}
        cls.fp8_tensors_recompute_buffer = []
        cls.fp8_available = None
        cls.reason_for_no_fp8 = ""
        cls.autocast_arguments = {}
        cls.autocast_to_fp8_params = {}
        cls.fp8_param_to_autocast = {}
        cls.skip_fp8_weight_update_tensor = None
        cls.mxfp8_available = None
        cls.reason_for_no_mxfp8 = ""

    @classmethod
    def set_skip_fp8_weight_update_tensor(cls, skip: bool) -> None:
        """`skip_fp8_weight_update_tensor` inplace setter.

        Lazily allocates a 1-element float32 CUDA tensor on first use and then
        fills it in place, so the same tensor object is reused across calls.
        """
        if cls.skip_fp8_weight_update_tensor is None:
            cls.skip_fp8_weight_update_tensor = torch.empty(1, dtype=torch.float32, device="cuda")
        cls.skip_fp8_weight_update_tensor.fill_(skip)

    @classmethod
    def get_skip_fp8_weight_update_tensor(cls) -> Optional[torch.Tensor]:
        """`skip_fp8_weight_update_tensor` getter (None until first set)."""
        return cls.skip_fp8_weight_update_tensor

    @classmethod
    def is_fp8_available(cls) -> Tuple[bool, str]:
        """Return if fp8 support is available (checked once, then cached)."""
        if cls.fp8_available is None:
            cls.fp8_available, cls.reason_for_no_fp8 = check_fp8_support()
        return cls.fp8_available, cls.reason_for_no_fp8

    @classmethod
    def is_mxfp8_available(cls) -> Tuple[bool, str]:
        """Return if MXFP8 support is available (checked once, then cached)."""
        if cls.mxfp8_available is None:
            cls.mxfp8_available, cls.reason_for_no_mxfp8 = check_mxfp8_support()
        return cls.mxfp8_available, cls.reason_for_no_mxfp8

    @staticmethod
    def get_meta_tensor_key(forward: bool = True) -> str:
        """Returns scaling key in `fp8_meta`."""
        if forward:
            return "scaling_fwd"
        return "scaling_bwd"

    @staticmethod
    def get_fwd_bwd_key(forward: bool = True) -> str:
        """Convert bool `forward` to string."""
        return "forward" if forward else "backward"

    @classmethod
    def get_buffer_info(cls) -> str:
        """
        Returns a key for `fp8_meta` that stores the module's index
        in the global buffers along with autocast information.
        """
        return "buffer_index_and_autocast_key"

    @classmethod
    def get_key_in_buffer(
        cls,
        forward: bool,
        fp8_recipe: Recipe,
        fp8_group: dist_group_type,
    ) -> str:
        """Returns a key into the global FP8 buffers.

        The key has the form ``"{forward|backward}_{autocast_key}"``.
        """
        autocast_key = cls.get_unique_autocast_key(fp8_recipe, fp8_group)
        fwd_bwd_key = cls.get_fwd_bwd_key(forward)
        return f"{fwd_bwd_key}_{autocast_key}"

    @classmethod
    def split_key_in_buffer(cls, key: str) -> Tuple[bool, str]:
        """Splits buffer key into (is_forward, autocast_key); inverse of `get_key_in_buffer`."""
        forward, autocast_key = key.split("_", 1)
        forward = forward == "forward"
        return forward, autocast_key

    @classmethod
    def add_fp8_tensors_to_global_buffer(
        cls,
        fp8_meta: Dict[str, Any],
    ) -> None:
        """
        Delayed scaling only.

        The amax reduction process happens completely outside the FP8 modules.
        To participate in the reduction, the only role played by a module is
        to call this function in order to append its FP8 tensors into the
        global buffers. There are 3 global buffers maintained, one each for
        amax, amax history, and scale. Each buffer has keys that hold FP8
        tensors. Keys have a `forward_` or `backward_` prefix to indicate the
        type of FP8 tensor, since the forward and backward reductions happen
        separately.

        After the call, ``fp8_meta[get_buffer_info()]`` is a flat list of
        alternating ``index_in_buffer, key`` entries, one pair per registered
        direction.

        Note: For CG capture, this method is called from the graphed
        wrapper. For non CG case, it's called from within the module.
        """
        # delayed scaling only function, noop for any other recipe
        if not fp8_meta["recipe"].delayed():
            return

        # Every module must call this function exactly once since
        # the amax tensors are static. Ensures that compatibility
        # with non-graphed modules is maintained.
        index_in_buffer = cls.get_buffer_info()  # Same index for fwd/bwd fp8 tensors.
        if index_in_buffer in fp8_meta:
            return

        fp8_meta[index_in_buffer] = []
        for forward in (True, False):
            fp8_meta_tensor_key = cls.get_meta_tensor_key(forward=forward)
            if fp8_meta_tensor_key not in fp8_meta:
                # Handles non-parameter FP8 modules, e.g. DPA.
                continue

            key = cls.get_key_in_buffer(forward, fp8_meta["recipe"], fp8_meta["fp8_group"])
            meta_tensors = fp8_meta[fp8_meta_tensor_key]

            if key not in cls.global_amax_buffer:
                cls.global_amax_buffer[key] = []
                cls.global_amax_history_buffer[key] = []
                cls.global_scale_buffer[key] = []
            # Row 0 of the history is the slot holding the current amax.
            cls.global_amax_buffer[key].append(meta_tensors.amax_history[0])
            cls.global_amax_history_buffer[key].append(meta_tensors.amax_history)
            cls.global_scale_buffer[key].append(meta_tensors.scale)
            fp8_meta[index_in_buffer].append(len(cls.global_amax_buffer[key]) - 1)
            fp8_meta[index_in_buffer].append(key)

    @classmethod
    def is_fp8_enabled(cls) -> bool:
        """Is FP8 enabled"""
        return cls.FP8_ENABLED

    @classmethod
    def is_fp8_calibration(cls) -> bool:
        """Is FP8 calibration"""
        return cls.FP8_CALIBRATION

    @classmethod
    def with_fp8_parameters(cls) -> bool:
        """Should the parameters be stored as FP8"""
        return cls.FP8_PARAMETERS

    @classmethod
    def fp8_graph_capturing(cls) -> bool:
        """Is CUDA graph capture under way?"""
        return cls.FP8_GRAPH_CAPTURING or torch.cuda.is_current_stream_capturing()

    @classmethod
    def is_first_fp8_module(cls):
        """Returns `True` only the first time when called multiple
        times from within the same `fp8_autocast` context.
        """
        tmp = cls.IS_FIRST_FP8_MODULE
        cls.IS_FIRST_FP8_MODULE = False
        return tmp

    @classmethod
    def get_fp8_recipe(cls) -> Recipe:
        """Return the active fp8 recipe, or the default recipe if none is set."""
        if cls.FP8_RECIPE is not None:
            return cls.FP8_RECIPE
        return get_default_fp8_recipe()

    @classmethod
    def get_fp8_group(cls) -> Union[dist_group_type, None]:
        """Return the fp8 group for scale/amax comm"""
        return cls.FP8_DISTRIBUTED_GROUP

    @classmethod
    def get_fp8_autocast_state(cls) -> Tuple[bool, bool, Recipe, dist_group_type, bool, bool]:
        """FP8 autocast state getter.

        Returns (enabled, calibrating, recipe, group, is_first_module,
        graph_capturing). `FP8_AUTOCAST_DEPTH` is not part of this state.
        """
        return (
            cls.FP8_ENABLED,
            cls.FP8_CALIBRATION,
            cls.FP8_RECIPE,
            cls.FP8_DISTRIBUTED_GROUP,
            cls.IS_FIRST_FP8_MODULE,
            cls.FP8_GRAPH_CAPTURING,
        )

    @classmethod
    def set_fp8_autocast_state(
        cls, fp8_state: Tuple[bool, bool, Recipe, dist_group_type, bool, bool]
    ) -> None:
        """FP8 autocast state setter; accepts the tuple from `get_fp8_autocast_state`."""
        (
            cls.FP8_ENABLED,
            cls.FP8_CALIBRATION,
            cls.FP8_RECIPE,
            cls.FP8_DISTRIBUTED_GROUP,
            cls.IS_FIRST_FP8_MODULE,
            cls.FP8_GRAPH_CAPTURING,
        ) = fp8_state

    @staticmethod
    def reduce_tensor_across_group_op_max(tensor: torch.Tensor, group: dist_group_type) -> None:
        """Reduce tensor across given group (in place, MAX), if distributed is initialized."""
        if torch.distributed.is_initialized():
            torch.distributed.all_reduce(
                tensor,
                op=torch.distributed.ReduceOp.MAX,
                group=group,
                async_op=False,
            )

    @classmethod
    def reduce_and_update_fp8_tensors(
        cls,
        forward: bool = True,
    ) -> None:
        """Delayed scaling only. Concatenate, reduce, and split amaxes in the global buffer."""
        # global_amax_buffer should only be non-empty for fp8 delayed scaling
        for buffer_key, amax_buffer in cls.global_amax_buffer.items():
            # Check for forward or backward reduction.
            fwd_update, autocast_key = cls.split_key_in_buffer(buffer_key)
            if fwd_update != forward:
                continue
            if len(amax_buffer) == 0:
                continue

            # Retrieve autocast specific args and concat amaxes.
            recipe, group = cls.autocast_arguments[autocast_key]
            contiguous_amax = torch.cat(amax_buffer)

            # Reduction.
            if (
                recipe.reduce_amax
                and torch.distributed.is_initialized()
                and torch.distributed.get_world_size(group=group) > 1
            ):
                cls.reduce_tensor_across_group_op_max(contiguous_amax, group)

            # Amax and scale update. Custom (callable) algorithms cannot be
            # handled by the fused kernel, so they force the unfused path.
            unfused_update = (
                bool(int(os.getenv("NVTE_UNFUSED_FP8_UPDATE", "0")))
                or callable(recipe.amax_compute_algo)
                or callable(recipe.scaling_factor_compute_algo)
            )

            if not unfused_update:
                tex.fused_amax_and_scale_update_after_reduction(
                    contiguous_amax,
                    cls.global_amax_history_buffer[buffer_key],
                    cls.global_scale_buffer[buffer_key],
                    recipe.amax_compute_algo,
                    get_fp8_te_dtype(recipe, forward),
                    recipe.margin,
                )
            else:
                # Write the reduced amaxes back into each module's history slot.
                split_and_copy(contiguous_amax, amax_buffer, [x.numel() for x in amax_buffer])

                for amax_history, scale in zip(
                    cls.global_amax_history_buffer[buffer_key],
                    cls.global_scale_buffer[buffer_key],
                ):
                    _amax_and_scale_update(
                        amax_history, scale, get_fp8_max(recipe, forward), recipe
                    )

    @classmethod
    def get_unique_autocast_key(
        cls,
        recipe: Optional[Recipe] = None,
        group: Optional[dist_group_type] = None,
    ) -> str:
        """
        For FP8, each autocast can be uniquely identified by the recipe and fp8 group.
        Safely using `hash` as we never cross checkpoint boundaries.
        """
        return f"{str(recipe)}:{hash(group)}"

    @classmethod
    def fp8_autocast_enter(
        cls,
        enabled: bool = False,
        calibrating: bool = False,
        fp8_recipe: Optional[Recipe] = None,
        fp8_group: Optional[dist_group_type] = None,
        _graph: bool = False,
    ) -> None:
        """Set state and tracking variables for entry into FP8 region.

        Raises
        ------
        AssertionError
            If `enabled` is True and FP8 (or, for an MXFP8 recipe, MXFP8) is
            not supported. In that case no global state is modified.
        """
        fp8_recipe = get_default_fp8_recipe() if fp8_recipe is None else fp8_recipe

        # Validate hardware support before touching any global state, so a
        # failed check cannot leave FP8 enabled or the depth counter bumped.
        if enabled:
            fp8_available, reason_for_no_fp8 = cls.is_fp8_available()
            assert fp8_available, reason_for_no_fp8
            if isinstance(fp8_recipe, MXFP8BlockScaling):
                mxfp8_available, reason_for_no_mxfp8 = cls.is_mxfp8_available()
                assert mxfp8_available, reason_for_no_mxfp8

        autocast_key = cls.get_unique_autocast_key(fp8_recipe, fp8_group)
        cls.autocast_arguments[autocast_key] = (fp8_recipe, fp8_group)

        cls.FP8_ENABLED = enabled
        cls.FP8_CALIBRATION = calibrating
        cls.FP8_RECIPE = fp8_recipe
        cls.FP8_DISTRIBUTED_GROUP = fp8_group
        cls.FP8_GRAPH_CAPTURING = _graph

        if cls.FP8_AUTOCAST_DEPTH == 0:
            cls.IS_FIRST_FP8_MODULE = True
        cls.FP8_AUTOCAST_DEPTH += 1

    @classmethod
    def fp8_autocast_exit(cls, enabled: bool, _graph: bool) -> None:
        """Set state and tracking variables for exit from FP8 region."""
        cls.FP8_AUTOCAST_DEPTH -= 1
        # Reduce only the non-FP8 weight modules here.
        # FP8 weight modules are reduced at the end of the optimizer
        # step after the weight amax is populated.
        if enabled and cls.FP8_AUTOCAST_DEPTH == 0 and not _graph and torch.is_grad_enabled():
            # Delayed scaling only: for other recipes nothing is ever added to
            # `cls.global_amax_buffer` (a dict), so this call is a no-op.
            cls.reduce_and_update_fp8_tensors(forward=True)

    @classmethod
    def copy_forward_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> None:
        """Copy the scaling factors and amaxes for recompute forward phase
        to ensure both forward steps are numerically same.
        """
        # delayed scaling only function, noop for any other recipe
        if not fp8_meta["recipe"].delayed():
            return

        buffer_position_key = "global_fp8_buffer_pos_fwd_recompute"

        to_copy = [
            fp8_meta["scaling_fwd"].amax_history.clone(),
            fp8_meta["scaling_fwd"].scale.clone(),
        ]

        if buffer_position_key in fp8_meta:
            cls.fp8_tensors_recompute_buffer[fp8_meta[buffer_position_key]].append(to_copy)
        else:
            # First stash for this module: give it its own FIFO queue.
            cls.fp8_tensors_recompute_buffer.append(deque([to_copy]))
            fp8_meta[buffer_position_key] = len(cls.fp8_tensors_recompute_buffer) - 1

    @classmethod
    def get_old_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> None:
        """Switch to the copied scaling factors and amaxes from phase
        1 forward for identical numerical outputs.
        """
        # delayed scaling only function, noop for any other recipe
        if not fp8_meta["recipe"].delayed():
            return

        # Store updated amaxes and scales from phase 1 post forward. These must
        # be clones: the live tensors are overwritten in place below, and plain
        # references would make `restore_fp8_meta_tensors` copy the stashed
        # values onto themselves, losing the phase 1 update.
        fp8_meta["updated_amax_history_fwd"] = fp8_meta["scaling_fwd"].amax_history.clone()
        fp8_meta["updated_scale_fwd"] = fp8_meta["scaling_fwd"].scale.clone()

        # Retrieve stashed amaxes and scales from phase 1 pre forward.
        buffer_position_key = "global_fp8_buffer_pos_fwd_recompute"
        stashed_fp8_meta = cls.fp8_tensors_recompute_buffer[fp8_meta[buffer_position_key]].popleft()

        # Replace amaxes and scales with stashed values for phase 2 forward.
        # Copy in place so tensors referenced by the global buffers stay valid.
        fp8_meta["scaling_fwd"].amax_history.copy_(stashed_fp8_meta[0])
        fp8_meta["scaling_fwd"].scale.copy_(stashed_fp8_meta[1])

    @staticmethod
    def restore_fp8_meta_tensors(fp8_meta: Dict[str, Any]) -> None:
        """Restore latest scaling factors and amaxes after recompute forward run."""
        # delayed scaling only function, noop for any other recipe
        if not fp8_meta["recipe"].delayed():
            return

        fp8_meta["scaling_fwd"].amax_history.copy_(fp8_meta["updated_amax_history_fwd"])
        fp8_meta["scaling_fwd"].scale.copy_(fp8_meta["updated_scale_fwd"])


@contextmanager
def fp8_model_init(enabled: bool = True, recipe: Optional[Recipe] = None) -> None:
    """
    Context manager for FP8 initialization of parameters.

    Example usage:

    .. code-block:: python

        with fp8_model_init(enabled=True):
            model = transformer_engine.pytorch.Linear(768, 768)

    Parameters
    ----------
    enabled: bool, default = `True`
             when enabled, Transformer Engine modules created inside this `fp8_model_init`
             region will hold only FP8 copies of its parameters, as opposed to the default
             behavior where both higher precision and FP8 copies are present. Setting this
             option to `True` may result in lower memory consumption and is especially
             useful for scenarios like:

             * full model training using optimizer with master weights, where the high
               precision copies of weights are already present in the optimizer.
             * inference, where only the FP8 copies of the parameters are used.
             * LoRA-like fine-tuning, where the main parameters of the model do not change.

             This functionality is *EXPERIMENTAL*.
    recipe: transformer_engine.common.recipe.Recipe, default = `None`
            Recipe used to create the parameters. If left to None, it uses the default FP8 recipe.

    The previous `FP8_PARAMETERS` and `FP8_RECIPE` values are restored on exit,
    including when the body raises.
    """
    # Resolve the recipe before mutating any global state, so that a failure
    # while building the default recipe leaves the state untouched.
    active_recipe = get_default_fp8_recipe() if recipe is None else recipe
    _fp8_parameters = FP8GlobalStateManager.FP8_PARAMETERS
    _fp8_recipe = FP8GlobalStateManager.FP8_RECIPE
    FP8GlobalStateManager.FP8_PARAMETERS = enabled
    FP8GlobalStateManager.FP8_RECIPE = active_recipe
    try:
        yield
    finally:
        FP8GlobalStateManager.FP8_PARAMETERS = _fp8_parameters
        FP8GlobalStateManager.FP8_RECIPE = _fp8_recipe


@contextmanager
def fp8_autocast(
    enabled: bool = True,
    calibrating: bool = False,
    fp8_recipe: Optional[Recipe] = None,
    fp8_group: Optional[dist_group_type] = None,
    _graph: bool = False,
) -> None:
    """
    Context manager for FP8 usage.

    .. code-block:: python

        with fp8_autocast(enabled=True):
            out = model(inp)

    .. note::

        Support for FP8 in the Linear layer of Transformer Engine is currently limited to tensors
        with shapes where both dimensions are divisible by 16. In terms of the input to the full
        Transformer network, this typically requires padding sequence length to be a multiple
        of 16.

    .. note::

        When :attr:`fp8_recipe.reduce_amax==True`, any module must not be invoked more than once
        inside a single `fp8_autocast` region. This is unsupported behavior because the amax
        reduction is handled during the exit of the `fp8_autocast` context. Calling the same
        module more than once inside an `fp8_autocast` region overrides the amax tensors
        before reduction can occur.

    Parameters
    ----------
    enabled: bool, default = `True`
             whether or not to enable fp8
    calibrating: bool, default = `False`
                 calibration mode allows collecting statistics such as amax and scale
                 data of fp8 tensors even when executing without fp8 enabled. This is
                 useful for saving an inference ready fp8 checkpoint while training
                 using a higher precision.
    fp8_recipe: recipe.Recipe, default = `None`
                recipe used for FP8 training. If `None`, the default FP8 recipe is used.
    fp8_group: torch._C._distributed_c10d.ProcessGroup, default = `None`
               distributed group over which amaxes for the fp8 tensors
               are reduced at the end of each training step.
    """
    fp8_state = FP8GlobalStateManager.get_fp8_autocast_state()
    # The nesting depth is not part of `fp8_state`, so it is saved separately.
    autocast_depth = FP8GlobalStateManager.FP8_AUTOCAST_DEPTH
    try:
        FP8GlobalStateManager.fp8_autocast_enter(
            enabled=enabled,
            calibrating=calibrating,
            fp8_recipe=fp8_recipe,
            fp8_group=fp8_group,
            _graph=_graph,
        )
    except BaseException:
        # Entry failed (e.g. the FP8 availability assertion): undo any partial
        # state change and skip the exit-time amax reduction, then re-raise.
        FP8GlobalStateManager.set_fp8_autocast_state(fp8_state)
        FP8GlobalStateManager.FP8_AUTOCAST_DEPTH = autocast_depth
        raise
    try:
        yield
    finally:
        FP8GlobalStateManager.set_fp8_autocast_state(fp8_state)
        FP8GlobalStateManager.fp8_autocast_exit(enabled, _graph=_graph)


def _update_amax_history(amax_history: torch.Tensor) -> torch.Tensor:
    """Advance the amax history by one step, in place.

    Rows are rotated so that row ``i + 1`` moves to row ``i`` (row 0 wraps
    around to the end), then row 0 — the slot for the next amax — is zeroed.
    The same tensor is returned.
    """
    history_len = amax_history.shape[0]
    if history_len > 1:
        rotated = torch.roll(amax_history, -1, 0)
        amax_history.copy_(rotated)
    amax_history[0].fill_(0.0)
    return amax_history


@torch.jit.script
def _default_get_amax_and_update_history(
    amax_history: torch.Tensor,
    amax_compute_algo: str,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Select an amax from the history, then advance the history.

    ``"max"`` takes the maximum over dim 0; any other value is treated as
    ``"most_recent"`` and takes a copy of row 0. The amax is computed before
    the history is rolled. Returns ``(updated_history, amax)``.
    """
    if amax_compute_algo != "max":  # "most_recent"
        amax = amax_history[0].clone()
    else:
        amax = torch.max(amax_history, dim=0).values
    return _update_amax_history(amax_history), amax


@jit_fuser
def _default_sf_compute(
    amax: torch.Tensor,
    scale: torch.Tensor,
    fp8_max: float,
    margin: int,
    _fp32_max: float = torch.finfo(torch.float32).max,  # finfo not available in jitter
) -> torch.Tensor:
    """Turn ``amax`` into a scaling factor and write it into ``scale`` in place.

    The candidate factor is ``(fp8_max / amax) / 2**margin``. It is used only
    where amax is positive and finite; elsewhere (zero, negative, inf, nan)
    the previous ``scale`` value is kept. A factor that overflows to inf
    (tiny positive amax) is clamped to the FP32 maximum. Returns ``scale``.
    """
    candidate = (fp8_max / amax) / (2**margin)
    usable_amax = (amax > 0.0) & torch.isfinite(amax)
    new_scale = torch.where(usable_amax, candidate, scale)
    fp32_ceiling = torch.full_like(new_scale, _fp32_max)
    new_scale = torch.where(torch.isinf(new_scale), fp32_ceiling, new_scale)
    scale.copy_(new_scale)
    return scale


def _compute_amax_and_update_history(
    amax_history: torch.Tensor,
    amax_compute_algo: Union[Callable, str],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return ``(updated_history, amax)`` for the given amax algorithm.

    A callable algorithm is applied to the history before it is rolled; a
    string algorithm is handled by the TorchScript default implementation.
    """
    if not callable(amax_compute_algo):
        return _default_get_amax_and_update_history(amax_history, amax_compute_algo)
    selected_amax = amax_compute_algo(amax_history)
    return _update_amax_history(amax_history), selected_amax


def _compute_scaling_factor(
    amax: torch.Tensor,
    scale: torch.Tensor,
    fp8_max: float,
    recipe: DelayedScaling,
) -> torch.Tensor:
    """Turn an amax into a scaling factor according to ``recipe``.

    Uses ``recipe.scaling_factor_compute_algo(amax, scale, fp8_max, recipe)``
    when the recipe provides one; otherwise falls back to the built-in
    ``_default_sf_compute`` with ``recipe.margin``.
    """
    custom_algo = recipe.scaling_factor_compute_algo
    if custom_algo is not None:
        return custom_algo(amax, scale, fp8_max, recipe)
    return _default_sf_compute(amax, scale, fp8_max, recipe.margin)


def _amax_and_scale_update(
    amax_history: torch.Tensor,
    scale: torch.Tensor,
    fp8_max: float,
    recipe: DelayedScaling,
) -> None:
    """Run one delayed-scaling update step, writing results in place.

    The amax is selected from the history via ``recipe.amax_compute_algo``,
    turned into a scaling factor, and both the new scale and the advanced
    history are copied back into the caller's tensors.
    """
    advanced_history, amax = _compute_amax_and_update_history(
        amax_history, recipe.amax_compute_algo
    )
    scale.copy_(_compute_scaling_factor(amax, scale, fp8_max, recipe))
    amax_history.copy_(advanced_history)


def split_and_copy(
    buffer: torch.Tensor,
    outputs: List[torch.Tensor],
    chunk_sizes: List[int],
) -> None:
    """Split `buffer` by `chunk_sizes` and copy into `outputs`."""
    splits = buffer.split(chunk_sizes)
    torch._foreach_copy_(outputs, splits)
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765


class RecipeState(abc.ABC):
    """Configuration and state for a quantization recipe.

    This is a builder class for quantizers, which are in turn builder
    classes for quantized tensors.

    This class may pack together the state for multiple quantizers,
    which is helpful for applying fused kernels with less overhead.

    """

    @staticmethod
    def create(
        recipe: Recipe,
        *,
        mode: str,
        num_quantizers: int = 1,
        device: Optional[torch.device] = None,
    ) -> RecipeState:
        """Factory method to create the state for a quantization recipe

        The concrete state class is chosen by querying the recipe's
        ``delayed()``, ``mxfp8()`` and ``float8_current_scaling()``
        predicates in that order; the first one that is true wins.

        Parameters
        ----------
        recipe: Recipe
            Quantization recipe.
        mode: {"forward", "backward"}
            Training stage where quantization will be performed.
        num_quantizers: int, default = 1
            Number of quantizers to create state for.
        device: torch.device, default = default CUDA device
            Device for quantized tensors.

        Returns
        -------
        RecipeState:
            Quantization recipe state.

        Raises
        ------
        ValueError
            If none of the recipe predicates above is true.

        """
        if recipe.delayed():
            cls = DelayedScalingRecipeState
        elif recipe.mxfp8():
            cls = MXFP8BlockScalingRecipeState
        elif recipe.float8_current_scaling():
            cls = Float8CurrentScalingRecipeState
        else:
            # f-string so the offending recipe class actually appears in the message.
            raise ValueError(f"{recipe.__class__.__name__} is not supported")
        return cls(
            recipe,
            mode=mode,
            num_quantizers=num_quantizers,
            device=device,
        )

    @abc.abstractmethod
    def make_quantizers(self) -> list:
        """Convert recipe state to quantizers.

        Quantizers are builder classes for quantized tensors. They are
        typically used to convert a high-precision tensor (e.g. in
        FP32 or BF16) into a quantized tensor (e.g. in FP8).

        """


class DelayedScalingRecipeState(RecipeState):
    """Recipe state for per-tensor FP8 quantization with delayed scaling.

    Each quantizer owns one scaling factor (used when casting to FP8) and
    one column of a rolling history of max-abs ("amax") values from recent
    FP8 casts. This class only allocates those buffers; turning the amax
    history into a new scale is handled externally by
    `FP8GlobalStateManager`.

    """

    recipe: DelayedScaling
    mode: str
    dtype: tex.DType
    scale: torch.Tensor
    amax_history: torch.Tensor
    num_quantizers: int

    def __init__(
        self,
        recipe: DelayedScaling,
        *,
        mode: str,
        num_quantizers: int = 1,
        device: Optional[torch.device] = None,
    ) -> None:
        is_forward = mode == "forward"
        self.recipe = recipe
        self.mode = mode
        self.num_quantizers = num_quantizers
        self.dtype = get_fp8_te_dtype(recipe, is_forward)

        target_device = torch.device("cuda") if device is None else device
        # scale: shape (num_quantizers,), initialised to 1.
        self.scale = torch.ones(num_quantizers, dtype=torch.float32, device=target_device)
        # amax_history: shape (amax_history_len, num_quantizers), zero-initialised.
        self.amax_history = torch.zeros(
            (recipe.amax_history_len, num_quantizers),
            dtype=torch.float32,
            device=target_device,
        )

    def make_quantizers(self) -> list:
        # TODO(ksivamani); Find better design for this, adding here to avoid circular import.
        from .tensor.float8_tensor import Float8Quantizer

        # Quantizer ``idx`` receives a view of scale[idx] and a one-element
        # view of amax_history[0, idx]; both alias this state's buffers.
        newest_amax = self.amax_history[0]
        return [
            Float8Quantizer(self.scale[idx], newest_amax[idx].reshape((1,)), self.dtype)
            for idx in range(self.num_quantizers)
        ]


838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
class Float8CurrentScalingRecipeState(RecipeState):
    """Recipe state for per-tensor FP8 current-scaling quantization.

    Current scaling keeps no scale or amax buffers. Only the recipe, mode,
    FP8 dtype, quantizer count and target device are recorded, and every
    call to `make_quantizers` builds fresh quantizers from them.

    """

    recipe: Float8CurrentScaling
    mode: str
    dtype: tex.DType
    device: torch.device
    num_quantizers: int

    def __init__(
        self,
        recipe: Float8CurrentScaling,
        *,
        mode: str,
        num_quantizers: int = 1,
        device: Optional[torch.device] = None,
    ) -> None:
        self.recipe = recipe
        self.mode = mode
        self.num_quantizers = num_quantizers
        self.dtype = get_fp8_te_dtype(recipe, mode == "forward")
        # Nothing to allocate; just remember where quantized tensors should live.
        self.device = torch.device("cuda") if device is None else device

    def make_quantizers(self) -> list:
        # Deferred import, presumably to avoid a circular import -- TODO confirm.
        from .tensor.float8_tensor import Float8CurrentScalingQuantizer

        # One independent quantizer instance per slot.
        return [
            Float8CurrentScalingQuantizer(self.dtype, device=self.device)
            for _ in range(self.num_quantizers)
        ]


877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
class MXFP8BlockScalingRecipeState(RecipeState):
    """Configuration for MXFP8 quantization.

    MXFP8 quantization does not require state: no scale or amax buffers
    are allocated. Only the recipe, mode, FP8 dtype, quantizer count and
    target device are recorded.

    """

    recipe: MXFP8BlockScaling
    mode: str
    dtype: tex.DType
    device: torch.device
    num_quantizers: int

    def __init__(
        self,
        recipe: MXFP8BlockScaling,
        *,
        mode: str,
        num_quantizers: int = 1,
        device: Optional[torch.device] = None,
    ) -> None:
        self.recipe = recipe
        self.mode = mode
        self.num_quantizers = num_quantizers
        self.dtype = get_fp8_te_dtype(recipe, mode == "forward")

        # Resolve the default device and keep it on the instance (consistent
        # with Float8CurrentScalingRecipeState) instead of computing it and
        # throwing it away. MXFP8Quantizer below is not given a device.
        if device is None:
            device = torch.device("cuda")
        self.device = device

    def make_quantizers(self) -> list:
        # TODO(ksivamani); Find better design for this, adding here to avoid circular import.
        from .tensor.mxfp8_tensor import MXFP8Quantizer

        # One independent quantizer instance per slot.
        return [MXFP8Quantizer(self.dtype) for _ in range(self.num_quantizers)]