# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Manager class for a pipeline of fusible operations."""

from __future__ import annotations
from collections.abc import Callable, Iterable, Sequence
import itertools
from typing import Any, Optional, TypeAlias

import torch

from ..quantization import FP8GlobalStateManager, Recipe, DelayedScaling
from ..quantized_tensor import prepare_for_saving, restore_from_saved
from .op import (
    BasicOperation,
    FusibleOperation,
    FusedOperation,
    OperationContext,
)


def _split_tuple(t: tuple, idx: int) -> tuple[tuple, tuple]:
    """Split tuple at index"""
    return t[:idx], t[idx:]


# Lazily imported function used in _is_graph_capturing
_is_graph_capturing_function: Optional[Callable[[], bool]] = None


def _is_graph_capturing() -> bool:
    """Whether function is called within `make_graphed_callables`

    Avoid circular import with lazy import.

    """
    global _is_graph_capturing_function
    if _is_graph_capturing_function is None:
        from ..graph import is_graph_capturing

        _is_graph_capturing_function = is_graph_capturing
    return _is_graph_capturing_function()


# Type alias for a function that may perform operation fusion
OperationFusionFunction: TypeAlias = (
    "Callable[tuple[list[FusibleOperation], ...], list[FusibleOperation]]"
)
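# A fusion function is invoked as ``func(ops, recipe=recipe)`` (see
# ``OperationFuser._fuse_ops`` below) and returns an updated op list.
# Hypothetical sketch of a conforming no-op fusion function:
#
#     def _identity_fusion(ops, *, recipe=None):
#         return list(ops)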


class _OperationFuserAutogradFunction(torch.autograd.Function):
    """Autograd function for a pipeline of operations

    Autograd must be done at the pipeline level since we may apply
    different fusions in the forward and backward passes.

    """

    # pylint: disable=unused-argument
    @staticmethod
    def forward(
        func_ctx: Optional[torch.autograd.function.FunctionCtx],
        input_: torch.Tensor,
        fuser: OperationFuser,
        basic_op_kwargs: list[dict[str, Any]],
        *params_and_extra_inputs: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Forward pass

        Parameters
        ----------
        func_ctx: torch.autograd.function.FunctionCtx
            Context for PyTorch autograd function
        input_: torch.Tensor
            Input to first operation in pipeline
        fuser: OperationFuser
            Container for the pipeline of operations to run
        basic_op_kwargs: list of dict
            Keyword arguments to BasicOperation
        *params_and_extra_inputs: torch.Tensor
            Other tensor inputs to include in autograd graph. Consists
            of parameter tensors, followed by extra operation inputs.

        Returns
        -------
        Output tensor(s). If none of the operations have any extra
        tensor outputs, then the pipeline's output tensor is returned.
        Otherwise, a tuple with the pipeline's output tensor and extra
        tensor outputs is returned.

        """

        # Operation autograd contexts
        basic_op_ctxs = [OperationContext() for _ in range(fuser._num_basic_ops)]

        # Mark input tensors as not deletable in backward
        for tensor in (input_,) + params_and_extra_inputs:
            tensor._do_not_clear = True

        # Unflatten list of parameters and extra tensor inputs
        extra_inputs = params_and_extra_inputs[-fuser.num_extra_inputs :]
        basic_op_extra_inputs = []
        for op in fuser._basic_ops:
            xs, extra_inputs = _split_tuple(extra_inputs, op.num_extra_inputs)
            basic_op_extra_inputs.append(xs)

        # Apply forward ops
        x = input_
        extra_outputs = [None] * fuser._num_basic_ops
        for op, basic_op_idxs in fuser._forward_ops:

            # Set if backward op is required
            for idx in basic_op_idxs:
                basic_op_ctxs[idx].requires_grad = idx >= fuser.first_op_requiring_backward

            # Forward op
            extra_inputs = [basic_op_extra_inputs[idx] for idx in basic_op_idxs]
            prev_op_idx = basic_op_idxs[0] - 1
            prev_op = fuser._basic_ops[prev_op_idx] if prev_op_idx >= 0 else None
            prev_op_grad_output_quantizer = None
            if prev_op is not None:
                prev_op_grad_output_quantizer = prev_op.get_grad_output_quantizer()
            next_op_idx = basic_op_idxs[-1] + 1
            next_op = fuser._basic_ops[next_op_idx] if next_op_idx < fuser._num_basic_ops else None
            next_op_input_quantizer = None
            if next_op is not None:
                next_op_input_quantizer = next_op.get_input_quantizer()

            x, fused_op_extra_outputs = op.fuser_forward(
                [basic_op_ctxs[idx] for idx in basic_op_idxs],
                x,
                basic_op_extra_inputs=extra_inputs,
                prev_op_grad_output_quantizer=prev_op_grad_output_quantizer,
                next_op_input_quantizer=next_op_input_quantizer,
                basic_op_kwargs=[basic_op_kwargs[idx] for idx in basic_op_idxs],
            )
            for idx, ys in zip(basic_op_idxs, fused_op_extra_outputs):
                for y in ys:
                    y.requires_grad_(idx >= fuser.first_op_requiring_backward)
                extra_outputs[idx] = ys

        # Flatten list of extra outputs
        extra_outputs_flat = []
        for idx, ys in enumerate(extra_outputs):
            ys = list(ys)
            num_extra_outputs = fuser._basic_ops[idx].num_extra_outputs
            if len(ys) != num_extra_outputs:
                raise RuntimeError(
                    f"Expected op {idx} to generate "
                    f"{num_extra_outputs} extra outputs, "
                    f"but got {len(ys)}"
                )
            extra_outputs_flat.extend(ys)

        # Save context for backward pass
        if func_ctx is not None:

            # Flatten list of saved tensors
            to_save = []
            for ctx in basic_op_ctxs:
                range_start = len(to_save)
                if ctx.to_save is not None:
                    to_save.extend(ctx.to_save)
                range_end = len(to_save)
                ctx.to_save = None
                ctx._saved_tensors_range = (range_start, range_end)

            # Save tensors for backward
            tensors_to_save, tensor_objects = prepare_for_saving(*to_save)
            func_ctx.save_for_backward(*tensors_to_save)
            func_ctx.tensor_objects = tensor_objects

            # Whether to perform recipe update in backward pass
            is_first_module = False
            if fuser.first_op_requiring_backward < fuser._num_basic_ops:
                is_first_module = FP8GlobalStateManager.is_first_fp8_module()

            # Other context
            func_ctx.backward_ops = fuser._backward_ops
            func_ctx.basic_ops = fuser._basic_ops
            func_ctx.basic_op_ctxs = basic_op_ctxs
            func_ctx.basic_op_num_params = fuser._basic_op_num_params
            func_ctx.num_extra_inputs = fuser.num_extra_inputs
            func_ctx.num_extra_outputs = len(extra_outputs_flat)
            func_ctx.is_first_module = is_first_module

        # Mark output tensors as not deletable in backward
        for tensor in [x] + extra_outputs_flat:
            tensor._do_not_clear = True

        x.requires_grad_(fuser.first_op_requiring_backward < fuser._num_basic_ops)

        if extra_outputs_flat:
            return x, *extra_outputs_flat

        return x

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(
        func_ctx: Any,
        grad_output: torch.Tensor,
        *grad_extra_outputs: torch.Tensor,
    ) -> tuple[Optional[torch.Tensor], ...]:
        """Backward pass"""

        # Operations and autograd state
        backward_ops = func_ctx.backward_ops
        basic_ops = func_ctx.basic_ops
        basic_op_ctxs = func_ctx.basic_op_ctxs

        # Restore saved tensors
        saved_tensors = restore_from_saved(func_ctx.tensor_objects, func_ctx.saved_tensors)

        # Unflatten list of saved tensors
        for ctx in basic_op_ctxs:
            ctx.saved_tensors = saved_tensors[slice(*ctx._saved_tensors_range)]
            ctx._saved_tensors_range = None

        # Unflatten list of extra tensor output grads
        if len(grad_extra_outputs) != func_ctx.num_extra_outputs:
            raise ValueError(
                f"Expected grads for {func_ctx.num_extra_outputs} extra tensor outputs, "
                f"but got {len(grad_extra_outputs)}"
            )
        basic_op_grad_extra_outputs = []
        for op in basic_ops:
            dys, grad_extra_outputs = _split_tuple(grad_extra_outputs, op.num_extra_outputs)
            basic_op_grad_extra_outputs.append(dys)

        # Apply backward ops
        dx = grad_output
        grad_params = [None for _ in range(len(basic_ops))]
        grad_extra_inputs = [None for _ in range(len(basic_ops))]
        for op, basic_op_idxs in reversed(backward_ops):

            # Stop if no more gradients are required
            if all(not basic_op_ctxs[idx].requires_grad for idx in basic_op_idxs):
                dx = None
                break

            # Backward op
            grad_extra_outputs = [basic_op_grad_extra_outputs[idx] for idx in basic_op_idxs]
            dx, fused_op_grad_params, fused_op_grad_extra_inputs = op.fuser_backward(
                [basic_op_ctxs[idx] for idx in basic_op_idxs],
                dx,
                basic_op_grad_extra_outputs=grad_extra_outputs,
            )
            for idx, dparams in zip(basic_op_idxs, fused_op_grad_params):
                grad_params[idx] = dparams
                basic_op_ctxs[idx].saved_tensors = None
            for idx, dxs in zip(basic_op_idxs, fused_op_grad_extra_inputs):
                grad_extra_inputs[idx] = dxs

        # Flatten list of parameter gradients
        grad_params_flat = []
        for idx, dparams in enumerate(grad_params):
            num_params = func_ctx.basic_op_num_params[idx]
            if dparams is None:
                dparams = [None for _ in range(num_params)]
            else:
                dparams = list(dparams)
            if len(dparams) != num_params:
                raise RuntimeError(
                    f"Expected op {idx} to generate {num_params} param grads, "
                    f"but got {len(dparams)}"
                )
            grad_params_flat.extend(dparams)

        # Flatten list of extra input gradients
        grad_extra_inputs_flat = []
        for idx, dxs in enumerate(grad_extra_inputs):
            num_extra_inputs = basic_ops[idx].num_extra_inputs
            if dxs is None:
                dxs = [None for _ in range(num_extra_inputs)]
            else:
                dxs = list(dxs)
            if len(dxs) != num_extra_inputs:
                raise RuntimeError(
                    f"Expected op {idx} to generate grads "
                    f"for {num_extra_inputs} extra inputs, "
                    f"but got {len(dxs)}"
                )
            grad_extra_inputs_flat.extend(dxs)

        # Update FP8 scaling factors
        if func_ctx.is_first_module and not _is_graph_capturing():
            FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)

        return (
            dx,  # input_
            None,  # fuser
            None,  # basic_op_kwargs
            *grad_params_flat,
            *grad_extra_inputs_flat,
        )


class OperationFuser:
    """Manages forward and backward passes for a pipeline of operations

    Parameters
    ----------
    ops : list of FusibleOperation
        Pipeline of operations
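
    Examples
    --------
    A minimal sketch, assuming ``BasicLinear`` and ``Bias`` are basic
    operations provided elsewhere in this package and ``inp`` is an
    appropriately shaped input tensor::

        fuser = OperationFuser([BasicLinear(hidden_size, hidden_size), Bias(hidden_size)])
        out = fuser(inp)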

    """

    # Functions to perform operation fusion
    forward_fusion_functions: list[OperationFusionFunction] = []
    backward_fusion_functions: list[OperationFusionFunction] = []

    def __init__(
        self,
        ops: list[FusibleOperation],
    ) -> None:

        # Get list of basic operations
        basic_ops = []
        for op in ops:
            if op.is_fused_op:
                basic_ops.extend(op.basic_ops)
            else:
                basic_ops.append(op)
        self._num_basic_ops: int = len(basic_ops)
        self._basic_ops: list[BasicOperation] = basic_ops

        # Number of extra tensor inputs
        self._basic_op_num_extra_inputs: list[int] = list(op.num_extra_inputs for op in basic_ops)
        self.num_extra_inputs: int = sum(self._basic_op_num_extra_inputs)

        # Ops for forward and backward pass, will be populated in maybe_fuse_ops
        self._forward_ops: list[tuple[FusibleOperation, list[int]]]
        self._backward_ops: list[tuple[FusibleOperation, list[int]]]

        # Cache and detect change of state relevant for fusing operations
        self.recipe_type = None
        self.first_op_requiring_backward = 0
        self._last_amax_history_len = 0

        # Flatten list of parameters
        self._basic_op_params = [list(op.parameters()) for op in self._basic_ops]
        self._basic_op_num_params = list(map(len, self._basic_op_params))
        self._flat_basic_op_params = sum(self._basic_op_params, [])

    @classmethod
    def _fuse_ops(
        cls,
        basic_ops: Sequence[BasicOperation],
        fusion_funcs: Iterable[OperationFusionFunction],
        recipe: Optional[Recipe],
    ) -> list[tuple[FusibleOperation, list[int]]]:
        """Apply operation fusions"""

        # Apply op fusions
        fused_ops = list(basic_ops)
        for func in fusion_funcs:
            fused_ops = func(fused_ops, recipe=recipe)

        def raise_mismatch_error() -> None:
            """Throw error indicating invalid op fusion"""
            raise RuntimeError(
                "Found mismatch after fusing operations "
                f"(basic_ops={[o.__class__.__name__ for o in basic_ops]}, "
                f"fused_ops={[o.__class__.__name__ for o in fused_ops]})"
            )

        # Determine basic op indices corresponding to each op
        out = []
        idx = 0
        for op in fused_ops:
            if isinstance(op, FusedOperation):
                idxs = []
                for basic_op in op.basic_ops:
                    if basic_op is not basic_ops[idx]:
                        raise_mismatch_error()
                    idxs.append(idx)
                    idx += 1
                out.append((op, idxs))
            else:
                if op is not basic_ops[idx]:
                    raise_mismatch_error()
                out.append((op, [idx]))
                idx += 1
        if idx != len(basic_ops):
            raise_mismatch_error()

        return out

    def maybe_fuse_ops(
        self,
        is_grad_enabled: bool,
        recipe: Optional[Recipe],
        input_: torch.Tensor,
        extra_inputs: list[Iterable[torch.Tensor]],
    ):
        """Attempt to fuse operations if neccesary"""

        # Determine which basic ops require backward
        if not is_grad_enabled:
            first_op_requiring_backward = self._num_basic_ops
        elif input_.requires_grad:
            first_op_requiring_backward = 0
        else:
            first_op_requiring_backward = self._num_basic_ops
            for op_idx in range(self._num_basic_ops):
                op_inputs = itertools.chain(self._basic_op_params[op_idx], extra_inputs[op_idx])
                if any(tensor.requires_grad for tensor in op_inputs):
                    first_op_requiring_backward = op_idx
                    break

        # Early exit if fusion parameters haven't changed
        need_reset = False
        recipe_type = type(recipe)
        fusion_params = (recipe_type, first_op_requiring_backward)
        if fusion_params != (self.recipe_type, self.first_op_requiring_backward):
            # Recipe type or grad requirements have changed
            need_reset = True
        elif (
            recipe is not None
            and recipe.delayed()
            and self._last_amax_history_len != recipe.amax_history_len
        ):
            # FP8 delayed scaling has changed amax history length
            need_reset = True
        if not need_reset:
            return

        # Reset recipe state
        for op in self._basic_ops:
            op.reset_recipe_state(recipe=recipe)

        # Check if this is the first iteration
        if self.recipe_type is None:
            for op in self._basic_ops:
                op.pre_first_fuser_forward()

        # Prepare basic op lists for fusions
        self._forward_ops = OperationFuser._fuse_ops(
            self._basic_ops,
            OperationFuser.forward_fusion_functions,
            recipe=recipe,
        )
        self._backward_ops = OperationFuser._fuse_ops(
            self._basic_ops,
            OperationFuser.backward_fusion_functions,
            recipe=recipe,
        )

        # Save current fusion params
        self.recipe_type, self.first_op_requiring_backward = fusion_params

        # Save amax history length
        if isinstance(recipe, DelayedScaling):
            self._last_amax_history_len = recipe.amax_history_len
        else:
            self._last_amax_history_len = 0

    def __call__(
        self,
        input: torch.Tensor,  # pylint: disable=redefined-builtin
        *extra_inputs: torch.Tensor,
        basic_op_kwargs: Optional[list[dict[str, Any]]] = None,
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        # Verify extra input count
        if len(extra_inputs) != self.num_extra_inputs:
            raise ValueError(
                f"Expected {self.num_extra_inputs} extra inputs but got {len(extra_inputs)}"
            )

        # Canonicalize op kwargs
        if basic_op_kwargs is None:
            basic_op_kwargs = [{} for _ in range(self._num_basic_ops)]

        # Unflatten list of extra tensor inputs
        extra_inputs_copy = list(extra_inputs)
        basic_op_extra_inputs = []
        for op in self._basic_ops:
            xs, extra_inputs_copy = _split_tuple(extra_inputs_copy, op.num_extra_inputs)
            basic_op_extra_inputs.append(xs)

        # Get environment state
        recipe = None
        if FP8GlobalStateManager.is_fp8_enabled():
            recipe = FP8GlobalStateManager.get_fp8_recipe()
        is_grad_enabled = torch.is_grad_enabled()

        # Attempt to fuse operations if necessary
        self.maybe_fuse_ops(is_grad_enabled, recipe, input, basic_op_extra_inputs)

        # Initialization before forward
        for idx, op in enumerate(self._basic_ops):
            op.pre_fuser_forward(requires_grad=idx >= self.first_op_requiring_backward)

        # Fuser forward pass
        if is_grad_enabled:
            forward_func = _OperationFuserAutogradFunction.apply
            args = []
        else:
            forward_func = _OperationFuserAutogradFunction.forward
            args = [None]
        args += (
            input,
            self,
            basic_op_kwargs,
            *self._flat_basic_op_params,
            *extra_inputs,
        )
        return forward_func(*args)


def register_forward_fusion(
    op_fusion_func: OperationFusionFunction,
    prepend: bool = False,
) -> None:
    """Register function to perform operation fusion for forward pass.

    The fusion function should have the following signature:

        func(ops, *, recipe) -> updated ops

    Parameters
    ----------
    op_fusion_func: function
        Function that takes a list of operations and may substitute
        them with fused operations.
    prepend: bool, default = ``False``
        Whether the operation fuser should apply this fusion function
        first. The default is to apply it last.
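
    Examples
    --------
    A hypothetical fusion function that returns the op list unchanged
    (a real implementation would substitute ``FusedOperation`` subclasses
    for fusible patterns)::

        def my_forward_fusion(ops, *, recipe=None):
            return list(ops)

        register_forward_fusion(my_forward_fusion)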

    """
    if prepend:
        OperationFuser.forward_fusion_functions.insert(0, op_fusion_func)
    else:
        OperationFuser.forward_fusion_functions.append(op_fusion_func)


def register_backward_fusion(
    op_fusion_func: OperationFusionFunction,
    prepend: bool = False,
) -> None:
    """Register function to perform operation fusion for backward pass.

    The fusion function should have the following signature:

        func(ops, *, recipe) -> updated ops

    Parameters
    ----------
    op_fusion_func: function
        Function that takes a list of operations and may substitute
        them with fused operations.
    prepend: bool, default = ``False``
        Whether the operation fuser should apply this fusion function
        first. The default is to apply it last.
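
    Examples
    --------
    Registration mirrors ``register_forward_fusion``; a hypothetical
    fusion function that returns the op list unchanged::

        def my_backward_fusion(ops, *, recipe=None):
            return list(ops)

        register_backward_fusion(my_backward_fusion)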

    """
    if prepend:
        OperationFuser.backward_fusion_functions.insert(0, op_fusion_func)
    else:
        OperationFuser.backward_fusion_functions.append(op_fusion_func)