# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Base classes for fusible operations."""

from __future__ import annotations
import abc
from collections.abc import Iterable
import dataclasses
import pickle
from typing import Any, Optional

import torch

from transformer_engine.common.recipe import Recipe
from ..quantization import (
    FP8GlobalStateManager,
    RecipeState,
    autocast,
)
from ..tensor import Quantizer


@dataclasses.dataclass
class OperationContext:
    """State needed to apply an operation

    Saves state from forward pass for use in backward pass.

    """

    # Tensors that have been saved from forward function
    # Note: Available in the backward function, matching tensors from
    # to_save.
    saved_tensors: Optional[tuple[Optional[torch.Tensor], ...]] = None
    # Tensors to save for backward function
    # Note: Expected to be set in the forward function, either
    # directly or with save_for_backward.
    to_save: Optional[tuple[Optional[torch.Tensor], ...]] = None

    # Corresponding range in pipeline's list of saved tensors
    _saved_tensors_range: Optional[tuple[int, int]] = None

    # Whether backward pass is required
    requires_grad: bool = True

    def save_for_backward(self, *tensors: Optional[torch.Tensor]) -> None:
        """Register tensors to be saved for the backward function

        Expected to be called in the forward function.

        """
        self.to_save = tensors
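
    # Sketch of the intended contract (hypothetical op code, for
    # illustration only): the forward implementation stashes tensors via
    # save_for_backward, and the backward implementation reads them back
    # from saved_tensors in the same order.
    #
    #     def op_forward(self, ctx, input_, **kwargs):
    #         out = input_.exp()
    #         ctx.save_for_backward(out)
    #         return out
    #
    #     def op_backward(self, ctx, grad_output):
    #         (out,) = ctx.saved_tensors
    #         return grad_output * out, ()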


class FusibleOperation(torch.nn.Module, metaclass=abc.ABCMeta):
    """Tensor operation supported by the operation fuser"""

    @property
    @abc.abstractmethod
    def is_fused_op(self) -> bool:
        """Whether this op is the fusion of one or more basic ops"""

    def pre_first_fuser_forward(self) -> None:
        """Preprocessing before first fuser forward pass"""

    def pre_fuser_forward(
        self,
        *,
        requires_grad: bool,  # pylint: disable=unused-argument
    ) -> None:
        """Preprocessing before fuser forward pass"""

    def get_input_quantizer(self) -> Optional[Quantizer]:
        """Get builder class for quantized input tensor"""

    def get_grad_output_quantizer(self) -> Optional[Quantizer]:
        """Get builder class for quantized output's grad tensor"""

    def fuser_forward(
        self,
        basic_op_ctxs: list[OperationContext],
        input_: torch.Tensor,
        *,
        basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
        prev_op_grad_output_quantizer: Optional[Quantizer],
        next_op_input_quantizer: Optional[Quantizer],
        basic_op_kwargs: list[dict[str, Any]],
    ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
        """Forward pass

        This op is either a basic op or the fusion of basic ops, so
        several of this function's arguments are lists of arguments to
        forward functions of corresponding basic ops.

        Called by `OperationFuser`.

        Parameters
        ----------
        basic_op_ctxs: list of OperationContext
            Contexts for basic operations
        input_: torch.Tensor
            Input tensor
        basic_op_extra_inputs: list of tuple of torch.Tensor
            Extra tensor inputs to basic operations
        prev_op_grad_output_quantizer: Quantizer, optional
            The grad_output_quantizer of the preceding operation
        next_op_input_quantizer: Quantizer, optional
            The input_quantizer of the following operation
        basic_op_kwargs: list of dict
            Keyword arguments to forward functions of basic
            operations.

        Returns
        -------
        torch.Tensor:
            Output tensor.
        Iterable of torch.Tensor:
            Extra tensor outputs from basic operations.

        """
        raise NotImplementedError(
            f"Forward pass is not implemented for operation ({self.__class__.__name__})"
        )

    def fuser_backward(
        self,
        basic_op_ctxs: list[OperationContext],
        grad_output: torch.Tensor,
        *,
        basic_op_grad_extra_outputs: list[tuple[torch.Tensor, ...]],
    ) -> tuple[
        torch.Tensor,
        Iterable[Iterable[Optional[torch.Tensor]]],
        Iterable[Iterable[Optional[torch.Tensor]]],
    ]:
        """Backward pass

        This op is either a basic op or the fusion of basic ops, so
        several of this function's arguments are lists of arguments to
        backward functions of corresponding basic ops.

        Called by `OperationFuser`.

        Parameters
        ----------
        basic_op_ctxs: list of OperationContext
            Contexts for basic operations
        grad_output: torch.Tensor
            Loss gradient w.r.t. operation output
        basic_op_grad_extra_outputs: list of tuple of torch.Tensor
            Loss gradients w.r.t. extra tensor outputs from basic
            operations.

        Returns
        -------
        torch.Tensor:
            Loss gradient w.r.t. operation input
        Iterable of iterable of torch.Tensor:
            Loss gradients w.r.t. parameters for basic operations
        Iterable of iterable of torch.Tensor:
            Loss gradients w.r.t. extra tensor inputs to basic
            operations

        """
        raise NotImplementedError(
            f"Backward pass is not implemented for operation ({self.__class__.__name__})"
        )


class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
    """Single tensor operation supported by the operation fuser

    This class holds parameters and state, even if the actual forward
    and backward passes are performed by a fused operation.

    """

    # Number of extra tensor inputs
    num_extra_inputs: int = 0
    # Number of extra tensor outputs
    num_extra_outputs: int = 0

    def __init__(self) -> None:
        super().__init__()

        # Objects for quantization
        self._fp8_metas: Optional[dict[str, dict[str, Any]]] = None
        self._quantizers: Optional[dict[str, list[Quantizer]]] = None

    @property
    def is_fused_op(self) -> bool:
        return False

    def num_quantizers(
        self,
        mode: str,  # pylint: disable=unused-argument
    ) -> int:
        """Number of quantizers

        Matches number of quantized tensors used in operation.

        Parameters
        ----------
        mode: {"forward", "backward"}
            Quantizer type

        """
        return 0
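
    # Example override (a sketch, not from this file): an op that
    # quantizes one tensor in the forward pass and one in the backward
    # pass would report one quantizer for each mode.
    #
    #     def num_quantizers(self, mode: str) -> int:
    #         return 1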

    def get_input_quantizer(self) -> Optional[Quantizer]:
        if self.num_quantizers("forward") > 0:
            return self.get_quantizer("forward", 0)
        return None

    def get_grad_output_quantizer(self) -> Optional[Quantizer]:
        if self.num_quantizers("backward") > 0:
            return self.get_quantizer("backward", 0)
        return None

    def reset_recipe_state(
        self,
        *,
        recipe: Optional[Recipe],
    ) -> None:
        """Construct state for quantization recipe"""

        # Clear quantization state if necessary
        if recipe is None:
            self._fp8_metas = None
            self._quantizers = None
            return

        # Communication group for FP8 amax reductions
        fp8_group = FP8GlobalStateManager.get_fp8_group()

        # Skip resetting the recipe state if the recipe type did not
        # actually change. This can happen, for example, when calling
        # BasicOperation.forward directly (the OperationFuser is not
        # persistent in that case) or when loading from a checkpoint.
        need_to_reset_recipe_state = False
        if self._fp8_metas is None or self._quantizers is None:
            need_to_reset_recipe_state = True
        else:
            for mode in ("forward", "backward"):
                fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
                    forward=(mode == "forward"),
                )
                if self._fp8_metas[mode] is None or fp8_meta_key not in self._fp8_metas[mode]:
                    continue
                recipe_state = self._fp8_metas[mode][fp8_meta_key]
                if not isinstance(recipe, type(recipe_state.recipe)):
                    need_to_reset_recipe_state = True
                    break

        if need_to_reset_recipe_state:
            # Construct quantization recipe states
            self._fp8_metas = {"forward": None, "backward": None}
            self._quantizers = {"forward": [], "backward": []}
            for mode in ("forward", "backward"):
                num_quantizers = self.num_quantizers(mode)
                if num_quantizers == 0:
                    continue

                if recipe.float8_block_scaling():
                    raise NotImplementedError(
                        "Fusible operations do not support FP8 block scaling recipe"
                    )

                # Construct quantization recipe state
                recipe_state = RecipeState.create(
                    recipe,
                    mode=mode,
                    num_quantizers=num_quantizers,
                )
                fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
                    forward=(mode == "forward"),
                )
                self._fp8_metas[mode] = {
                    fp8_meta_key: recipe_state,
                    "recipe": recipe,
                    "fp8_group": fp8_group,
                }

                # Construct builder class for quantized tensors
                self._quantizers[mode] = recipe_state.make_quantizers()
        else:
            # Update quantization recipe states
            for mode in ("forward", "backward"):
                if self._fp8_metas[mode] is None:
                    continue
                self._fp8_metas[mode]["recipe"] = recipe
                self._fp8_metas[mode]["fp8_group"] = fp8_group

                # Update amax history for FP8 delayed scaling
                if recipe.delayed():
                    fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
                        forward=(mode == "forward"),
                    )
                    recipe_state = self._fp8_metas[mode][fp8_meta_key]

                    # Reallocate amax history if needed
                    current_length = recipe_state.amax_history.size(0)
                    target_length = recipe.amax_history_len
                    if target_length < current_length:
                        with torch.no_grad():
                            recipe_state.amax_history = recipe_state.amax_history[
                                :target_length
                            ].clone()
                    elif target_length > current_length:
                        with torch.no_grad():
                            recipe_state.amax_history = torch.nn.functional.pad(
                                recipe_state.amax_history,
                                pad=(0, 0, 0, target_length - current_length),
                            )
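
                    # Worked example (assuming 3 quantizers): growing the
                    # history from length 2 to 4 pads amax_history from
                    # shape [2, 3] to [4, 3] with zero rows; shrinking to
                    # length 1 slices it to amax_history[:1].clone().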

                    # Update quantizers with new amax pointers
                    self._quantizers[mode] = recipe_state.make_quantizers()

                    # Update the global buffers with new amax pointers
                    if FP8GlobalStateManager.get_buffer_info() in self._fp8_metas[mode]:
                        pos, buffer_key = self._fp8_metas[mode][
                            FP8GlobalStateManager.get_buffer_info()
                        ]
                        if buffer_key in FP8GlobalStateManager.global_amax_buffer:
                            assert (
                                buffer_key in FP8GlobalStateManager.global_amax_history_buffer
                            ), "TE internal error during amax history change."
                            FP8GlobalStateManager.global_amax_buffer[buffer_key][pos] = (
                                recipe_state.amax_history[0]
                            )
                            FP8GlobalStateManager.global_amax_history_buffer[buffer_key][
                                pos
                            ] = recipe_state.amax_history

        # Add meta tensors to global buffer to participate in reduction
        for mode in ("forward", "backward"):
            if (
                FP8GlobalStateManager.is_fp8_enabled()
                and self.num_quantizers(mode)
                and not FP8GlobalStateManager.fp8_graph_capturing()
            ):
                FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
                    self._fp8_metas[mode],
                )

    def get_quantizer(
        self,
        mode: str,
        index: int,
    ) -> Optional[Quantizer]:
        """Get builder class for quantized tensor

        Parameters
        ----------
        mode: {"forward", "backward"}
            Quantizer type
        index: int
            Index among quantizers of the given type

        """
        if self._quantizers is None:
            return None
        return self._quantizers[mode][index]

    @torch.no_grad()
    def _save_fp8_metas(self) -> Optional[dict[str, Any]]:
        """Create copies of tensors in FP8 metadata

        Tensor copies can be loaded with _load_fp8_metas.

        """
        if self._fp8_metas is None:
            return None
        out = {}
        for mode, fp8_meta in self._fp8_metas.items():
            if fp8_meta is None:
                continue
            out[mode] = {}
            for is_forward in (True, False):
                fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=is_forward)
                if fp8_meta_key not in fp8_meta:
                    continue
                out[mode][fp8_meta_key] = (
                    fp8_meta[fp8_meta_key].scale.clone(),
                    fp8_meta[fp8_meta_key].amax_history.clone(),
                )
        return out

    @torch.no_grad()
    def _load_fp8_metas(self, fp8_metas: Optional[dict[str, Any]]) -> None:
        """Update FP8 metadata with saved tensor copies

        Tensor copies should be generated with _save_fp8_metas.

        """
        assert (self._fp8_metas is None) == (
            fp8_metas is None
        ), "Saved FP8 metadata does not match operation's FP8 metadata"
        if fp8_metas is None:
            return
        for mode, fp8_meta in fp8_metas.items():
            assert (
                mode in self._fp8_metas
            ), f"Found an unexpected key ({mode=}) in saved FP8 metadata"
            for fp8_meta_key, tensors in fp8_meta.items():
                assert (
                    fp8_meta_key in self._fp8_metas[mode]
                ), f"Found an unexpected key ({mode=}, {fp8_meta_key=}) in saved FP8 metadata"
                scale, amax_history = tensors
                self._fp8_metas[mode][fp8_meta_key].scale.copy_(scale)
                self._fp8_metas[mode][fp8_meta_key].amax_history.copy_(amax_history)

    @abc.abstractmethod
    def op_forward(
        self,
        ctx: OperationContext,
        input_: torch.Tensor,
        *,
        prev_op_grad_output_quantizer: Optional[Quantizer],
        next_op_input_quantizer: Optional[Quantizer],
        **kwargs: Any,
    ) -> torch.Tensor:
        """Forward pass

        Parameters
        ----------
        ctx: OperationContext
            Context to coordinate between forward and backward passes
        input_: torch.Tensor
            Input tensor
        prev_op_grad_output_quantizer: Quantizer, optional
            The grad_output_quantizer of the preceding operation
        next_op_input_quantizer: Quantizer, optional
            The input_quantizer of the following operation

        Returns
        -------
        torch.Tensor:
            Output tensor

        """

    @abc.abstractmethod
    def op_backward(
        self,
        ctx: OperationContext,
        grad_output: torch.Tensor,
    ) -> tuple[torch.Tensor, Iterable[Optional[torch.Tensor]]]:
        """Backward pass

        Parameters
        ----------
        ctx: OperationContext
            Context to coordinate between forward and backward passes
        grad_output: torch.Tensor
            Loss gradient w.r.t. operation output

        Returns
        -------
        torch.Tensor
            Loss gradient w.r.t. operation input
        Iterable of torch.Tensor:
            Loss gradients w.r.t. parameters

        """

    def fuser_forward(
        self,
        basic_op_ctxs: list[OperationContext],
        input_: torch.Tensor,
        *,
        basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
        prev_op_grad_output_quantizer: Optional[Quantizer],
        next_op_input_quantizer: Optional[Quantizer],
        basic_op_kwargs: list[dict[str, Any]],
    ) -> tuple[torch.Tensor, list[tuple[()]]]:
        if self.num_extra_inputs > 0 or self.num_extra_outputs > 0:
            raise RuntimeError(
                f"{self.__class__.__name__} operation has "
                f"{self.num_extra_inputs} extra tensor inputs "
                f"and {self.num_extra_outputs} extra tensor outputs. "
                "It should override `fuser_forward` instead of `op_forward`."
            )
        output = self.op_forward(
            basic_op_ctxs[0],
            input_,
            prev_op_grad_output_quantizer=prev_op_grad_output_quantizer,
            next_op_input_quantizer=next_op_input_quantizer,
            **basic_op_kwargs[0],
        )
        return output, [()]

    def fuser_backward(
        self,
        basic_op_ctxs: list[OperationContext],
        grad_output: torch.Tensor,
        *,
        basic_op_grad_extra_outputs: list[tuple[torch.Tensor, ...]],
    ) -> tuple[
        torch.Tensor,
        list[Iterable[Optional[torch.Tensor]]],
        list[tuple[()]],
    ]:
        if self.num_extra_inputs > 0 or self.num_extra_outputs > 0:
            raise RuntimeError(
                f"{self.__class__.__name__} operation has "
                f"{self.num_extra_inputs} extra tensor inputs "
                f"and {self.num_extra_outputs} extra tensor outputs. "
                "It should override `fuser_backward` instead of `op_backward`."
            )
        grad_input, grad_params = self.op_backward(basic_op_ctxs[0], grad_output)
        return grad_input, [grad_params], [()]

    def forward(
        self,
        input: torch.Tensor,  # pylint: disable=redefined-builtin
        *extra_inputs: torch.Tensor,
        **kwargs: Any,
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Apply operation"""
        from .fuser import OperationFuser

        return OperationFuser([self])(
            input,
            *extra_inputs,
            basic_op_kwargs=[kwargs],
        )
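
    # Calling a BasicOperation directly wraps it in a single-op
    # OperationFuser. Usage sketch (with the hypothetical _ExampleScale
    # from the comment above):
    #
    #     op = _ExampleScale()
    #     y = op(torch.randn(4, 4))  # forward through OperationFuser([op])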

    def get_extra_state(self) -> torch.Tensor:
        """Serialize extra state

        Contains metadata for quantization recipe.

        """

        # This implementation is working around a few issues:
        #
        # (1) PyTorch's "extra state" infrastructure might be able to
        #     support any picklable type, but they make no guarantees.
        #     We have experienced problems (e.g. in ONNX export) with
        #     non-tensor extra state.
        # (2) PyTorch's checkpointing infrastructure does not remap
        #     devices for "extra state" like it does for "state dict".
        #     Thus, we want to avoid putting extra state on the GPU
        #     since it may be loaded on the wrong device.
        # (3) The extra state consists of many small tensors. If we
        #     want to copy them all to CPU, then we need to avoid the
        #     overhead of many GPU-CPU memory transfers.
        #
        # See: https://github.com/NVIDIA/TransformerEngine/pull/351
        # See: https://github.com/NVIDIA/TransformerEngine/pull/363

        def to_cpu(src: torch.Tensor) -> torch.Tensor:
            """Helper function to make CPU copy of tensor

            Memory transfer is asynchronous w.r.t. host, so GPU should
            be synchronized before using result.

            """
            dst = torch.empty_like(src, device="cpu")
            dst.copy_(src, non_blocking=True)
            return dst

        # Store quantizer state if needed
        state = {}
        for mode in ("forward", "backward"):

            # Skip if op has no quantizer state
            if self._fp8_metas is None or self._fp8_metas[mode] is None:
                continue

            # Quantizer state
            fp8_meta = self._fp8_metas[mode]
            state[mode] = {}
            state[mode]["recipe"] = fp8_meta["recipe"]

            # Copy tensors to CPU and store
            if state[mode]["recipe"].delayed():
                if mode == "forward":
                    state[mode]["scale_fwd"] = to_cpu(fp8_meta["scaling_fwd"].scale)
                    state[mode]["amax_history_fwd"] = to_cpu(fp8_meta["scaling_fwd"].amax_history)
                if mode == "backward":
                    state[mode]["scale_bwd"] = to_cpu(fp8_meta["scaling_bwd"].scale)
                    state[mode]["amax_history_bwd"] = to_cpu(fp8_meta["scaling_bwd"].amax_history)

            # Store other picklable items
            extra = {}
            for key, val in fp8_meta.items():
                if key == "buffer_index_and_autocast_key":
                    continue
                if not isinstance(val, (bool, int, float, str, tuple, list)):
                    continue
                extra[key] = val
            state[mode]["extra_fp8_variables"] = extra

        if not state:
            return torch.empty(0, dtype=torch.uint8)

        # Serialize state into byte tensor
        torch.cuda.synchronize()
        state_serialized = bytearray(pickle.dumps(state))
        state_serialized = torch.frombuffer(state_serialized, dtype=torch.uint8)
        return state_serialized

    def set_extra_state(self, state: Optional[torch.Tensor]) -> None:
        """Load extra state"""
        if state is None or state.numel() == 0:
            return

        # Deserialize state from byte tensor
        state = pickle.loads(state.detach().numpy(force=True).tobytes())
        if state is None or len(state) == 0:
            return

        def copy_tensor(src: torch.Tensor, dst: torch.Tensor) -> None:
            """Helper function to copy tensor from CPU

            Memory transfer is asynchronous w.r.t. host, so GPU should
            be synchronized before using result.

            """
            if src.size() != dst.size():
                dst.data = torch.empty(src.size(), dtype=dst.dtype, device=dst.device)
            dst.copy_(src, non_blocking=True)

        # Load quantizer state if needed
        for mode in ("forward", "backward"):

            # Skip if checkpoint has no quantizer state
            if mode not in state:
                continue

            # Get op's quantizer state, initializing if needed
            if self._fp8_metas is None or self._fp8_metas[mode] is None:
                with autocast(recipe=state[mode]["recipe"]):
                    self.reset_recipe_state(recipe=state[mode]["recipe"])
            fp8_meta = self._fp8_metas[mode]

            # Load extra items
            fp8_meta["recipe"] = state[mode]["recipe"]
            fp8_meta.update(state[mode]["extra_fp8_variables"])
            if "global_fp8_buffer_pos_fwd_recompute" in fp8_meta:
                del fp8_meta["global_fp8_buffer_pos_fwd_recompute"]

            # Load tensors
            if state[mode]["recipe"].delayed():
                if mode == "forward":
                    copy_tensor(state[mode]["scale_fwd"], fp8_meta["scaling_fwd"].scale)
                    copy_tensor(
                        state[mode]["amax_history_fwd"], fp8_meta["scaling_fwd"].amax_history
                    )
                if mode == "backward":
                    copy_tensor(state[mode]["scale_bwd"], fp8_meta["scaling_bwd"].scale)
                    copy_tensor(
                        state[mode]["amax_history_bwd"], fp8_meta["scaling_bwd"].amax_history
                    )

        # Finish CPU-GPU memory transfers
        torch.cuda.synchronize()
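
    # Checkpoint round-trip sketch (hypothetical `op` and `op_restored`):
    #
    #     sd = op.state_dict()             # includes serialized extra state
    #     op_restored.load_state_dict(sd)  # restores recipe metadata and,
    #                                      # for delayed scaling, the scale
    #                                      # and amax_history tensors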

    def _load_from_state_dict(self, *args, **kwargs) -> None:
        """Load state"""

        # In the base PyTorch module class, the extra state is loaded
        # _after_ the parameters. However, copying values into FP8
        # parameters requires an FP8 cast, which uses a scaling factor
        # from the operation's FP8 metadata. The FP8 metadata is
        # included in the operation's extra state, so we need to
        # manually load the extra state before loading parameters.

        state_dict, prefix = args[0], args[1]
        extra_state_key = prefix + torch.nn.modules.module._EXTRA_STATE_KEY_SUFFIX
        if extra_state_key in state_dict:
            self.set_extra_state(state_dict[extra_state_key])
        super()._load_from_state_dict(*args, **kwargs)


class FusedOperation(FusibleOperation):
    """Compound tensor operation supported by the operation fuser

    If the forward or backward passes are defined, they must be
    functionally equivalent to the forward/backward passes of the
    corresponding basic ops. This class should hold no parameters or
    other state, but should access them from the basic ops.

    Parameters
    ----------
    basic_ops: iterable of FusibleOperation
        Basic ops that are interchangeable with this op

    """

    def __init__(
        self,
        basic_ops: Iterable[FusibleOperation],
    ) -> None:
        super().__init__()

        # Basic operations that comprise this fused operation
        self.basic_ops: torch.nn.ModuleList = torch.nn.ModuleList(basic_ops)
        if len(self.basic_ops) == 0:
            raise ValueError(
                "Attempted to construct a fused operation "
                "without specifying its corresponding basic operations"
            )

    @property
    def is_fused_op(self) -> bool:
        return True

    def get_input_quantizer(self) -> Optional[Quantizer]:
        return self.basic_ops[0].get_input_quantizer()

    def get_grad_output_quantizer(self) -> Optional[Quantizer]:
        return self.basic_ops[-1].get_grad_output_quantizer()

    def pre_first_fuser_forward(self) -> None:
        for op in self.basic_ops:
            op.pre_first_fuser_forward()

    def pre_fuser_forward(self, *, requires_grad: bool) -> None:
        for op in self.basic_ops:
            op.pre_fuser_forward(requires_grad=requires_grad)

    def forward(
        self,
        input: torch.Tensor,  # pylint: disable=redefined-builtin
        *extra_inputs: torch.Tensor,
        basic_op_kwargs: Optional[list[dict[str, Any]]] = None,
    ) -> torch.Tensor:
        """Apply operation"""
        if basic_op_kwargs is None:
            basic_op_kwargs = [{} for _ in range(len(self.basic_ops))]
        from .fuser import OperationFuser

        return OperationFuser([self])(
            input,
            *extra_inputs,
            basic_op_kwargs=basic_op_kwargs,
        )
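

# A minimal sketch of a fused op (hypothetical, for illustration only): a
# fusion that is interchangeable with two existing basic ops. A real
# implementation would also override `fuser_forward` (and possibly
# `fuser_backward`) to launch the fused kernel.
#
#     class _ExampleFusedBiasActivation(FusedOperation):
#
#         def __init__(self, bias_op, activation_op) -> None:
#             super().__init__((bias_op, activation_op))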