# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""FP8 utilities for TransformerEngine."""
from __future__ import annotations

import abc
import itertools
import os
from collections import deque
from contextlib import contextmanager
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
from torch.utils.cpp_extension import IS_HIP_EXTENSION

import transformer_engine_torch as tex
from transformer_engine.common.recipe import (
    Recipe,
    DelayedScaling,
    Format,
    MXFP8BlockScaling,
    Float8CurrentScaling,
    Float8BlockScaling,
)

from .constants import dist_group_type
from .utils import get_device_compute_capability
from .jit import jit_fuser
def _read_int_env(name: str, default: int) -> int:
    """Read an integer-valued environment variable, falling back to ``default``.

    Raises ``ValueError`` if the variable is set to a non-integer string.
    """
    return int(os.getenv(name, str(default)))


# DCU-specific knobs, read once at import time.
# NVTE_INT8_SIM_FP8: when non-zero, FP8 is offered through INT8 simulation on
# supported DCU devices (see the HIP branch of `check_fp8_support`).
int8_simulation_fp8 = bool(_read_int_env("NVTE_INT8_SIM_FP8", 0))
# NOTE(review): consumed outside this chunk; presumably selects a tensor-wise
# variant of the INT8 simulation — confirm against its users.
int8_simulation_fp8_tensorwise = bool(_read_int_env("NVTE_INT8_SIM_FP8_TENSORWISE", 0))
# NOTE(review): consumed outside this chunk; presumably the block length used by
# blockwise FP8 scaling — confirm against its users.
blockwise_fp8_block_len = _read_int_env("NVTE_BLOCKWISE_FP8_BLOCK_LEN", 128)
33

34
# Names exported by `from ... import *`.
__all__ = ["fp8_autocast", "fp8_model_init"]

if IS_HIP_EXTENSION:
    # DCU (HIP/ROCm) builds only: device probes consulted by the FP8 support checks.
    from transformer_engine.pytorch.utils import is_BW, is_K100_AI
38
39

def check_fp8_support() -> Tuple[bool, str]:
    """Report whether FP8 execution is possible on the current device.

    Returns:
        A ``(supported, reason)`` pair. ``reason`` describes why FP8 is
        unavailable; it is empty for CUDA devices that support FP8.
    """
    if IS_HIP_EXTENSION:
        # DCU: FP8 is only offered as an INT8 simulation, and only on K100_AI / BW
        # devices with NVTE_INT8_SIM_FP8 enabled.
        dcu_fp8_ok = (is_K100_AI() or is_BW()) and int8_simulation_fp8
        if dcu_fp8_ok:
            return True, "DCU turn on fp8 simulation with int8"
        return False, "DCU not support fp8 for now"

    capability = get_device_compute_capability()
    if capability >= (9, 0):  # Hopper and newer.
        return True, ""
    if capability < (8, 9):  # Anything before Ada is rejected.
        return False, "Device compute capability 8.9 or higher required for FP8 execution."
    # Ada (8.9) additionally needs a recent enough cuBLASLt and CUDA toolkit.
    if tex.get_cublasLt_version() < 120103:
        return False, "CublasLt version 12.1.3.x or higher required for FP8 execution on Ada."
    if float(torch.version.cuda) < 12.1:
        return False, "Cuda version 12.1 or higher required for FP8 execution on Ada."
    return True, ""


58
59
def check_mxfp8_support() -> Tuple[bool, str]:
    """Report whether MXFP8 (microscaling FP8) execution is possible on the current device.

    Returns:
        A ``(supported, reason)`` pair; ``reason`` is empty when supported.
    """
    capability = get_device_compute_capability()
    # 12.0+ is deliberately rejected until MXFP8 works for every GEMM layout there.
    if capability >= (12, 0):
        return False, "MXFP8 (for all gemm layouts) is not supported on 12.0+ architectures yet."
    if capability < (10, 0):
        return False, "Device compute capability 10.0 or higher required for MXFP8 execution."
    # Blackwell-class devices: capability in [10.0, 12.0).
    return True, ""


67
68
def check_fp8_block_scaling_support() -> Tuple[bool, str]:
    """Report whether FP8 block-scaled GEMMs are available on the current device.

    Returns:
        A ``(supported, reason)`` pair; ``reason`` is empty when supported.
    """
    if IS_HIP_EXTENSION:
        # DCU: only K100_AI and BW devices provide block-scaled FP8.
        if not (is_K100_AI() or is_BW()):
            return False, "DCU not support block_scaling fp8 for now"
        return True, ""

    capability = get_device_compute_capability()
    is_hopper = (9, 0) <= capability < (10, 0)
    # The CUDA version is only inspected on Hopper (short-circuit), as before.
    if is_hopper and float(torch.version.cuda) >= 12.9:
        return True, ""
    return False, "FP8 block scaled GEMM requires Hopper and CUDA >= 12.9."


83
84
85
86
87
88
89
90
91
92
93
94
95
def check_recipe_support(recipe: Recipe) -> None:
    """Check that the current device can run ``recipe``.

    Recipes without a dedicated hardware check (including ``None``) are accepted.

    Raises:
        AssertionError: if the device cannot run the recipe; the message is the
            reason reported by the matching ``check_*_support`` helper.
    """
    if isinstance(recipe, (DelayedScaling, Float8CurrentScaling)):
        supported, reason = check_fp8_support()
    elif isinstance(recipe, Float8BlockScaling):
        supported, reason = check_fp8_block_scaling_support()
    elif isinstance(recipe, MXFP8BlockScaling):
        supported, reason = check_mxfp8_support()
    else:
        return
    # Raise explicitly instead of using `assert`: assertions are stripped under
    # `python -O`, which would silently let an unsupported recipe through.
    # AssertionError is kept so existing callers see the same exception type.
    if not supported:
        raise AssertionError(reason)


