# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Base classes for fusible operations."""

from __future__ import annotations
import abc
from collections.abc import Iterable
import dataclasses
import pickle
from typing import Any, Optional

import torch

from transformer_engine.common.recipe import Recipe
from ..fp8 import (
    MXFP8BlockScalingRecipeState,
    DelayedScalingRecipeState,
    Float8BlockScalingRecipeState,
    FP8GlobalStateManager,
    RecipeState,
    fp8_autocast,
)
from ..tensor import Quantizer


@dataclasses.dataclass
class OperationContext:
    """Per-operation state carried from the forward to the backward pass

    Tensors staged in ``to_save`` during the forward function are made
    available to the backward function as ``saved_tensors``.

    """

    # Tensors visible to the backward function. They correspond to the
    # tensors that were staged in ``to_save`` during the forward pass.
    saved_tensors: tuple[torch.Tensor | None, ...] | None = None
    # Tensors staged for the backward function. The forward function is
    # expected to fill this, either by assigning it directly or by
    # calling save_for_backward.
    to_save: tuple[torch.Tensor | None, ...] | None = None

    # Range (pair of ints) covering this op's entries in the pipeline's
    # list of saved tensors
    _saved_tensors_range: tuple[int, int] | None = None

    # Whether this operation needs a backward pass
    requires_grad: bool = True

    def save_for_backward(self, *tensors: torch.Tensor | None) -> None:
        """Stage tensors for use by the backward function

        Intended to be called from the forward function. Any tensors
        staged previously are replaced.

        """
        staged = tensors
        self.to_save = staged


