# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""FP8 utilities for TransformerEngine"""

from __future__ import annotations

import abc
9
import itertools
Sangkug Lym's avatar
Sangkug Lym committed
10
import os
Przemek Tredak's avatar
Przemek Tredak committed
11
from contextlib import contextmanager
12
from collections import deque
13
from typing import Callable, List, Optional, Dict, Any, Tuple, Union
Przemek Tredak's avatar
Przemek Tredak committed
14
15

import torch
16
import transformer_engine_torch as tex
17
18
19
20
21
22
from transformer_engine.common.recipe import (
    Recipe,
    DelayedScaling,
    Format,
    MXFP8BlockScaling,
    Float8CurrentScaling,
23
    Float8BlockScaling,
24
)
Przemek Tredak's avatar
Przemek Tredak committed
25
26

from .constants import dist_group_type
27
from .utils import get_device_compute_capability
28
from .jit import jit_fuser
Przemek Tredak's avatar
Przemek Tredak committed
29

30

31
# Public names exported by ``from <this module> import *``.
__all__ = ["fp8_autocast", "fp8_model_init"]
32
33
34


def check_fp8_support() -> Tuple[bool, str]:
    """Return whether FP8 execution is supported on the current device.

    Returns:
        A ``(supported, reason)`` pair. ``reason`` is an empty string when
        ``supported`` is ``True`` and a human-readable explanation otherwise.
    """
    compute_capability = get_device_compute_capability()
    if compute_capability >= (9, 0):  # hopper and above
        return True, ""
    if compute_capability < (8, 9):  # pre-ada
        return False, "Device compute capability 8.9 or higher required for FP8 execution."
    # Ada (8.9) additionally needs recent enough cuBLASLt and CUDA versions.
    if tex.get_cublasLt_version() < 120103:
        return False, "CublasLt version 12.1.3.x or higher required for FP8 execution on Ada."
    if float(torch.version.cuda) < 12.1:
        return False, "Cuda version 12.1 or higher required for FP8 execution on Ada."
    return True, ""


47
48
def check_mxfp8_support() -> Tuple[bool, str]:
    """Return whether MXFP8 execution is supported on the current device.

    MXFP8 requires compute capability 10.0 or newer, but 12.0+ architectures
    are currently excluded because not all GEMM layouts are supported there.

    Returns:
        A ``(supported, reason)`` pair. ``reason`` is an empty string when
        ``supported`` is ``True`` and a human-readable explanation otherwise.
    """
    compute_capability = get_device_compute_capability()
    if compute_capability >= (12, 0):
        return False, "MXFP8 (for all gemm layouts) is not supported on 12.0+ architectures yet."
    if compute_capability >= (10, 0):  # blackwell and above
        return True, ""
    return False, "Device compute capability 10.0 or higher required for MXFP8 execution."


56
57
58
59
60
61
62
63
64
65
66
def check_fp8_block_scaling_support() -> Tuple[bool, str]:
    """Report whether FP8 block-scaled GEMMs can run on this device.

    Block scaling is enabled only for compute capability 9.x (Hopper) when
    PyTorch was built against CUDA 12.9 or newer.

    Returns:
        ``(True, "")`` when supported, otherwise ``(False, reason)``.
    """
    capability = get_device_compute_capability()
    is_hopper = (9, 0) <= capability < (10, 0)
    # `and` short-circuits: the CUDA version string is only parsed on Hopper.
    if is_hopper and float(torch.version.cuda) >= 12.9:
        return True, ""
    return False, "FP8 block scaled GEMM requires Hopper and CUDA >= 12.9."


67
68
69
70
71
72
73
74
75
76
77
78
79
def check_recipe_support(recipe: Recipe) -> None:
    """Assert that the current device can execute ``recipe``.

    The support check is chosen from the recipe's type. Any other value,
    including ``None``, passes without a check.

    Raises:
        AssertionError: with the reason string when the recipe is unsupported.
    """
    if isinstance(recipe, (DelayedScaling, Float8CurrentScaling)):
        supported, reason = check_fp8_support()
    elif isinstance(recipe, Float8BlockScaling):
        supported, reason = check_fp8_block_scaling_support()
    elif isinstance(recipe, MXFP8BlockScaling):
        supported, reason = check_mxfp8_support()
    else:
        supported, reason = True, ""
    assert supported, reason


80
def get_default_fp8_recipe() -> Recipe:
    """FP8 recipe with default args.

    The recipe is chosen based on what the current device supports:
    ``MXFP8BlockScaling`` when MXFP8 is available, ``Float8CurrentScaling`` on
    compute capability 12.0+ (where MXFP8 is not yet supported for all GEMM
    layouts), and ``DelayedScaling`` otherwise.
    """
    if check_mxfp8_support()[0]:
        return MXFP8BlockScaling()
    if get_device_compute_capability() >= (12, 0):
        # This is a temporary restriction until MXFP8 is supported for all gemm layouts.
        return Float8CurrentScaling()
    return DelayedScaling()
88
89


90
def get_fp8_torch_dtype(fp8_recipe: Recipe, fprop_tensor: bool = True) -> torch.dtype:
    """Get fp8 data type according to recipe and tensor.

    Args:
        fp8_recipe: recipe whose ``fp8_format`` selects the encoding.
        fprop_tensor: whether the tensor is used in forward propagation. This only
            matters for ``Format.HYBRID``, which uses E4M3 for forward tensors
            and E5M2 for the rest.

    Returns:
        ``torch.float8_e4m3fn`` for E4M3 (and HYBRID forward tensors), otherwise
        ``torch.float8_e5m2``.
    """
    if fp8_recipe.fp8_format == Format.E4M3 or (
        fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor
    ):
        return torch.float8_e4m3fn
    return torch.float8_e5m2
97
98


