# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Base classes for fusible operations."""

from __future__ import annotations
import abc
from collections.abc import Iterable
import dataclasses
import pickle
from typing import Any, Optional

import torch

from transformer_engine.common.recipe import Recipe
from ..fp8 import (
    MXFP8BlockScalingRecipeState,
    DelayedScalingRecipeState,
    Float8BlockScalingRecipeState,
    FP8GlobalStateManager,
    RecipeState,
    fp8_autocast,
)
from ..tensor import Quantizer


@dataclasses.dataclass
class OperationContext:
    """State needed to apply an operation

    Saves state from forward pass for use in backward pass.

    """

    # Tensors that have been saved from forward function
    # Note: Available in the backward function, matching tensors from
    # to_save.
    saved_tensors: Optional[tuple[Optional[torch.Tensor], ...]] = None
    # Tensors to save for backward function
    # Note: Expected to be set in the forward function, either
    # directly or with save_for_backward.
    to_save: Optional[tuple[Optional[torch.Tensor], ...]] = None

    # Corresponding range in pipeline's list of saved tensors
    _saved_tensors_range: Optional[tuple[int, int]] = None

    # Whether backward pass is required
    requires_grad: bool = True

    def save_for_backward(self, *tensors: Optional[torch.Tensor]) -> None:
        """Register tensors to be saved for the backward function

        Expected to be called in the forward function. Each call
        replaces whatever was previously stored in ``to_save``.

        Parameters
        ----------
        *tensors: torch.Tensor or None
            Tensors to save, stored as a tuple in ``to_save``.

        """
        self.to_save = tensors


