# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Manager class for a pipeline of fusible operations."""

from __future__ import annotations
from collections.abc import Callable
from typing import Any, Optional

import torch

from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.ops.op import (
    BasicOperation,
    FusibleOperation,
    OperationContext,
)
from transformer_engine.pytorch.ops.fused import (
    fuse_backward_linear_add,
    fuse_forward_linear_bias_activation,
    fuse_forward_linear_bias_add,
    fuse_userbuffers_backward_linear,
    fuse_userbuffers_forward_linear,
)


def _split_tuple(t: tuple, idx: int) -> tuple[tuple, tuple]:
    """Split tuple at index"""
    return t[:idx], t[idx:]


# Lazily imported function used in _is_graph_capturing
_is_graph_capturing_function: Optional[Callable[[], bool]] = None


def _is_graph_capturing() -> bool:
    """Whether function is called within `make_graphed_callables`

    Avoid circular import with lazy import.

    """
    global _is_graph_capturing_function
    if _is_graph_capturing_function is None:
        from ..graph import is_graph_capturing

        _is_graph_capturing_function = is_graph_capturing
    return _is_graph_capturing_function()


class _OperationFuserAutogradFunction(torch.autograd.Function):
    """Autograd function for a pipeline of operations

    Autograd must be done at the pipeline level since we may apply
    different fusions in the forward and backward passes.

    """

    # pylint: disable=unused-argument
    @staticmethod
    def forward(
        func_ctx: Optional[torch.autograd.function.FunctionCtx],
        input_: torch.Tensor,
        fuser: OperationFuser,
        basic_op_kwargs: list[dict[str, Any]],
        is_grad_enabled: bool,
        *params_and_extra_inputs: torch.nn.Parameter,
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Forward pass

        Parameters
        ----------
        func_ctx: torch.autograd.function.FunctionCtx, optional
            Context for PyTorch autograd function. None if autograd
            is not needed.
        input_: torch.Tensor
            Input to first operation in pipeline
        fuser: OperationFuser
            Container for the pipeline of operations to run
        basic_op_kwargs: list of dict
            Keyword arguments to forward to each BasicOperation
        is_grad_enabled: bool
            Whether context should be saved for the backward pass
        *params_and_extra_inputs: torch.Tensor
            Other tensor inputs to include in autograd graph. Consists
            of parameter tensors, followed by extra operation inputs.

        Returns
        -------
        Output tensor(s). If none of the operations have any extra
        tensor outputs, then the pipeline's output tensor is returned.
        Otherwise, a tuple with the pipeline's output tensor and extra
        tensor outputs is returned.

        """

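        # Note: `params_and_extra_inputs` is packed by
        # `OperationFuser.__call__` as (*basic_op_params, *extra_inputs):
        # every op's parameters first, then every op's extra tensor
        # inputs, each flattened in op order.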
        # Operation autograd contexts
        basic_op_ctxs = [OperationContext() for _ in range(fuser._num_basic_ops)]

        # Unflatten list of parameters and extra tensor inputs
        # Note: Avoid a negative-index slice since t[-0:] would return
        # the whole tuple when there are no extra inputs
        num_params = len(params_and_extra_inputs) - fuser._num_extra_inputs
        extra_inputs = params_and_extra_inputs[num_params:]
        basic_op_extra_inputs = []
        for op in fuser._basic_ops:
            xs, extra_inputs = _split_tuple(extra_inputs, op.num_extra_inputs)
            basic_op_extra_inputs.append(xs)

        # Apply forward ops
        x = input_
        requires_grad = is_grad_enabled and x.requires_grad
        extra_outputs = [None] * fuser._num_basic_ops
        for op, basic_op_idxs in fuser._forward_ops:

            # Check if backward op is required
            if is_grad_enabled:
                if not requires_grad:
                    requires_grad = any(param.requires_grad for param in op.parameters())
                if not requires_grad:
                    # Check the extra tensor inputs feeding this op's basic ops
                    requires_grad = any(
                        x.requires_grad
                        for idx in basic_op_idxs
                        for x in basic_op_extra_inputs[idx]
                    )
            for idx in basic_op_idxs:
                basic_op_ctxs[idx].requires_grad = requires_grad
            if requires_grad != x.requires_grad:
                if requires_grad:
                    x.requires_grad_()
                else:
                    x = x.detach()

            # Forward op
            extra_inputs = [basic_op_extra_inputs[idx] for idx in basic_op_idxs]
            prev_ops = [fuser._basic_ops[idx - 1] if idx > 0 else None for idx in basic_op_idxs]
            next_ops = [
                fuser._basic_ops[idx + 1] if (idx < fuser._num_basic_ops - 1) else None
                for idx in basic_op_idxs
            ]
            x, fused_op_extra_outputs = op.fuser_forward(
                [basic_op_ctxs[idx] for idx in basic_op_idxs],
                x,
                basic_op_extra_inputs=extra_inputs,
                basic_op_prev_ops=prev_ops,
                basic_op_next_ops=next_ops,
                basic_op_kwargs=[basic_op_kwargs[idx] for idx in basic_op_idxs],
            )
            x.requires_grad_(requires_grad=requires_grad)
            for idx, ys in zip(basic_op_idxs, fused_op_extra_outputs):
                for y in ys:
                    y.requires_grad_(requires_grad=requires_grad)
                extra_outputs[idx] = ys

        # Flatten list of extra outputs
        extra_outputs_flat = []
        for idx, ys in enumerate(extra_outputs):
            ys = list(ys)
            num_extra_outputs = fuser._basic_ops[idx].num_extra_outputs
            if len(ys) != num_extra_outputs:
                raise RuntimeError(
                    f"Expected op {idx} to generate "
                    f"{num_extra_outputs} extra outputs, "
                    f"but got {len(ys)}"
                )
            extra_outputs_flat.extend(ys)

        # Save context for backward pass
        if is_grad_enabled:

            # Flatten list of saved tensors
            to_save = []
            for ctx in basic_op_ctxs:
                range_start = len(to_save)
                if ctx.to_save is not None:
                    to_save.extend(ctx.to_save)
                range_end = len(to_save)
                ctx.to_save = None
                ctx._saved_tensors_range = (range_start, range_end)
            func_ctx.save_for_backward(*to_save)
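            # All ops share one save_for_backward call; the per-op
            # ranges recorded above let backward() slice each op's
            # saved tensors back out.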

            # Other context
            func_ctx.backward_ops = fuser._backward_ops
            func_ctx.basic_ops = fuser._basic_ops
            func_ctx.basic_op_ctxs = basic_op_ctxs
            func_ctx.basic_op_num_params = fuser._basic_op_num_params
            func_ctx.num_extra_inputs = fuser._num_extra_inputs
            func_ctx.num_extra_outputs = len(extra_outputs_flat)
            func_ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module()

        if extra_outputs_flat:
            return x, *extra_outputs_flat
        return x

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(
        func_ctx: Any,
        grad_output: torch.Tensor,
        *grad_extra_outputs: torch.Tensor,
    ) -> tuple[Optional[torch.Tensor], ...]:
        """Backward pass"""

        # Operations and autograd state
        backward_ops = func_ctx.backward_ops
        basic_ops = func_ctx.basic_ops
        basic_op_ctxs = func_ctx.basic_op_ctxs

        # Unflatten list of saved tensors
        saved_tensors = func_ctx.saved_tensors
        for ctx in basic_op_ctxs:
            ctx.saved_tensors = saved_tensors[slice(*ctx._saved_tensors_range)]
            ctx._saved_tensors_range = None

        # Unflatten list of extra tensor output grads
        if len(grad_extra_outputs) != func_ctx.num_extra_outputs:
            raise ValueError(
                f"Expected grads for {func_ctx.num_extra_outputs} extra tensor outputs, "
                f"but got {len(grad_extra_outputs)}"
            )
        basic_op_grad_extra_outputs = []
        for op in basic_ops:
            dys, grad_extra_outputs = _split_tuple(grad_extra_outputs, op.num_extra_outputs)
            basic_op_grad_extra_outputs.append(dys)

        # Apply backward ops
        dx = grad_output
        grad_params = [None for _ in range(len(basic_ops))]
        grad_extra_inputs = [None for _ in range(len(basic_ops))]
        for op, basic_op_idxs in backward_ops:

            # Stop if no more gradients are required
            if all(not basic_op_ctxs[idx].requires_grad for idx in basic_op_idxs):
                dx = None
                break

            # Backward op
            grad_extra_outputs = [basic_op_grad_extra_outputs[idx] for idx in basic_op_idxs]
            dx, fused_op_grad_params, fused_op_grad_extra_inputs = op.fuser_backward(
                [basic_op_ctxs[idx] for idx in basic_op_idxs],
                dx,
                basic_op_grad_extra_outputs=grad_extra_outputs,
            )
            for idx, dparams in zip(basic_op_idxs, fused_op_grad_params):
                grad_params[idx] = dparams
                basic_op_ctxs[idx].saved_tensors = None
            for idx, dxs in zip(basic_op_idxs, fused_op_grad_extra_inputs):
                grad_extra_inputs[idx] = dxs

        # Flatten list of parameter gradients
        grad_params_flat = []
        for idx, dparams in enumerate(grad_params):
            num_params = func_ctx.basic_op_num_params[idx]
            if dparams is None:
                dparams = [None for _ in range(num_params)]
            else:
                dparams = list(dparams)
            if len(dparams) != num_params:
                raise RuntimeError(
                    f"Expected op {idx} to generate {num_params} param grads, "
                    f"but got {len(dparams)}"
                )
            grad_params_flat.extend(dparams)

        # Flatten list of gradients w.r.t. extra tensor inputs
        grad_extra_inputs_flat = []
        for idx, dxs in enumerate(grad_extra_inputs):
            num_extra_inputs = basic_ops[idx].num_extra_inputs
            if dxs is None:
                dxs = [None for _ in range(num_extra_inputs)]
            else:
                dxs = list(dxs)
            if len(dxs) != num_extra_inputs:
                raise RuntimeError(
                    f"Expected op {idx} to generate grads "
                    f"for {num_extra_inputs} extra inputs, "
                    f"but got {len(dxs)}"
                )
            grad_extra_inputs_flat.extend(dxs)

        # Update FP8 scaling factors
        if func_ctx.is_first_module and not _is_graph_capturing():
            FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)

        return (
            dx,  # input_
            None,  # fuser
            None,  # basic_op_kwargs
            None,  # is_grad_enabled
            *grad_params_flat,
            *grad_extra_inputs_flat,
        )


class OperationFuser:
    """Manages forward and backward passes for a pipeline of operations

    Parameters
    ----------
    ops: list of FusibleOperation
        Pipeline of operations
    fuse_ops: bool, default = `True`
        Whether to attempt fusing operations

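    Examples
    --------
    A minimal sketch (illustrative; assumes basic ops such as
    ``BasicLinear`` and ``Bias`` from ``transformer_engine.pytorch.ops``
    are in scope)::

        ops = [BasicLinear(128, 128), Bias(128)]
        fuser = OperationFuser(ops)
        y = fuser(x)  # linear and bias may run as one fused op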
    """

    def __init__(
        self,
        ops: list[FusibleOperation],
        fuse_ops: bool = True,
    ) -> None:

        # Get list of basic operations
        basic_ops = []
        for op in ops:
            if op.is_fused_op:
                basic_ops.extend(op.basic_ops)
            else:
                basic_ops.append(op)
        self._num_basic_ops: int = len(basic_ops)
        self._basic_ops: list[BasicOperation] = basic_ops

        # Number of extra tensor inputs
        self._num_extra_inputs: int = sum(op.num_extra_inputs for op in basic_ops)

        # Ops for forward and backward pass
        self._forward_ops: list[tuple[FusibleOperation, list[int]]]
        self._backward_ops: list[tuple[FusibleOperation, list[int]]]
        self._forward_ops = [(op, (idx,)) for idx, op in enumerate(self._basic_ops)]
        self._backward_ops = list(reversed(self._forward_ops))
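        # The backward pipeline starts as the forward pipeline in
        # reverse; fusion below may then specialize each direction
        # independently.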

        # Fuse ops if needed
        if fuse_ops:
            self.fuse_ops()

        # Flatten list of parameters
        self._basic_op_params = [param for op in self._basic_ops for param in op.parameters()]
        self._basic_op_num_params = [sum(1 for _ in op.parameters()) for op in self._basic_ops]

    @classmethod
    def _fuse_forward_ops(
        cls,
        ops: list[tuple[FusibleOperation, list[int]]],
    ) -> list[tuple[FusibleOperation, list[int]]]:
        """Attempt to fuse operations in forward pass"""
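        # Fusion passes run in sequence: each pass rewrites the
        # (op, basic-op-indices) list and later passes see the result,
        # so pass order matters.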
        ops = fuse_userbuffers_forward_linear(ops)
        ops = fuse_forward_linear_bias_add(ops)
        ops = fuse_forward_linear_bias_activation(ops)
        return ops

    @classmethod
    def _fuse_backward_ops(
        cls,
        ops: list[tuple[FusibleOperation, list[int]]],
    ) -> list[tuple[FusibleOperation, list[int]]]:
        """Attempt to fuse operations in backward pass"""
        ops = fuse_userbuffers_backward_linear(ops)
        ops = fuse_backward_linear_add(ops)
        return ops

    def fuse_ops(self) -> None:
        """Attempt to fuse operations"""
        self._forward_ops = self._fuse_forward_ops(self._forward_ops)
        self._backward_ops = self._fuse_backward_ops(self._backward_ops)

    def __call__(
        self,
        input: torch.Tensor,  # pylint: disable=redefined-builtin
        *extra_inputs: torch.Tensor,
        basic_op_kwargs: Optional[list[dict[str, Any]]] = None,
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
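        """Apply the pipeline of operations to an input tensor

        Extra tensor inputs to the basic operations are passed as
        positional arguments after the main input tensor.

        """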
        # Verify extra input count
        if len(extra_inputs) != self._num_extra_inputs:
            raise ValueError(
                f"Expected {self._num_extra_inputs} extra inputs but got {len(extra_inputs)}"
            )

        # Initialization before forward pass
        for op in self._basic_ops:
            op.pre_forward()

        # Canonicalize op kwargs
        if basic_op_kwargs is None:
            # Use a separate dict per op so kwargs are not aliased
            # across ops ([{}] * n would reuse one dict object)
            basic_op_kwargs = [{} for _ in range(self._num_basic_ops)]

        # Fuser forward pass
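        # With grad disabled, call forward() directly with a None
        # autograd context instead of going through autograd, avoiding
        # graph-construction overhead.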
        is_grad_enabled = torch.is_grad_enabled()
        if is_grad_enabled:
            forward_func = _OperationFuserAutogradFunction.apply
            args = []
        else:
            forward_func = _OperationFuserAutogradFunction.forward
            args = [None]
        args += (
            input,
            self,
            basic_op_kwargs,
            is_grad_enabled,
            *self._basic_op_params,
            *extra_inputs,
        )
        return forward_func(*args)