99
def get_fp8_te_dtype(fp8_recipe: Recipe, fprop_tensor: bool = True) -> tex.DType:
    """Get fp8 data type according to recipe and tensor.

    Same selection rule as the torch-dtype variant, but returns the
    TransformerEngine extension's dtype enum.

    Args:
        fp8_recipe: recipe whose ``fp8_format`` selects the encoding.
        fprop_tensor: whether the tensor is used in forward propagation. This only
            matters for ``Format.HYBRID``, which uses E4M3 for forward tensors
            and E5M2 for the rest.

    Returns:
        ``tex.DType.kFloat8E4M3`` for E4M3 (and HYBRID forward tensors),
        otherwise ``tex.DType.kFloat8E5M2``.
    """
    if fp8_recipe.fp8_format == Format.E4M3 or (
        fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor
    ):
        return tex.DType.kFloat8E4M3
    return tex.DType.kFloat8E5M2
106
107


108
def get_fp8_max(fp8_recipe: Recipe, fprop_tensor: bool = True) -> float:
    """Get max representable FP8 value.

    Args:
        fp8_recipe: recipe whose ``fp8_format`` selects the encoding.
        fprop_tensor: whether the tensor is used in forward propagation. This only
            matters for ``Format.HYBRID``, which uses E4M3 for forward tensors
            and E5M2 for the rest.

    Returns:
        ``max_fwd`` of the selected format (E4M3 or E5M2). This is a value,
        not a dtype; it is passed on as the ``fp8_max`` scale bound.
    """
    if fp8_recipe.fp8_format == Format.E4M3 or (
        fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor
    ):
        return Format.E4M3.value.max_fwd
    return Format.E5M2.value.max_fwd