class FusibleOperation(torch.nn.Module, metaclass=abc.ABCMeta):
    """Tensor operation supported by the operation fuser

    An instance is either a basic op or the fusion of one or more
    basic ops, as reported by ``is_fused_op``.

    """

    @property
    @abc.abstractmethod
    def is_fused_op(self) -> bool:
        """Whether this op is the fusion of one or more basic ops"""

    def pre_forward(self) -> None:
        """Preprocessing before forward pass

        The base implementation does nothing.

        """

    def fuser_forward(
        self,
        basic_op_ctxs: list[OperationContext],
        input_: torch.Tensor,
        *,
        basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
        basic_op_prev_ops: list[Optional[BasicOperation]],
        basic_op_next_ops: list[Optional[BasicOperation]],
        basic_op_kwargs: list[dict[str, Any]],
    ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
        """Forward pass

        This op is either a basic op or the fusion of basic ops, so
        several of this function's arguments are lists of arguments to
        forward functions of corresponding basic ops.

        Called by `OperationFuser`.

        Parameters
        ----------
        basic_op_ctxs: list of OperationContext
            Contexts for basic operations
        input_: torch.Tensor
            Input tensor
        basic_op_extra_inputs: list of tuple of torch.Tensor
            Extra tensor inputs to basic operations
        basic_op_prev_ops: list of BasicOperation or None
            Basic operations that precede this operation's basic
            operations
        basic_op_next_ops: list of BasicOperation or None
            Basic operations that follow this operation's basic
            operations
        basic_op_kwargs: list of dict
            Keyword arguments to forward functions of basic
            operations.

        Returns
        -------
        torch.Tensor:
            Output tensor.
        Iterable of iterable of torch.Tensor:
            Extra tensor outputs from basic operations.

        Raises
        ------
        NotImplementedError
            Raised by this base implementation.

        """
        raise NotImplementedError(
            f"Forward pass is not implemented for operation ({self.__class__.__name__})"
        )

    def fuser_backward(
        self,
        basic_op_ctxs: list[OperationContext],
        grad_output: torch.Tensor,
        *,
        basic_op_grad_extra_outputs: list[tuple[torch.Tensor, ...]],
    ) -> tuple[
        torch.Tensor,
        Iterable[Iterable[Optional[torch.Tensor]]],
        Iterable[Iterable[Optional[torch.Tensor]]],
    ]:
        """Backward pass

        This op is either a basic op or the fusion of basic ops, so
        several of this function's arguments are lists of arguments to
        backward functions of corresponding basic ops.

        Called by `OperationFuser`.

        Parameters
        ----------
        basic_op_ctxs: list of OperationContext
            Contexts for basic operations
        grad_output: torch.Tensor
            Loss gradient w.r.t. operation output
        basic_op_grad_extra_outputs: list of tuple of torch.Tensor
            Loss gradients w.r.t. extra tensor outputs from basic
            operations.

        Returns
        -------
        torch.Tensor:
            Loss gradient w.r.t. operation input
        Iterable of iterable of torch.Tensor:
            Loss gradients w.r.t. parameters for basic operations
        Iterable of iterable of torch.Tensor:
            Loss gradients w.r.t. extra tensor inputs to basic
            operations

        Raises
        ------
        NotImplementedError
            Raised by this base implementation.

        """
        raise NotImplementedError(
            f"Backward pass is not implemented for operation ({self.__class__.__name__})"
        )


class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
    """Single tensor operation supported by the operation fuser

    This class holds parameters and state, even if the actual forward
    and backward passes are performed by a fused operation.

    """

    # Number of extra tensor inputs
    num_extra_inputs: int = 0
    # Number of extra tensor outputs
    num_extra_outputs: int = 0

    def __init__(self) -> None:
        super().__init__()

        # Objects for quantization. Both stay None until quantization
        # state is first constructed.
        self._quantizers: Optional[dict[str, list[Quantizer]]] = None
        self._fp8_metas: Optional[dict[str, dict[str, Any]]] = None

    @property
    def is_fused_op(self) -> bool:
        return False

    def num_quantizers(
        self,
        mode: str,  # pylint: disable=unused-argument
    ) -> int:
        """Number of quantizers

        Matches number of quantized tensors used in operation.

        Parameters
        ----------
        mode: {"forward", "backward"}
            Quantizer type

        Returns
        -------
        int
            Number of quantizers. The base implementation returns 0.

        """
        return 0

    def _reset_quantization_recipe_state(
        self,
        *,
        recipe: Optional[Recipe] = None,
    ) -> None:
        """Construct state for quantization recipe

        Discards any existing quantization state and rebuilds the FP8
        metadata and quantizers for the forward and backward pass.

        Parameters
        ----------
        recipe: Recipe, optional
            Quantization recipe. Defaults to
            ``FP8GlobalStateManager.get_fp8_recipe()``.

        Raises
        ------
        NotImplementedError
            If a mode needs quantizers and the recipe uses FP8 block
            scaling.

        """

        # Quantization recipe
        if recipe is None:
            recipe = FP8GlobalStateManager.get_fp8_recipe()

        # Quantization recipe state for forward and backward pass
        self._fp8_metas = {"forward": None, "backward": None}
        self._quantizers = {"forward": [], "backward": []}
        for mode in ("forward", "backward"):
            num_quantizers = self.num_quantizers(mode)
            if num_quantizers == 0:
                continue

            if recipe.float8_block_scaling():
                raise NotImplementedError(
                    "Fusible operations do not support FP8 block scaling recipe"
                )

            # Construct quantization recipe state
            recipe_state = RecipeState.create(
                recipe,
                mode=mode,
                num_quantizers=num_quantizers,
            )
            fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
                forward=(mode == "forward"),
            )
            self._fp8_metas[mode] = {
                fp8_meta_key: recipe_state,
                "recipe": recipe,
                "fp8_group": FP8GlobalStateManager.get_fp8_group(),
            }

            # Construct builder class for quantized tensors
            self._quantizers[mode] = recipe_state.make_quantizers()

    def _update_quantization_recipe_state(
        self,
        *,
        recipe: Optional[Recipe] = None,
    ) -> None:
        """Make sure quantizer state matches quantization recipe

        Rebuilds all state with ``_reset_quantization_recipe_state`` in
        three cases:

        - no state exists yet;
        - an existing recipe state has the wrong type for ``recipe``;
        - a mode needs quantizers but has no FP8 metadata.

        Otherwise, it refreshes the stored recipe and FP8 group. For
        delayed scaling, it also truncates or zero-pads the amax history
        to ``recipe.amax_history_len`` and then rebuilds the quantizers.

        Parameters
        ----------
        recipe: Recipe, optional
            Quantization recipe. Defaults to
            ``FP8GlobalStateManager.get_fp8_recipe()``.

        """

        # Quantization recipe
        if recipe is None:
            recipe = FP8GlobalStateManager.get_fp8_recipe()

        # Reset quantization state if needed
        if self._fp8_metas is None or self._quantizers is None:
            self._reset_quantization_recipe_state(recipe=recipe)
            return
        for mode in ("forward", "backward"):
            fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
                forward=(mode == "forward"),
            )
            if self._fp8_metas[mode] is None or fp8_meta_key not in self._fp8_metas[mode]:
                continue
            recipe_state = self._fp8_metas[mode][fp8_meta_key]
            need_to_reset_recipe_state = (
                (recipe.delayed() and not isinstance(recipe_state, DelayedScalingRecipeState))
                or (recipe.mxfp8() and not isinstance(recipe_state, MXFP8BlockScalingRecipeState))
                or (
                    recipe.float8_block_scaling()
                    and not isinstance(recipe_state, Float8BlockScalingRecipeState)
                )
            )
            if need_to_reset_recipe_state:
                self._reset_quantization_recipe_state(recipe=recipe)
                return

        # Quantization recipe state for forward and backward pass
        for mode in ("forward", "backward"):
            num_quantizers = self.num_quantizers(mode)
            if num_quantizers == 0:
                continue

            # Update FP8 metadata
            fp8_meta = self._fp8_metas[mode]
            if fp8_meta is None:
                # This mode needs quantizers but has no recipe state,
                # so nothing can be updated in place. Rebuild everything
                # rather than failing on the None metadata.
                self._reset_quantization_recipe_state(recipe=recipe)
                return
            fp8_meta["recipe"] = recipe
            fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group()

            # Get recipe state
            fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
                forward=(mode == "forward"),
            )
            recipe_state = fp8_meta[fp8_meta_key]

            # Reallocate amax history if needed (delayed scaling only)
            if not recipe.delayed():
                continue

            current_length = recipe_state.amax_history.size(0)
            target_length = recipe.amax_history_len
            if current_length != target_length:
                with torch.no_grad():
                    if target_length < current_length:
                        recipe_state.amax_history = recipe_state.amax_history[
                            :target_length
                        ].clone()
                    else:
                        recipe_state.amax_history = torch.nn.functional.pad(
                            recipe_state.amax_history,
                            pad=(0, 0, 0, target_length - current_length),
                        )
                # Quantizers are rebuilt from the resized recipe state
                self._quantizers[mode] = recipe_state.make_quantizers()

    def get_quantizer(
        self,
        mode: str,
        index: int,
    ) -> Quantizer:
        """Get builder class for quantized tensor

        Constructs quantization state from the current global recipe
        if none exists yet.

        Parameters
        ----------
        mode: {"forward", "backward"}
            Quantizer type
        index: int
            Position of the quantizer within the list for ``mode``

        """
        if self._quantizers is None:
            self._reset_quantization_recipe_state()
        return self._quantizers[mode][index]

    @torch.no_grad()
    def _save_fp8_metas(self) -> Optional[dict[str, Any]]:
        """Create copies of tensors in FP8 metadata

        Tensor copies can be loaded with _load_fp8_metas.

        """
        if self._fp8_metas is None:
            return None
        out = {}
        for mode, fp8_meta in self._fp8_metas.items():
            if fp8_meta is None:
                continue
            out[mode] = {}
            for is_forward in (True, False):
                fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=is_forward)
                if fp8_meta_key not in fp8_meta:
                    continue
                out[mode][fp8_meta_key] = (
                    fp8_meta[fp8_meta_key].scale.clone(),
                    fp8_meta[fp8_meta_key].amax_history.clone(),
                )
        return out

    @torch.no_grad()
    def _load_fp8_metas(self, fp8_metas: Optional[dict[str, Any]]) -> None:
        """Update FP8 metadata with saved tensor copies

        Tensor copies should be generated with _save_fp8_metas.

        """
        assert (self._fp8_metas is None) == (
            fp8_metas is None
        ), "Saved FP8 metadata does not match operation's FP8 metadata"
        if fp8_metas is None:
            return
        for mode, fp8_meta in fp8_metas.items():
            assert (
                mode in self._fp8_metas
            ), f"Found an unexpected key ({mode=}) in saved FP8 metadata"
            for fp8_meta_key, tensors in fp8_meta.items():
                assert (
                    fp8_meta_key in self._fp8_metas[mode]
                ), f"Found an unexpected key ({mode=}, {fp8_meta_key=}) in saved FP8 metadata"
                scale, amax_history = tensors
                self._fp8_metas[mode][fp8_meta_key].scale.copy_(scale)
                self._fp8_metas[mode][fp8_meta_key].amax_history.copy_(amax_history)

    def pre_forward(
        self,
        *,
        fp8_enabled: Optional[bool] = None,
        fp8_recipe: Optional[Recipe] = None,
    ) -> None:
        """Preprocessing before forward pass

        When FP8 is enabled, this syncs quantization state with the
        recipe. Outside FP8 graph capture, it also registers each mode's
        FP8 metadata with the global FP8 buffer.

        Parameters
        ----------
        fp8_enabled: bool, optional
            Defaults to ``FP8GlobalStateManager.is_fp8_enabled()``.
        fp8_recipe: Recipe, optional
            Forwarded to ``_update_quantization_recipe_state``.

        """

        # Initialize FP8 metadata if needed
        if fp8_enabled is None:
            fp8_enabled = FP8GlobalStateManager.is_fp8_enabled()
        if fp8_enabled:
            self._update_quantization_recipe_state(recipe=fp8_recipe)
            if not FP8GlobalStateManager.fp8_graph_capturing():
                if self.num_quantizers("forward"):
                    FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
                        self._fp8_metas["forward"],
                    )
                if self.num_quantizers("backward"):
                    FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
                        self._fp8_metas["backward"],
                    )

    @abc.abstractmethod
    def op_forward(
        self,
        ctx: OperationContext,
        input_: torch.Tensor,
        *,
        prev_op: Optional[BasicOperation] = None,
        next_op: Optional[BasicOperation] = None,
        **kwargs: Any,
    ) -> torch.Tensor:
        """Forward pass

        Parameters
        ----------
        ctx: OperationContext
            Context to coordinate between forward and backward passes
        input_: torch.Tensor
            Input tensor
        prev_op: BasicOperation, optional
            Basic operation that precedes this operation
        next_op: BasicOperation, optional
            Basic operation that follows this operation
        **kwargs: Any
            Additional keyword arguments (``fuser_forward`` passes the
            op's entry from ``basic_op_kwargs``)

        Returns
        -------
        torch.Tensor:
            Output tensor

        """

    @abc.abstractmethod
    def op_backward(
        self,
        ctx: OperationContext,
        grad_output: torch.Tensor,
    ) -> tuple[torch.Tensor, Iterable[Optional[torch.Tensor]]]:
        """Backward pass

        Parameters
        ----------
        ctx: OperationContext
            Context to coordinate between forward and backward passes
        grad_output: torch.Tensor
            Loss gradient w.r.t. operation output

        Returns
        -------
        torch.Tensor
            Loss gradient w.r.t. operation input
        Iterable of torch.Tensor:
            Loss gradients w.r.t. parameters

        """

    def fuser_forward(
        self,
        basic_op_ctxs: list[OperationContext],
        input_: torch.Tensor,
        *,
        basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
        basic_op_prev_ops: list[Optional[BasicOperation]],
        basic_op_next_ops: list[Optional[BasicOperation]],
        basic_op_kwargs: list[dict[str, Any]],
    ) -> tuple[torch.Tensor, list[tuple[()]]]:
        """Forward pass implemented with ``op_forward``

        Only valid for ops without extra tensor inputs or outputs.

        Raises
        ------
        RuntimeError
            If the op declares extra tensor inputs or outputs.

        """
        if self.num_extra_inputs > 0 or self.num_extra_outputs > 0:
            raise RuntimeError(
                f"{self.__class__.__name__} operation has "
                f"{self.num_extra_inputs} extra tensor inputs "
                f"and {self.num_extra_outputs} extra tensor outputs. "
                "It should override `fuser_forward` instead of `op_forward`."
            )
        output = self.op_forward(
            basic_op_ctxs[0],
            input_,
            prev_op=basic_op_prev_ops[0],
            next_op=basic_op_next_ops[0],
            **basic_op_kwargs[0],
        )
        # No extra tensor outputs
        return output, [()]

    def fuser_backward(
        self,
        basic_op_ctxs: list[OperationContext],
        grad_output: torch.Tensor,
        *,
        basic_op_grad_extra_outputs: list[tuple[torch.Tensor, ...]],
    ) -> tuple[
        torch.Tensor,
        list[Iterable[Optional[torch.Tensor]]],
        list[tuple[()]],
    ]:
        """Backward pass implemented with ``op_backward``

        Only valid for ops without extra tensor inputs or outputs.

        Raises
        ------
        RuntimeError
            If the op declares extra tensor inputs or outputs.

        """
        if self.num_extra_inputs > 0 or self.num_extra_outputs > 0:
            raise RuntimeError(
                f"{self.__class__.__name__} operation has "
                f"{self.num_extra_inputs} extra tensor inputs "
                f"and {self.num_extra_outputs} extra tensor outputs. "
                "It should override `fuser_backward` instead of `op_backward`."
            )
        grad_input, grad_params = self.op_backward(basic_op_ctxs[0], grad_output)
        # No gradients for extra tensor inputs
        return grad_input, [grad_params], [()]

    def forward(
        self,
        input: torch.Tensor,  # pylint: disable=redefined-builtin
        *extra_inputs: torch.Tensor,
        **kwargs: Any,
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Apply operation

        Runs this op alone through an ``OperationFuser`` with fusion
        disabled. ``kwargs`` are forwarded to ``op_forward``.

        """
        from .fuser import OperationFuser

        return OperationFuser([self], fuse_ops=False)(
            input,
            *extra_inputs,
            basic_op_kwargs=[kwargs],
        )

    def get_extra_state(self) -> torch.Tensor:
        """Serialize extra state

        Contains metadata for quantization recipe.

        Returns
        -------
        torch.Tensor
            ``uint8`` tensor holding the pickled state dict

        """

        # This implementation is working around a few issues:
        #
        # (1) PyTorch's "extra state" infrastructure might be able to
        #     support any picklable type, but they make no guarantees.
        #     We have experienced problems (e.g. in ONNX export) with
        #     non-tensor extra state.
        # (2) PyTorch's checkpointing infrastructure does not remap
        #     devices for "extra state" like it does for "state dict".
        #     Thus, we want to avoid putting extra state on the GPU
        #     since it may be loaded on the wrong device.
        # (3) The extra state consists of many small tensors. If we
        #     want to copy them all to CPU, then we need to avoid the
        #     overhead of many GPU-CPU memory transfers.
        #
        # See: https://github.com/NVIDIA/TransformerEngine/pull/351
        # See: https://github.com/NVIDIA/TransformerEngine/pull/363

        def to_cpu(src: torch.Tensor) -> torch.Tensor:
            """Helper function to make CPU copy of tensor

            Memory transfer is asynchronous w.r.t. host, so GPU should
            be synchronized before using result.

            """
            dst = torch.empty_like(src, device="cpu")
            dst.copy_(src, non_blocking=True)
            return dst

        # Store quantizer state if needed
        state = {}
        for mode in ("forward", "backward"):

            # Skip if op has no quantizer state
            if self._fp8_metas is None or self._fp8_metas[mode] is None:
                continue

            # Quantizer state
            fp8_meta = self._fp8_metas[mode]
            state[mode] = {}
            state[mode]["recipe"] = fp8_meta["recipe"]

            # Copy tensors to CPU and store
            if state[mode]["recipe"].delayed():
                if mode == "forward":
                    state[mode]["scale_fwd"] = to_cpu(fp8_meta["scaling_fwd"].scale)
                    state[mode]["amax_history_fwd"] = to_cpu(fp8_meta["scaling_fwd"].amax_history)
                if mode == "backward":
                    state[mode]["scale_bwd"] = to_cpu(fp8_meta["scaling_bwd"].scale)
                    state[mode]["amax_history_bwd"] = to_cpu(fp8_meta["scaling_bwd"].amax_history)

            # Store other picklable items
            extra = {}
            for key, val in fp8_meta.items():
                if key == "buffer_index_and_autocast_key":
                    continue
                if not isinstance(val, (bool, int, float, str, tuple, list)):
                    continue
                extra[key] = val
            state[mode]["extra_fp8_variables"] = extra

        # Serialize state into byte tensor. The synchronize completes
        # the non-blocking CPU copies before they are pickled.
        torch.cuda.synchronize()
        state_serialized = bytearray(pickle.dumps(state))
        state_serialized = torch.frombuffer(state_serialized, dtype=torch.uint8)
        return state_serialized

    def set_extra_state(self, state: Optional[torch.Tensor]) -> None:
        """Load extra state

        Parameters
        ----------
        state: torch.Tensor, optional
            Byte tensor produced by ``get_extra_state``. ``None`` or an
            empty tensor is a no-op.

        """
        if state is None or state.numel() == 0:
            return

        # Deserialize state from byte tensor
        # NOTE(review): pickle.loads can execute arbitrary code; only load
        # checkpoints from trusted sources.
        state = pickle.loads(state.detach().numpy(force=True).tobytes())
        if state is None or len(state) == 0:
            return

        def copy_tensor(src: torch.Tensor, dst: torch.Tensor) -> None:
            """Helper function to copy tensor from CPU

            Memory transfer is asynchronous w.r.t. host, so GPU should
            be synchronized before using result.

            """
            if src.size() != dst.size():
                dst.data = torch.empty(src.size(), dtype=dst.dtype, device=dst.device)
            dst.copy_(src, non_blocking=True)

        # Load quantizer state if needed
        for mode in ("forward", "backward"):

            # Skip if checkpoint has no quantizer state
            if mode not in state:
                continue

            # Get op's quantizer state, initializing if needed
            if self._fp8_metas is None or self._fp8_metas[mode] is None:
                with fp8_autocast(fp8_recipe=state[mode]["recipe"]):
                    self._reset_quantization_recipe_state()
            fp8_meta = self._fp8_metas[mode]

            # Load extra items
            fp8_meta["recipe"] = state[mode]["recipe"]
            fp8_meta.update(state[mode]["extra_fp8_variables"])
            if "global_fp8_buffer_pos_fwd_recompute" in fp8_meta:
                del fp8_meta["global_fp8_buffer_pos_fwd_recompute"]

            # Load tensors
            if state[mode]["recipe"].delayed():
                if mode == "forward":
                    copy_tensor(state[mode]["scale_fwd"], fp8_meta["scaling_fwd"].scale)
                    copy_tensor(
                        state[mode]["amax_history_fwd"], fp8_meta["scaling_fwd"].amax_history
                    )
                if mode == "backward":
                    copy_tensor(state[mode]["scale_bwd"], fp8_meta["scaling_bwd"].scale)
                    copy_tensor(
                        state[mode]["amax_history_bwd"], fp8_meta["scaling_bwd"].amax_history
                    )

        # Finish CPU-GPU memory transfers
        torch.cuda.synchronize()

    def _load_from_state_dict(self, *args, **kwargs) -> None:
        """Load state"""

        # In the base PyTorch module class, the extra state is loaded
        # _after_ the parameters. However, copying values into FP8
        # parameters requires an FP8 cast, which uses a scaling factor
        # from the operation's FP8 metadata. The FP8 metadata is
        # included in the operation's extra state, so we need to
        # manually load the extra state before loading parameters.

        state_dict, prefix = args[0], args[1]
        extra_state_key = prefix + torch.nn.modules.module._EXTRA_STATE_KEY_SUFFIX
        if extra_state_key in state_dict:
            self.set_extra_state(state_dict[extra_state_key])
        super()._load_from_state_dict(*args, **kwargs)


class FusedOperation(FusibleOperation):
    """Compound tensor operation supported by the operation fuser

    If the forward or backward passes are defined, they must be
    functionally equivalent to the forward/backward passes of the
    corresponding basic ops. This class should hold no parameters or
    other state, but should access them from the basic ops.

    Parameters
    ----------
    basic_ops: iterable of FusibleOperation
        Basic ops that are interchangeable with this op

    Raises
    ------
    ValueError
        If ``basic_ops`` is empty.

    """

    def __init__(
        self,
        basic_ops: Iterable[FusibleOperation],
    ) -> None:
        super().__init__()

        # Basic operations that comprise this fused operation
        self.basic_ops: torch.nn.ModuleList = torch.nn.ModuleList(basic_ops)
        if len(self.basic_ops) == 0:
            raise ValueError(
                "Attempted to construct a fused operation "
                "without specifying its corresponding basic operations"
            )

    @property
    def is_fused_op(self) -> bool:
        return True

    def pre_forward(self) -> None:
        """Preprocessing before forward pass

        Calls ``pre_forward`` on each basic op, in order.

        """
        for op in self.basic_ops:
            op.pre_forward()

    def forward(
        self,
        input: torch.Tensor,  # pylint: disable=redefined-builtin
        *extra_inputs: torch.Tensor,
        basic_op_kwargs: Optional[list[dict[str, Any]]] = None,
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Apply operation

        Runs this op through an ``OperationFuser`` with fusion disabled.

        Parameters
        ----------
        input: torch.Tensor
            Input tensor
        *extra_inputs: torch.Tensor
            Extra tensor inputs, forwarded to the operation fuser
        basic_op_kwargs: list of dict, optional
            Keyword arguments for each basic op's forward function.
            Defaults to one empty dict per basic op.

        Returns
        -------
        torch.Tensor or tuple of torch.Tensor
            Result of the operation fuser call

        """
        from .fuser import OperationFuser

        if basic_op_kwargs is None:
            basic_op_kwargs = [{} for _ in self.basic_ops]
        return OperationFuser([self], fuse_ops=False)(
            input,
            *extra_inputs,
            basic_op_kwargs=basic_op_kwargs,
        )