# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Base classes for fusible operations."""

from __future__ import annotations
import abc
from collections.abc import Iterable
import dataclasses
import pickle
from typing import Any, Optional

import torch

from transformer_engine.common.recipe import Recipe
from ..fp8 import (
    MXFP8BlockScalingRecipeState,
    DelayedScalingRecipeState,
    Float8BlockScalingRecipeState,
    FP8GlobalStateManager,
    RecipeState,
    fp8_autocast,
)
from ..tensor import Quantizer


@dataclasses.dataclass
class OperationContext:
    """State needed to apply an operation

    Saves state from forward pass for use in backward pass.

    """

    # Tensors that have been saved from forward function
    # Note: Available in the backward function, matching tensors from
    # to_save.
    saved_tensors: Optional[tuple[Optional[torch.Tensor], ...]] = None
    # Tensors to save for backward function
    # Note: Expected to be set in the forward function, either
    # directly or with save_for_backward.
    to_save: Optional[tuple[Optional[torch.Tensor], ...]] = None

    # Corresponding range in pipeline's list of saved tensors
    _saved_tensors_range: Optional[tuple[int, int]] = None

    # Whether backward pass is required
    requires_grad: bool = True

    def save_for_backward(self, *tensors: Optional[torch.Tensor]) -> None:
        """Register tensors to be saved for the backward function

        Expected to be called in the forward function. Each call
        replaces whatever was previously stored in ``to_save``.

        Parameters
        ----------
        *tensors: torch.Tensor, optional
            Tensors (or ``None`` placeholders) to store, in order, as
            the ``to_save`` tuple.

        """
        self.to_save = tensors


