# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Manager class for a pipeline of fusible operations."""

from __future__ import annotations
from collections.abc import Callable
from typing import Any, Optional

import torch

from transformer_engine.pytorch.fp8 import FP8GlobalStateManager, Recipe
from transformer_engine.pytorch.ops.op import (
    BasicOperation,
    FusibleOperation,
    OperationContext,
)
from transformer_engine.pytorch.ops.fused import (
    fuse_backward_bias_activation,
    fuse_backward_linear_add,
    fuse_forward_linear_bias_activation,
    fuse_forward_linear_bias_add,
    fuse_userbuffers_backward_linear,
    fuse_userbuffers_forward_linear,
)
from transformer_engine.pytorch.tensor.quantized_tensor import (
    prepare_for_saving,
    restore_from_saved,
)


def _split_tuple(t: tuple, idx: int) -> tuple[tuple, tuple]:
    """Split tuple at index"""
    return t[:idx], t[idx:]

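# Illustrative example (comment only): `_split_tuple((a, b, c), 2)` returns
# `((a, b), (c,))`. The fuser uses this to peel off each op's extra tensors
# from a flat tuple.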

# Lazily imported function used in _is_graph_capturing
_is_graph_capturing_function: Optional[Callable[[], bool]] = None


def _is_graph_capturing() -> bool:
    """Whether function is called within `make_graphed_callables`

    Avoid circular import with lazy import.

    """
    global _is_graph_capturing_function
    if _is_graph_capturing_function is None:
        from ..graph import is_graph_capturing

        _is_graph_capturing_function = is_graph_capturing
    return _is_graph_capturing_function()


class _OperationFuserAutogradFunction(torch.autograd.Function):
    """Autograd function for a pipeline of operations

    Autograd must be done at the pipeline level since we may apply
    different fusions in the forward and backward passes.

    """

    # pylint: disable=unused-argument
    @staticmethod
    def forward(
        func_ctx: Optional[torch.autograd.function.FunctionCtx],
        input_: torch.Tensor,
        fuser: OperationFuser,
        basic_op_kwargs: list[dict[str, Any]],
        is_grad_enabled: bool,
        *params_and_extra_inputs: torch.nn.Parameter,
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Forward pass

        Parameters
        ----------
        func_ctx: torch.autograd.function.FunctionCtx, optional
            Context for PyTorch autograd function (None when autograd
            is not needed)
        input_: torch.Tensor
            Input to first operation in pipeline
        fuser: OperationFuser
            Container for the pipeline of operations to run
        basic_op_kwargs: list of dict
            Keyword arguments to BasicOperation
        is_grad_enabled: bool
            Should context be saved for backward
        *params_and_extra_inputs: torch.Tensor
            Other tensor inputs to include in autograd graph. Consists
            of parameter tensors, followed by extra operation inputs.

        Returns
        -------
        Output tensor(s). If none of the operations have any extra
        tensor outputs, then the pipeline's output tensor is returned.
        Otherwise, a tuple with the pipeline's output tensor and extra
        tensor outputs is returned.

        """

        # Operation autograd contexts
        basic_op_ctxs = [OperationContext() for _ in range(fuser._num_basic_ops)]

        # Mark input tensors as not deletable in backward
        for tensor in (input_,) + params_and_extra_inputs:
            tensor.do_not_clear = True

        # Unflatten list of parameters and extra tensor inputs
        # Note: guard against num_extra_inputs == 0, since t[-0:] is the whole tuple
        extra_inputs = ()
        if fuser._num_extra_inputs > 0:
            extra_inputs = params_and_extra_inputs[-fuser._num_extra_inputs :]
        basic_op_extra_inputs = []
        for op in fuser._basic_ops:
            xs, extra_inputs = _split_tuple(extra_inputs, op.num_extra_inputs)
            basic_op_extra_inputs.append(xs)
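        # Illustrative example: with per-op `num_extra_inputs` of (2, 0, 1),
        # a flat tuple (a, b, c) unflattens to [(a, b), (), (c,)].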

        # Apply forward ops
        x = input_
        requires_grad = is_grad_enabled and x.requires_grad
        with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
        extra_outputs = [None] * fuser._num_basic_ops
        for op, basic_op_idxs in fuser._forward_ops:

            # Check if backward op is required
            if is_grad_enabled:
                if not requires_grad:
                    requires_grad = any(param.requires_grad for param in op.parameters())
                if not requires_grad:
                    requires_grad = any(
                        any(x.requires_grad for x in basic_op_extra_inputs[idx])
                        for idx in basic_op_idxs
                    )
            for idx in basic_op_idxs:
                basic_op_ctxs[idx].requires_grad = requires_grad

            # Forward op
            extra_inputs = [basic_op_extra_inputs[idx] for idx in basic_op_idxs]
            prev_op_idx = basic_op_idxs[0] - 1
            prev_op = fuser._basic_ops[prev_op_idx] if prev_op_idx >= 0 else None
            prev_op_grad_input_quantizer = None
            if prev_op is not None and with_quantized_compute:
                prev_op_grad_input_quantizer = prev_op.get_grad_input_quantizer()
            next_op_idx = basic_op_idxs[-1] + 1
            next_op = fuser._basic_ops[next_op_idx] if next_op_idx < fuser._num_basic_ops else None
            next_op_input_quantizer = None
            if next_op is not None and with_quantized_compute:
                next_op_input_quantizer = next_op.get_input_quantizer()
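            # Passing the neighboring ops' quantizers lets a fused op produce
            # its output (or grad input) directly in the quantized format the
            # adjacent op consumes, so redundant casts can be avoided.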

            x, fused_op_extra_outputs = op.fuser_forward(
                [basic_op_ctxs[idx] for idx in basic_op_idxs],
                x,
                basic_op_extra_inputs=extra_inputs,
                prev_op_grad_input_quantizer=prev_op_grad_input_quantizer,
                next_op_input_quantizer=next_op_input_quantizer,
                basic_op_kwargs=[basic_op_kwargs[idx] for idx in basic_op_idxs],
            )
            for idx, ys in zip(basic_op_idxs, fused_op_extra_outputs):
                for y in ys:
                    y.requires_grad_(requires_grad)
                extra_outputs[idx] = ys

        # Flatten list of extra outputs
        extra_outputs_flat = []
        for idx, ys in enumerate(extra_outputs):
            ys = list(ys)
            num_extra_outputs = fuser._basic_ops[idx].num_extra_outputs
            if len(ys) != num_extra_outputs:
                raise RuntimeError(
                    f"Expected op {idx} to generate "
                    "{num_extra_outputs} extra inputs, "
                    f"but got {len(ys)}"
                )
            extra_outputs_flat.extend(ys)

        # Save context for backward pass
        if is_grad_enabled:

            # Flatten list of saved tensors
            to_save = []
            for ctx in basic_op_ctxs:
                range_start = len(to_save)
                if ctx.to_save is not None:
                    to_save.extend(ctx.to_save)
                range_end = len(to_save)
                ctx.to_save = None
                ctx._saved_tensors_range = (range_start, range_end)
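            # Illustrative example: if op 0 saves two tensors and op 1 saves
            # one, the ranges are (0, 2) and (2, 3) into the flat list.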

            # Save tensors for backward
            if with_quantized_compute:
                tensors_to_save, tensor_objects = prepare_for_saving(*to_save)
                func_ctx.save_for_backward(*tensors_to_save)
                func_ctx.tensor_objects = tensor_objects
            else:
                func_ctx.save_for_backward(*to_save)

            # Other context
            func_ctx.backward_ops = fuser._backward_ops
            func_ctx.basic_ops = fuser._basic_ops
            func_ctx.basic_op_ctxs = basic_op_ctxs
            func_ctx.basic_op_num_params = fuser._num_list_basic_op_params
            func_ctx.num_extra_inputs = fuser._num_extra_inputs
            func_ctx.num_extra_outputs = len(extra_outputs_flat)
            func_ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module()
            func_ctx.with_quantized_compute = with_quantized_compute

        x.requires_grad_(requires_grad)

        if extra_outputs_flat:
            return x, *extra_outputs_flat

        return x

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(
        func_ctx: Any,
        grad_output: torch.Tensor,
        *grad_extra_outputs: torch.Tensor,
    ) -> tuple[Optional[torch.Tensor], ...]:
        """Backward pass"""

        # Operations and autograd state
        backward_ops = func_ctx.backward_ops
        basic_ops = func_ctx.basic_ops
        basic_op_ctxs = func_ctx.basic_op_ctxs

        # Restore saved tensors
        if func_ctx.with_quantized_compute:
            saved_tensors = restore_from_saved(func_ctx.tensor_objects, func_ctx.saved_tensors)
        else:
            saved_tensors = func_ctx.saved_tensors

        # Unflatten list of saved tensors
        for ctx in basic_op_ctxs:
            ctx.saved_tensors = saved_tensors[slice(*ctx._saved_tensors_range)]
            ctx._saved_tensors_range = None

        # Unflatten list of extra tensor output grads
        if len(grad_extra_outputs) != func_ctx.num_extra_outputs:
            raise ValueError(
                f"Expected grads for {func_ctx.num_extra_outputs} extra tensor outputs, "
                f"but got {len(grad_extra_outputs)}"
            )
        basic_op_grad_extra_outputs = []
        for op in basic_ops:
            dys, grad_extra_outputs = _split_tuple(grad_extra_outputs, op.num_extra_outputs)
            basic_op_grad_extra_outputs.append(dys)
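        # (Mirrors the unflattening of extra inputs in the forward pass.)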

        # Apply backward ops
        dx = grad_output
        grad_params = [None for _ in range(len(basic_ops))]
        grad_extra_inputs = [None for _ in range(len(basic_ops))]
        for op, basic_op_idxs in backward_ops:

            # Stop if no more gradients are required
            if all(not basic_op_ctxs[idx].requires_grad for idx in basic_op_idxs):
                dx = None
                break

            # Backward op
            grad_extra_outputs = [basic_op_grad_extra_outputs[idx] for idx in basic_op_idxs]
            dx, fused_op_grad_params, fused_op_grad_extra_inputs = op.fuser_backward(
                [basic_op_ctxs[idx] for idx in basic_op_idxs],
                dx,
                basic_op_grad_extra_outputs=grad_extra_outputs,
            )
            for idx, dparams in zip(basic_op_idxs, fused_op_grad_params):
                grad_params[idx] = dparams
                basic_op_ctxs[idx].saved_tensors = None
            for idx, dxs in zip(basic_op_idxs, fused_op_grad_extra_inputs):
                grad_extra_inputs[idx] = dxs

        # Flatten list of parameter gradients
        grad_params_flat = []
        for idx, dparams in enumerate(grad_params):
            num_params = func_ctx.basic_op_num_params[idx]
            if dparams is None:
                dparams = [None for _ in range(num_params)]
            else:
                dparams = list(dparams)
            if len(dparams) != num_params:
                raise RuntimeError(
                    f"Expected op {idx} to generate {num_params} param grads, "
                    f"but got {len(dparams)}"
                )
            grad_params_flat.extend(dparams)

        # Flatten list of extra input gradients
        grad_extra_inputs_flat = []
        for idx, dxs in enumerate(grad_extra_inputs):
            num_extra_inputs = basic_ops[idx].num_extra_inputs
            if dxs is None:
                dxs = [None for _ in range(num_extra_inputs)]
            else:
                dxs = list(dxs)
            if len(dxs) != num_extra_inputs:
                raise RuntimeError(
                    f"Expected op {idx} to generate grads "
                    f"for {num_extra_inputs} extra inputs, "
                    f"but got {len(dxs)}"
                )
            grad_extra_inputs_flat.extend(dxs)

        # Update FP8 scaling factors
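        # (Only the first FP8 module in the step performs the reduction, and
        # it is skipped while capturing a CUDA graph.)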
        if func_ctx.is_first_module and not _is_graph_capturing():
            FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)

        return (
            dx,  # input_
            None,  # fuser
            None,  # basic_op_kwargs
            None,  # is_grad_enabled
            *grad_params_flat,
            *grad_extra_inputs_flat,
        )


class OperationFuser:
    """Manages forward and backward passes for a pipeline of operations

    Parameters
    ----------
    ops: list of FusibleOperation
        Pipeline of operations
    fuse_ops: bool
        Whether to attempt fusing operations
    recipe: Recipe, optional
        Quantization recipe to use when fusing and executing operations.
        Note: certain fusions may depend on what kind of recipe is being used.

    """

    def __init__(
        self,
        ops: list[FusibleOperation],
        fuse_ops: bool,
        recipe: Optional[Recipe],
    ) -> None:

        # Get list of basic operations
        basic_ops = []
        for op in ops:
            if op.is_fused_op:
                basic_ops.extend(op.basic_ops)
            else:
                basic_ops.append(op)
        self._num_basic_ops: int = len(basic_ops)
        self._basic_ops: list[BasicOperation] = basic_ops

        # Number of extra tensor inputs
        self._num_extra_inputs: int = sum(op.num_extra_inputs for op in basic_ops)

        # Ops for forward and backward pass
        self._forward_ops: list[tuple[FusibleOperation, list[int]]]
        self._backward_ops: list[tuple[FusibleOperation, list[int]]]
        self._forward_ops = [(op, (idx,)) for idx, op in enumerate(self._basic_ops)]
        self._backward_ops = list(reversed(self._forward_ops))
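        # Illustrative example: ops [a, b, c] yield forward ops
        # [(a, (0,)), (b, (1,)), (c, (2,))] and the same list reversed for
        # backward, until `fuse_ops` replaces runs of ops with fused ops.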

        # Flag for checking if this is the first iteration
        self._is_first_forward = True

        # Fuse ops if needed
        self.recipe = recipe
        if fuse_ops:
            self.fuse_ops()

        # Flatten list of parameters
        self._basic_op_params = [param for op in self._basic_ops for param in op.parameters()]
        self._num_list_basic_op_params = [sum(1 for _ in op.parameters()) for op in self._basic_ops]
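        # `_num_list_basic_op_params[i]` is the parameter count of basic op i,
        # used in backward to unflatten the flat list of parameter grads.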

    @classmethod
    def _fuse_forward_ops(
        cls,
        ops: list[tuple[FusibleOperation, list[int]]],
        recipe: Optional[Recipe],  # pylint: disable=unused-argument
    ) -> list[tuple[FusibleOperation, list[int]]]:
        """Attempt to fuse operations in forward pass"""
        ops = fuse_userbuffers_forward_linear(ops)
        ops = fuse_forward_linear_bias_add(ops)
        ops = fuse_forward_linear_bias_activation(ops)
        return ops

    @classmethod
    def _fuse_backward_ops(
        cls,
        ops: list[tuple[FusibleOperation, list[int]]],
        recipe: Optional[Recipe],
    ) -> list[tuple[FusibleOperation, list[int]]]:
        """Attempt to fuse operations in backward pass"""
        ops = fuse_userbuffers_backward_linear(ops)
        ops = fuse_backward_linear_add(ops)
        ops = fuse_backward_bias_activation(ops, recipe)
        return ops

    def fuse_ops(self) -> None:
        """Attempt to fuse operations"""
        self._forward_ops = self._fuse_forward_ops(self._forward_ops, self.recipe)
        self._backward_ops = self._fuse_backward_ops(self._backward_ops, self.recipe)

    def __call__(
        self,
        input: torch.Tensor,  # pylint: disable=redefined-builtin
        *extra_inputs: torch.Tensor,
        basic_op_kwargs: Optional[list[dict[str, Any]]] = None,
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        # Verify extra input count
        if len(extra_inputs) != self._num_extra_inputs:
            raise ValueError(
                f"Expected {self._num_extra_inputs} extra inputs but got {len(extra_inputs)}"
            )

        # Initialization before forward pass
        if self._is_first_forward:
            for op in self._basic_ops:
                op.pre_first_forward(recipe=self.recipe)
            self._is_first_forward = False

        # Canonicalize op kwargs
        if basic_op_kwargs is None:
            basic_op_kwargs = [{} for _ in range(self._num_basic_ops)]

        # Fuser forward pass
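        # When grad is disabled, call the static `forward` directly with a
        # `None` autograd context to skip autograd bookkeeping.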
        is_grad_enabled = torch.is_grad_enabled()
        if is_grad_enabled:
            forward_func = _OperationFuserAutogradFunction.apply
            args = []
        else:
            forward_func = _OperationFuserAutogradFunction.forward
            args = [None]
        args += (
            input,
            self,
            basic_op_kwargs,
            is_grad_enabled,
            *self._basic_op_params,
            *extra_inputs,
        )
        return forward_func(*args)