96
def get_default_fp8_recipe() -> Recipe:
    """Pick an FP8 recipe (with default arguments) suited to the current device.

    Preference order: MXFP8 block scaling when supported, per-tensor current
    scaling on capability 12.0+, and delayed scaling otherwise.
    """
    mxfp8_supported, _ = check_mxfp8_support()
    if mxfp8_supported:
        return MXFP8BlockScaling()
    # Temporary restriction: MXFP8 is not yet supported for all GEMM layouts on
    # 12.0+ devices, so fall back to current scaling there.
    if get_device_compute_capability() >= (12, 0):
        return Float8CurrentScaling()
    return DelayedScaling()
104
105


106
def _uses_e4m3(fp8_recipe: Recipe, fprop_tensor: bool) -> bool:
    """Return True if a tensor should use the E4M3 format under ``fp8_recipe``.

    E4M3 is used for every tensor under ``Format.E4M3`` and for forward-pass
    tensors under ``Format.HYBRID``; every other case uses E5M2.
    """
    fp8_format = fp8_recipe.fp8_format
    return fp8_format == Format.E4M3 or (fp8_format == Format.HYBRID and fprop_tensor)


def get_fp8_torch_dtype(fp8_recipe: Recipe, fprop_tensor: bool = True) -> torch.dtype:
    """Get the torch FP8 dtype for a tensor under ``fp8_recipe``.

    Args:
        fp8_recipe: recipe whose ``fp8_format`` selects the format.
        fprop_tensor: True for forward-pass tensors, False for backward-pass tensors.
    """
    if _uses_e4m3(fp8_recipe, fprop_tensor):
        return torch.float8_e4m3fn
    return torch.float8_e5m2


def get_fp8_te_dtype(fp8_recipe: Recipe, fprop_tensor: bool = True) -> tex.DType:
    """Get the Transformer Engine FP8 dtype for a tensor under ``fp8_recipe``.

    Args:
        fp8_recipe: recipe whose ``fp8_format`` selects the format.
        fprop_tensor: True for forward-pass tensors, False for backward-pass tensors.
    """
    if _uses_e4m3(fp8_recipe, fprop_tensor):
        return tex.DType.kFloat8E4M3
    return tex.DType.kFloat8E5M2


def get_fp8_max(fp8_recipe: Recipe, fprop_tensor: bool = True) -> float:
    """Get the max representable FP8 value for a tensor under ``fp8_recipe``.

    Args:
        fp8_recipe: recipe whose ``fp8_format`` selects the format.
        fprop_tensor: True for forward-pass tensors, False for backward-pass tensors.

    Returns:
        The ``max_fwd`` value of the selected format (a number, not a dtype).
    """
    if _uses_e4m3(fp8_recipe, fprop_tensor):
        return Format.E4M3.value.max_fwd
    return Format.E5M2.value.max_fwd


