# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Manager class for a pipeline of fusible operations."""

from __future__ import annotations
from typing import Any, Optional

import torch

from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.graph import is_graph_capturing
from transformer_engine.pytorch.ops.op import (
    BasicOperation,
    FusibleOperation,
    OperationContext,
)
from transformer_engine.pytorch.ops.fused_forward import (
    fuse_forward_linear_bias_activation,
)


class _OperationFuserAutogradFunction(torch.autograd.Function):
    """Autograd function for a pipeline of operations

    Autograd must be done at the pipeline level since we may apply
    different fusions in the forward and backward passes.

    """

    # pylint: disable=unused-argument
    @staticmethod
    def forward(
        func_ctx: torch.autograd.function.FunctionCtx,
        input_: torch.Tensor,
        forward_ops: list[tuple[FusibleOperation, list[int]]],
        backward_ops: list[tuple[FusibleOperation, list[int]]],
        basic_ops: list[BasicOperation],
        basic_op_kwargs: list[dict[str, Any]],
        *params: torch.nn.Parameter,
    ) -> torch.Tensor:
        """Forward pass

        Parameters
        ----------
        func_ctx: torch.autograd.function.FunctionCtx
            Context for PyTorch autograd function
        input_: torch.Tensor
            Input to first operation in pipeline
        forward_ops: list of tuple
            Forward pass operations and the indices of the
            corresponding basic operations. The order should match
            basic_ops.
        backward_ops: list of tuple
            Backward pass operations and the indices of the
            corresponding basic operations. The order should be the
            reverse of basic_ops.
        basic_ops: list of BasicOperation
            Basic operations
        basic_op_kwargs: list of dict
            Keyword arguments to BasicOperation
        *params: torch.nn.Parameter
            Parameters in operation pipeline
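
        Returns
        -------
        torch.Tensor
            Output of the final operation in the pipeline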

        """

        # Operation autograd contexts
        basic_op_ctxs = [OperationContext() for _ in range(len(basic_ops))]

        # Apply forward ops
        x = input_
        requires_grad = x.requires_grad
        for op, basic_op_idxs in forward_ops:

            # Forward op
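            # (each basic op also receives its immediate neighbors so it
            # can adapt to adjacent ops, e.g. when choosing output formats)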
            prev_ops = [basic_ops[idx - 1] if idx > 0 else None for idx in basic_op_idxs]
            next_ops = [
                basic_ops[idx + 1] if (idx < len(basic_ops) - 1) else None for idx in basic_op_idxs
            ]
            x = op.fuser_forward(
                [basic_op_ctxs[idx] for idx in basic_op_idxs],
                x,
                prev_ops,
                next_ops,
                [basic_op_kwargs[idx] for idx in basic_op_idxs],
            )

            # Check if backward op is required
            if not requires_grad:
                requires_grad = any(param.requires_grad for param in op.parameters())
            for idx in basic_op_idxs:
                basic_op_ctxs[idx]._requires_grad = requires_grad
            x.requires_grad_(requires_grad=requires_grad)

        # Flatten list of saved tensors
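        # (save_for_backward may only be called once per autograd function,
        # so each op's tensors are packed into one flat list and per-op
        # slice bounds are recorded for unflattening in backward)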
        to_save = []
        for ctx in basic_op_ctxs:
            range_start = len(to_save)
            if ctx.to_save is not None:
                to_save.extend(ctx.to_save)
            range_end = len(to_save)
            ctx.to_save = None
            ctx._saved_tensors_range = (range_start, range_end)
        func_ctx.save_for_backward(*to_save)

        # Other context for backward pass
        func_ctx.backward_ops = backward_ops
        func_ctx.basic_ops = basic_ops
        func_ctx.basic_op_ctxs = basic_op_ctxs
        func_ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module()

        return x

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(
        func_ctx: Any,
        grad_output: torch.Tensor,
    ) -> tuple[Optional[torch.Tensor], ...]:
        """Backward pass"""

        # Operations and autograd state
        backward_ops = func_ctx.backward_ops
        basic_ops = func_ctx.basic_ops
        basic_op_ctxs = func_ctx.basic_op_ctxs

        # Unflatten list of saved tensors
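        # (inverse of the flattening in forward: each op context receives
        # the slice of saved tensors it recorded)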
        saved_tensors = func_ctx.saved_tensors
        for ctx in basic_op_ctxs:
            ctx.saved_tensors = saved_tensors[slice(*ctx._saved_tensors_range)]
            ctx._saved_tensors_range = None
        del saved_tensors

        # Apply backward ops
        dx = grad_output
        grad_params = [None for _ in range(len(basic_ops))]
        for op, basic_op_idxs in backward_ops:

            # Stop if no more gradients are required
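            # (dx is set to None so no gradient flows back to the input)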
            if all(not basic_op_ctxs[idx]._requires_grad for idx in basic_op_idxs):
                dx = None
                break

            # Backward op
            dx, fused_op_dparams = op.fuser_backward(
                [basic_op_ctxs[idx] for idx in basic_op_idxs],
                dx,
            )
            for idx, basic_op_dparams in zip(basic_op_idxs, fused_op_dparams):
                grad_params[idx] = basic_op_dparams
                basic_op_ctxs[idx].saved_tensors = None

        # Flatten list of parameter gradients
        grad_params_flat = []
        for idx, dparams in enumerate(grad_params):
            params = list(basic_ops[idx].parameters())
            if dparams is None:
                dparams = [None for _ in range(len(params))]
            else:
                dparams = list(dparams)
            if len(dparams) != len(params):
                raise RuntimeError(
                    f"Expected op {idx} to generate {len(params)} param grads, "
                    f"but got {len(dparams)}"
                )
            grad_params_flat.extend(dparams)

        # Update FP8 scaling factors
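        # (the module that ran first in forward runs last in backward, so
        # checking is_first_module triggers the reduction once per step)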
        if func_ctx.is_first_module and not is_graph_capturing():
            FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)

        return (
            dx,  # input_
            None,  # forward_ops
            None,  # backward_ops
            None,  # basic_ops
            None,  # basic_op_kwargs
            *grad_params_flat,  # params
        )


class OperationFuser:
    """Manages forward and backward passes for a pipeline of operations

    Parameters
    ----------
    ops: list of FusibleOperation
        Pipeline of operations
    fuse_ops: bool, default = `True`
        Whether to attempt fusing operations
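
    Example
    -------
    A minimal sketch, assuming the ``BasicLinear`` and ``GELU`` basic ops
    exported by ``transformer_engine.pytorch.ops``:

    >>> from transformer_engine.pytorch import ops
    >>> fuser = OperationFuser([ops.BasicLinear(16, 16), ops.GELU()])
    >>> y = fuser(torch.randn(8, 16, device="cuda"))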

    """

    def __init__(
        self,
        ops: list[FusibleOperation],
        fuse_ops: bool = True,
    ) -> None:

        # Get list of basic operations
        basic_ops = []
        for op in ops:
            if op.is_fused_op:
                basic_ops.extend(op.basic_ops)
            else:
                basic_ops.append(op)
        self._num_basic_ops: int = len(basic_ops)
        self._basic_ops: list[BasicOperation] = basic_ops

        # Ops for forward and backward pass
        self._forward_ops: list[tuple[FusibleOperation, list[int]]]
        self._backward_ops: list[tuple[FusibleOperation, list[int]]]
        self._forward_ops = [(op, [idx]) for idx, op in enumerate(self._basic_ops)]
        self._backward_ops = list(reversed(self._forward_ops))

        # Fuse ops if needed
        if fuse_ops:
            self.fuse_ops()

    @classmethod
    def _fuse_forward_ops(
        cls,
        ops: list[tuple[FusibleOperation, list[int]]],
    ) -> list[tuple[FusibleOperation, list[int]]]:
        """Attempt to fuse operations in forward pass"""
        ops = fuse_forward_linear_bias_activation(ops)
        return ops

    @classmethod
    def _fuse_backward_ops(
        cls,
        ops: list[tuple[FusibleOperation, list[int]]],
    ) -> list[tuple[FusibleOperation, list[int]]]:
        """Attempt to fuse operations in backward pass"""
        return ops

    def fuse_ops(self) -> None:
        """Attempt to fuse operations"""
        self._forward_ops = self._fuse_forward_ops(self._forward_ops)
        self._backward_ops = self._fuse_backward_ops(self._backward_ops)

    def __call__(
        self,
        input: torch.Tensor,  # pylint: disable=redefined-builtin
        basic_op_kwargs: Optional[list[dict[str, Any]]] = None,
    ) -> torch.Tensor:
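        """Apply the pipeline of operations to an input tensor"""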

        # Initialization before forward pass
        for op in self._basic_ops:
            op.pre_forward()

        # Canonicalize op kwargs
        if basic_op_kwargs is None:
            basic_op_kwargs = [{} for _ in range(len(self._basic_ops))]

        # Flatten list of parameters
        params = []
        for op in self._basic_ops:
            params.extend(op.parameters())

        # Fuser forward pass
        return _OperationFuserAutogradFunction.apply(
            input,
            self._forward_ops,
            self._backward_ops,
            self._basic_ops,
            basic_op_kwargs,
            *params,
        )