# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Base classes for fusible operations."""

from __future__ import annotations
import abc
from collections.abc import Iterable
import dataclasses
11
import pickle
12
13
14
15
from typing import Any, Optional

import torch

16
17
from transformer_engine.common.recipe import Recipe
from ..fp8 import (
18
    FP8GlobalStateManager,
19
    RecipeState,
20
    fp8_autocast,
21
)
22
from ..tensor import Quantizer


@dataclasses.dataclass
class OperationContext:
    """Per-operation state carried from the forward to the backward pass

    The forward function records here whatever the backward function
    will later need, most notably tensors via ``save_for_backward``.

    """

    # Tensors handed back to the backward function. They correspond
    # one-to-one with the tensors that were placed in ``to_save``.
    saved_tensors: Optional[tuple[Optional[torch.Tensor], ...]] = None
    # Tensors the forward function wants kept for the backward
    # function. Populated either by assigning this attribute directly
    # or by calling ``save_for_backward``.
    to_save: Optional[tuple[Optional[torch.Tensor], ...]] = None

    # Index range (pair of ints) inside the pipeline's list of saved
    # tensors that belongs to this context
    _saved_tensors_range: Optional[tuple[int, int]] = None

    # Whether the backward pass has to be executed for this operation
    requires_grad: bool = True

    def save_for_backward(self, *tensors: Optional[torch.Tensor]) -> None:
        """Record tensors that the backward function will need

        Meant to be called from the forward function. Calling it again
        replaces whatever was registered before.

        """
        self.to_save = tensors