133
134
135
class FP8GlobalStateManager:
    """Class to keep track of and manipulate the global
    FP8 state at different stages of execution.

    All state is held in class attributes and every method is a class or static
    method, so the class itself serves as a process-wide state namespace.
    """

    FP8_ENABLED = False
    FP8_CALIBRATION = False
    FP8_RECIPE = None
    FP8_DISTRIBUTED_GROUP = None
    FP8_PARAMETERS = False
    HIGH_PRECISION_INIT_VAL = False
    IS_FIRST_FP8_MODULE = False
    FP8_GRAPH_CAPTURING = False
    # Nesting depth of fp8_autocast_enter / fp8_autocast_exit calls.
    FP8_AUTOCAST_DEPTH = 0
    # Delayed-scaling buffers keyed by `get_key_in_buffer`; each value holds one
    # entry per registered module.
    global_amax_buffer = {}
    global_amax_history_buffer = {}
    global_scale_buffer = {}
    # One deque of stashed [amax_history, scale] copies per module (recompute).
    fp8_tensors_recompute_buffer = []
    # Cached result of check_fp8_support(); None until first queried.
    fp8_available = None
    reason_for_no_fp8 = ""
    # autocast key -> (recipe, fp8_group), filled in by fp8_autocast_enter.
    autocast_arguments = {}
    autocast_to_fp8_params = {}
    fp8_param_to_autocast = {}
    skip_fp8_weight_update_tensor = None
    # Cached result of check_mxfp8_support(); None until first queried.
    mxfp8_available = None
    reason_for_no_mxfp8 = ""
    # Cached result of check_fp8_block_scaling_support(); None until first queried.
    fp8_block_scaling_available = None
    # Empty string (not None), matching reset() and the other reason_* defaults.
    reason_for_no_fp8_block_scaling = ""

    @classmethod
    def reset(cls) -> None:
        """Reset the global state"""
        cls.FP8_ENABLED = False
        cls.FP8_CALIBRATION = False
        cls.FP8_RECIPE = None
        cls.FP8_DISTRIBUTED_GROUP = None
        cls.FP8_PARAMETERS = False
        cls.HIGH_PRECISION_INIT_VAL = False
        cls.IS_FIRST_FP8_MODULE = False
        cls.FP8_GRAPH_CAPTURING = False
        cls.FP8_AUTOCAST_DEPTH = 0
        cls.global_amax_buffer = {}
        cls.global_amax_history_buffer = {}
        cls.global_scale_buffer = {}
        cls.fp8_tensors_recompute_buffer = []
        cls.fp8_available = None
        cls.reason_for_no_fp8 = ""
        cls.autocast_arguments = {}
        cls.autocast_to_fp8_params = {}
        cls.fp8_param_to_autocast = {}
        cls.skip_fp8_weight_update_tensor = None
        cls.mxfp8_available = None
        cls.reason_for_no_mxfp8 = ""
        cls.fp8_block_scaling_available = None
        cls.reason_for_no_fp8_block_scaling = ""

    @classmethod
    def set_skip_fp8_weight_update_tensor(cls, skip: bool) -> None:
        """`skip_fp8_weight_update_tensor` inplace setter.

        Lazily allocates a 1-element float32 CUDA tensor and fills it in place,
        so existing references to the tensor observe the new value.
        """
        if cls.skip_fp8_weight_update_tensor is None:
            cls.skip_fp8_weight_update_tensor = torch.empty(1, dtype=torch.float32, device="cuda")
        cls.skip_fp8_weight_update_tensor.fill_(skip)

    @classmethod
    def get_skip_fp8_weight_update_tensor(cls) -> Optional[torch.Tensor]:
        """`skip_fp8_weight_update_tensor` getter (None until first set)."""
        return cls.skip_fp8_weight_update_tensor

    @classmethod
    def is_fp8_available(cls) -> Tuple[bool, str]:
        """Return if fp8 support is available (cached result of `check_fp8_support`)."""
        if cls.fp8_available is None:
            cls.fp8_available, cls.reason_for_no_fp8 = check_fp8_support()
        return cls.fp8_available, cls.reason_for_no_fp8

    @classmethod
    def is_mxfp8_available(cls) -> Tuple[bool, str]:
        """Return if MXFP8 support is available (cached result of `check_mxfp8_support`)."""
        if cls.mxfp8_available is None:
            cls.mxfp8_available, cls.reason_for_no_mxfp8 = check_mxfp8_support()
        return cls.mxfp8_available, cls.reason_for_no_mxfp8

    @classmethod
    def is_fp8_block_scaling_available(cls) -> Tuple[bool, str]:
        """Return if Float8 block scaling support is available (cached)."""
        if cls.fp8_block_scaling_available is None:
            cls.fp8_block_scaling_available, cls.reason_for_no_fp8_block_scaling = (
                check_fp8_block_scaling_support()
            )
        return cls.fp8_block_scaling_available, cls.reason_for_no_fp8_block_scaling

    @staticmethod
    def get_meta_tensor_key(forward: bool = True) -> str:
        """Returns scaling key in `fp8_meta`."""
        if forward:
            return "scaling_fwd"
        return "scaling_bwd"

    @staticmethod
    def get_fwd_bwd_key(forward: bool = True) -> str:
        """Convert bool `forward` to string."""
        return "forward" if forward else "backward"

    @classmethod
    def get_buffer_info(cls) -> str:
        """
        Returns a key for `fp8_meta` that stores the module's index
        in the global buffers along with autocast information.
        """
        return "buffer_index_and_autocast_key"

    @classmethod
    def get_key_in_buffer(
        cls,
        forward: bool,
        fp8_recipe: Recipe,
        fp8_group: dist_group_type,
    ) -> str:
        """Returns a key into the global FP8 buffers: ``"<forward|backward>_<autocast key>"``."""
        autocast_key = cls.get_unique_autocast_key(fp8_recipe, fp8_group)
        fwd_bwd_key = cls.get_fwd_bwd_key(forward)
        return f"{fwd_bwd_key}_{autocast_key}"

    @classmethod
    def split_key_in_buffer(cls, key: str) -> Tuple[bool, str]:
        """Splits a buffer key into ``(is_forward, autocast_key)``.

        Only the first underscore is split on, since the direction prefix
        ("forward"/"backward") contains none while the autocast key may.
        """
        forward, autocast_key = key.split("_", 1)
        forward = forward == "forward"
        return forward, autocast_key

    @classmethod
    def add_fp8_tensors_to_global_buffer(
        cls,
        fp8_meta: Dict[str, Any],
    ) -> None:
        """
        Delayed scaling only.

        The amax reduction process happens completely outside the FP8 modules.
        To participate in the reduction, the only role played by a module is
        to call this function in order to append its FP8 tensors into the
        global buffers. Three global buffers are maintained: amax, amax history,
        and scale. Each buffer has keys that hold FP8 tensors. Keys have a
        `forward_` or `backward_` prefix to indicate the type of FP8 tensor,
        since the forward and backward reductions happen separately.

        Note: For CG capture, this method is called from the graphed
        wrapper. For non CG case, it's called from within the module.
        """
        # delayed scaling only function, noop for any other recipe
        if not fp8_meta["recipe"].delayed():
            return

        # Every module must call this function exactly once since
        # the amax tensors are static. Ensures that compatibility
        # with non-graphed modules is maintained.
        index_in_buffer = cls.get_buffer_info()  # Same index for fwd/bwd fp8 tensors.
        if index_in_buffer in fp8_meta:
            return

        fp8_meta[index_in_buffer] = []
        for forward in (True, False):
            fp8_meta_tensor_key = cls.get_meta_tensor_key(forward=forward)
            if fp8_meta_tensor_key not in fp8_meta:
                # Handles non-parameter FP8 modules, e.g. DPA.
                continue

            key = cls.get_key_in_buffer(forward, fp8_meta["recipe"], fp8_meta["fp8_group"])
            scaling = fp8_meta[fp8_meta_tensor_key]

            if key not in cls.global_amax_buffer:
                cls.global_amax_buffer[key] = [scaling.amax_history[0]]
                cls.global_amax_history_buffer[key] = [scaling.amax_history]
                cls.global_scale_buffer[key] = [scaling.scale]
            else:
                cls.global_amax_buffer[key].append(scaling.amax_history[0])
                cls.global_amax_history_buffer[key].append(scaling.amax_history)
                cls.global_scale_buffer[key].append(scaling.scale)
            # Flat list: one (index, key) pair per direction that is present.
            fp8_meta[index_in_buffer].append(len(cls.global_amax_buffer[key]) - 1)
            fp8_meta[index_in_buffer].append(key)

    @classmethod
    def is_fp8_enabled(cls) -> bool:
        """Is FP8 enabled"""
        return cls.FP8_ENABLED

    @classmethod
    def is_fp8_calibration(cls) -> bool:
        """Is FP8 calibration"""
        return cls.FP8_CALIBRATION

    @classmethod
    def with_fp8_parameters(cls) -> bool:
        """Should the parameters be stored as FP8"""
        return cls.FP8_PARAMETERS

    @classmethod
    def with_high_precision_init_val(cls) -> bool:
        """Should the high precision initial values be stored with FP8 parameters"""
        return cls.HIGH_PRECISION_INIT_VAL

    @classmethod
    def fp8_graph_capturing(cls) -> bool:
        """Is CUDA graph capture under way?"""
        return cls.FP8_GRAPH_CAPTURING or torch.cuda.is_current_stream_capturing()

    @classmethod
    def is_first_fp8_module(cls):
        """Returns `True` only the first time when called multiple
        times from within the same `fp8_autocast` context.
        """
        tmp = cls.IS_FIRST_FP8_MODULE
        cls.IS_FIRST_FP8_MODULE = False
        return tmp

    @classmethod
    def get_fp8_recipe(cls) -> Recipe:
        """Return the fp8 recipe (the device default if none is set)."""
        if cls.FP8_RECIPE is not None:
            return cls.FP8_RECIPE
        return get_default_fp8_recipe()

    @classmethod
    def get_fp8_group(cls) -> Union[dist_group_type, None]:
        """Return the fp8 group for scale/amax comm"""
        return cls.FP8_DISTRIBUTED_GROUP

    @classmethod
    def get_fp8_autocast_state(cls) -> Tuple[bool, bool, Recipe, dist_group_type, bool, bool]:
        """FP8 autocast state getter.

        Returns ``(enabled, calibrating, recipe, group, is_first_module, graph_capturing)``.
        """
        return (
            cls.FP8_ENABLED,
            cls.FP8_CALIBRATION,
            cls.FP8_RECIPE,
            cls.FP8_DISTRIBUTED_GROUP,
            cls.IS_FIRST_FP8_MODULE,
            cls.FP8_GRAPH_CAPTURING,
        )

    @classmethod
    def set_fp8_autocast_state(
        cls, fp8_state: Tuple[bool, bool, Recipe, dist_group_type, bool, bool]
    ) -> None:
        """FP8 autocast state setter; accepts the tuple from `get_fp8_autocast_state`."""
        (
            cls.FP8_ENABLED,
            cls.FP8_CALIBRATION,
            cls.FP8_RECIPE,
            cls.FP8_DISTRIBUTED_GROUP,
            cls.IS_FIRST_FP8_MODULE,
            cls.FP8_GRAPH_CAPTURING,
        ) = fp8_state

    @staticmethod
    def reduce_tensor_across_group_op_max(tensor: torch.Tensor, group: dist_group_type) -> None:
        """Reduce tensor across given group (in-place MAX all-reduce; no-op without torch.distributed)."""
        if torch.distributed.is_initialized():
            torch.distributed.all_reduce(
                tensor,
                op=torch.distributed.ReduceOp.MAX,
                group=group,
                async_op=False,
            )

    @classmethod
    def reduce_and_update_fp8_tensors(
        cls,
        forward: bool = True,
    ) -> None:
        """Delayed scaling only. Concatenate, reduce, and split amaxes in the global buffer."""
        # global_amax_buffer should only be non-empty for fp8 delayed scaling
        for buffer_key, amax_buffer in cls.global_amax_buffer.items():
            # Check for forward or backward reduction.
            fwd_update, autocast_key = cls.split_key_in_buffer(buffer_key)
            if fwd_update != forward:
                continue
            if len(amax_buffer) == 0:
                continue

            # Retrieve autocast specific args and concat amaxes.
            recipe, group = cls.autocast_arguments[autocast_key]
            contiguous_amax = torch.cat(amax_buffer)

            # Reduction.
            if (
                recipe.reduce_amax
                and torch.distributed.is_initialized()
                and torch.distributed.get_world_size(group=group) > 1
            ):
                cls.reduce_tensor_across_group_op_max(contiguous_amax, group)

            # Amax and scale update. User-supplied callables cannot go through the
            # fused kernel, so they force the unfused path.
            unfused_update = (
                bool(int(os.getenv("NVTE_UNFUSED_FP8_UPDATE", "0")))
                or callable(recipe.amax_compute_algo)
                or callable(recipe.scaling_factor_compute_algo)
            )

            if not unfused_update:
                tex.fused_amax_and_scale_update_after_reduction(
                    contiguous_amax,
                    cls.global_amax_history_buffer[buffer_key],
                    cls.global_scale_buffer[buffer_key],
                    recipe.amax_compute_algo,
                    get_fp8_te_dtype(recipe, forward),
                    recipe.margin,
                )
            else:
                split_and_copy(contiguous_amax, amax_buffer, [x.numel() for x in amax_buffer])

                for amax_history, scale in zip(
                    cls.global_amax_history_buffer[buffer_key],
                    cls.global_scale_buffer[buffer_key],
                ):
                    _amax_and_scale_update(
                        amax_history, scale, get_fp8_max(recipe, forward), recipe
                    )

    @classmethod
    def get_unique_autocast_key(
        cls,
        recipe: Optional[Recipe] = None,
        group: Optional[dist_group_type] = None,
    ):
        """
        For FP8, each autocast can be uniquely identified by the recipe and fp8 group.
        Safely using `hash` as we never cross checkpoint boundaries.
        """
        return f"{str(recipe)}:{hash(group)}"

    @classmethod
    def fp8_autocast_enter(
        cls,
        enabled: bool = False,
        calibrating: bool = False,
        fp8_recipe: Optional[Recipe] = None,
        fp8_group: Optional[dist_group_type] = None,
        _graph: bool = False,
    ) -> None:
        """Set state and tracking variables for entry into FP8 region.

        Raises:
            AssertionError: if ``enabled`` and the device cannot run ``fp8_recipe``.
                In that case no global state is modified.
        """
        fp8_recipe = get_default_fp8_recipe() if fp8_recipe is None else fp8_recipe

        # Validate support *before* touching global state: a failed check must not
        # leave FP8 enabled or the autocast depth incremented, since callers such as
        # `fp8_autocast` invoke this outside their restoring try/finally.
        if enabled:
            fp8_available, reason_for_no_fp8 = cls.is_fp8_available()
            if not fp8_available:
                raise AssertionError(reason_for_no_fp8)
            if isinstance(fp8_recipe, MXFP8BlockScaling):
                mxfp8_available, reason_for_no_mxfp8 = cls.is_mxfp8_available()
                if not mxfp8_available:
                    raise AssertionError(reason_for_no_mxfp8)
            if isinstance(fp8_recipe, Float8BlockScaling):
                fp8_block_available, reason_for_no_fp8_block = cls.is_fp8_block_scaling_available()
                if not fp8_block_available:
                    raise AssertionError(reason_for_no_fp8_block)

        autocast_key = cls.get_unique_autocast_key(fp8_recipe, fp8_group)
        cls.autocast_arguments[autocast_key] = (fp8_recipe, fp8_group)

        cls.FP8_ENABLED = enabled
        cls.FP8_CALIBRATION = calibrating
        cls.FP8_RECIPE = fp8_recipe
        cls.FP8_DISTRIBUTED_GROUP = fp8_group
        cls.FP8_GRAPH_CAPTURING = _graph

        if cls.FP8_AUTOCAST_DEPTH == 0:
            cls.IS_FIRST_FP8_MODULE = True
        cls.FP8_AUTOCAST_DEPTH += 1

    @classmethod
    def fp8_autocast_exit(cls, enabled: bool, _graph: bool) -> None:
        """Set state and tracking variables for exit from FP8 region."""
        cls.FP8_AUTOCAST_DEPTH -= 1
        # Reduce only the non-FP8 weight modules here.
        # FP8 weight modules are reduced at the end of the optimizer
        # step after the weight amax is populated.
        if enabled and cls.FP8_AUTOCAST_DEPTH == 0 and not _graph and torch.is_grad_enabled():
            # Delayed scaling only: only delayed-scaling modules register tensors in
            # `global_amax_buffer` (a dict), so with other recipes this is a no-op.
            cls.reduce_and_update_fp8_tensors(forward=True)

    @classmethod
    def copy_forward_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> None:
        """Copy the scaling factors and amaxes for recompute forward phase
        to ensure both forward steps are numerically same.
        """

        # delayed scaling only function, noop for any other recipe
        if not fp8_meta["recipe"].delayed():
            return

        buffer_position_key = "global_fp8_buffer_pos_fwd_recompute"

        to_copy = [
            fp8_meta["scaling_fwd"].amax_history.clone(),
            fp8_meta["scaling_fwd"].scale.clone(),
        ]

        if buffer_position_key in fp8_meta:
            cls.fp8_tensors_recompute_buffer[fp8_meta[buffer_position_key]].append(to_copy)
        else:
            # First stash for this module: give it its own deque and remember its slot.
            cls.fp8_tensors_recompute_buffer.append(deque([to_copy]))
            fp8_meta[buffer_position_key] = len(cls.fp8_tensors_recompute_buffer) - 1

    @classmethod
    def get_old_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> None:
        """Switch to the copied scaling factors and amaxes from phase
        1 forward for identical numerical outputs.
        """
        # delayed scaling only function, noop for any other recipe
        if not fp8_meta["recipe"].delayed():
            return

        # Store updated amaxes and scales from phase 1 post forward.
        fp8_meta["updated_amax_history_fwd"] = fp8_meta["scaling_fwd"].amax_history.clone()
        fp8_meta["updated_scale_fwd"] = fp8_meta["scaling_fwd"].scale.clone()

        # Retrieve stashed amaxes and scales from phase 1 pre forward (FIFO order).
        buffer_position_key = "global_fp8_buffer_pos_fwd_recompute"
        stashed_fp8_meta = cls.fp8_tensors_recompute_buffer[fp8_meta[buffer_position_key]].popleft()

        # Replace amaxes and scales with stashed values for phase 2 forward
        fp8_meta["scaling_fwd"].amax_history.copy_(stashed_fp8_meta[0])
        fp8_meta["scaling_fwd"].scale.copy_(stashed_fp8_meta[1])

    @staticmethod
    def restore_fp8_meta_tensors(fp8_meta: Dict[str, Any]) -> None:
        """Restore latest scaling factors and amaxes after recompute forward run."""
        # delayed scaling only function, noop for any other recipe
        if not fp8_meta["recipe"].delayed():
            return

        fp8_meta["scaling_fwd"].amax_history.copy_(fp8_meta["updated_amax_history_fwd"])
        fp8_meta["scaling_fwd"].scale.copy_(fp8_meta["updated_scale_fwd"])