117
118
119
class FP8GlobalStateManager:
    """Class to keep track of and manipulate the global
    FP8 state at different stages of execution.

    All state lives in class attributes and is accessed through class/static
    methods, so it is shared by every caller in the process.
    """

    # Current FP8 settings, set by `fp8_autocast` / `fp8_model_init` regions.
    FP8_ENABLED = False
    FP8_CALIBRATION = False
    FP8_RECIPE = None
    FP8_DISTRIBUTED_GROUP = None
    FP8_PARAMETERS = False
    HIGH_PRECISION_INIT_VAL = False
    IS_FIRST_FP8_MODULE = False
    FP8_GRAPH_CAPTURING = False
    # Number of `fp8_autocast_enter` calls not yet matched by `fp8_autocast_exit`.
    FP8_AUTOCAST_DEPTH = 0

    # Delayed-scaling buffers keyed by `get_key_in_buffer`; each value is a list
    # with one tensor per registered module.
    global_amax_buffer = {}
    global_amax_history_buffer = {}
    global_scale_buffer = {}

    # One deque per module of stashed [amax_history, scale] copies, used so that a
    # recomputed forward pass matches the original one numerically.
    fp8_tensors_recompute_buffer = []

    # Cached hardware-support results (None means "not checked yet").
    fp8_available = None
    reason_for_no_fp8 = ""

    # Autocast key -> (recipe, group), registered by `fp8_autocast_enter`.
    autocast_arguments = {}
    autocast_to_fp8_params = {}
    fp8_param_to_autocast = {}
    # Lazily allocated 1-element CUDA tensor; see `set_skip_fp8_weight_update_tensor`.
    skip_fp8_weight_update_tensor = None

    mxfp8_available = None
    reason_for_no_mxfp8 = ""

    fp8_block_scaling_available = None
    # Empty string (not None) to match `reset()` and the other `reason_for_no_*` fields.
    reason_for_no_fp8_block_scaling = ""

    @classmethod
    def reset(cls) -> None:
        """Reset the global state to its initial values."""
        cls.FP8_ENABLED = False
        cls.FP8_CALIBRATION = False
        cls.FP8_RECIPE = None
        cls.FP8_DISTRIBUTED_GROUP = None
        cls.FP8_PARAMETERS = False
        cls.HIGH_PRECISION_INIT_VAL = False
        cls.IS_FIRST_FP8_MODULE = False
        cls.FP8_GRAPH_CAPTURING = False
        cls.FP8_AUTOCAST_DEPTH = 0
        cls.global_amax_buffer = {}
        cls.global_amax_history_buffer = {}
        cls.global_scale_buffer = {}
        cls.fp8_tensors_recompute_buffer = []
        cls.fp8_available = None
        cls.reason_for_no_fp8 = ""
        cls.autocast_arguments = {}
        cls.autocast_to_fp8_params = {}
        cls.fp8_param_to_autocast = {}
        cls.skip_fp8_weight_update_tensor = None
        cls.mxfp8_available = None
        cls.reason_for_no_mxfp8 = ""
        cls.fp8_block_scaling_available = None
        cls.reason_for_no_fp8_block_scaling = ""

    @classmethod
    def set_skip_fp8_weight_update_tensor(cls, skip: bool) -> None:
        """`skip_fp8_weight_update_tensor` inplace setter.

        Allocates a 1-element float32 CUDA tensor on first use, then fills it
        with ``skip`` in place so the same tensor object is reused across calls.
        """
        if cls.skip_fp8_weight_update_tensor is None:
            cls.skip_fp8_weight_update_tensor = torch.empty(1, dtype=torch.float32, device="cuda")
        cls.skip_fp8_weight_update_tensor.fill_(skip)

    @classmethod
    def get_skip_fp8_weight_update_tensor(cls) -> Optional[torch.Tensor]:
        """`skip_fp8_weight_update_tensor` getter (``None`` until first set)."""
        return cls.skip_fp8_weight_update_tensor

    @classmethod
    def is_fp8_available(cls) -> Tuple[bool, str]:
        """Return if fp8 support is available (result is cached after the first call)."""
        if cls.fp8_available is None:
            cls.fp8_available, cls.reason_for_no_fp8 = check_fp8_support()
        return cls.fp8_available, cls.reason_for_no_fp8

    @classmethod
    def is_mxfp8_available(cls) -> Tuple[bool, str]:
        """Return if MXFP8 support is available (result is cached after the first call)."""
        if cls.mxfp8_available is None:
            cls.mxfp8_available, cls.reason_for_no_mxfp8 = check_mxfp8_support()
        return cls.mxfp8_available, cls.reason_for_no_mxfp8

    @classmethod
    def is_fp8_block_scaling_available(cls) -> Tuple[bool, str]:
        """Return if Float8 block scaling support is available (cached after the first call)."""
        if cls.fp8_block_scaling_available is None:
            cls.fp8_block_scaling_available, cls.reason_for_no_fp8_block_scaling = (
                check_fp8_block_scaling_support()
            )
        return cls.fp8_block_scaling_available, cls.reason_for_no_fp8_block_scaling

    @staticmethod
    def get_meta_tensor_key(forward: bool = True) -> str:
        """Returns scaling key in `fp8_meta`."""
        return "scaling_fwd" if forward else "scaling_bwd"

    @staticmethod
    def get_fwd_bwd_key(forward: bool = True) -> str:
        """Convert bool `forward` to string."""
        return "forward" if forward else "backward"

    @classmethod
    def get_buffer_info(cls) -> str:
        """
        Returns a key for `fp8_meta` that stores the module's index
        in the global buffers along with autocast information.
        """
        return "buffer_index_and_autocast_key"

    @classmethod
    def get_key_in_buffer(
        cls,
        forward: bool,
        fp8_recipe: Recipe,
        fp8_group: dist_group_type,
    ) -> str:
        """Returns a key into the global FP8 buffers: ``"<forward|backward>_<autocast key>"``."""
        autocast_key = cls.get_unique_autocast_key(fp8_recipe, fp8_group)
        fwd_bwd_key = cls.get_fwd_bwd_key(forward)
        return f"{fwd_bwd_key}_{autocast_key}"

    @classmethod
    def split_key_in_buffer(cls, key: str) -> Tuple[bool, str]:
        """Splits buffer key into ``(is_forward, autocast_key)``; inverse of `get_key_in_buffer`."""
        # Split only on the first "_": the direction prefix never contains one,
        # but the autocast key (a recipe repr) may.
        forward, autocast_key = key.split("_", 1)
        return forward == "forward", autocast_key

    @classmethod
    def add_fp8_tensors_to_global_buffer(
        cls,
        fp8_meta: Dict[str, Any],
    ) -> None:
        """
        Delayed scaling only.

        The amax reduction process happens completely outside the FP8 modules.
        To participate in the reduction, the only role played by a module is
        to call this function in order to append its FP8 tensors into the
        global buffers. There are 3 global buffers maintained, one each for
        amax, amax history, and scale. Each buffer has keys that hold FP8
        tensors. Keys have a `forward_` or `backward_` prefix to indicate the
        type of FP8 tensor, since the forward and backward reductions happen
        separately.

        Afterwards, ``fp8_meta[cls.get_buffer_info()]`` holds, for each of the
        forward/backward scaling entries present in ``fp8_meta``, the module's
        index within the buffer followed by the buffer key.

        Note: For CG capture, this method is called from the graphed
        wrapper. For non CG case, it's called from within the module.
        """
        # delayed scaling only function, noop for any other recipe
        if not fp8_meta["recipe"].delayed():
            return

        # Every module must call this function exactly once since
        # the amax tensors are static. Ensures that compatibility
        # with non-graphed modules is maintained.
        index_in_buffer = cls.get_buffer_info()  # Same index for fwd/bwd fp8 tensors.
        if index_in_buffer in fp8_meta:
            return

        fp8_meta[index_in_buffer] = []
        for forward in (True, False):
            fp8_meta_tensor_key = cls.get_meta_tensor_key(forward=forward)
            if fp8_meta_tensor_key not in fp8_meta:
                # Handles non-parameter FP8 modules, e.g. DPA.
                continue

            key = cls.get_key_in_buffer(forward, fp8_meta["recipe"], fp8_meta["fp8_group"])
            scaling = fp8_meta[fp8_meta_tensor_key]

            if key not in cls.global_amax_buffer:
                cls.global_amax_buffer[key] = [scaling.amax_history[0]]
                cls.global_amax_history_buffer[key] = [scaling.amax_history]
                cls.global_scale_buffer[key] = [scaling.scale]
            else:
                cls.global_amax_buffer[key].append(scaling.amax_history[0])
                cls.global_amax_history_buffer[key].append(scaling.amax_history)
                cls.global_scale_buffer[key].append(scaling.scale)
            fp8_meta[index_in_buffer].append(len(cls.global_amax_buffer[key]) - 1)
            fp8_meta[index_in_buffer].append(key)

    @classmethod
    def is_fp8_enabled(cls) -> bool:
        """Is FP8 enabled"""
        return cls.FP8_ENABLED

    @classmethod
    def is_fp8_calibration(cls) -> bool:
        """Is FP8 calibration"""
        return cls.FP8_CALIBRATION

    @classmethod
    def with_fp8_parameters(cls) -> bool:
        """Should the parameters be stored as FP8"""
        return cls.FP8_PARAMETERS

    @classmethod
    def with_high_precision_init_val(cls) -> bool:
        """Should the high precision initial values be stored with FP8 parameters"""
        return cls.HIGH_PRECISION_INIT_VAL

    @classmethod
    def fp8_graph_capturing(cls) -> bool:
        """Is CUDA graph capture under way?"""
        return cls.FP8_GRAPH_CAPTURING or torch.cuda.is_current_stream_capturing()

    @classmethod
    def is_first_fp8_module(cls) -> bool:
        """Returns `True` only the first time when called multiple
        times from within the same `fp8_autocast` context.
        """
        # Read-and-clear: subsequent calls return False until the flag is set again.
        tmp = cls.IS_FIRST_FP8_MODULE
        cls.IS_FIRST_FP8_MODULE = False
        return tmp

    @classmethod
    def get_fp8_recipe(cls) -> Recipe:
        """Return the fp8 recipe, falling back to the default recipe when none is set."""
        if cls.FP8_RECIPE is not None:
            return cls.FP8_RECIPE
        return get_default_fp8_recipe()

    @classmethod
    def get_fp8_group(cls) -> Union[dist_group_type, None]:
        """Return the fp8 group for scale/amax comm"""
        return cls.FP8_DISTRIBUTED_GROUP

    @classmethod
    def get_fp8_autocast_state(
        cls,
    ) -> Tuple[bool, bool, Recipe, dist_group_type, bool, bool]:
        """FP8 autocast state getter.

        Returns:
            ``(enabled, calibrating, recipe, group, is_first_module, graph_capturing)``,
            in the order accepted by `set_fp8_autocast_state`.
        """
        return (
            cls.FP8_ENABLED,
            cls.FP8_CALIBRATION,
            cls.FP8_RECIPE,
            cls.FP8_DISTRIBUTED_GROUP,
            cls.IS_FIRST_FP8_MODULE,
            cls.FP8_GRAPH_CAPTURING,
        )

    @classmethod
    def set_fp8_autocast_state(
        cls, fp8_state: Tuple[bool, bool, Recipe, dist_group_type, bool, bool]
    ) -> None:
        """FP8 autocast state setter; accepts the tuple from `get_fp8_autocast_state`."""
        (
            cls.FP8_ENABLED,
            cls.FP8_CALIBRATION,
            cls.FP8_RECIPE,
            cls.FP8_DISTRIBUTED_GROUP,
            cls.IS_FIRST_FP8_MODULE,
            cls.FP8_GRAPH_CAPTURING,
        ) = fp8_state

    @staticmethod
    def reduce_tensor_across_group_op_max(tensor: torch.Tensor, group: dist_group_type) -> None:
        """Reduce tensor across given group (in place, MAX op, synchronous).

        No-op when `torch.distributed` is not initialized.
        """
        if torch.distributed.is_initialized():
            torch.distributed.all_reduce(
                tensor,
                op=torch.distributed.ReduceOp.MAX,
                group=group,
                async_op=False,
            )

    @classmethod
    def reduce_and_update_fp8_tensors(
        cls,
        forward: bool = True,
    ) -> None:
        """Delayed scaling only. Concatenate, reduce, and split amaxes in the global buffer.

        Only buffers whose key direction matches ``forward`` are processed. For
        each one, the per-module amaxes are concatenated and max-reduced across
        the autocast's group. The reduction happens only when
        ``recipe.reduce_amax`` is set and the group has more than one rank. The
        result then updates the amax histories and scales.
        """
        # global_amax_buffer should only be non-empty for fp8 delayed scaling
        for buffer_key, amax_buffer in cls.global_amax_buffer.items():
            # Check for forward or backward reduction.
            fwd_update, autocast_key = cls.split_key_in_buffer(buffer_key)
            if fwd_update != forward or len(amax_buffer) == 0:
                continue

            # Retrieve autocast specific args and concat amaxes.
            recipe, group = cls.autocast_arguments[autocast_key]
            contiguous_amax = torch.cat(amax_buffer)

            # Reduction.
            if (
                recipe.reduce_amax
                and torch.distributed.is_initialized()
                and torch.distributed.get_world_size(group=group) > 1
            ):
                cls.reduce_tensor_across_group_op_max(contiguous_amax, group)

            # Amax and scale update. The fused kernel is bypassed when it is
            # disabled via NVTE_UNFUSED_FP8_UPDATE or when the recipe supplies
            # callable (custom) algorithms.
            unfused_update = (
                bool(int(os.getenv("NVTE_UNFUSED_FP8_UPDATE", "0")))
                or callable(recipe.amax_compute_algo)
                or callable(recipe.scaling_factor_compute_algo)
            )

            if not unfused_update:
                tex.fused_amax_and_scale_update_after_reduction(
                    contiguous_amax,
                    cls.global_amax_history_buffer[buffer_key],
                    cls.global_scale_buffer[buffer_key],
                    recipe.amax_compute_algo,
                    get_fp8_te_dtype(recipe, forward),
                    recipe.margin,
                )
            else:
                # Presumably copies the reduced amaxes back into each module's
                # amax tensor (split sizes match the originals) -- verify helper.
                split_and_copy(contiguous_amax, amax_buffer, [x.numel() for x in amax_buffer])

                for amax_history, scale in zip(
                    cls.global_amax_history_buffer[buffer_key],
                    cls.global_scale_buffer[buffer_key],
                ):
                    _amax_and_scale_update(
                        amax_history, scale, get_fp8_max(recipe, forward), recipe
                    )

    @classmethod
    def get_unique_autocast_key(
        cls,
        recipe: Optional[Recipe] = None,
        group: Optional[dist_group_type] = None,
    ) -> str:
        """
        For FP8, each autocast can be uniquely identified by the recipe and fp8 group.
        Safely using `hash` as we never cross checkpoint boundaries.
        """
        return f"{str(recipe)}:{hash(group)}"

    @classmethod
    def fp8_autocast_enter(
        cls,
        enabled: bool = False,
        calibrating: bool = False,
        fp8_recipe: Optional[Recipe] = None,
        fp8_group: Optional[dist_group_type] = None,
        _graph: bool = False,
    ) -> None:
        """Set state and tracking variables for entry into FP8 region.

        Raises:
            AssertionError: if ``enabled`` and the device does not support FP8
                (or the MXFP8 / block-scaling features the recipe needs). In that
                case no global state is modified.
        """
        fp8_recipe = get_default_fp8_recipe() if fp8_recipe is None else fp8_recipe

        # Validate before touching any global state. Otherwise a failed check
        # would leave FP8_AUTOCAST_DEPTH incremented, so `IS_FIRST_FP8_MODULE`
        # and the depth-0 amax reduction in `fp8_autocast_exit` would never
        # trigger again.
        if enabled:
            fp8_available, reason_for_no_fp8 = cls.is_fp8_available()
            assert fp8_available, reason_for_no_fp8
            if isinstance(fp8_recipe, MXFP8BlockScaling):
                mxfp8_available, reason_for_no_mxfp8 = cls.is_mxfp8_available()
                assert mxfp8_available, reason_for_no_mxfp8
            if isinstance(fp8_recipe, Float8BlockScaling):
                fp8_block_available, reason_for_no_fp8_block = cls.is_fp8_block_scaling_available()
                assert fp8_block_available, reason_for_no_fp8_block

        autocast_key = cls.get_unique_autocast_key(fp8_recipe, fp8_group)
        cls.autocast_arguments[autocast_key] = (fp8_recipe, fp8_group)

        cls.FP8_ENABLED = enabled
        cls.FP8_CALIBRATION = calibrating
        cls.FP8_RECIPE = fp8_recipe
        cls.FP8_DISTRIBUTED_GROUP = fp8_group
        cls.FP8_GRAPH_CAPTURING = _graph

        # Only the outermost region marks the next module as the first one.
        if cls.FP8_AUTOCAST_DEPTH == 0:
            cls.IS_FIRST_FP8_MODULE = True
        cls.FP8_AUTOCAST_DEPTH += 1

    @classmethod
    def fp8_autocast_exit(cls, enabled: bool, _graph: bool) -> None:
        """Set state and tracking variables for exit from FP8 region."""
        cls.FP8_AUTOCAST_DEPTH -= 1
        # Reduce only the non-FP8 weight modules here.
        # FP8 weight modules are reduced at the end of the optimizer
        # step after the weight amax is populated.
        if enabled and cls.FP8_AUTOCAST_DEPTH == 0 and not _graph and torch.is_grad_enabled():
            # Delayed scaling only. For other recipes (current scaling with any
            # granularity) this is a no-op because cls.global_amax_buffer is empty.
            cls.reduce_and_update_fp8_tensors(forward=True)

    @classmethod
    def copy_forward_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> None:
        """Copy the scaling factors and amaxes for recompute forward phase
        to ensure both forward steps are numerically same.
        """
        # delayed scaling only function, noop for any other recipe
        if not fp8_meta["recipe"].delayed():
            return

        buffer_position_key = "global_fp8_buffer_pos_fwd_recompute"

        to_copy = [
            fp8_meta["scaling_fwd"].amax_history.clone(),
            fp8_meta["scaling_fwd"].scale.clone(),
        ]

        if buffer_position_key in fp8_meta:
            cls.fp8_tensors_recompute_buffer[fp8_meta[buffer_position_key]].append(to_copy)
        else:
            # First call for this module: give it its own deque and remember the slot.
            cls.fp8_tensors_recompute_buffer.append(deque([to_copy]))
            fp8_meta[buffer_position_key] = len(cls.fp8_tensors_recompute_buffer) - 1

    @classmethod
    def get_old_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> None:
        """Switch to the copied scaling factors and amaxes from phase
        1 forward for identical numerical outputs.
        """
        # delayed scaling only function, noop for any other recipe
        if not fp8_meta["recipe"].delayed():
            return

        # Store updated amaxes and scales from phase 1 post forward.
        fp8_meta["updated_amax_history_fwd"] = fp8_meta["scaling_fwd"].amax_history.clone()
        fp8_meta["updated_scale_fwd"] = fp8_meta["scaling_fwd"].scale.clone()

        # Retrieve stashed amaxes and scales from phase 1 pre forward (FIFO order).
        buffer_position_key = "global_fp8_buffer_pos_fwd_recompute"
        stashed_fp8_meta = cls.fp8_tensors_recompute_buffer[fp8_meta[buffer_position_key]].popleft()

        # Replace amaxes and scales with stashed values for phase 2 forward
        fp8_meta["scaling_fwd"].amax_history.copy_(stashed_fp8_meta[0])
        fp8_meta["scaling_fwd"].scale.copy_(stashed_fp8_meta[1])

    @staticmethod
    def restore_fp8_meta_tensors(fp8_meta: Dict[str, Any]) -> None:
        """Restore latest scaling factors and amaxes after recompute forward run."""
        # delayed scaling only function, noop for any other recipe
        if not fp8_meta["recipe"].delayed():
            return

        fp8_meta["scaling_fwd"].amax_history.copy_(fp8_meta["updated_amax_history_fwd"])
        fp8_meta["scaling_fwd"].scale.copy_(fp8_meta["updated_scale_fwd"])
