# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Manager class for a pipeline of fusible operations."""

from __future__ import annotations
from collections.abc import Callable
from typing import Any, Optional

import torch

from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.ops.op import (
    BasicOperation,
    FusibleOperation,
    OperationContext,
)
from transformer_engine.pytorch.ops.fused import (
    fuse_backward_linear_add,
    fuse_forward_linear_bias_activation,
    fuse_forward_linear_bias_add,
    fuse_userbuffers_backward_linear,
    fuse_userbuffers_forward_linear,
)


def _split_tuple(t: tuple, idx: int) -> tuple[tuple, tuple]:
    """Split tuple at index"""
    return t[:idx], t[idx:]


# Lazily imported function used in _is_graph_capturing
_is_graph_capturing_function: Optional[Callable[[], bool]] = None


def _is_graph_capturing() -> bool:
    """Whether function is called within `make_graphed_callables`

    Avoid circular import with lazy import.

    """
    global _is_graph_capturing_function
    if _is_graph_capturing_function is None:
        from ..graph import is_graph_capturing

        _is_graph_capturing_function = is_graph_capturing
    return _is_graph_capturing_function()


class _OperationFuserAutogradFunction(torch.autograd.Function):
    """Autograd function for a pipeline of operations

    Autograd must be done at the pipeline level since we may apply
    different fusions in the forward and backward passes.

    """

    # pylint: disable=unused-argument
    @staticmethod
    def forward(
        func_ctx: Optional[torch.autograd.function.FunctionCtx],
        input_: torch.Tensor,
        forward_ops: list[tuple[FusibleOperation, list[int]]],
        backward_ops: list[tuple[FusibleOperation, list[int]]],
        basic_ops: list[BasicOperation],
        basic_op_kwargs: list[dict[str, Any]],
        is_grad_enabled: bool,
        num_params: int,
        num_extra_inputs: int,
        *params_and_extra_inputs: torch.nn.Parameter,
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Forward pass

        Parameters
        ----------
        func_ctx: torch.autograd.function.FunctionCtx
            Context for PyTorch autograd function
        input_: torch.Tensor
            Input to first operation in pipeline
        forward_ops: list of tuple
            Forward pass operations and the indices of the
            corresponding basic operations. The order should match
            basic_ops.
        backward_ops: list of tuple
            Backward pass operations and the indices of the
            corresponding basic operations. The order should be the
            reverse of basic_ops.
        basic_ops: list of BasicOperation
            Basic operations
        basic_op_kwargs: list of dict
            Keyword arguments to BasicOperation
        is_grad_enabled: bool
            Whether the autograd graph should be constructed
        num_params: int
            Number of parameter tensors to include in autograd graph.
        num_extra_inputs: int
            Number of extra tensor inputs to include in autograd graph.
        *params_and_extra_inputs: torch.Tensor
            Other tensor inputs to include in autograd graph. Consists
            of parameter tensors, followed by extra operation inputs.

        Returns
        -------
        Output tensor(s). If none of the operations have any extra
        tensor outputs, then the pipeline's output tensor is returned.
        Otherwise, a tuple with the pipeline's output tensor and extra
        tensor outputs is returned.

        """

        # Operation autograd contexts
        basic_op_ctxs = [OperationContext() for _ in range(len(basic_ops))]

        # Unflatten list of parameters and extra tensor inputs
        if len(params_and_extra_inputs) != num_params + num_extra_inputs:
            raise ValueError(
                f"Expected {num_params + num_extra_inputs} extra tensor arguments "
                f"({num_params} parameters, {num_extra_inputs} extra inputs), "
                f"but got {len(params_and_extra_inputs)}"
            )
        _, extra_inputs = _split_tuple(params_and_extra_inputs, num_params)
        basic_op_extra_inputs = []
        for op in basic_ops:
            xs, extra_inputs = _split_tuple(extra_inputs, op.num_extra_inputs)
            basic_op_extra_inputs.append(xs)
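        # e.g. if the basic ops take (2, 0, 1) extra inputs, the flat tuple
        # (a, b, c) is split into basic_op_extra_inputs = [(a, b), (), (c,)]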

        # Apply forward ops
        x = input_
        requires_grad = is_grad_enabled and x.requires_grad
        extra_outputs = [None for _ in range(len(basic_ops))]
        for op, basic_op_idxs in forward_ops:

            # Check if backward op is required
            if is_grad_enabled:
                if not requires_grad:
                    requires_grad = any(param.requires_grad for param in op.parameters())
                if not requires_grad:
                    requires_grad = any(any(x.requires_grad for x in xs) for xs in extra_inputs)
            for idx in basic_op_idxs:
                basic_op_ctxs[idx].requires_grad = requires_grad
            x.requires_grad_(requires_grad=requires_grad)

            # Forward op
            extra_inputs = [basic_op_extra_inputs[idx] for idx in basic_op_idxs]
            prev_ops = [basic_ops[idx - 1] if idx > 0 else None for idx in basic_op_idxs]
            next_ops = [
                basic_ops[idx + 1] if (idx < len(basic_ops) - 1) else None for idx in basic_op_idxs
            ]
            x, fused_op_extra_outputs = op.fuser_forward(
                [basic_op_ctxs[idx] for idx in basic_op_idxs],
                x,
                basic_op_extra_inputs=extra_inputs,
                basic_op_prev_ops=prev_ops,
                basic_op_next_ops=next_ops,
                basic_op_kwargs=[basic_op_kwargs[idx] for idx in basic_op_idxs],
            )
            x.requires_grad_(requires_grad=requires_grad)
            for idx, ys in zip(basic_op_idxs, fused_op_extra_outputs):
                for y in ys:
                    y.requires_grad_(requires_grad=requires_grad)
                extra_outputs[idx] = ys

        # Flatten list of extra outputs
        extra_outputs_flat = []
        for idx, ys in enumerate(extra_outputs):
            ys = list(ys)
            num_extra_outputs = basic_ops[idx].num_extra_outputs
            if len(ys) != num_extra_outputs:
                raise RuntimeError(
                    f"Expected op {idx} to generate "
                    f"{num_extra_outputs} extra outputs, "
                    f"but got {len(ys)}"
                )
            extra_outputs_flat.extend(ys)

        # Save context for backward pass
        if is_grad_enabled:

            # Flatten list of saved tensors
            to_save = []
            for ctx in basic_op_ctxs:
                range_start = len(to_save)
                if ctx.to_save is not None:
                    to_save.extend(ctx.to_save)
                range_end = len(to_save)
                ctx.to_save = None
                ctx._saved_tensors_range = (range_start, range_end)
            func_ctx.save_for_backward(*to_save)

            # Other context
            func_ctx.backward_ops = backward_ops
            func_ctx.basic_ops = basic_ops
            func_ctx.basic_op_ctxs = basic_op_ctxs
            func_ctx.num_params = num_params
            func_ctx.num_extra_inputs = num_extra_inputs
            func_ctx.num_extra_outputs = len(extra_outputs_flat)
            func_ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module()

        if extra_outputs_flat:
            return x, *extra_outputs_flat
        return x

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(
        func_ctx: Any,
        grad_output: torch.Tensor,
        *grad_extra_outputs: torch.Tensor,
    ) -> tuple[Optional[torch.Tensor], ...]:
        """Backward pass"""

        # Operations and autograd state
        backward_ops = func_ctx.backward_ops
        basic_ops = func_ctx.basic_ops
        basic_op_ctxs = func_ctx.basic_op_ctxs

        # Unflatten list of saved tensors
        for ctx in basic_op_ctxs:
            ctx.saved_tensors = func_ctx.saved_tensors[slice(*ctx._saved_tensors_range)]
            ctx._saved_tensors_range = None

        # Unflatten list of extra tensor output grads
        if len(grad_extra_outputs) != func_ctx.num_extra_outputs:
            raise ValueError(
                f"Expected grads for {func_ctx.num_extra_outputs} extra tensor outputs, "
                f"but got {len(grad_extra_outputs)}"
            )
        basic_op_grad_extra_outputs = []
        for op in basic_ops:
            dys, grad_extra_outputs = _split_tuple(grad_extra_outputs, op.num_extra_outputs)
            basic_op_grad_extra_outputs.append(dys)

        # Apply backward ops
        dx = grad_output
        grad_params = [None for _ in range(len(basic_ops))]
        grad_extra_inputs = [None for _ in range(len(basic_ops))]
        for op, basic_op_idxs in backward_ops:

            # Stop if no more gradients are required
            if all(not basic_op_ctxs[idx].requires_grad for idx in basic_op_idxs):
                dx = None
                break

            # Backward op
            grad_extra_outputs = [basic_op_grad_extra_outputs[idx] for idx in basic_op_idxs]
            dx, fused_op_grad_params, fused_op_grad_extra_inputs = op.fuser_backward(
                [basic_op_ctxs[idx] for idx in basic_op_idxs],
                dx,
                basic_op_grad_extra_outputs=grad_extra_outputs,
            )
            for idx, dparams in zip(basic_op_idxs, fused_op_grad_params):
                grad_params[idx] = dparams
                basic_op_ctxs[idx].saved_tensors = None
            for idx, dxs in zip(basic_op_idxs, fused_op_grad_extra_inputs):
                grad_extra_inputs[idx] = dxs

        # Flatten list of parameter gradients
        grad_params_flat = []
        for idx, dparams in enumerate(grad_params):
            params = list(basic_ops[idx].parameters())
            if dparams is None:
                dparams = [None for _ in range(len(params))]
            else:
                dparams = list(dparams)
            if len(dparams) != len(params):
                raise RuntimeError(
                    f"Expected op {idx} to generate {len(params)} param grads, "
                    f"but got {len(dparams)}"
                )
            grad_params_flat.extend(dparams)

        # Flatten list of extra input gradients
        grad_extra_inputs_flat = []
        for idx, dxs in enumerate(grad_extra_inputs):
            num_extra_inputs = basic_ops[idx].num_extra_inputs
            if dxs is None:
                dxs = [None for _ in range(num_extra_inputs)]
            else:
                dxs = list(dxs)
            if len(dxs) != num_extra_inputs:
                raise RuntimeError(
                    f"Expected op {idx} to generate grads "
                    f"for {num_extra_inputs} extra inputs, "
                    f"but got {len(dxs)}"
                )
            grad_extra_inputs_flat.extend(dxs)

        # Update FP8 scaling factors
        if func_ctx.is_first_module and not _is_graph_capturing():
            FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)

        return (
            dx,  # input_
            None,  # forward_ops
            None,  # backward_ops
            None,  # basic_ops
            None,  # basic_op_kwargs
            None,  # is_grad_enabled
            None,  # num_params
            None,  # num_extra_inputs
            *grad_params_flat,
            *grad_extra_inputs_flat,
        )


class OperationFuser:
    """Manages forward and backward passes for a pipeline of operations

    Parameters
    ----------
    ops: list of FusibleOperation
        Pipeline of operations
    fuse_ops: bool, default = `True`
        Whether to attempt fusing operations

    """

    def __init__(
        self,
        ops: list[FusibleOperation],
        fuse_ops: bool = True,
    ) -> None:

        # Get list of basic operations
        basic_ops = []
        for op in ops:
            if op.is_fused_op:
                basic_ops.extend(op.basic_ops)
            else:
                basic_ops.append(op)
        self._num_basic_ops: int = len(basic_ops)
        self._basic_ops: list[BasicOperation] = basic_ops

        # Number of extra tensor inputs
        self._num_extra_inputs: int = sum(op.num_extra_inputs for op in basic_ops)

        # Ops for forward and backward pass
        self._forward_ops: list[tuple[FusibleOperation, list[int]]]
        self._backward_ops: list[tuple[FusibleOperation, list[int]]]
        self._forward_ops = [(op, (idx,)) for idx, op in enumerate(self._basic_ops)]
        self._backward_ops = list(reversed(self._forward_ops))

        # Fuse ops if needed
        if fuse_ops:
            self.fuse_ops()

    @classmethod
    def _fuse_forward_ops(
        cls,
        ops: list[tuple[FusibleOperation, list[int]]],
    ) -> list[tuple[FusibleOperation, list[int]]]:
        """Attempt to fuse operations in forward pass"""
        ops = fuse_userbuffers_forward_linear(ops)
        ops = fuse_forward_linear_bias_add(ops)
        ops = fuse_forward_linear_bias_activation(ops)
        return ops

    @classmethod
    def _fuse_backward_ops(
        cls,
        ops: list[tuple[FusibleOperation, list[int]]],
    ) -> list[tuple[FusibleOperation, list[int]]]:
        """Attempt to fuse operations in backward pass"""
        ops = fuse_userbuffers_backward_linear(ops)
        ops = fuse_backward_linear_add(ops)
        return ops

    def fuse_ops(self) -> None:
        """Attempt to fuse operations"""
        self._forward_ops = self._fuse_forward_ops(self._forward_ops)
        self._backward_ops = self._fuse_backward_ops(self._backward_ops)

    def __call__(
        self,
        input: torch.Tensor,  # pylint: disable=redefined-builtin
        *extra_inputs: torch.Tensor,
        basic_op_kwargs: Optional[list[dict[str, Any]]] = None,
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Forward pass through the pipeline of operations"""

        # Initialization before forward pass
        for op in self._basic_ops:
            op.pre_forward()

        # Canonicalize op kwargs
        if basic_op_kwargs is None:
            basic_op_kwargs = [{} for _ in range(len(self._basic_ops))]

        # Flatten list of parameters
        params = [param for op in self._basic_ops for param in op.parameters()]

        # Fuser forward pass
        is_grad_enabled = torch.is_grad_enabled()
        if is_grad_enabled:
            forward_func = _OperationFuserAutogradFunction.apply
            args = []
        else:
            forward_func = _OperationFuserAutogradFunction.forward
            args = [None]
        args += (
            input,
            self._forward_ops,
            self._backward_ops,
            self._basic_ops,
            basic_op_kwargs,
            is_grad_enabled,
            len(params),
            self._num_extra_inputs,
            *params,
            *extra_inputs,
        )
        return forward_func(*args)
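

# A minimal usage sketch of OperationFuser (illustrative only; `BasicLinear`
# and `Bias` are stand-ins for BasicOperation subclasses and may not match
# the op names actually exported by transformer_engine.pytorch.ops):
#
#     import torch
#     from transformer_engine.pytorch.ops import BasicLinear, Bias
#
#     ops = [BasicLinear(16, 16), Bias(16)]
#     fuser = OperationFuser(ops, fuse_ops=True)  # attempt forward/backward fusions
#     x = torch.randn(8, 16, requires_grad=True)
#     y = fuser(x)        # forward pass through the (possibly fused) pipeline
#     y.sum().backward()  # backward pass may use a different set of fusions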