570
571


572
@contextmanager
def fp8_model_init(
    enabled: bool = True,
    recipe: Optional[Recipe] = None,
    preserve_high_precision_init_val: bool = False,
) -> None:
    """
    Context manager for FP8 initialization of parameters.

    The previous parameter, recipe and init-value settings are restored when the
    region exits (also on exceptions), so regions may be nested.

    Example usage:

    .. code-block:: python

        with fp8_model_init(enabled=True):
            model = transformer_engine.pytorch.Linear(768, 768)

        # Preserving high precision initial value to initialize master weight
        with fp8_model_init(enabled=True, preserve_high_precision_init_val=True):
            model = transformer_engine.pytorch.Linear(768, 768)
        master_weight = model.weight.get_high_precision_init_val()
        model.weight.clear_high_precision_init_val()

    Parameters
    ----------
    enabled: bool, default = `True`
             when enabled, Transformer Engine modules created inside this `fp8_model_init`
             region will hold only FP8 copies of their parameters, as opposed to the default
             behavior where both higher precision and FP8 copies are present. Setting this
             option to `True` may result in lower memory consumption and is especially
             useful for scenarios like:

             * full model training using optimizer with master weights, where the high
               precision copies of weights are already present in the optimizer.
             * inference, where only the FP8 copies of the parameters are used.
             * LoRA-like fine-tuning, where the main parameters of the model do not change.
    recipe: transformer_engine.common.recipe.Recipe, default = `None`
            Recipe used to create the parameters. If left to None, it uses the default FP8 recipe.
    preserve_high_precision_init_val: bool, default = `False`
             when enabled, store the high precision tensor used to initialize FP8 parameters
             in CPU memory, and add two function attributes named `get_high_precision_init_val()`
             and `clear_high_precision_init_val()` to FP8 parameters to get/clear this high
             precision tensor. The purpose is that users can use this high-precision copy
             to initialize master weights, avoiding the loss of precision that can occur when
             using FP8 parameters directly. Note that after the master weights are initialized,
             users should call `clear_high_precision_init_val()` to release this CPU memory.

             This functionality is *EXPERIMENTAL*.
    """
    state = FP8GlobalStateManager
    saved_settings = (
        state.FP8_PARAMETERS,
        state.FP8_RECIPE,
        state.HIGH_PRECISION_INIT_VAL,
    )

    state.FP8_PARAMETERS = enabled
    state.FP8_RECIPE = recipe if recipe is not None else get_default_fp8_recipe()
    state.HIGH_PRECISION_INIT_VAL = preserve_high_precision_init_val
    try:
        yield
    finally:
        # Restore in the same order the settings were saved.
        (
            state.FP8_PARAMETERS,
            state.FP8_RECIPE,
            state.HIGH_PRECISION_INIT_VAL,
        ) = saved_settings