Przemek Tredak's avatar
Przemek Tredak committed
554
555


556
@contextmanager
def fp8_model_init(
    enabled: bool = True,
    recipe: Optional[Recipe] = None,
    preserve_high_precision_init_val: bool = False,
) -> None:
    """
    Context manager for FP8 initialization of parameters.

    Example usage:

    .. code-block:: python

        with fp8_model_init(enabled=True):
            model = transformer_engine.pytorch.Linear(768, 768)

        # Preserving high precision initial value to initialize master weight
        with fp8_model_init(enabled=True, preserve_high_precision_init_val=True):
            model = transformer_engine.pytorch.Linear(768, 768)
        master_weight = model.weight.get_high_precision_init_val()
        model.weight.clear_high_precision_init_val()

    Parameters
    ----------
    enabled: bool, default = `True`
             when enabled, Transformer Engine modules created inside this `fp8_model_init`
             region will hold only FP8 copies of its parameters, as opposed to the default
             behavior where both higher precision and FP8 copies are present. Setting this
             option to `True` may result in lower memory consumption and is especially
             useful for scenarios like:

             * full model training using optimizer with master weights, where the high
               precision copies of weights are already present in the optimizer.
             * inference, where only the FP8 copies of the parameters are used.
             * LoRA-like fine-tuning, where the main parameters of the model do not change.
    recipe: transformer_engine.common.recipe.Recipe, default = `None`
            Recipe used to create the parameters. If left to None, it uses the default FP8 recipe.
    preserve_high_precision_init_val: bool, default = `False`
             when enabled, store the high precision tensor used to initialize FP8 parameters
             in CPU memory, and add two function attributes named `get_high_precision_init_val()`
             and `clear_high_precision_init_val()` to FP8 parameters to get/clear this high
             precision tensor. The purpose is that users can use this high-precision copy
             to initialize master weights, avoiding the loss of precision that can occur when
             using FP8 parameters directly. Note that after the master weights are initialized,
             users should call `clear_high_precision_init_val()` to release this CPU memory.

             This functionality is *EXPERIMENTAL*.
    """
    # Save the previous settings so nested regions restore them on exit.
    _fp8_parameters = FP8GlobalStateManager.FP8_PARAMETERS
    _fp8_recipe = FP8GlobalStateManager.FP8_RECIPE
    _high_precision_init_val = FP8GlobalStateManager.HIGH_PRECISION_INIT_VAL
    FP8GlobalStateManager.FP8_PARAMETERS = enabled
    FP8GlobalStateManager.FP8_RECIPE = get_default_fp8_recipe() if recipe is None else recipe
    FP8GlobalStateManager.HIGH_PRECISION_INIT_VAL = preserve_high_precision_init_val
    try:
        yield
    finally:
        # Restore even if the body raised.
        FP8GlobalStateManager.FP8_PARAMETERS = _fp8_parameters
        FP8GlobalStateManager.FP8_RECIPE = _fp8_recipe
        FP8GlobalStateManager.HIGH_PRECISION_INIT_VAL = _high_precision_init_val