class FusibleOperation(torch.nn.Module, metaclass=abc.ABCMeta):
    """Tensor operation supported by the operation fuser"""

    @property
    @abc.abstractmethod
    def is_fused_op(self) -> bool:
        """Whether this op is the fusion of one or more basic ops"""

    def pre_first_fuser_forward(self) -> None:
        """Preprocessing before first fuser forward pass

        The default implementation does nothing.

        """

    def get_input_quantizer(self) -> Optional[Quantizer]:
        """Get builder class for quantized input tensor

        The default implementation returns ``None`` (no quantizer).

        """
        return None

    def get_grad_output_quantizer(self) -> Optional[Quantizer]:
        """Get builder class for quantized output's grad tensor

        The default implementation returns ``None`` (no quantizer).

        """
        return None

    def fuser_forward(
        self,
        basic_op_ctxs: list[OperationContext],
        input_: torch.Tensor,
        *,
        basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
        prev_op_grad_output_quantizer: Optional[Quantizer],
        next_op_input_quantizer: Optional[Quantizer],
        basic_op_kwargs: list[dict[str, Any]],
    ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
        """Forward pass

        This op is either a basic op or the fusion of basic ops, so
        several of this function's arguments are lists of arguments to
        forward functions of corresponding basic ops.

        Called by `OperationFuser`.

        Parameters
        ----------
        basic_op_ctxs: list of OperationContext
            Contexts for basic operations
        input_: torch.Tensor
            Input tensor
        basic_op_extra_inputs: list of tuple of torch.Tensor
            Extra tensor inputs to basic operations
        prev_op_grad_output_quantizer: Quantizer, optional
            The grad_output_quantizer of the preceding operation
        next_op_input_quantizer: Quantizer, optional
            The input_quantizer of the following operation
        basic_op_kwargs: list of dict
            Keyword arguments to forward functions of basic
            operations.

        Returns
        -------
        torch.Tensor:
            Output tensor.
        Iterable of iterable of torch.Tensor:
            Extra tensor outputs from basic operations.

        Raises
        ------
        NotImplementedError
            Always, in this base implementation; concrete ops must
            override it.

        """
        raise NotImplementedError(
            f"Forward pass is not implemented for operation ({self.__class__.__name__})"
        )

    def fuser_backward(
        self,
        basic_op_ctxs: list[OperationContext],
        grad_output: torch.Tensor,
        *,
        basic_op_grad_extra_outputs: list[tuple[torch.Tensor, ...]],
    ) -> tuple[
        torch.Tensor,
        Iterable[Iterable[Optional[torch.Tensor]]],
        Iterable[Iterable[Optional[torch.Tensor]]],
    ]:
        """Backward pass

        This op is either a basic op or the fusion of basic ops, so
        several of this function's arguments are lists of arguments to
        backward functions of corresponding basic ops.

        Called by `OperationFuser`.

        Parameters
        ----------
        basic_op_ctxs: list of OperationContext
            Contexts for basic operations
        grad_output: torch.Tensor
            Loss gradient w.r.t. operation output
        basic_op_grad_extra_outputs: list of tuple of torch.Tensor
            Loss gradients w.r.t. extra tensor outputs from basic
            operations.

        Returns
        -------
        torch.Tensor:
            Loss gradient w.r.t. operation input
        Iterable of iterable of torch.Tensor:
            Loss gradients w.r.t. parameters for basic operations
        Iterable of iterable of torch.Tensor:
            Loss gradients w.r.t. extra tensor inputs to basic
            operations

        Raises
        ------
        NotImplementedError
            Always, in this base implementation; concrete ops must
            override it.

        """
        raise NotImplementedError(
            f"Backward pass is not implemented for operation ({self.__class__.__name__})"
        )


class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
    """Single tensor operation supported by the operation fuser

    This class holds parameters and state, even if the actual forward
    and backward passes are performed by a fused operation.

    """

    # Number of extra tensor inputs
    num_extra_inputs: int = 0
    # Number of extra tensor outputs
    num_extra_outputs: int = 0

    def __init__(self) -> None:
        super().__init__()

        # Objects for quantization, keyed by mode ("forward" or
        # "backward"). Both stay None while no recipe is active, and a
        # per-mode FP8 meta entry is None when that mode has no
        # quantizers.
        self._fp8_metas: Optional[dict[str, Optional[dict[str, Any]]]] = None
        self._quantizers: Optional[dict[str, list[Quantizer]]] = None

        # Only build quantization state at construction time when FP8
        # parameters are requested; otherwise recipe is None and
        # reset_recipe_state leaves the state empty.
        with_fp8_parameters = FP8GlobalStateManager.with_fp8_parameters()
        recipe = FP8GlobalStateManager.get_fp8_recipe() if with_fp8_parameters else None
        self.reset_recipe_state(recipe=recipe)

    @property
    def is_fused_op(self) -> bool:
        return False

    def num_quantizers(
        self,
        mode: str,  # pylint: disable=unused-argument
    ) -> int:
        """Number of quantizers

        Matches number of quantized tensors used in operation. This
        base implementation returns 0, so an op only gets quantizers
        if it overrides this method.

        Parameters
        ----------
        mode: {"forward", "backward"}
            Quantizer type

        """
        return 0

    def get_input_quantizer(self) -> Optional[Quantizer]:
        # The first forward quantizer is used for the op's input
        if self.num_quantizers("forward") > 0:
            return self.get_quantizer("forward", 0)
        return None

    def get_grad_output_quantizer(self) -> Optional[Quantizer]:
        # The first backward quantizer is used for the output's grad
        if self.num_quantizers("backward") > 0:
            return self.get_quantizer("backward", 0)
        return None

    def reset_recipe_state(
        self,
        *,
        recipe: Optional[Recipe],
    ) -> None:
        """Construct state for quantization recipe

        If ``recipe`` is ``None``, all quantization state is cleared.
        If the op has no quantization state yet, or the recipe type
        differs from the one in use, recipe states and quantizers are
        rebuilt from scratch. Otherwise the existing state is updated
        in place; for FP8 delayed scaling the amax history is resized
        to ``recipe.amax_history_len``.

        Parameters
        ----------
        recipe: Recipe, optional
            Quantization recipe

        Raises
        ------
        NotImplementedError
            If the state has to be rebuilt, the op uses quantizers,
            and ``recipe`` is an FP8 block scaling recipe.

        """

        # Clear quantization state if necessary
        if recipe is None:
            self._fp8_metas = None
            self._quantizers = None
            return

        # Communication group for FP8 amax reductions
        fp8_group = FP8GlobalStateManager.get_fp8_group()

        # Skip resetting recipe type if it did not actually change.
        # This could happen for example if calling BasicOperation.forward directly, as in that
        # case, the OperationFuser is not persistent, or when loading from a checkpoint
        need_to_reset_recipe_state = False
        if self._fp8_metas is None or self._quantizers is None:
            need_to_reset_recipe_state = True
        else:
            for mode in ("forward", "backward"):
                fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
                    forward=(mode == "forward"),
                )
                if self._fp8_metas[mode] is None or fp8_meta_key not in self._fp8_metas[mode]:
                    continue
                recipe_state = self._fp8_metas[mode][fp8_meta_key]
                if not isinstance(recipe, type(recipe_state.recipe)):
                    need_to_reset_recipe_state = True
                    break

        if need_to_reset_recipe_state:
            # Construct quantization recipe states
            self._fp8_metas = {"forward": None, "backward": None}
            self._quantizers = {"forward": [], "backward": []}
            for mode in ("forward", "backward"):
                num_quantizers = self.num_quantizers(mode)
                if num_quantizers == 0:
                    continue

                if recipe.float8_block_scaling():
                    raise NotImplementedError(
                        "Fusible operations do not support FP8 block scaling recipe"
                    )

                # Construct quantization recipe state
                recipe_state = RecipeState.create(
                    recipe,
                    mode=mode,
                    num_quantizers=num_quantizers,
                )
                fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
                    forward=(mode == "forward"),
                )
                self._fp8_metas[mode] = {
                    fp8_meta_key: recipe_state,
                    "recipe": recipe,
                    "fp8_group": fp8_group,
                }

                # Construct builder class for quantized tensors
                self._quantizers[mode] = recipe_state.make_quantizers()
        else:
            # Update quantization recipe states
            for mode in ("forward", "backward"):
                if self._fp8_metas[mode] is None:
                    continue
                self._fp8_metas[mode]["recipe"] = recipe
                self._fp8_metas[mode]["fp8_group"] = fp8_group

                # Update amax history for FP8 delayed scaling
                if recipe.delayed():
                    fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
                        forward=(mode == "forward"),
                    )
                    recipe_state = self._fp8_metas[mode][fp8_meta_key]

                    # Reallocate amax history if needed: truncate the
                    # leading rows when shrinking, zero-pad at the end
                    # of dim 0 when growing
                    current_length = recipe_state.amax_history.size(0)
                    target_length = recipe.amax_history_len
                    if target_length < current_length:
                        with torch.no_grad():
                            recipe_state.amax_history = recipe_state.amax_history[
                                :target_length
                            ].clone()
                    elif target_length > current_length:
                        with torch.no_grad():
                            recipe_state.amax_history = torch.nn.functional.pad(
                                recipe_state.amax_history,
                                pad=(0, 0, 0, target_length - current_length),
                            )

                    # Update quantizers with new amax pointers
                    self._quantizers[mode] = recipe_state.make_quantizers()

                    # Update the global buffers with new amax pointers
                    if FP8GlobalStateManager.get_buffer_info() in self._fp8_metas[mode]:
                        pos, buffer_key = self._fp8_metas[mode][
                            FP8GlobalStateManager.get_buffer_info()
                        ]
                        if buffer_key in FP8GlobalStateManager.global_amax_buffer:
                            assert (
                                buffer_key in FP8GlobalStateManager.global_amax_history_buffer
                            ), "TE internal error during amax history change."
                            FP8GlobalStateManager.global_amax_buffer[buffer_key][pos] = (
                                recipe_state.amax_history[0]
                            )
                            FP8GlobalStateManager.global_amax_history_buffer[buffer_key][
                                pos
                            ] = recipe_state.amax_history

        # Add meta tensors to global buffer to participate in reduction
        for mode in ("forward", "backward"):
            if (
                FP8GlobalStateManager.is_fp8_enabled()
                and self.num_quantizers(mode)
                and not FP8GlobalStateManager.fp8_graph_capturing()
            ):
                FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
                    self._fp8_metas[mode],
                )

    def get_quantizer(
        self,
        mode: str,
        index: int,
    ) -> Optional[Quantizer]:
        """Get builder class for quantized tensor

        Parameters
        ----------
        mode: {"forward", "backward"}
            Quantizer type
        index: int
            Position of the quantizer within the list for ``mode``

        Returns
        -------
        Quantizer or None
            ``None`` if no quantization recipe is active.

        """
        if self._quantizers is None:
            return None
        return self._quantizers[mode][index]

    @torch.no_grad()
    def _save_fp8_metas(self) -> Optional[dict[str, Any]]:
        """Create copies of tensors in FP8 metadata

        Tensor copies can be loaded with _load_fp8_metas.

        """
        if self._fp8_metas is None:
            return None
        out = {}
        for mode, fp8_meta in self._fp8_metas.items():
            if fp8_meta is None:
                continue
            out[mode] = {}
            for is_forward in (True, False):
                fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=is_forward)
                if fp8_meta_key not in fp8_meta:
                    continue
                out[mode][fp8_meta_key] = (
                    fp8_meta[fp8_meta_key].scale.clone(),
                    fp8_meta[fp8_meta_key].amax_history.clone(),
                )
        return out

    @torch.no_grad()
    def _load_fp8_metas(self, fp8_metas: Optional[dict[str, Any]]) -> None:
        """Update FP8 metadata with saved tensor copies

        Tensor copies should be generated with _save_fp8_metas.

        """
        assert (self._fp8_metas is None) == (
            fp8_metas is None
        ), "Saved FP8 metadata does not match operation's FP8 metadata"
        if fp8_metas is None:
            return
        for mode, fp8_meta in fp8_metas.items():
            assert (
                mode in self._fp8_metas
            ), f"Found an unexpected key ({mode=}) in saved FP8 metadata"
            for fp8_meta_key, tensors in fp8_meta.items():
                assert (
                    fp8_meta_key in self._fp8_metas[mode]
                ), f"Found an unexpected key ({mode=}, {fp8_meta_key=}) in saved FP8 metadata"
                scale, amax_history = tensors
                self._fp8_metas[mode][fp8_meta_key].scale.copy_(scale)
                self._fp8_metas[mode][fp8_meta_key].amax_history.copy_(amax_history)

    @abc.abstractmethod
    def op_forward(
        self,
        ctx: OperationContext,
        input_: torch.Tensor,
        *,
        prev_op_grad_output_quantizer: Optional[Quantizer],
        next_op_input_quantizer: Optional[Quantizer],
        **kwargs: Any,
    ) -> torch.Tensor:
        """Forward pass

        Parameters
        ----------
        ctx: OperationContext
            Context to coordinate between forward and backward passes
        input_: torch.Tensor
            Input tensor
        prev_op_grad_output_quantizer: Quantizer, optional
            The grad_output_quantizer of the preceding operation
        next_op_input_quantizer: Quantizer, optional
            The input_quantizer of the following operation

        Returns
        -------
        torch.Tensor:
            Output tensor

        """

    @abc.abstractmethod
    def op_backward(
        self,
        ctx: OperationContext,
        grad_output: torch.Tensor,
    ) -> tuple[torch.Tensor, Iterable[Optional[torch.Tensor]]]:
        """Backward pass

        Parameters
        ----------
        ctx: OperationContext
            Context to coordinate between forward and backward passes
        grad_output: torch.Tensor
            Loss gradient w.r.t. operation output

        Returns
        -------
        torch.Tensor
            Loss gradient w.r.t. operation input
        Iterable of torch.Tensor:
            Loss gradients w.r.t. parameters

        """

    def fuser_forward(
        self,
        basic_op_ctxs: list[OperationContext],
        input_: torch.Tensor,
        *,
        basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
        prev_op_grad_output_quantizer: Optional[Quantizer],
        next_op_input_quantizer: Optional[Quantizer],
        basic_op_kwargs: list[dict[str, Any]],
    ) -> tuple[torch.Tensor, list[tuple[()]]]:
        # The default implementation delegates to op_forward, which has
        # no way to receive extra inputs or return extra outputs
        if self.num_extra_inputs > 0 or self.num_extra_outputs > 0:
            raise RuntimeError(
                f"{self.__class__.__name__} operation has "
                f"{self.num_extra_inputs} extra tensor inputs "
                f"and {self.num_extra_outputs} extra tensor outputs. "
                "It should override `fuser_forward` instead of `op_forward`."
            )
        output = self.op_forward(
            basic_op_ctxs[0],
            input_,
            prev_op_grad_output_quantizer=prev_op_grad_output_quantizer,
            next_op_input_quantizer=next_op_input_quantizer,
            **basic_op_kwargs[0],
        )
        return output, [()]

    def fuser_backward(
        self,
        basic_op_ctxs: list[OperationContext],
        grad_output: torch.Tensor,
        *,
        basic_op_grad_extra_outputs: list[tuple[torch.Tensor, ...]],
    ) -> tuple[
        torch.Tensor,
        list[Iterable[Optional[torch.Tensor]]],
        list[tuple[()]],
    ]:
        # The default implementation delegates to op_backward, which
        # has no way to handle extra inputs or outputs
        if self.num_extra_inputs > 0 or self.num_extra_outputs > 0:
            raise RuntimeError(
                f"{self.__class__.__name__} operation has "
                f"{self.num_extra_inputs} extra tensor inputs "
                f"and {self.num_extra_outputs} extra tensor outputs. "
                "It should override `fuser_backward` instead of `op_backward`."
            )
        grad_input, grad_params = self.op_backward(basic_op_ctxs[0], grad_output)
        return grad_input, [grad_params], [()]

    def forward(
        self,
        input: torch.Tensor,  # pylint: disable=redefined-builtin
        *extra_inputs: torch.Tensor,
        **kwargs: Any,
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Apply operation"""
        from .fuser import OperationFuser

        return OperationFuser([self])(
            input,
            *extra_inputs,
            basic_op_kwargs=[kwargs],
        )

    def get_extra_state(self) -> torch.Tensor:
        """Serialize extra state

        Contains metadata for quantization recipe.

        Returns
        -------
        torch.Tensor
            1D uint8 CPU tensor holding the pickled state.

        """

        # This implementation is working around a few issues:
        #
        # (1) PyTorch's "extra state" infrastructure might be able to
        #     support any picklable type, but they make no guarantees.
        #     We have experienced problems (e.g. in ONNX export) with
        #     non-tensor extra state.
        # (2) PyTorch's checkpointing infrastructure does not remap
        #     devices for "extra state" like it does for "state dict".
        #     Thus, we want to avoid putting extra state on the GPU
        #     since it may be loaded on the wrong device.
        # (3) The extra state consists of many small tensors. If we
        #     want to copy them all to CPU, then we need to avoid the
        #     overhead of many GPU-CPU memory transfers.
        #
        # See: https://github.com/NVIDIA/TransformerEngine/pull/351
        # See: https://github.com/NVIDIA/TransformerEngine/pull/363

        def to_cpu(src: torch.Tensor) -> torch.Tensor:
            """Helper function to make CPU copy of tensor

            Memory transfer is asynchronous w.r.t. host, so GPU should
            be synchronized before using result.

            """
            dst = torch.empty_like(src, device="cpu")
            dst.copy_(src, non_blocking=True)
            return dst

        # Store quantizer state if needed
        state = {}
        for mode in ("forward", "backward"):

            # Skip if op has no quantizer state
            if self._fp8_metas is None or self._fp8_metas[mode] is None:
                continue

            # Quantizer state
            fp8_meta = self._fp8_metas[mode]
            state[mode] = {}
            state[mode]["recipe"] = fp8_meta["recipe"]

            # Copy tensors to CPU and store
            if state[mode]["recipe"].delayed():
                if mode == "forward":
                    state[mode]["scale_fwd"] = to_cpu(fp8_meta["scaling_fwd"].scale)
                    state[mode]["amax_history_fwd"] = to_cpu(fp8_meta["scaling_fwd"].amax_history)
                if mode == "backward":
                    state[mode]["scale_bwd"] = to_cpu(fp8_meta["scaling_bwd"].scale)
                    state[mode]["amax_history_bwd"] = to_cpu(fp8_meta["scaling_bwd"].amax_history)

            # Store other picklable items
            extra = {}
            for key, val in fp8_meta.items():
                if key == "buffer_index_and_autocast_key":
                    continue
                if not isinstance(val, (bool, int, float, str, tuple, list)):
                    continue
                extra[key] = val
            state[mode]["extra_fp8_variables"] = extra

        # Wait for the asynchronous GPU->CPU copies issued by to_cpu.
        # Copies are only issued after a mode entry has been added to
        # `state`, so an empty state needs no device-wide sync.
        if state:
            torch.cuda.synchronize()

        # Serialize state into byte tensor
        state_serialized = bytearray(pickle.dumps(state))
        state_serialized = torch.frombuffer(state_serialized, dtype=torch.uint8)
        return state_serialized

    def set_extra_state(self, state: Optional[torch.Tensor]) -> None:
        """Load extra state

        Parameters
        ----------
        state: torch.Tensor, optional
            Byte tensor produced by ``get_extra_state``. ``None`` or an
            empty tensor is ignored.

        """
        if state is None or state.numel() == 0:
            return

        # Deserialize state from byte tensor
        # NOTE(security): pickle.loads can execute arbitrary code
        # embedded in the byte stream. Only load checkpoints from
        # trusted sources.
        state = pickle.loads(state.detach().numpy(force=True).tobytes())
        if state is None or len(state) == 0:
            return

        def copy_tensor(src: torch.Tensor, dst: torch.Tensor) -> None:
            """Helper function to copy tensor from CPU

            Memory transfer is asynchronous w.r.t. host, so GPU should
            be synchronized before using result.

            """
            if src.size() != dst.size():
                dst.data = torch.empty(src.size(), dtype=dst.dtype, device=dst.device)
            dst.copy_(src, non_blocking=True)

        # Load quantizer state if needed
        for mode in ("forward", "backward"):

            # Skip if checkpoint has no quantizer state
            if mode not in state:
                continue

            # Get op's quantizer state, initializing if needed. The
            # checkpoint's recipe is activated while rebuilding,
            # presumably so the global FP8 state matches it.
            if self._fp8_metas is None or self._fp8_metas[mode] is None:
                with fp8_autocast(fp8_recipe=state[mode]["recipe"]):
                    self.reset_recipe_state(recipe=state[mode]["recipe"])
            fp8_meta = self._fp8_metas[mode]

            # Load extra items
            fp8_meta["recipe"] = state[mode]["recipe"]
            fp8_meta.update(state[mode]["extra_fp8_variables"])
            fp8_meta.pop("global_fp8_buffer_pos_fwd_recompute", None)

            # Load tensors
            if state[mode]["recipe"].delayed():
                if mode == "forward":
                    copy_tensor(state[mode]["scale_fwd"], fp8_meta["scaling_fwd"].scale)
                    copy_tensor(
                        state[mode]["amax_history_fwd"], fp8_meta["scaling_fwd"].amax_history
                    )
                if mode == "backward":
                    copy_tensor(state[mode]["scale_bwd"], fp8_meta["scaling_bwd"].scale)
                    copy_tensor(
                        state[mode]["amax_history_bwd"], fp8_meta["scaling_bwd"].amax_history
                    )

        # Finish CPU-GPU memory transfers
        torch.cuda.synchronize()

    def _load_from_state_dict(self, *args, **kwargs) -> None:
        """Load state"""

        # In the base PyTorch module class, the extra state is loaded
        # _after_ the parameters. However, copying values into FP8
        # parameters requires an FP8 cast, which uses a scaling factor
        # from the operation's FP8 metadata. The FP8 metadata is
        # included in the operation's extra state, so we need to
        # manually load the extra state before loading parameters.

        state_dict, prefix = args[0], args[1]
        extra_state_key = prefix + torch.nn.modules.module._EXTRA_STATE_KEY_SUFFIX
        if extra_state_key in state_dict:
            self.set_extra_state(state_dict[extra_state_key])
        super()._load_from_state_dict(*args, **kwargs)


class FusedOperation(FusibleOperation):
    """Compound tensor operation supported by the operation fuser

    If the forward or backward passes are defined, they must be
    functionally equivalent to the forward/backward passes of the
    corresponding basic ops. This class should hold no parameters or
    other state, but should access them from the basic ops.

    Parameters
    ----------
    basic_ops: iterable of FusibleOperation
        Basic ops that are interchangeable with this op

    Raises
    ------
    ValueError
        If ``basic_ops`` is empty.

    """

    def __init__(
        self,
        basic_ops: Iterable[FusibleOperation],
    ) -> None:
        super().__init__()

        # Basic operations that comprise this fused operation
        self.basic_ops: torch.nn.ModuleList = torch.nn.ModuleList(basic_ops)
        if len(self.basic_ops) == 0:
            raise ValueError(
                "Attempted to construct a fused operation "
                "without specifying its corresponding basic operations"
            )

    @property
    def is_fused_op(self) -> bool:
        return True

    def get_input_quantizer(self) -> Optional[Quantizer]:
        # Input quantization is determined by the first basic op
        return self.basic_ops[0].get_input_quantizer()

    def get_grad_output_quantizer(self) -> Optional[Quantizer]:
        # Output-grad quantization is determined by the last basic op
        return self.basic_ops[-1].get_grad_output_quantizer()

    def pre_first_fuser_forward(self) -> None:
        for op in self.basic_ops:
            op.pre_first_fuser_forward()

    def forward(
        self,
        input: torch.Tensor,  # pylint: disable=redefined-builtin
        *extra_inputs: torch.Tensor,
        basic_op_kwargs: Optional[list[dict[str, Any]]] = None,
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Apply operation

        Parameters
        ----------
        input: torch.Tensor
            Input tensor
        *extra_inputs: torch.Tensor
            Extra tensor inputs to the basic operations
        basic_op_kwargs: list of dict, optional
            Keyword arguments to forward functions of basic
            operations. Defaults to an empty dict per basic op.

        Returns
        -------
        torch.Tensor or tuple of torch.Tensor
            Whatever the ``OperationFuser`` returns for this op.

        """
        from .fuser import OperationFuser

        if basic_op_kwargs is None:
            basic_op_kwargs = [{} for _ in range(len(self.basic_ops))]
        return OperationFuser([self])(
            input,
            *extra_inputs,
            basic_op_kwargs=basic_op_kwargs,
        )