634
635
@contextmanager
def fp8_autocast(
    enabled: bool = True,
    calibrating: bool = False,
    fp8_recipe: Optional[Recipe] = None,
    fp8_group: Optional[dist_group_type] = None,
    _graph: bool = False,
) -> None:
    """
    Context manager for FP8 usage.

    .. code-block:: python

        with fp8_autocast(enabled=True):
            out = model(inp)

    .. note::

        Support for FP8 in the Linear layer of Transformer Engine is currently limited to tensors
        with shapes where both dimensions are divisible by 16. In terms of the input to the full
        Transformer network, this typically requires padding sequence length to be a multiple
        of 16.

    .. note::

        When :attr:`fp8_recipe.reduce_amax==True`, any module must not be invoked more than once
        inside a single `fp8_autocast` region. This is unsupported behavior because the amax
        reduction is handled during the exit of the `fp8_autocast` context. Calling the same
        module more than once inside an `fp8_autocast` region overrides the amax tensors
        before reduction can occur.

    Parameters
    ----------
    enabled: bool, default = `True`
             whether or not to enable fp8
    calibrating: bool, default = `False`
                 calibration mode allows collecting statistics such as amax and scale
                 data of fp8 tensors even when executing without fp8 enabled. This is
                 useful for saving an inference ready fp8 checkpoint while training
                 using a higher precision.
    fp8_recipe: recipe.Recipe, default = `None`
                recipe used for FP8 training. If `None`, the default recipe for the
                current device is used.
    fp8_group: torch._C._distributed_c10d.ProcessGroup, default = `None`
               distributed group over which amaxes for the fp8 tensors
               are reduced at the end of each training step.
    """
    manager = FP8GlobalStateManager

    if enabled:
        check_recipe_support(fp8_recipe)

    # Snapshot the enclosing region's state so it is restored on exit (supports nesting).
    outer_state = manager.get_fp8_autocast_state()
    manager.fp8_autocast_enter(
        enabled=enabled,
        calibrating=calibrating,
        fp8_recipe=fp8_recipe,
        fp8_group=fp8_group,
        _graph=_graph,
    )
    try:
        yield
    finally:
        manager.set_fp8_autocast_state(outer_state)
        manager.fp8_autocast_exit(enabled, _graph=_graph)