616
617


Przemek Tredak's avatar
Przemek Tredak committed
618
619
@contextmanager
def fp8_autocast(
    enabled: bool = True,
    calibrating: bool = False,
    fp8_recipe: Optional[Recipe] = None,
    fp8_group: Optional[dist_group_type] = None,
    _graph: bool = False,
) -> None:
    """
    Context manager for FP8 usage.

    .. code-block:: python

        with fp8_autocast(enabled=True):
            out = model(inp)

    .. note::

        Support for FP8 in the Linear layer of Transformer Engine is currently limited to tensors
        with shapes where both dimensions are divisible by 16. In terms of the input to the full
        Transformer network, this typically requires padding sequence length to be multiple of 16.

    .. note::

        When :attr:`fp8_recipe.reduce_amax==True`, any module must not be invoked more than once
        inside a single `fp8_autocast` region. This is unsupported behavior because the amax
        reduction is handled during the exit of the `fp8_autocast` context. Calling the same
        module more than once inside an `fp8_autocast` region overrides the amax tensors
        before reduction can occur.

    Parameters
    ----------
    enabled: bool, default = `True`
             whether or not to enable fp8
    calibrating: bool, default = `False`
                 calibration mode allows collecting statistics such as amax and scale
                 data of fp8 tensors even when executing without fp8 enabled. This is
                 useful for saving an inference ready fp8 checkpoint while training
                 using a higher precision.
    fp8_recipe: recipe.Recipe, default = `None`
                recipe used for FP8 training.
    fp8_group: torch._C._distributed_c10d.ProcessGroup, default = `None`
               distributed group over which amaxes for the fp8 tensors
               are reduced at the end of each training step.
    """
    # An explicit recipe is validated here. With `fp8_recipe=None` this check
    # passes through; `fp8_autocast_enter` substitutes the default recipe and
    # checks device support itself.
    if enabled:
        check_recipe_support(fp8_recipe)
    # Snapshot the enclosing region's state so it can be restored on exit.
    fp8_state = FP8GlobalStateManager.get_fp8_autocast_state()
    FP8GlobalStateManager.fp8_autocast_enter(
        enabled=enabled,
        calibrating=calibrating,
        fp8_recipe=fp8_recipe,
        fp8_group=fp8_group,
        _graph=_graph,
    )
    try:
        yield
    finally:
        FP8GlobalStateManager.set_fp8_autocast_state(fp8_state)
        FP8GlobalStateManager.fp8_autocast_exit(enabled, _graph=_graph)