class FusibleOperation(torch.nn.Module, metaclass=abc.ABCMeta):
    """Tensor operation supported by the operation fuser

    Subclasses must implement `is_fused_op`. The default `fuser_forward`
    and `fuser_backward` raise `NotImplementedError`.

    """

    @property
    @abc.abstractmethod
    def is_fused_op(self) -> bool:
        """Whether this op is the fusion of one or more basic ops"""

    def pre_forward(self) -> None:
        """Preprocessing before forward pass

        Does nothing by default.

        """

    def get_input_quantizer(self) -> Optional[Quantizer]:
        """Get builder class for quantized input tensor

        Returns None by default.

        """

    def get_grad_input_quantizer(self) -> Optional[Quantizer]:
        """Get builder class for quantized input's grad tensor

        Returns None by default.

        """

    def fuser_forward(
        self,
        basic_op_ctxs: list[OperationContext],
        input_: torch.Tensor,
        *,
        basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
        prev_op_grad_input_quantizer: Optional[Quantizer],
        next_op_input_quantizer: Optional[Quantizer],
        is_first_op: bool,
        basic_op_kwargs: list[dict[str, Any]],
    ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
        """Forward pass

        This op is either a basic op or the fusion of basic ops, so
        several of this function's arguments are lists of arguments to
        forward functions of corresponding basic ops.

        Called by `OperationFuser`.

        Parameters
        ----------
        basic_op_ctxs: list of OperationContext
            Contexts for basic operations
        input_: torch.Tensor
            Input tensor
        basic_op_extra_inputs: list of tuple of torch.Tensor
            Extra tensor inputs to basic operations
        prev_op_grad_input_quantizer: Quantizer, optional
            The grad_input_quantizer of the preceding operation
        next_op_input_quantizer: Quantizer, optional
            The input_quantizer of the following operation
        is_first_op: bool
            Whether this op is the first one in the fuser, i.e. it has
            no preceding op. Used in the backward pass to decide whether
            the saved input tensor can be safely deleted once it is no
            longer needed, which is only possible when there is a
            preceding op.
        basic_op_kwargs: list of dict
            Keyword arguments to forward functions of basic
            operations.

        Returns
        -------
        torch.Tensor:
            Output tensor.
        Iterable of iterable of torch.Tensor:
            Extra tensor outputs from basic operations.

        Raises
        ------
        NotImplementedError
            Always, in this base implementation.

        """
        raise NotImplementedError(
            f"Forward pass is not implemented for operation ({self.__class__.__name__})"
        )

    def fuser_backward(
        self,
        basic_op_ctxs: list[OperationContext],
        grad_output: torch.Tensor,
        *,
        basic_op_grad_extra_outputs: list[tuple[torch.Tensor, ...]],
    ) -> tuple[
        torch.Tensor,
        Iterable[Iterable[Optional[torch.Tensor]]],
        Iterable[Iterable[Optional[torch.Tensor]]],
    ]:
        """Backward pass

        This op is either a basic op or the fusion of basic ops, so
        several of this function's arguments are lists of arguments to
        backward functions of corresponding basic ops.

        Called by `OperationFuser`.

        Parameters
        ----------
        basic_op_ctxs: list of OperationContext
            Contexts for basic operations
        grad_output: torch.Tensor
            Loss gradient w.r.t. operation output
        basic_op_grad_extra_outputs: list of tuple of torch.Tensor
            Loss gradients w.r.t. extra tensor outputs from basic
            operations.

        Returns
        -------
        torch.Tensor:
            Loss gradient w.r.t. operation input
        Iterable of iterable of torch.Tensor:
            Loss gradients w.r.t. parameters for basic operations
        Iterable of iterable of torch.Tensor:
            Loss gradients w.r.t. extra tensor inputs to basic
            operations

        Raises
        ------
        NotImplementedError
            Always, in this base implementation.

        """
        raise NotImplementedError(
            f"Backward pass is not implemented for operation ({self.__class__.__name__})"
        )


class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
    """Single tensor operation supported by the operation fuser

    This class holds parameters and state, even if the actual forward
    and backward passes are performed by a fused operation.

    Quantization state (per-mode FP8 metadata and quantizers) is built
    lazily, e.g. by `get_quantizer` or by `pre_forward` when FP8 is
    enabled.

    """

    # Number of extra tensor inputs
    num_extra_inputs: int = 0
    # Number of extra tensor outputs
    num_extra_outputs: int = 0

    def __init__(self) -> None:
        super().__init__()

        # Objects for quantization, keyed by "forward" and "backward".
        # Both stay None until _reset_quantization_recipe_state runs.
        # After that, a mode without quantizers has an empty quantizer
        # list and None FP8 metadata.
        self._quantizers: Optional[dict[str, list[Quantizer]]] = None
        self._fp8_metas: Optional[dict[str, Optional[dict[str, Any]]]] = None

    @property
    def is_fused_op(self) -> bool:
        return False

    def num_quantizers(
        self,
        mode: str,  # pylint: disable=unused-argument
    ) -> int:
        """Number of quantizers

        Matches number of quantized tensors used in operation. The
        base implementation uses no quantizers.

        Parameters
        ----------
        mode: {"forward", "backward"}
            Quantizer type

        """
        return 0

    def get_input_quantizer(self) -> Optional[Quantizer]:
        # First forward quantizer, if the op has any
        if self.num_quantizers("forward") > 0:
            return self.get_quantizer("forward", 0)
        return None

    def get_grad_input_quantizer(self) -> Optional[Quantizer]:
        # First backward quantizer, if the op has any
        if self.num_quantizers("backward") > 0:
            return self.get_quantizer("backward", 0)
        return None

    def _reset_quantization_recipe_state(
        self,
        *,
        recipe: Optional[Recipe] = None,
    ) -> None:
        """Construct state for quantization recipe

        Discards any existing quantization state and rebuilds the FP8
        metadata and quantizers for both the forward and backward pass.

        Parameters
        ----------
        recipe: Recipe, optional
            Quantization recipe. Defaults to the recipe reported by
            `FP8GlobalStateManager.get_fp8_recipe`.

        Raises
        ------
        NotImplementedError
            If the recipe uses FP8 block scaling and the op has
            quantizers for some mode.

        """

        # Quantization recipe
        if recipe is None:
            recipe = FP8GlobalStateManager.get_fp8_recipe()

        # Quantization recipe state for forward and backward pass
        self._fp8_metas = {"forward": None, "backward": None}
        self._quantizers = {"forward": [], "backward": []}
        for mode in ("forward", "backward"):
            num_quantizers = self.num_quantizers(mode)
            if num_quantizers == 0:
                continue

            if recipe.float8_block_scaling():
                raise NotImplementedError(
                    "Fusible operations do not support FP8 block scaling recipe"
                )

            # Construct quantization recipe state
            recipe_state = RecipeState.create(
                recipe,
                mode=mode,
                num_quantizers=num_quantizers,
            )
            fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
                forward=(mode == "forward"),
            )
            self._fp8_metas[mode] = {
                fp8_meta_key: recipe_state,
                "recipe": recipe,
                "fp8_group": FP8GlobalStateManager.get_fp8_group(),
            }

            # Construct builder class for quantized tensors
            self._quantizers[mode] = recipe_state.make_quantizers()

    @staticmethod
    def _recipe_state_matches_recipe(recipe_state: Any, recipe: Recipe) -> bool:
        """Whether a recipe state object fits the kind of recipe

        For every recipe kind with a known state class, the state must
        be an instance of that class exactly when the recipe is of that
        kind. Checking both directions also detects a switch away from
        one of these kinds (e.g. from delayed scaling to another
        recipe), which a one-directional check would miss.

        """
        known_kinds = (
            (recipe.delayed(), DelayedScalingRecipeState),
            (recipe.mxfp8(), MXFP8BlockScalingRecipeState),
            (recipe.float8_block_scaling(), Float8BlockScalingRecipeState),
        )
        return all(
            recipe_is_kind == isinstance(recipe_state, state_cls)
            for recipe_is_kind, state_cls in known_kinds
        )

    def _update_quantization_recipe_state(
        self,
        *,
        recipe: Optional[Recipe] = None,
    ) -> None:
        """Make sure quantizer state matches quantization recipe

        Rebuilds all quantization state if none exists yet, if the
        recipe kind no longer matches the stored recipe state, or if a
        mode needs quantizers but has no FP8 metadata. Otherwise it
        refreshes the stored recipe and FP8 group. For delayed scaling,
        it also resizes the amax history to the recipe's
        ``amax_history_len``.

        Parameters
        ----------
        recipe: Recipe, optional
            Quantization recipe. Defaults to the recipe reported by
            `FP8GlobalStateManager.get_fp8_recipe`.

        """

        # Quantization recipe
        if recipe is None:
            recipe = FP8GlobalStateManager.get_fp8_recipe()

        # Reset quantization state if needed
        if self._fp8_metas is None or self._quantizers is None:
            self._reset_quantization_recipe_state(recipe=recipe)
            return
        for mode in ("forward", "backward"):
            fp8_meta = self._fp8_metas[mode]
            if fp8_meta is None:
                # No metadata was built for this mode. If the op now
                # needs quantizers for it, the state is stale.
                if self.num_quantizers(mode) > 0:
                    self._reset_quantization_recipe_state(recipe=recipe)
                    return
                continue
            fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
                forward=(mode == "forward"),
            )
            if fp8_meta_key not in fp8_meta:
                continue
            recipe_state = fp8_meta[fp8_meta_key]
            if not self._recipe_state_matches_recipe(recipe_state, recipe):
                self._reset_quantization_recipe_state(recipe=recipe)
                return

        # Quantization recipe state for forward and backward pass
        for mode in ("forward", "backward"):
            num_quantizers = self.num_quantizers(mode)
            if num_quantizers == 0:
                continue

            # Update FP8 metadata
            fp8_meta = self._fp8_metas[mode]
            fp8_meta["recipe"] = recipe
            fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group()

            # Get recipe state
            fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
                forward=(mode == "forward"),
            )
            recipe_state = fp8_meta[fp8_meta_key]

            # Reallocate amax history if needed (delayed scaling only)
            if not recipe.delayed():
                continue

            current_length = recipe_state.amax_history.size(0)
            target_length = recipe.amax_history_len
            if current_length != target_length:
                with torch.no_grad():
                    if target_length < current_length:
                        # Keep the first target_length rows
                        recipe_state.amax_history = recipe_state.amax_history[
                            :target_length
                        ].clone()
                    else:
                        # Zero-pad extra rows at the end of dim 0
                        recipe_state.amax_history = torch.nn.functional.pad(
                            recipe_state.amax_history,
                            pad=(0, 0, 0, target_length - current_length),
                        )
                # Quantizers are rebuilt from the reallocated history
                self._quantizers[mode] = recipe_state.make_quantizers()

    def get_quantizer(
        self,
        mode: str,
        index: int,
    ) -> Quantizer:
        """Get builder class for quantized tensor

        If quantization state has not been constructed yet, it is built
        from the recipe reported by `FP8GlobalStateManager`.

        Parameters
        ----------
        mode: {"forward", "backward"}
            Quantizer type
        index: int
            Position of the quantizer in the mode's quantizer list
            (0 <= index < num_quantizers(mode))

        """
        if self._quantizers is None:
            self._reset_quantization_recipe_state()
        return self._quantizers[mode][index]

    @torch.no_grad()
    def _save_fp8_metas(self) -> Optional[dict[str, Any]]:
        """Create copies of tensors in FP8 metadata

        Tensor copies can be loaded with _load_fp8_metas. Only recipe
        states that hold ``scale`` and ``amax_history`` tensors are
        copied. Other recipe states have nothing to save.

        """
        if self._fp8_metas is None:
            return None
        out = {}
        for mode, fp8_meta in self._fp8_metas.items():
            if fp8_meta is None:
                continue
            out[mode] = {}
            for is_forward in (True, False):
                fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=is_forward)
                recipe_state = fp8_meta.get(fp8_meta_key)
                # Skip missing keys and recipe states without tensors to
                # copy (accessing them would raise AttributeError)
                if not (
                    hasattr(recipe_state, "scale") and hasattr(recipe_state, "amax_history")
                ):
                    continue
                out[mode][fp8_meta_key] = (
                    recipe_state.scale.clone(),
                    recipe_state.amax_history.clone(),
                )
        return out

    @torch.no_grad()
    def _load_fp8_metas(self, fp8_metas: Optional[dict[str, Any]]) -> None:
        """Update FP8 metadata with saved tensor copies

        Tensor copies should be generated with _save_fp8_metas. Values
        are copied in place into the existing scale and amax history
        tensors.

        """
        assert (self._fp8_metas is None) == (
            fp8_metas is None
        ), "Saved FP8 metadata does not match operation's FP8 metadata"
        if fp8_metas is None:
            return
        for mode, fp8_meta in fp8_metas.items():
            assert (
                mode in self._fp8_metas
            ), f"Found an unexpected key ({mode=}) in saved FP8 metadata"
            for fp8_meta_key, tensors in fp8_meta.items():
                assert (
                    fp8_meta_key in self._fp8_metas[mode]
                ), f"Found an unexpected key ({mode=}, {fp8_meta_key=}) in saved FP8 metadata"
                scale, amax_history = tensors
                self._fp8_metas[mode][fp8_meta_key].scale.copy_(scale)
                self._fp8_metas[mode][fp8_meta_key].amax_history.copy_(amax_history)

    def pre_forward(
        self,
        *,
        fp8_enabled: Optional[bool] = None,
        fp8_recipe: Optional[Recipe] = None,
    ) -> None:
        """Preprocessing before forward pass

        When FP8 is enabled, brings the quantization state in line with
        the recipe. Unless FP8 graph capturing is active, it also
        registers the forward/backward FP8 metadata with the global FP8
        buffer.

        Parameters
        ----------
        fp8_enabled: bool, optional
            Whether FP8 is enabled. Defaults to
            `FP8GlobalStateManager.is_fp8_enabled()`.
        fp8_recipe: Recipe, optional
            Quantization recipe. Defaults to the global recipe.

        """

        # Initialize FP8 metadata if needed
        if fp8_enabled is None:
            fp8_enabled = FP8GlobalStateManager.is_fp8_enabled()
        if fp8_enabled:
            self._update_quantization_recipe_state(recipe=fp8_recipe)
            if not FP8GlobalStateManager.fp8_graph_capturing():
                if self.num_quantizers("forward"):
                    FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
                        self._fp8_metas["forward"],
                    )
                if self.num_quantizers("backward"):
                    FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
                        self._fp8_metas["backward"],
                    )

    @abc.abstractmethod
    def op_forward(
        self,
        ctx: OperationContext,
        input_: torch.Tensor,
        *,
        prev_op_grad_input_quantizer: Optional[Quantizer],
        next_op_input_quantizer: Optional[Quantizer],
        is_first_op: bool,
        **kwargs: Any,
    ) -> torch.Tensor:
        """Forward pass

        Parameters
        ----------
        ctx: OperationContext
            Context to coordinate between forward and backward passes
        input_: torch.Tensor
            Input tensor
        prev_op_grad_input_quantizer: Quantizer, optional
            The grad_input_quantizer of the preceding operation
        next_op_input_quantizer: Quantizer, optional
            The input_quantizer of the following operation
        is_first_op: bool
            Whether this op is the first one in the fuser, i.e. it has
            no preceding op. Used in the backward pass to decide whether
            the saved input tensor can be safely deleted once it is no
            longer needed, which is only possible when there is a
            preceding op.
        **kwargs: Any
            Op-specific keyword arguments (this op's entry in the
            fuser's basic_op_kwargs)

        Returns
        -------
        torch.Tensor:
            Output tensor

        """

    @abc.abstractmethod
    def op_backward(
        self,
        ctx: OperationContext,
        grad_output: torch.Tensor,
    ) -> tuple[torch.Tensor, Iterable[Optional[torch.Tensor]]]:
        """Backward pass

        Parameters
        ----------
        ctx: OperationContext
            Context to coordinate between forward and backward passes
        grad_output: torch.Tensor
            Loss gradient w.r.t. operation output

        Returns
        -------
        torch.Tensor
            Loss gradient w.r.t. operation input
        Iterable of torch.Tensor:
            Loss gradients w.r.t. parameters

        """

    def fuser_forward(
        self,
        basic_op_ctxs: list[OperationContext],
        input_: torch.Tensor,
        *,
        basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
        prev_op_grad_input_quantizer: Optional[Quantizer],
        next_op_input_quantizer: Optional[Quantizer],
        is_first_op: bool,
        basic_op_kwargs: list[dict[str, Any]],
    ) -> tuple[torch.Tensor, list[tuple[()]]]:
        # A basic op is its own single "fused" op, so it delegates to
        # op_forward with its own (first and only) context and kwargs
        if self.num_extra_inputs > 0 or self.num_extra_outputs > 0:
            raise RuntimeError(
                f"{self.__class__.__name__} operation has "
                f"{self.num_extra_inputs} extra tensor inputs "
                f"and {self.num_extra_outputs} extra tensor outputs. "
                "It should override `fuser_forward` instead of `op_forward`."
            )
        output = self.op_forward(
            basic_op_ctxs[0],
            input_,
            prev_op_grad_input_quantizer=prev_op_grad_input_quantizer,
            next_op_input_quantizer=next_op_input_quantizer,
            is_first_op=is_first_op,
            **basic_op_kwargs[0],
        )
        return output, [()]

    def fuser_backward(
        self,
        basic_op_ctxs: list[OperationContext],
        grad_output: torch.Tensor,
        *,
        basic_op_grad_extra_outputs: list[tuple[torch.Tensor, ...]],
    ) -> tuple[
        torch.Tensor,
        list[Iterable[Optional[torch.Tensor]]],
        list[tuple[()]],
    ]:
        # Delegates to op_backward; no extra-input gradients exist
        if self.num_extra_inputs > 0 or self.num_extra_outputs > 0:
            raise RuntimeError(
                f"{self.__class__.__name__} operation has "
                f"{self.num_extra_inputs} extra tensor inputs "
                f"and {self.num_extra_outputs} extra tensor outputs. "
                "It should override `fuser_backward` instead of `op_backward`."
            )
        grad_input, grad_params = self.op_backward(basic_op_ctxs[0], grad_output)
        return grad_input, [grad_params], [()]

    def forward(
        self,
        input: torch.Tensor,  # pylint: disable=redefined-builtin
        *extra_inputs: torch.Tensor,
        **kwargs: Any,
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Apply operation

        Runs this op alone through an unfused `OperationFuser`, passing
        ``kwargs`` as this op's keyword arguments.

        """
        from .fuser import OperationFuser

        return OperationFuser([self], fuse_ops=False)(
            input,
            *extra_inputs,
            basic_op_kwargs=[kwargs],
        )

    def get_extra_state(self) -> torch.Tensor:
        """Serialize extra state

        Contains metadata for quantization recipe. Returned as a uint8
        CPU tensor holding a pickled dict.

        """

        # This implementation is working around a few issues:
        #
        # (1) PyTorch's "extra state" infrastructure might be able to
        #     support any picklable type, but they make no guarantees.
        #     We have experienced problems (e.g. in ONNX export) with
        #     non-tensor extra state.
        # (2) PyTorch's checkpointing infrastructure does not remap
        #     devices for "extra state" like it does for "state dict".
        #     Thus, we want to avoid putting extra state on the GPU
        #     since it may be loaded on the wrong device.
        # (3) The extra state consists of many small tensors. If we
        #     want to copy them all to CPU, then we need to avoid the
        #     overhead of many GPU-CPU memory transfers.
        #
        # See: https://github.com/NVIDIA/TransformerEngine/pull/351
        # See: https://github.com/NVIDIA/TransformerEngine/pull/363

        def to_cpu(src: torch.Tensor) -> torch.Tensor:
            """Helper function to make CPU copy of tensor

            Memory transfer is asynchronous w.r.t. host, so GPU should
            be synchronized before using result.

            """
            dst = torch.empty_like(src, device="cpu")
            dst.copy_(src, non_blocking=True)
            return dst

        # Store quantizer state if needed
        state = {}
        copied_tensors = False
        for mode in ("forward", "backward"):

            # Skip if op has no quantizer state
            if self._fp8_metas is None or self._fp8_metas[mode] is None:
                continue

            # Quantizer state
            fp8_meta = self._fp8_metas[mode]
            state[mode] = {}
            state[mode]["recipe"] = fp8_meta["recipe"]

            # Copy tensors to CPU and store
            if state[mode]["recipe"].delayed():
                if mode == "forward":
                    state[mode]["scale_fwd"] = to_cpu(fp8_meta["scaling_fwd"].scale)
                    state[mode]["amax_history_fwd"] = to_cpu(fp8_meta["scaling_fwd"].amax_history)
                if mode == "backward":
                    state[mode]["scale_bwd"] = to_cpu(fp8_meta["scaling_bwd"].scale)
                    state[mode]["amax_history_bwd"] = to_cpu(fp8_meta["scaling_bwd"].amax_history)
                copied_tensors = True

            # Store other picklable items
            extra = {}
            for key, val in fp8_meta.items():
                if key == "buffer_index_and_autocast_key":
                    continue
                if not isinstance(val, (bool, int, float, str, tuple, list)):
                    continue
                extra[key] = val
            state[mode]["extra_fp8_variables"] = extra

        # Finish GPU-CPU memory transfers before pickling the CPU copies.
        # Skipped when nothing was copied, so ops without quantization
        # state do not require a CUDA device here.
        if copied_tensors:
            torch.cuda.synchronize()

        # Serialize state into byte tensor
        state_serialized = bytearray(pickle.dumps(state))
        state_serialized = torch.frombuffer(state_serialized, dtype=torch.uint8)
        return state_serialized

    def set_extra_state(self, state: Optional[torch.Tensor]) -> None:
        """Load extra state

        Expects a byte tensor produced by `get_extra_state`. If the op
        has no quantization state yet, it is constructed using the
        checkpointed recipe.

        """
        if state is None or state.numel() == 0:
            return

        # Deserialize state from byte tensor
        # NOTE(security): pickle.loads can execute arbitrary code while
        # unpickling. Only load checkpoints from trusted sources.
        state = pickle.loads(state.detach().numpy(force=True).tobytes())
        if state is None or len(state) == 0:
            return

        def copy_tensor(src: torch.Tensor, dst: torch.Tensor) -> None:
            """Helper function to copy tensor from CPU

            Memory transfer is asynchronous w.r.t. host, so GPU should
            be synchronized before using result.

            """
            if src.size() != dst.size():
                dst.data = torch.empty(src.size(), dtype=dst.dtype, device=dst.device)
            dst.copy_(src, non_blocking=True)

        # Load quantizer state if needed
        copied_tensors = False
        for mode in ("forward", "backward"):

            # Skip if checkpoint has no quantizer state
            if mode not in state:
                continue

            # Get op's quantizer state, initializing if needed. Do not
            # rebuild just because this op has no quantizers for the
            # mode, since rebuilding would discard state already loaded
            # for the other mode.
            if self._fp8_metas is None or (
                self._fp8_metas[mode] is None and self.num_quantizers(mode) > 0
            ):
                with fp8_autocast(fp8_recipe=state[mode]["recipe"]):
                    self._reset_quantization_recipe_state()
            fp8_meta = self._fp8_metas[mode]
            if fp8_meta is None:
                # Op has no quantizers for this mode, so there is no
                # state to load the checkpoint into
                continue

            # Load extra items
            fp8_meta["recipe"] = state[mode]["recipe"]
            fp8_meta.update(state[mode]["extra_fp8_variables"])
            fp8_meta.pop("global_fp8_buffer_pos_fwd_recompute", None)

            # Load tensors
            if state[mode]["recipe"].delayed():
                if mode == "forward":
                    copy_tensor(state[mode]["scale_fwd"], fp8_meta["scaling_fwd"].scale)
                    copy_tensor(
                        state[mode]["amax_history_fwd"], fp8_meta["scaling_fwd"].amax_history
                    )
                if mode == "backward":
                    copy_tensor(state[mode]["scale_bwd"], fp8_meta["scaling_bwd"].scale)
                    copy_tensor(
                        state[mode]["amax_history_bwd"], fp8_meta["scaling_bwd"].amax_history
                    )
                copied_tensors = True

        # Finish CPU-GPU memory transfers
        if copied_tensors:
            torch.cuda.synchronize()

    def _load_from_state_dict(self, *args, **kwargs) -> None:
        """Load state"""

        # In the base PyTorch module class, the extra state is loaded
        # _after_ the parameters. However, copying values into FP8
        # parameters requires an FP8 cast, which uses a scaling factor
        # from the operation's FP8 metadata. The FP8 metadata is
        # included in the operation's extra state, so we need to
        # manually load the extra state before loading parameters.
        # NOTE(review): the base implementation presumably loads the
        # extra state again afterwards; confirm the repeat is harmless.

        state_dict, prefix = args[0], args[1]
        extra_state_key = prefix + torch.nn.modules.module._EXTRA_STATE_KEY_SUFFIX
        if extra_state_key in state_dict:
            self.set_extra_state(state_dict[extra_state_key])
        super()._load_from_state_dict(*args, **kwargs)


class FusedOperation(FusibleOperation):
    """Compound tensor operation supported by the operation fuser

    If the forward or backward passes are defined, they must be
    functionally equivalent to the forward/backward passes of the
    corresponding basic ops. This class should hold no parameters or
    other state, but should access them from the basic ops.

    Parameters
    ----------
    basic_ops: iterable of FusibleOperation
        Basic ops that are interchangeable with this op

    Raises
    ------
    ValueError
        If `basic_ops` is empty.

    """

    def __init__(
        self,
        basic_ops: Iterable[FusibleOperation],
    ) -> None:
        super().__init__()

        # Basic operations that comprise this fused operation
        self.basic_ops: torch.nn.ModuleList = torch.nn.ModuleList(basic_ops)
        if len(self.basic_ops) == 0:
            raise ValueError(
                "Attempted to construct a fused operation "
                "without specifying its corresponding basic operations"
            )

    @property
    def is_fused_op(self) -> bool:
        return True

    def get_input_quantizer(self) -> Optional[Quantizer]:
        # Delegates to the first basic op
        return self.basic_ops[0].get_input_quantizer()

    def get_grad_input_quantizer(self) -> Optional[Quantizer]:
        # Delegates to the last basic op (get_input_quantizer uses the
        # first one)
        return self.basic_ops[-1].get_grad_input_quantizer()

    def pre_forward(self) -> None:
        """Preprocessing before forward pass

        Calls `pre_forward` on each basic op, in order.

        """
        for op in self.basic_ops:
            op.pre_forward()

    def forward(
        self,
        input: torch.Tensor,  # pylint: disable=redefined-builtin
        *extra_inputs: torch.Tensor,
        basic_op_kwargs: Optional[list[dict[str, Any]]] = None,
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Apply operation

        Runs this op through an unfused `OperationFuser`.

        Parameters
        ----------
        input: torch.Tensor
            Input tensor
        *extra_inputs: torch.Tensor
            Extra tensor inputs, passed through to the fuser
        basic_op_kwargs: list of dict, optional
            Keyword arguments for each basic op's forward function.
            Defaults to an empty dict per basic op.

        Returns
        -------
        torch.Tensor or tuple of torch.Tensor
            Result of the fuser call (annotated the same way as
            `BasicOperation.forward`)

        """
        if basic_op_kwargs is None:
            basic_op_kwargs = [{} for _ in range(len(self.basic_ops))]
        from .fuser import OperationFuser

        return OperationFuser([self], fuse_ops=False)(
            input,
            *extra_inputs,
            basic_op_kwargs=basic_op_kwargs,
        )