694
695


696
def _update_amax_history(amax_history: torch.Tensor) -> torch.Tensor:
Przemek Tredak's avatar
Przemek Tredak committed
697
    """Update amax history and set next amax to zero."""
698
    if amax_history.shape[0] > 1:
699
700
        new_amax_history = torch.roll(amax_history, -1, 0)
        amax_history.copy_(new_amax_history)
Przemek Tredak's avatar
Przemek Tredak committed
701
702
703
704
    amax_history[0].fill_(0.0)
    return amax_history


705
@torch.jit.script
def _default_get_amax_and_update_history(
    amax_history: torch.Tensor,
    amax_compute_algo: str,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Pick the amax to use for scaling, then advance the history.

    ``amax_compute_algo == "max"`` reduces over dim 0 of the history; any other
    value (expected: ``"most_recent"``) takes row 0. The amax is captured before
    the history is rolled and row 0 is zeroed in place.

    Returns:
        ``(amax_history, amax)``, with ``amax_history`` updated in place.
    """
    if amax_compute_algo != "max":
        # Clone: row 0 is about to be zeroed in place by the history update.
        amax = amax_history[0].clone()
    else:
        amax = torch.max(amax_history, dim=0).values
    amax_history = _update_amax_history(amax_history)
    return amax_history, amax


720
@jit_fuser
def _default_sf_compute(
    amax: torch.Tensor,
    scale: torch.Tensor,
    fp8_max: float,
    margin: int,
    _fp32_max: float = torch.finfo(torch.float32).max,  # finfo not available in jitter
) -> torch.Tensor:
    """Default function to convert amax to scaling factor.

    The candidate scale is ``(fp8_max / amax) / 2**margin``. It is then
    adjusted element-wise:

    1. amax <= 0 (including amax == 0):
       No meaningful scale can be derived; keep the previous ``scale``.
    2. amax is inf or nan:
       No meaningful scale can be derived; keep the previous ``scale``.
    3. The resulting scale is inf (e.g. amax so tiny that ``fp8_max / amax``
       overflows FP32):
       Clamp to the largest finite FP32 value.
    4. Otherwise:
       Use the candidate scale.

    The result is written into ``scale`` in place, and ``scale`` is returned.
    """
    sf = (fp8_max / amax) / (2**margin)
    sf = torch.where(amax > 0.0, sf, scale)
    sf = torch.where(torch.isfinite(amax), sf, scale)
    # Replace any inf (from a tiny amax, or an inf previous scale) with FP32 max.
    sf = torch.where(torch.isinf(sf), torch.full_like(sf, _fp32_max), sf)
    scale.copy_(sf)
    return scale
746

Przemek Tredak's avatar
Przemek Tredak committed
747

748
def _compute_amax_and_update_history(
    amax_history: torch.Tensor,
    amax_compute_algo: Union[Callable, str],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Obtain the amax from the history, then advance the history.

    Parameters
    ----------
    amax_history: torch.Tensor
        Amax history buffer. It is advanced in place by
        ``_update_amax_history``.
    amax_compute_algo: Callable or str
        Either a callable mapping the history to an amax, or a string passed
        to ``_default_get_amax_and_update_history`` (``"max"``; any other
        string behaves as ``"most_recent"``).

    Returns
    -------
    Tuple[torch.Tensor, torch.Tensor]
        ``(amax_history, amax)``.
    """
    if callable(amax_compute_algo):
        amax = amax_compute_algo(amax_history)
        # The history is rolled and row 0 is zeroed in place just below. If
        # the custom algorithm returned a view into the history (e.g. h[0]),
        # that update would silently overwrite the amax. Give it its own
        # storage first, as the default "most_recent" path does.
        if isinstance(amax, torch.Tensor):
            amax = amax.clone()
        amax_history = _update_amax_history(amax_history)
        return amax_history, amax
    return _default_get_amax_and_update_history(
        amax_history,
        amax_compute_algo,
    )


def _compute_scaling_factor(
    amax: torch.Tensor,
    scale: torch.Tensor,
    fp8_max: float,
    recipe: DelayedScaling,
) -> torch.Tensor:
    """Turn an amax value into an FP8 scaling factor.

    If the recipe provides ``scaling_factor_compute_algo``, it is called as
    ``algo(amax, scale, fp8_max, recipe)`` and its result is returned.
    Otherwise ``_default_sf_compute`` is used with ``recipe.margin``.
    """
    custom_algo = recipe.scaling_factor_compute_algo
    if custom_algo is not None:
        return custom_algo(amax, scale, fp8_max, recipe)
    return _default_sf_compute(amax, scale, fp8_max, recipe.margin)


782
783
784
785
786
def _amax_and_scale_update(
    amax_history: torch.Tensor,
    scale: torch.Tensor,
    fp8_max: float,
    recipe: DelayedScaling,
) -> None:
    """Updates FP8 meta tensors in place.

    Derives an amax from ``amax_history`` using ``recipe.amax_compute_algo``,
    advances the history, and recomputes ``scale`` from that amax. Both
    results are written back into the caller's tensors.
    """
    new_amax_history, amax = _compute_amax_and_update_history(
        amax_history,
        recipe.amax_compute_algo,
    )
    new_scale = _compute_scaling_factor(amax, scale, fp8_max, recipe)
    # The default helpers already update these tensors in place, so the
    # copies below are then self-copies. A custom scaling_factor_compute_algo
    # may return a fresh tensor, though, so copy back unconditionally.
    scale.copy_(new_scale)
    amax_history.copy_(new_amax_history)


def split_and_copy(
    buffer: torch.Tensor,
    outputs: List[torch.Tensor],
    chunk_sizes: List[int],
) -> None:
    """Split `buffer` by `chunk_sizes` and copy into `outputs`."""
    splits = buffer.split(chunk_sizes)
    torch._foreach_copy_(outputs, splits)
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851


class RecipeState(abc.ABC):
    """Configuration and state for a quantization recipe.

    This is a builder class for quantizers, which are in turn builder
    classes for quantized tensors.

    This class may pack together the state for multiple quantizers,
    which is helpful for applying fused kernels with less overhead.

    """

    @staticmethod
    def create(
        recipe: Recipe,
        *,
        mode: str,
        num_quantizers: int = 1,
        device: Optional[torch.device] = None,
    ) -> RecipeState:
        """Factory method to create the state for a quantization recipe

        The recipe's predicates are queried in this order: ``delayed()``,
        ``mxfp8()``, ``float8_current_scaling()``, ``float8_block_scaling()``.
        The first one returning true selects the state class.

        Parameters
        ----------
        recipe: Recipe
            Quantization recipe.
        mode: {"forward", "backward"}
            Training stage where quantization will be performed.
        num_quantizers: int, default = 1
            Number of quantizers to create state for.
        device: torch.device, default = default CUDA device
            Device for quantized tensors.

        Returns
        -------
        RecipeState:
            Quantization recipe state.

        Raises
        ------
        ValueError
            If none of the recipe predicates match.

        """
        if recipe.delayed():
            cls = DelayedScalingRecipeState
        elif recipe.mxfp8():
            cls = MXFP8BlockScalingRecipeState
        elif recipe.float8_current_scaling():
            cls = Float8CurrentScalingRecipeState
        elif recipe.float8_block_scaling():
            cls = Float8BlockScalingRecipeState
        else:
            raise ValueError(f"{recipe.__class__.__name__} is not supported")
        return cls(
            recipe,
            mode=mode,
            num_quantizers=num_quantizers,
            device=device,
        )

    @abc.abstractmethod
    def make_quantizers(self) -> list:
        """Convert recipe state to quantizers.

        Quantizers are builder classes for quantized tensors. They are
        typically used to convert a high-precision tensor (e.g. in
        FP32 or BF16) into a quantized tensor (e.g. in FP8).

        """


class DelayedScalingRecipeState(RecipeState):
    """Per-tensor delayed-scaling state for FP8 quantization.

    Holds two FP32 buffers:

    - ``scale``: one scaling factor per quantizer, shape
      ``(num_quantizers,)``, initialised to 1.
    - ``amax_history``: a window of recent max-abs values, shape
      ``(recipe.amax_history_len, num_quantizers)``, initialised to 0.

    Updating the scale from the history is handled externally by
    `FP8GlobalStateManager`.

    """

    recipe: DelayedScaling
    mode: str
    dtype: tex.DType
    scale: torch.Tensor
    amax_history: torch.Tensor

    def __init__(
        self,
        recipe: DelayedScaling,
        *,
        mode: str,
        num_quantizers: int = 1,
        device: Optional[torch.device] = None,
    ) -> None:
        self.recipe = recipe
        self.mode = mode
        self.num_quantizers = num_quantizers
        self.dtype = get_fp8_te_dtype(recipe, mode == "forward")

        target = torch.device("cuda") if device is None else device
        self.scale = torch.ones(num_quantizers, dtype=torch.float32, device=target)
        self.amax_history = torch.zeros(
            recipe.amax_history_len, num_quantizers, dtype=torch.float32, device=target
        )

    def make_quantizers(self) -> list:
        # TODO(ksivamani); Find better design for this, adding here to avoid circular import.
        from .tensor.float8_tensor import Float8Quantizer

        quantizers = []
        for idx in range(self.num_quantizers):
            # Basic indexing yields views, not copies: each quantizer gets its
            # own slot of `scale` and of row 0 of `amax_history`.
            slot_scale = self.scale[idx]
            slot_amax = self.amax_history[0][idx].reshape((1,))
            quantizers.append(Float8Quantizer(slot_scale, slot_amax, self.dtype))
        return quantizers


926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
class Float8CurrentScalingRecipeState(RecipeState):
    """State for per-tensor current-scaling FP8 quantization.

    Current scaling keeps no persistent buffers. This object only records
    the FP8 dtype for its mode and the device given to each quantizer.

    """

    recipe: Float8CurrentScaling
    mode: str
    dtype: tex.DType
    device: torch.device

    def __init__(
        self,
        recipe: Float8CurrentScaling,
        *,
        mode: str,
        num_quantizers: int = 1,
        device: Optional[torch.device] = None,
    ) -> None:
        self.recipe = recipe
        self.mode = mode
        self.num_quantizers = num_quantizers
        self.dtype = get_fp8_te_dtype(recipe, mode == "forward")
        # Fall back to torch.device("cuda") when no device is given.
        self.device = device if device is not None else torch.device("cuda")

    def make_quantizers(self) -> list:
        # Deferred import to avoid a circular import at module load.
        from .tensor.float8_tensor import Float8CurrentScalingQuantizer

        quantizers = []
        for _ in range(self.num_quantizers):
            quantizers.append(Float8CurrentScalingQuantizer(self.dtype, device=self.device))
        return quantizers


965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
class MXFP8BlockScalingRecipeState(RecipeState):
    """Configuration for MXFP8 quantization.

    MXFP8 quantization keeps no persistent buffers. This object records
    the FP8 dtype for its mode and the resolved target device.

    """

    recipe: MXFP8BlockScaling
    mode: str
    dtype: tex.DType
    device: torch.device

    def __init__(
        self,
        recipe: MXFP8BlockScaling,
        *,
        mode: str,
        num_quantizers: int = 1,
        device: Optional[torch.device] = None,
    ) -> None:
        self.recipe = recipe
        self.mode = mode
        self.num_quantizers = num_quantizers
        self.dtype = get_fp8_te_dtype(recipe, mode == "forward")

        # Resolve and keep the device, matching the other recipe states.
        # The MXFP8 quantizers built below do not take a device argument.
        if device is None:
            device = torch.device("cuda")
        self.device = device

    def make_quantizers(self) -> list:
        # TODO(ksivamani); Find better design for this, adding here to avoid circular import.
        from .tensor.mxfp8_tensor import MXFP8Quantizer

        return [MXFP8Quantizer(self.dtype) for _ in range(self.num_quantizers)]
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102


class Float8BlockScalingRecipeState(RecipeState):
    """State for block-scaled FP8 quantization (Float8BlockScaling).

    No persistent buffers are kept. The quantizers differ by role:

    - forward mode builds (input, weight, output) triples per GEMM;
    - backward mode builds (grad_output, grad_input) pairs per GEMM.
    """

    recipe: Float8BlockScaling
    mode: str
    qx_dtype: tex.DType
    qw_dtype: tex.DType
    qgrad_dtype: tex.DType

    def __init__(
        self,
        recipe: Float8BlockScaling,
        *,
        mode: str,
        num_quantizers: int = 1,
        device: Optional[torch.device] = None,
    ) -> None:
        self.recipe = recipe
        self.mode = mode
        self.num_quantizers = num_quantizers
        # The second argument mirrors `mode == "forward"` in the other recipe
        # states: activations and weights request True, gradients request False.
        self.qx_dtype = get_fp8_te_dtype(recipe, True)
        self.qw_dtype = get_fp8_te_dtype(recipe, True)
        self.qgrad_dtype = get_fp8_te_dtype(recipe, False)
        self.device = torch.device("cuda") if device is None else device

    def make_quantizers(self) -> list:
        # TODO(ksivamani); Find better design for this, adding here to avoid circular import.
        from .tensor.float8_blockwise_tensor import Float8BlockQuantizer

        recipe = self.recipe

        def build(fp8_dtype, quant_params, scaling_dim):
            # Every quantizer here emits both rowwise and columnwise data.
            return Float8BlockQuantizer(
                fp8_dtype=fp8_dtype,
                rowwise=True,
                columnwise=True,
                amax_epsilon=quant_params.amax_epsilon,
                force_pow_2_scales=quant_params.power_2_scale,
                block_scaling_dim=scaling_dim,
            )

        quantizers = []
        if self.mode == "forward":
            # The index convention (coming from base.py set_meta_tensor)
            # is somewhat awkward, and doesn't play nicely with QuantizeOp,
            # which is not associated with a GEMM.
            assert self.num_quantizers % 3 == 0  # x, w, output per gemm
            for _ in range(self.num_quantizers // 3):
                quantizers.append(
                    build(self.qx_dtype, recipe.fp8_quant_fwd_inp, recipe.x_block_scaling_dim)
                )
                quantizers.append(
                    build(self.qw_dtype, recipe.fp8_quant_fwd_weight, recipe.w_block_scaling_dim)
                )
                # The output slot reuses the input settings.
                quantizers.append(
                    build(self.qx_dtype, recipe.fp8_quant_fwd_inp, recipe.x_block_scaling_dim)
                )
            return quantizers

        assert self.mode == "backward", f"Unexpected mode {self.mode}"
        assert self.num_quantizers % 2 == 0  # grad_output and grad_input per gemm
        for _ in range(self.num_quantizers // 2):
            for _slot in range(2):
                quantizers.append(
                    build(
                        self.qgrad_dtype,
                        recipe.fp8_quant_bwd_grad,
                        recipe.grad_block_scaling_dim,
                    )
                )
        return quantizers