Przemek Tredak's avatar
Przemek Tredak committed
678
679


680
def _update_amax_history(amax_history: torch.Tensor) -> torch.Tensor:
Przemek Tredak's avatar
Przemek Tredak committed
681
    """Update amax history and set next amax to zero."""
682
    if amax_history.shape[0] > 1:
683
684
        new_amax_history = torch.roll(amax_history, -1, 0)
        amax_history.copy_(new_amax_history)
Przemek Tredak's avatar
Przemek Tredak committed
685
686
687
688
    amax_history[0].fill_(0.0)
    return amax_history


689
@torch.jit.script
def _default_get_amax_and_update_history(
    amax_history: torch.Tensor,
    amax_compute_algo: str,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Default function to obtain amax from history.

    With ``amax_compute_algo == "max"``, amax is the maximum over dim 0 of the
    history. Any other value uses row 0 (the ``"most_recent"`` entry). The
    history is then advanced in place by ``_update_amax_history``.

    Returns:
        ``(amax_history, amax)``. ``amax`` is taken before the history update.
    """
    if amax_compute_algo == "max":
        amax = torch.max(amax_history, dim=0).values
    else:  # amax_compute_algo == "most_recent"
        # Clone so that zeroing row 0 below does not overwrite the returned amax.
        amax = amax_history[0].clone()

    amax_history = _update_amax_history(amax_history)
    return amax_history, amax


704
@jit_fuser
def _default_sf_compute(
    amax: torch.Tensor,
    scale: torch.Tensor,
    fp8_max: float,
    margin: int,
    _fp32_max: float = torch.finfo(torch.float32).max,  # finfo not available in jitter
) -> torch.Tensor:
    """Default function to convert amax to scaling factor.

    Computing the scaling factor requires consideration of the following scenarios:
    1. amax == 0:
       No action is possible, set scale to the previous scale (or 1).
    2. 0 < amax < tiny_amax
       The amax is so small that the scale overflows to infinity in FP32.
       Set scale = FP32_max
    3. tiny_amax <= amax < FP32_max:
       Set scale = FP8_max / amax / 2**margin
    4. When amax == inf or amax == nan:
       No action is possible, set scale to the previous scale (or 1).

    The result is written into ``scale`` in place, and ``scale`` is returned.
    """
    sf = (fp8_max / amax) / (2**margin)
    # Cases 1 and 4 (amax <= 0 or NaN, then non-finite amax): keep previous scale.
    sf = torch.where(amax > 0.0, sf, scale)
    sf = torch.where(torch.isfinite(amax), sf, scale)
    # Case 2: clamp an overflowed scale to the largest finite FP32 value.
    sf = torch.where(torch.isinf(sf), torch.full_like(sf, _fp32_max), sf)
    scale.copy_(sf)
    return scale
730

Przemek Tredak's avatar
Przemek Tredak committed
731

732
def _compute_amax_and_update_history(
    amax_history: torch.Tensor,
    amax_compute_algo: Union[Callable, str],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Reduce the amax history to a single amax and advance the history.

    ``amax_compute_algo`` can take one of two forms:

    * A string, which is forwarded to the TorchScript helper
      ``_default_get_amax_and_update_history``.
    * A callable, which maps the history tensor to an amax.

    Returns ``(amax_history, amax)``.
    """
    if not callable(amax_compute_algo):
        return _default_get_amax_and_update_history(amax_history, amax_compute_algo)
    # Call the user function before rotating: the rotation zeroes row 0.
    amax = amax_compute_algo(amax_history)
    return _update_amax_history(amax_history), amax


def _compute_scaling_factor(
    amax: torch.Tensor,
    scale: torch.Tensor,
    fp8_max: float,
    recipe: DelayedScaling,
) -> torch.Tensor:
    """Convert amax to scaling factor."""

    if recipe.scaling_factor_compute_algo is None:
        return _default_sf_compute(
            amax,
            scale,
            fp8_max,
            recipe.margin,
        )
    return recipe.scaling_factor_compute_algo(amax, scale, fp8_max, recipe)


766
767
768
769
770
def _amax_and_scale_update(
    amax_history: torch.Tensor,
    scale: torch.Tensor,
    fp8_max: float,
    recipe: DelayedScaling,
) -> None:
    """Run one delayed-scaling update step on a set of FP8 meta tensors.

    The step has three parts:

    1. Derive an amax from ``amax_history`` using ``recipe.amax_compute_algo``,
       which also rotates the history.
    2. Convert that amax into a new scaling factor.
    3. Write the new scale and the rotated history back into the caller's
       ``scale`` and ``amax_history`` tensors in place.
    """
    rotated_history, amax = _compute_amax_and_update_history(
        amax_history,
        recipe.amax_compute_algo,
    )
    scale.copy_(_compute_scaling_factor(amax, scale, fp8_max, recipe))
    amax_history.copy_(rotated_history)


def split_and_copy(
    buffer: torch.Tensor,
    outputs: List[torch.Tensor],
    chunk_sizes: List[int],
) -> None:
    """Split `buffer` by `chunk_sizes` and copy into `outputs`."""
    splits = buffer.split(chunk_sizes)
    torch._foreach_copy_(outputs, splits)
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835


class RecipeState(abc.ABC):
    """Base class for the configuration and state of a quantization recipe.

    A recipe state is a builder for quantizers, and quantizers are in turn
    builders for quantized tensors. One state object may hold the state of
    several quantizers together, which makes it cheaper to apply fused
    kernels across them.

    """

    @staticmethod
    def create(
        recipe: Recipe,
        *,
        mode: str,
        num_quantizers: int = 1,
        device: Optional[torch.device] = None,
    ) -> RecipeState:
        """Build the recipe-state object that matches ``recipe``.

        Parameters
        ----------
        recipe: Recipe
            Quantization recipe.
        mode: {"forward", "backward"}
            Training stage in which quantization will be performed.
        num_quantizers: int, default = 1
            Number of quantizers the state should cover.
        device: torch.device, default = default CUDA device
            Device for quantized tensors.

        Returns
        -------
        RecipeState:
            State object for the given recipe.

        Raises
        ------
        ValueError
            If ``recipe`` is none of the supported recipe kinds.

        """

        def _resolve_state_class():
            # Probe the recipe kinds in a fixed order; the first match wins.
            if recipe.delayed():
                return DelayedScalingRecipeState
            if recipe.mxfp8():
                return MXFP8BlockScalingRecipeState
            if recipe.float8_current_scaling():
                return Float8CurrentScalingRecipeState
            if recipe.float8_block_scaling():
                return Float8BlockScalingRecipeState
            raise ValueError(f"{recipe.__class__.__name__} is not supported")

        state_cls = _resolve_state_class()
        return state_cls(
            recipe,
            mode=mode,
            num_quantizers=num_quantizers,
            device=device,
        )

    @abc.abstractmethod
    def make_quantizers(self) -> list:
        """Turn this recipe state into a list of quantizers.

        Quantizers build quantized tensors. They typically convert a
        high-precision tensor (e.g. FP32 or BF16) into a quantized tensor
        (e.g. FP8).

        """


class DelayedScalingRecipeState(RecipeState):
    """Recipe state for per-tensor FP8 quantization with delayed scaling.

    Each quantizer is backed by two things:

    * One entry of ``scale``, the factor applied when casting to FP8.
    * One column of ``amax_history``, a buffer of recent max-abs ("amax")
      values from which the scale is refreshed.

    This class only owns the buffers. The scale update is handled
    externally by `FP8GlobalStateManager`.

    """

    recipe: DelayedScaling
    mode: str
    dtype: tex.DType
    scale: torch.Tensor
    amax_history: torch.Tensor

    def __init__(
        self,
        recipe: DelayedScaling,
        *,
        mode: str,
        num_quantizers: int = 1,
        device: Optional[torch.device] = None,
    ) -> None:
        self.recipe = recipe
        self.mode = mode
        self.num_quantizers = num_quantizers
        # The second argument is True only in forward mode
        # (presumably selecting the forward-pass FP8 format).
        self.dtype = get_fp8_te_dtype(recipe, mode == "forward")

        target_device = torch.device("cuda") if device is None else device
        # One scale per quantizer, initialised to 1.
        self.scale = torch.ones(num_quantizers, dtype=torch.float32, device=target_device)
        # History buffer of shape (amax_history_len, num_quantizers), zero-initialised.
        self.amax_history = torch.zeros(
            (recipe.amax_history_len, num_quantizers),
            dtype=torch.float32,
            device=target_device,
        )

    def make_quantizers(self) -> list:
        # TODO(ksivamani); Find better design for this, adding here to avoid circular import.
        from .tensor.float8_tensor import Float8Quantizer

        # Each quantizer receives its scale entry and a 1-element view into
        # row 0 of the history (presumably where it records the current amax).
        amax_row = self.amax_history[0]
        return [
            Float8Quantizer(self.scale[idx], amax_row[idx].reshape((1,)), self.dtype)
            for idx in range(self.num_quantizers)
        ]


910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
class Float8CurrentScalingRecipeState(RecipeState):
    """Recipe state for per-tensor FP8 quantization with current scaling.

    Per-tensor current scaling keeps no state between steps, so nothing is
    allocated here. The object only records the dtype, the device and the
    quantizer count needed to build the quantizers.

    """

    recipe: Float8CurrentScaling
    mode: str
    dtype: tex.DType
    device: torch.device

    def __init__(
        self,
        recipe: Float8CurrentScaling,
        *,
        mode: str,
        num_quantizers: int = 1,
        device: Optional[torch.device] = None,
    ) -> None:
        self.recipe = recipe
        self.mode = mode
        self.num_quantizers = num_quantizers
        self.dtype = get_fp8_te_dtype(recipe, mode == "forward")
        # Fall back to the CUDA device when none is given.
        self.device = device if device is not None else torch.device("cuda")

    def make_quantizers(self) -> list:
        # Imported here rather than at module level, presumably to avoid a circular import.
        from .tensor.float8_tensor import Float8CurrentScalingQuantizer

        fp8_dtype, target_device = self.dtype, self.device
        return [
            Float8CurrentScalingQuantizer(fp8_dtype, device=target_device)
            for _ in range(self.num_quantizers)
        ]


949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
class MXFP8BlockScalingRecipeState(RecipeState):
    """Configuration for MXFP8 quantization.

    MXFP8 quantization does not require state. No scale or amax buffers are
    allocated, and each call to ``make_quantizers`` builds fresh
    ``MXFP8Quantizer`` objects.

    """

    recipe: MXFP8BlockScaling
    mode: str
    dtype: tex.DType
    device: torch.device

    def __init__(
        self,
        recipe: MXFP8BlockScaling,
        *,
        mode: str,
        num_quantizers: int = 1,
        device: Optional[torch.device] = None,
    ) -> None:
        self.recipe = recipe
        self.mode = mode
        self.num_quantizers = num_quantizers
        self.dtype = get_fp8_te_dtype(recipe, mode == "forward")

        # No buffers are allocated for MXFP8. Still keep the resolved target
        # device (default: CUDA), as the other stateless recipe states do,
        # instead of computing it and then discarding it.
        if device is None:
            device = torch.device("cuda")
        self.device = device

    def make_quantizers(self) -> list:
        # TODO(ksivamani); Find better design for this, adding here to avoid circular import.
        from .tensor.mxfp8_tensor import MXFP8Quantizer

        return [MXFP8Quantizer(self.dtype) for _ in range(self.num_quantizers)]
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086


class Float8BlockScalingRecipeState(RecipeState):
    """Configuration for Float8BlockScaling quantization.

    Float8BlockScaling quantization does not require state. The quantizers
    it builds are configured differently depending on their role: forward
    inputs/outputs, weights and gradients each take their own settings from
    the recipe.
    """

    recipe: Float8BlockScaling
    mode: str
    qx_dtype: tex.DType
    qw_dtype: tex.DType
    qgrad_dtype: tex.DType
    device: torch.device

    def __init__(
        self,
        recipe: Float8BlockScaling,
        *,
        mode: str,
        num_quantizers: int = 1,
        device: Optional[torch.device] = None,
    ) -> None:
        self.recipe = recipe
        self.mode = mode
        self.num_quantizers = num_quantizers
        # Inputs and weights are queried with True, gradients with False.
        self.qx_dtype = get_fp8_te_dtype(recipe, True)
        self.qw_dtype = get_fp8_te_dtype(recipe, True)
        self.qgrad_dtype = get_fp8_te_dtype(recipe, False)

        # No buffers are allocated; only the target device is recorded.
        if device is None:
            device = torch.device("cuda")
        self.device = device

    def make_quantizers(self) -> list:
        """Build the block-scaling quantizers for this state's ``mode``.

        Forward mode:
            Yields ``(input, weight, output)`` quantizers for each GEMM, so
            ``num_quantizers`` must be a multiple of 3. The output quantizer
            reuses the input settings.

        Backward mode:
            Yields ``(grad_output, grad_input)`` quantizers for each GEMM, so
            ``num_quantizers`` must be a multiple of 2. Both use the gradient
            settings.

        All quantizers produce both rowwise and columnwise data.
        """
        # TODO(ksivamani); Find better design for this, adding here to avoid circular import.
        from .tensor.float8_blockwise_tensor import Float8BlockQuantizer

        recipe = self.recipe

        def build(fp8_dtype, qparams, block_scaling_dim):
            # Only the dtype, amax epsilon, power-of-2 flag and block dim vary
            # between roles; every quantizer is both rowwise and columnwise.
            return Float8BlockQuantizer(
                fp8_dtype=fp8_dtype,
                rowwise=True,
                columnwise=True,
                amax_epsilon=qparams.amax_epsilon,
                force_pow_2_scales=qparams.power_2_scale,
                block_scaling_dim=block_scaling_dim,
            )

        if self.mode == "forward":
            # The index convention (coming from base.py set_meta_tensor)
            # is somewhat awkward, and doesn't play nicely with QuantizeOp,
            # which is not associated with a GEMM.
            assert self.num_quantizers % 3 == 0, (
                "Float8BlockScaling forward mode expects 3 quantizers "
                "(input, weight, output) per GEMM, got "
                f"num_quantizers={self.num_quantizers}"
            )
            input_cfg = (self.qx_dtype, recipe.fp8_quant_fwd_inp, recipe.x_block_scaling_dim)
            weight_cfg = (self.qw_dtype, recipe.fp8_quant_fwd_weight, recipe.w_block_scaling_dim)
            quantizers = []
            for _ in range(self.num_quantizers // 3):
                # input, weight, output (output reuses the input settings)
                quantizers += [build(*input_cfg), build(*weight_cfg), build(*input_cfg)]
            return quantizers

        assert self.mode == "backward", f"Unexpected mode {self.mode}"
        assert self.num_quantizers % 2 == 0, (
            "Float8BlockScaling backward mode expects 2 quantizers "
            "(grad_output, grad_input) per GEMM, got "
            f"num_quantizers={self.num_quantizers}"
        )
        # grad_output and grad_input share identical settings, so the per-GEMM
        # pairs flatten to num_quantizers identical quantizers.
        grad_cfg = (self.qgrad_dtype, recipe.fp8_quant_bwd_grad, recipe.grad_block_scaling_dim)
        return [build(*grad_cfg) for _ in range(self.num_quantizers)]