class FusibleOperation(torch.nn.Module, metaclass=abc.ABCMeta):
    """Tensor operation supported by the operation fuser"""

    @property
    @abc.abstractmethod
    def is_fused_op(self) -> bool:
        """Whether this op is the fusion of one or more basic ops"""

    def pre_first_forward(
        self,
        *,
        recipe: Optional[Recipe],
    ) -> None:
        """Preprocessing before forward pass

        The base implementation does nothing.

        Parameters
        ----------
        recipe: Recipe, optional
            Quantization recipe, or ``None``.

        """

    def get_input_quantizer(self) -> Optional[Quantizer]:
        """Get builder class for quantized input tensor

        The base implementation returns ``None``.

        """

    def get_grad_input_quantizer(self) -> Optional[Quantizer]:
        """Get builder class for quantized input's grad tensor

        The base implementation returns ``None``.

        """

    def fuser_forward(
        self,
        basic_op_ctxs: list[OperationContext],
        input_: torch.Tensor,
        *,
        basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
        prev_op_grad_input_quantizer: Optional[Quantizer],
        next_op_input_quantizer: Optional[Quantizer],
        basic_op_kwargs: list[dict[str, Any]],
    ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
        """Forward pass

        This op is either a basic op or the fusion of basic ops, so
        several of this function's arguments are lists of arguments to
        forward functions of corresponding basic ops.

        Called by `OperationFuser`.

        Parameters
        ----------
        basic_op_ctxs: list of OperationContext
            Contexts for basic operations
        input_: torch.Tensor
            Input tensor
        basic_op_extra_inputs: list of tuple of torch.Tensor
            Extra tensor inputs to basic operations
        prev_op_grad_input_quantizer: Quantizer, optional
            The grad_input_quantizer of the preceding operation
        next_op_input_quantizer: Quantizer, optional
            The input_quantizer of the following operation
        basic_op_kwargs: list of dict
            Keyword arguments to forward functions of basic
            operations.

        Returns
        -------
        torch.Tensor:
            Output tensor.
        Iterable of iterable of torch.Tensor:
            Extra tensor outputs from basic operations.

        Raises
        ------
        NotImplementedError
            Always, in this base implementation.

        """
        raise NotImplementedError(
            f"Forward pass is not implemented for operation ({self.__class__.__name__})"
        )

    def fuser_backward(
        self,
        basic_op_ctxs: list[OperationContext],
        grad_output: torch.Tensor,
        *,
        basic_op_grad_extra_outputs: list[tuple[torch.Tensor, ...]],
    ) -> tuple[
        torch.Tensor,
        Iterable[Iterable[Optional[torch.Tensor]]],
        Iterable[Iterable[Optional[torch.Tensor]]],
    ]:
        """Backward pass

        This op is either a basic op or the fusion of basic ops, so
        several of this function's arguments are lists of arguments to
        backward functions of corresponding basic ops.

        Called by `OperationFuser`.

        Parameters
        ----------
        basic_op_ctxs: list of OperationContext
            Contexts for basic operations
        grad_output: torch.Tensor
            Loss gradient w.r.t. operation output
        basic_op_grad_extra_outputs: list of tuple of torch.Tensor
            Loss gradients w.r.t. extra tensor outputs from basic
            operations.

        Returns
        -------
        torch.Tensor:
            Loss gradient w.r.t. operation input
        Iterable of iterable of torch.Tensor:
            Loss gradients w.r.t. parameters for basic operations
        Iterable of iterable of torch.Tensor:
            Loss gradients w.r.t. extra tensor inputs to basic
            operations

        Raises
        ------
        NotImplementedError
            Always, in this base implementation.

        """
        raise NotImplementedError(
            f"Backward pass is not implemented for operation ({self.__class__.__name__})"
        )


class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
    """Single tensor operation supported by the operation fuser

    This class holds parameters and state, even if the actual forward
    and backward passes are performed by a fused operation.

    """

    # Number of extra tensor inputs
    num_extra_inputs: int = 0
    # Number of extra tensor outputs
    num_extra_outputs: int = 0

    def __init__(self) -> None:
        super().__init__()

        # Objects for quantization
        # Note: Both are lazily constructed by
        # _reset_quantization_recipe_state and keyed by mode
        # ("forward" / "backward").
        self._quantizers: Optional[dict[str, list[Quantizer]]] = None
        self._fp8_metas: Optional[dict[str, dict[str, Any]]] = None

    @property
    def is_fused_op(self) -> bool:
        return False

    def num_quantizers(
        self,
        mode: str,  # pylint: disable=unused-argument
    ) -> int:
        """Number of quantizers

        Matches number of quantized tensors used in operation. The
        base implementation reports zero.

        Parameters
        ----------
        mode: {"forward", "backward"}
            Quantizer type

        """
        return 0

    def get_input_quantizer(self) -> Optional[Quantizer]:
        """Get first forward quantizer, or None if op has none"""
        if self.num_quantizers("forward") > 0:
            return self.get_quantizer("forward", 0)
        return None

    def get_grad_input_quantizer(self) -> Optional[Quantizer]:
        """Get first backward quantizer, or None if op has none"""
        if self.num_quantizers("backward") > 0:
            return self.get_quantizer("backward", 0)
        return None

    def _reset_quantization_recipe_state(
        self,
        *,
        recipe: Recipe,
    ) -> None:
        """Construct state for quantization recipe

        Discards any existing FP8 metadata and quantizers and rebuilds
        them for each mode that has at least one quantizer.

        Parameters
        ----------
        recipe: Recipe
            Quantization recipe

        Raises
        ------
        NotImplementedError
            If the recipe is an FP8 block scaling recipe and the op
            has quantizers.

        """

        # Quantization recipe state for forward and backward pass
        self._fp8_metas = {"forward": None, "backward": None}
        self._quantizers = {"forward": [], "backward": []}
        for mode in ("forward", "backward"):
            num_quantizers = self.num_quantizers(mode)
            if num_quantizers == 0:
                continue

            if recipe.float8_block_scaling():
                raise NotImplementedError(
                    "Fusible operations do not support FP8 block scaling recipe"
                )

            # Construct quantization recipe state
            recipe_state = RecipeState.create(
                recipe,
                mode=mode,
                num_quantizers=num_quantizers,
            )
            fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
                forward=(mode == "forward"),
            )
            self._fp8_metas[mode] = {
                fp8_meta_key: recipe_state,
                "recipe": recipe,
                "fp8_group": FP8GlobalStateManager.get_fp8_group(),
            }

            # Construct builder class for quantized tensors
            self._quantizers[mode] = recipe_state.make_quantizers()

    def _update_quantization_recipe_state(
        self,
        *,
        recipe: Recipe,
    ) -> None:
        """Make sure quantizer state matches quantization recipe

        Rebuilds the state from scratch if it has not been constructed
        yet or if the stored recipe state type does not match the
        recipe. Otherwise refreshes the recipe and FP8 group in the
        metadata and, for delayed scaling, resizes the amax history to
        the recipe's length.

        Parameters
        ----------
        recipe: Recipe
            Quantization recipe

        """

        # Reset quantization state if needed
        if self._fp8_metas is None or self._quantizers is None:
            self._reset_quantization_recipe_state(recipe=recipe)
            return
        for mode in ("forward", "backward"):
            fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
                forward=(mode == "forward"),
            )
            if self._fp8_metas[mode] is None or fp8_meta_key not in self._fp8_metas[mode]:
                continue
            recipe_state = self._fp8_metas[mode][fp8_meta_key]
            need_to_reset_recipe_state = (
                (recipe.delayed() and not isinstance(recipe_state, DelayedScalingRecipeState))
                or (recipe.mxfp8() and not isinstance(recipe_state, MXFP8BlockScalingRecipeState))
                or (
                    recipe.float8_block_scaling()
                    and not isinstance(recipe_state, Float8BlockScalingRecipeState)
                )
            )
            if need_to_reset_recipe_state:
                self._reset_quantization_recipe_state(recipe=recipe)
                return

        # Quantization recipe state for forward and backward pass
        for mode in ("forward", "backward"):
            num_quantizers = self.num_quantizers(mode)
            if num_quantizers == 0:
                continue

            # Update FP8 metadata
            fp8_meta = self._fp8_metas[mode]
            fp8_meta["recipe"] = recipe
            fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group()

            # Get recipe state
            fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
                forward=(mode == "forward"),
            )
            recipe_state = fp8_meta[fp8_meta_key]

            # Reallocate amax history if needed
            # Note: Only delayed scaling is handled here.
            if not recipe.delayed():
                continue

            current_length = recipe_state.amax_history.size(0)
            target_length = recipe.amax_history_len
            if current_length != target_length:
                with torch.no_grad():
                    if target_length < current_length:
                        # Truncate along dim 0
                        recipe_state.amax_history = recipe_state.amax_history[
                            :target_length
                        ].clone()
                    else:
                        # Pad spec is ordered from the last dim
                        # backward: last dim untouched, preceding dim
                        # extended at its end (dim 0, assuming the
                        # history is 2D -- TODO confirm).
                        recipe_state.amax_history = torch.nn.functional.pad(
                            recipe_state.amax_history,
                            pad=(0, 0, 0, target_length - current_length),
                        )
                # Rebuild quantizers since amax history was replaced
                self._quantizers[mode] = recipe_state.make_quantizers()

    def get_quantizer(
        self,
        mode: str,
        index: int,
    ) -> Quantizer:
        """Get builder class for quantized tensor

        Constructs the quantization state from the current global FP8
        recipe if it has not been constructed yet.

        Parameters
        ----------
        mode: {"forward", "backward"}
            Quantizer type
        index: int
            Position within the list of quantizers for the mode

        """
        if self._quantizers is None:
            self._reset_quantization_recipe_state(recipe=FP8GlobalStateManager.get_fp8_recipe())
        return self._quantizers[mode][index]

    @torch.no_grad()
    def _save_fp8_metas(self) -> Optional[dict[str, Any]]:
        """Create copies of tensors in FP8 metadata

        Tensor copies can be loaded with _load_fp8_metas.

        Returns
        -------
        dict, optional
            Maps mode to a dict from FP8 meta tensor key to a
            ``(scale, amax_history)`` tuple of cloned tensors, or
            ``None`` if the op has no FP8 metadata.

        """
        if self._fp8_metas is None:
            return None
        out = {}
        for mode, fp8_meta in self._fp8_metas.items():
            if fp8_meta is None:
                continue
            out[mode] = {}
            for is_forward in (True, False):
                fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=is_forward)
                if fp8_meta_key not in fp8_meta:
                    continue
                out[mode][fp8_meta_key] = (
                    fp8_meta[fp8_meta_key].scale.clone(),
                    fp8_meta[fp8_meta_key].amax_history.clone(),
                )
        return out

    @torch.no_grad()
    def _load_fp8_metas(self, fp8_metas: Optional[dict[str, Any]]) -> None:
        """Update FP8 metadata with saved tensor copies

        Tensor copies should be generated with _save_fp8_metas. Values
        are copied in-place into the op's existing tensors.

        Parameters
        ----------
        fp8_metas: dict, optional
            Saved tensors in the layout produced by _save_fp8_metas.

        """
        assert (self._fp8_metas is None) == (
            fp8_metas is None
        ), "Saved FP8 metadata does not match operation's FP8 metadata"
        if fp8_metas is None:
            return
        for mode, fp8_meta in fp8_metas.items():
            assert (
                mode in self._fp8_metas
            ), f"Found an unexpected key ({mode=}) in saved FP8 metadata"
            for fp8_meta_key, tensors in fp8_meta.items():
                assert (
                    fp8_meta_key in self._fp8_metas[mode]
                ), f"Found an unexpected key ({mode=}, {fp8_meta_key=}) in saved FP8 metadata"
                scale, amax_history = tensors
                self._fp8_metas[mode][fp8_meta_key].scale.copy_(scale)
                self._fp8_metas[mode][fp8_meta_key].amax_history.copy_(amax_history)

    def pre_first_forward(
        self,
        *,
        recipe: Optional[Recipe],
    ) -> None:
        """Preprocessing before forward pass

        Parameters
        ----------
        recipe: Recipe, optional
            Quantization recipe. Nothing is done if ``None``.

        """

        # Initialize FP8 metadata if needed
        if recipe is not None:
            self._update_quantization_recipe_state(recipe=recipe)
            if not FP8GlobalStateManager.fp8_graph_capturing():
                if self.num_quantizers("forward"):
                    FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
                        self._fp8_metas["forward"],
                    )
                if self.num_quantizers("backward"):
                    FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
                        self._fp8_metas["backward"],
                    )

    @abc.abstractmethod
    def op_forward(
        self,
        ctx: OperationContext,
        input_: torch.Tensor,
        *,
        prev_op_grad_input_quantizer: Optional[Quantizer],
        next_op_input_quantizer: Optional[Quantizer],
        **kwargs: Any,
    ) -> torch.Tensor:
        """Forward pass

        Parameters
        ----------
        ctx: OperationContext
            Context to coordinate between forward and backward passes
        input_: torch.Tensor
            Input tensor
        prev_op_grad_input_quantizer: Quantizer, optional
            The grad_input_quantizer of the preceding operation
        next_op_input_quantizer: Quantizer, optional
            The input_quantizer of the following operation

        Returns
        -------
        torch.Tensor:
            Output tensor

        """

    @abc.abstractmethod
    def op_backward(
        self,
        ctx: OperationContext,
        grad_output: torch.Tensor,
    ) -> tuple[torch.Tensor, Iterable[Optional[torch.Tensor]]]:
        """Backward pass

        Parameters
        ----------
        ctx: OperationContext
            Context to coordinate between forward and backward passes
        grad_output: torch.Tensor
            Loss gradient w.r.t. operation output

        Returns
        -------
        torch.Tensor
            Loss gradient w.r.t. operation input
        Iterable of torch.Tensor:
            Loss gradients w.r.t. parameters

        """

    def fuser_forward(
        self,
        basic_op_ctxs: list[OperationContext],
        input_: torch.Tensor,
        *,
        basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
        prev_op_grad_input_quantizer: Optional[Quantizer],
        next_op_input_quantizer: Optional[Quantizer],
        basic_op_kwargs: list[dict[str, Any]],
    ) -> tuple[torch.Tensor, list[tuple[()]]]:
        """Forward pass, delegating to `op_forward`

        Only valid for ops without extra tensor inputs or outputs.
        Uses the first entry of each per-basic-op list.

        Raises
        ------
        RuntimeError
            If the op declares extra tensor inputs or outputs.

        """
        if self.num_extra_inputs > 0 or self.num_extra_outputs > 0:
            raise RuntimeError(
                f"{self.__class__.__name__} operation has "
                f"{self.num_extra_inputs} extra tensor inputs "
                f"and {self.num_extra_outputs} extra tensor outputs. "
                "It should override `fuser_forward` instead of `op_forward`."
            )
        output = self.op_forward(
            basic_op_ctxs[0],
            input_,
            prev_op_grad_input_quantizer=prev_op_grad_input_quantizer,
            next_op_input_quantizer=next_op_input_quantizer,
            **basic_op_kwargs[0],
        )
        # No extra tensor outputs
        return output, [()]

    def fuser_backward(
        self,
        basic_op_ctxs: list[OperationContext],
        grad_output: torch.Tensor,
        *,
        basic_op_grad_extra_outputs: list[tuple[torch.Tensor, ...]],
    ) -> tuple[
        torch.Tensor,
        list[Iterable[Optional[torch.Tensor]]],
        list[tuple[()]],
    ]:
        """Backward pass, delegating to `op_backward`

        Only valid for ops without extra tensor inputs or outputs.
        Uses the first entry of the list of contexts.

        Raises
        ------
        RuntimeError
            If the op declares extra tensor inputs or outputs.

        """
        if self.num_extra_inputs > 0 or self.num_extra_outputs > 0:
            raise RuntimeError(
                f"{self.__class__.__name__} operation has "
                f"{self.num_extra_inputs} extra tensor inputs "
                f"and {self.num_extra_outputs} extra tensor outputs. "
                "It should override `fuser_backward` instead of `op_backward`."
            )
        grad_input, grad_params = self.op_backward(basic_op_ctxs[0], grad_output)
        # No grads for extra tensor inputs
        return grad_input, [grad_params], [()]

    def forward(
        self,
        input: torch.Tensor,  # pylint: disable=redefined-builtin
        *extra_inputs: torch.Tensor,
        **kwargs: Any,
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Apply operation

        Runs this op on its own through an `OperationFuser` with
        fusion disabled. Keyword arguments are forwarded as this op's
        entry in ``basic_op_kwargs``.

        """
        from .fuser import OperationFuser

        with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
        recipe = FP8GlobalStateManager.get_fp8_recipe() if with_quantized_compute else None
        return OperationFuser([self], fuse_ops=False, recipe=recipe)(
            input,
            *extra_inputs,
            basic_op_kwargs=[kwargs],
        )

    def get_extra_state(self) -> torch.Tensor:
        """Serialize extra state

        Contains metadata for quantization recipe.

        Returns
        -------
        torch.Tensor
            Byte tensor (uint8) holding the pickled state dict.

        """

        # This implementation is working around a few issues:
        #
        # (1) PyTorch's "extra state" infrastructure might be able to
        #     support any picklable type, but they make no guarantees.
        #     We have experienced problems (e.g. in ONNX export) with
        #     non-tensor extra state.
        # (2) PyTorch's checkpointing infrastructure does not remap
        #     devices for "extra state" like it does for "state dict".
        #     Thus, we want to avoid putting extra state on the GPU
        #     since it may be loaded on the wrong device.
        # (3) The extra state consists of many small tensors. If we
        #     want to copy them all to CPU, then we need to avoid the
        #     overhead of many GPU-CPU memory transfers.
        #
        # See: https://github.com/NVIDIA/TransformerEngine/pull/351
        # See: https://github.com/NVIDIA/TransformerEngine/pull/363

        def to_cpu(src: torch.Tensor) -> torch.Tensor:
            """Helper function to make CPU copy of tensor

            Memory transfer is asynchronous w.r.t. host, so GPU should
            be synchronized before using result.

            """
            dst = torch.empty_like(src, device="cpu")
            dst.copy_(src, non_blocking=True)
            return dst

        # Store quantizer state if needed
        state = {}
        for mode in ("forward", "backward"):

            # Skip if op has no quantizer state
            if self._fp8_metas is None or self._fp8_metas[mode] is None:
                continue

            # Quantizer state
            fp8_meta = self._fp8_metas[mode]
            state[mode] = {}
            state[mode]["recipe"] = fp8_meta["recipe"]

            # Copy tensors to CPU and store
            # Note: Tensors are only stored for delayed scaling.
            if state[mode]["recipe"].delayed():
                if mode == "forward":
                    state[mode]["scale_fwd"] = to_cpu(fp8_meta["scaling_fwd"].scale)
                    state[mode]["amax_history_fwd"] = to_cpu(fp8_meta["scaling_fwd"].amax_history)
                if mode == "backward":
                    state[mode]["scale_bwd"] = to_cpu(fp8_meta["scaling_bwd"].scale)
                    state[mode]["amax_history_bwd"] = to_cpu(fp8_meta["scaling_bwd"].amax_history)

            # Store other picklable items
            extra = {}
            for key, val in fp8_meta.items():
                if key == "buffer_index_and_autocast_key":
                    continue
                if not isinstance(val, (bool, int, float, str, tuple, list)):
                    continue
                extra[key] = val
            state[mode]["extra_fp8_variables"] = extra

        # Serialize state into byte tensor
        # Note: Synchronize first so that the non-blocking CPU copies
        # above have finished before they are pickled.
        torch.cuda.synchronize()
        state_serialized = bytearray(pickle.dumps(state))
        state_serialized = torch.frombuffer(state_serialized, dtype=torch.uint8)
        return state_serialized

    def set_extra_state(self, state: Optional[torch.Tensor]) -> None:
        """Load extra state

        Parameters
        ----------
        state: torch.Tensor, optional
            Byte tensor generated with get_extra_state. Nothing is
            done if it is ``None`` or empty.

        """
        if state is None or state.numel() == 0:
            return

        # Deserialize state from byte tensor
        # NOTE(security): pickle.loads can execute arbitrary code, so
        # checkpoints must only be loaded from trusted sources.
        state = pickle.loads(state.detach().numpy(force=True).tobytes())
        if state is None or len(state) == 0:
            return

        def copy_tensor(src: torch.Tensor, dst: torch.Tensor) -> None:
            """Helper function to copy tensor from CPU

            Memory transfer is asynchronous w.r.t. host, so GPU should
            be synchronized before using result.

            """
            # Reallocate destination if checkpoint has different size
            if src.size() != dst.size():
                dst.data = torch.empty(src.size(), dtype=dst.dtype, device=dst.device)
            dst.copy_(src, non_blocking=True)

        # Load quantizer state if needed
        for mode in ("forward", "backward"):

            # Skip if checkpoint has no quantizer state
            if mode not in state:
                continue

            # Get op's quantizer state, initializing if needed
            if self._fp8_metas is None or self._fp8_metas[mode] is None:
                with fp8_autocast(fp8_recipe=state[mode]["recipe"]):
                    self._reset_quantization_recipe_state(recipe=state[mode]["recipe"])
            fp8_meta = self._fp8_metas[mode]

            # Load extra items
            fp8_meta["recipe"] = state[mode]["recipe"]
            fp8_meta.update(state[mode]["extra_fp8_variables"])
            if "global_fp8_buffer_pos_fwd_recompute" in fp8_meta:
                del fp8_meta["global_fp8_buffer_pos_fwd_recompute"]

            # Load tensors
            if state[mode]["recipe"].delayed():
                if mode == "forward":
                    copy_tensor(state[mode]["scale_fwd"], fp8_meta["scaling_fwd"].scale)
                    copy_tensor(
                        state[mode]["amax_history_fwd"], fp8_meta["scaling_fwd"].amax_history
                    )
                if mode == "backward":
                    copy_tensor(state[mode]["scale_bwd"], fp8_meta["scaling_bwd"].scale)
                    copy_tensor(
                        state[mode]["amax_history_bwd"], fp8_meta["scaling_bwd"].amax_history
                    )

        # Finish CPU-GPU memory transfers
        torch.cuda.synchronize()

    def _load_from_state_dict(self, *args, **kwargs) -> None:
        """Load state"""

        # In the base PyTorch module class, the extra state is loaded
        # _after_ the parameters. However, copying values into FP8
        # parameters requires an FP8 cast, which uses a scaling factor
        # from the operation's FP8 metadata. The FP8 metadata is
        # included in the operation's extra state, so we need to
        # manually load the extra state before loading parameters.

        state_dict, prefix = args[0], args[1]
        extra_state_key = prefix + torch.nn.modules.module._EXTRA_STATE_KEY_SUFFIX
        if extra_state_key in state_dict:
            self.set_extra_state(state_dict[extra_state_key])
        super()._load_from_state_dict(*args, **kwargs)

674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707

class FusedOperation(FusibleOperation):
    """Compound tensor operation supported by the operation fuser

    If the forward or backward passes are defined, they must be
    functionally equivalent to the forward/backward passes of the
    corresponding basic ops. This class should hold no parameters or
    other state, but should access them from the basic ops.

    Parameters
    ----------
    basic_ops: iterable of FusibleOperation
        Basic ops that are interchangeable with this op

    Raises
    ------
    ValueError
        If no basic ops are provided.

    """

    def __init__(
        self,
        basic_ops: Iterable[FusibleOperation],
    ) -> None:
        super().__init__()

        # Basic operations that comprise this fused operation
        self.basic_ops: torch.nn.ModuleList = torch.nn.ModuleList(basic_ops)
        if len(self.basic_ops) == 0:
            raise ValueError(
                "Attempted to construct a fused operation "
                "without specifying its corresponding basic operations"
            )

    @property
    def is_fused_op(self) -> bool:
        return True

    def get_input_quantizer(self) -> Optional[Quantizer]:
        """Get input quantizer of the first basic op"""
        return self.basic_ops[0].get_input_quantizer()

    def get_grad_input_quantizer(self) -> Optional[Quantizer]:
        """Get grad input quantizer of the last basic op"""
        return self.basic_ops[-1].get_grad_input_quantizer()

    def pre_first_forward(self, *args, **kwargs) -> None:
        """Preprocessing before forward pass

        Forwards all arguments to each basic op in order.

        """
        for op in self.basic_ops:
            op.pre_first_forward(*args, **kwargs)

    def forward(
        self,
        input: torch.Tensor,  # pylint: disable=redefined-builtin
        *extra_inputs: torch.Tensor,
        basic_op_kwargs: Optional[list[dict[str, Any]]] = None,
    ) -> torch.Tensor:
        """Apply operation

        Runs this op on its own through an `OperationFuser` with
        fusion disabled.

        Parameters
        ----------
        input: torch.Tensor
            Input tensor
        *extra_inputs: torch.Tensor
            Extra tensor inputs, passed through to the fuser
        basic_op_kwargs: list of dict, optional
            Keyword arguments for each basic op. Defaults to one
            empty dict per basic op.

        """
        if basic_op_kwargs is None:
            basic_op_kwargs = [{} for _ in self.basic_ops]
        from .fuser import OperationFuser

        with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
        recipe = FP8GlobalStateManager.get_fp8_recipe() if with_quantized_compute else None
        return OperationFuser([self], fuse_ops=False, recipe=recipe)(
            input,
            *extra_inputs,
            basic_op_kwargs=basic_op_kwargs,
        )