# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Manager class for a pipeline of fusible operations."""

from __future__ import annotations
from collections.abc import Callable
from typing import Any, Optional

import torch

from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.ops.op import (
    BasicOperation,
    FusibleOperation,
    OperationContext,
)
from transformer_engine.pytorch.ops.fused import (
    fuse_backward_linear_add,
    fuse_forward_linear_bias_activation,
    fuse_forward_linear_bias_add,
    fuse_userbuffers_backward_linear,
    fuse_userbuffers_forward_linear,
)
from transformer_engine.pytorch.tensor.quantized_tensor import (
    prepare_for_saving,
    restore_from_saved,
)


def _split_tuple(t: tuple, idx: int) -> tuple[tuple, tuple]:
    """Split tuple at index"""
    return t[:idx], t[idx:]


# Lazily imported function used in _is_graph_capturing
_is_graph_capturing_function: Optional[Callable[[], bool]] = None


def _is_graph_capturing() -> bool:
    """Whether function is called within `make_graphed_callables`

    Avoid circular import with lazy import.

    """
    global _is_graph_capturing_function
    if _is_graph_capturing_function is None:
        from ..graph import is_graph_capturing

        _is_graph_capturing_function = is_graph_capturing
    return _is_graph_capturing_function()


class _OperationFuserAutogradFunction(torch.autograd.Function):
    """Autograd function for a pipeline of operations

    Autograd must be done at the pipeline level since we may apply
    different fusions in the forward and backward passes.

    """

    # pylint: disable=unused-argument
    @staticmethod
    def forward(
        func_ctx: Optional[torch.autograd.function.FunctionCtx],
        input_: torch.Tensor,
        fuser: OperationFuser,
        basic_op_kwargs: list[dict[str, Any]],
        is_grad_enabled: bool,
        *params_and_extra_inputs: torch.nn.Parameter,
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Forward pass

        Parameters
        ----------
        func_ctx: torch.autograd.function.FunctionCtx, optional
            Context for PyTorch autograd function. None if autograd is
            disabled and forward is called directly.
        input_: torch.Tensor
            Input to first operation in pipeline
        fuser: OperationFuser
            Container for the pipeline of operations to run
        basic_op_kwargs: list of dict
            Keyword arguments to BasicOperation
        is_grad_enabled: bool
            Should context be saved for backward
        *params_and_extra_inputs: torch.Tensor
            Other tensor inputs to include in autograd graph. Consists
            of parameter tensors, followed by extra operation inputs.

        Returns
        -------
        Output tensor(s). If none of the operations have any extra
        tensor outputs, then the pipeline's output tensor is returned.
        Otherwise, a tuple with the pipeline's output tensor and extra
        tensor outputs is returned.

        """

        # Operation autograd contexts
        basic_op_ctxs = [OperationContext() for _ in range(fuser._num_basic_ops)]

        # Unflatten list of parameters and extra tensor inputs
        # (avoid `[-0:]`, which would return the whole tuple)
        num_params = len(params_and_extra_inputs) - fuser._num_extra_inputs
        extra_inputs = params_and_extra_inputs[num_params:]
        basic_op_extra_inputs = []
        for op in fuser._basic_ops:
            xs, extra_inputs = _split_tuple(extra_inputs, op.num_extra_inputs)
            basic_op_extra_inputs.append(xs)
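        # e.g. (illustrative): if op 0 takes one extra input and op 1 takes
        # two, the tail (x0, x1a, x1b) unflattens to [(x0,), (x1a, x1b)]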

        # Apply forward ops
        x = input_
        requires_grad = is_grad_enabled and x.requires_grad
        extra_outputs = [None] * fuser._num_basic_ops
        for op, basic_op_idxs in fuser._forward_ops:

            # Gather this op's extra inputs
            extra_inputs = [basic_op_extra_inputs[idx] for idx in basic_op_idxs]

            # Check if backward op is required
            if is_grad_enabled:
                if not requires_grad:
                    requires_grad = any(param.requires_grad for param in op.parameters())
                if not requires_grad:
                    requires_grad = any(any(x.requires_grad for x in xs) for xs in extra_inputs)
            for idx in basic_op_idxs:
                basic_op_ctxs[idx].requires_grad = requires_grad

            # Forward op
            prev_op_idx = basic_op_idxs[0] - 1
            prev_op = fuser._basic_ops[prev_op_idx] if prev_op_idx >= 0 else None
            prev_op_grad_input_quantizer = None
            if prev_op is not None:
                prev_op_grad_input_quantizer = prev_op.get_grad_input_quantizer()
            next_op_idx = basic_op_idxs[-1] + 1
            next_op = fuser._basic_ops[next_op_idx] if next_op_idx < fuser._num_basic_ops else None
            next_op_input_quantizer = None
            if next_op is not None:
                next_op_input_quantizer = next_op.get_input_quantizer()
            is_first_op = prev_op is None

            x, fused_op_extra_outputs = op.fuser_forward(
                [basic_op_ctxs[idx] for idx in basic_op_idxs],
                x,
                basic_op_extra_inputs=extra_inputs,
                prev_op_grad_input_quantizer=prev_op_grad_input_quantizer,
                next_op_input_quantizer=next_op_input_quantizer,
                is_first_op=is_first_op,
                basic_op_kwargs=[basic_op_kwargs[idx] for idx in basic_op_idxs],
            )
            for idx, ys in zip(basic_op_idxs, fused_op_extra_outputs):
                for y in ys:
                    y.requires_grad_(requires_grad)
                extra_outputs[idx] = ys

        # Flatten list of extra outputs
        extra_outputs_flat = []
        for idx, ys in enumerate(extra_outputs):
            ys = list(ys)
            num_extra_outputs = fuser._basic_ops[idx].num_extra_outputs
            if len(ys) != num_extra_outputs:
                raise RuntimeError(
                    f"Expected op {idx} to generate "
                    f"{num_extra_outputs} extra outputs, "
                    f"but got {len(ys)}"
                )
            extra_outputs_flat.extend(ys)

        # Save context for backward pass
        if is_grad_enabled:

            # Flatten list of saved tensors
            to_save = []
            for ctx in basic_op_ctxs:
                range_start = len(to_save)
                if ctx.to_save is not None:
                    to_save.extend(ctx.to_save)
                range_end = len(to_save)
                ctx.to_save = None
                ctx._saved_tensors_range = (range_start, range_end)
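            # e.g. if op 0 saves two tensors and op 1 saves three, their
            # ranges into the flattened list are (0, 2) and (2, 5)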

            # Save tensors for backward
            with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
            if with_quantized_compute:
                tensors_to_save, tensor_objects = prepare_for_saving(*to_save)
                func_ctx.save_for_backward(*tensors_to_save)
                func_ctx.tensor_objects = tensor_objects
            else:
                func_ctx.save_for_backward(*to_save)

            # Other context
            func_ctx.backward_ops = fuser._backward_ops
            func_ctx.basic_ops = fuser._basic_ops
            func_ctx.basic_op_ctxs = basic_op_ctxs
            func_ctx.basic_op_num_params = fuser._num_list_basic_op_params
            func_ctx.num_extra_inputs = fuser._num_extra_inputs
            func_ctx.num_extra_outputs = len(extra_outputs_flat)
            func_ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module()
            func_ctx.with_quantized_compute = with_quantized_compute

        # Mark pipeline output as requiring grad before returning it
        x.requires_grad_(requires_grad)

        if extra_outputs_flat:
            return x, *extra_outputs_flat
        return x

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(
        func_ctx: Any,
        grad_output: torch.Tensor,
        *grad_extra_outputs: torch.Tensor,
    ) -> tuple[Optional[torch.Tensor], ...]:
        """Backward pass"""

        # Operations and autograd state
        backward_ops = func_ctx.backward_ops
        basic_ops = func_ctx.basic_ops
        basic_op_ctxs = func_ctx.basic_op_ctxs

        # Restore saved tensors
        if func_ctx.with_quantized_compute:
            saved_tensors = restore_from_saved(func_ctx.tensor_objects, func_ctx.saved_tensors)
        else:
            saved_tensors = func_ctx.saved_tensors

        # Unflatten list of saved tensors
        for ctx in basic_op_ctxs:
            ctx.saved_tensors = saved_tensors[slice(*ctx._saved_tensors_range)]
            ctx._saved_tensors_range = None

        # Unflatten list of extra tensor output grads
        if len(grad_extra_outputs) != func_ctx.num_extra_outputs:
            raise ValueError(
                f"Expected grads for {func_ctx.num_extra_outputs} extra tensor outputs, "
                f"but got {len(grad_extra_outputs)}"
            )
        basic_op_grad_extra_outputs = []
        for op in basic_ops:
            dys, grad_extra_outputs = _split_tuple(grad_extra_outputs, op.num_extra_outputs)
            basic_op_grad_extra_outputs.append(dys)
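        # Mirrors the forward pass: extra-output grads arrive flattened in
        # op order and are regrouped into one tuple per basic op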

        # Apply backward ops
        dx = grad_output
        grad_params = [None for _ in range(len(basic_ops))]
        grad_extra_inputs = [None for _ in range(len(basic_ops))]
        for op, basic_op_idxs in backward_ops:

            # Stop if no more gradients are required
            if all(not basic_op_ctxs[idx].requires_grad for idx in basic_op_idxs):
                dx = None
                break

            # Backward op
            grad_extra_outputs = [basic_op_grad_extra_outputs[idx] for idx in basic_op_idxs]
            dx, fused_op_grad_params, fused_op_grad_extra_inputs = op.fuser_backward(
                [basic_op_ctxs[idx] for idx in basic_op_idxs],
                dx,
                basic_op_grad_extra_outputs=grad_extra_outputs,
            )
            for idx, dparams in zip(basic_op_idxs, fused_op_grad_params):
                grad_params[idx] = dparams
                basic_op_ctxs[idx].saved_tensors = None
            for idx, dxs in zip(basic_op_idxs, fused_op_grad_extra_inputs):
                grad_extra_inputs[idx] = dxs

        # Flatten list of parameter gradients
        grad_params_flat = []
        for idx, dparams in enumerate(grad_params):
            num_params = func_ctx.basic_op_num_params[idx]
            if dparams is None:
                dparams = [None for _ in range(num_params)]
            else:
                dparams = list(dparams)
            if len(dparams) != num_params:
                raise RuntimeError(
                    f"Expected op {idx} to generate {num_params} param grads, "
                    f"but got {len(dparams)}"
                )
            grad_params_flat.extend(dparams)

        # Flatten list of extra input gradients
        grad_extra_inputs_flat = []
        for idx, dxs in enumerate(grad_extra_inputs):
            num_extra_inputs = basic_ops[idx].num_extra_inputs
            if dxs is None:
                dxs = [None for _ in range(num_extra_inputs)]
            else:
                dxs = list(dxs)
            if len(dxs) != num_extra_inputs:
                raise RuntimeError(
                    f"Expected op {idx} to generate grads "
                    f"for {num_extra_inputs} extra inputs, "
                    f"but got {len(dxs)}"
                )
            grad_extra_inputs_flat.extend(dxs)

        # Update FP8 scaling factors
        if func_ctx.is_first_module and not _is_graph_capturing():
            FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)

        return (
            dx,  # input_
            None,  # fuser
            None,  # basic_op_kwargs
            None,  # is_grad_enabled
            *grad_params_flat,
            *grad_extra_inputs_flat,
        )


class OperationFuser:
    """Manages forward and backward passes for a pipeline of operations

    Parameters
    ----------
    ops: list of FusibleOperation
        Pipeline of operations
    fuse_ops: bool, default = `True`
        Whether to attempt fusing operations

    """

    def __init__(
        self,
        ops: list[FusibleOperation],
        fuse_ops: bool = True,
    ) -> None:

        # Get list of basic operations
        basic_ops = []
        for op in ops:
            if op.is_fused_op:
                basic_ops.extend(op.basic_ops)
            else:
                basic_ops.append(op)
        self._num_basic_ops: int = len(basic_ops)
        self._basic_ops: list[BasicOperation] = basic_ops

        # Number of extra tensor inputs
        self._num_extra_inputs: int = sum(op.num_extra_inputs for op in basic_ops)

        # Ops for forward and backward pass
        self._forward_ops: list[tuple[FusibleOperation, list[int]]]
        self._backward_ops: list[tuple[FusibleOperation, list[int]]]
        self._forward_ops = [(op, (idx,)) for idx, op in enumerate(self._basic_ops)]
        self._backward_ops = list(reversed(self._forward_ops))
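        # e.g. for basic ops [A, B, C], the forward pass runs A -> B -> C and
        # the backward pass runs C -> B -> A (before any fusions are applied)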

        # Flag for checking if this is the first iteration
        self._is_first_forward = True

        # Fuse ops if needed
        if fuse_ops:
            self.fuse_ops()

        # Flatten list of parameters
        self._basic_op_params = [param for op in self._basic_ops for param in op.parameters()]
        self._num_list_basic_op_params = [sum(1 for _ in op.parameters()) for op in self._basic_ops]

    @classmethod
    def _fuse_forward_ops(
        cls,
        ops: list[tuple[FusibleOperation, list[int]]],
    ) -> list[tuple[FusibleOperation, list[int]]]:
        """Attempt to fuse operations in forward pass"""
        ops = fuse_userbuffers_forward_linear(ops)
        ops = fuse_forward_linear_bias_add(ops)
        ops = fuse_forward_linear_bias_activation(ops)
        return ops

    @classmethod
    def _fuse_backward_ops(
        cls,
        ops: list[tuple[FusibleOperation, list[int]]],
    ) -> list[tuple[FusibleOperation, list[int]]]:
        """Attempt to fuse operations in backward pass"""
        ops = fuse_userbuffers_backward_linear(ops)
        ops = fuse_backward_linear_add(ops)
        return ops

    def fuse_ops(self) -> None:
        """Attempt to fuse operations"""
        self._forward_ops = self._fuse_forward_ops(self._forward_ops)
        self._backward_ops = self._fuse_backward_ops(self._backward_ops)

    def __call__(
        self,
        input: torch.Tensor,  # pylint: disable=redefined-builtin
        *extra_inputs: torch.Tensor,
        basic_op_kwargs: Optional[list[dict[str, Any]]] = None,
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        # Verify extra input count
        if len(extra_inputs) != self._num_extra_inputs:
            raise ValueError(
                f"Expected {self._num_extra_inputs} extra inputs but got {len(extra_inputs)}"
            )

        # Initialization before forward pass
        if self._is_first_forward:
            with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
            recipe = FP8GlobalStateManager.get_fp8_recipe() if with_quantized_compute else None
            for op in self._basic_ops:
                op.pre_first_forward(recipe=recipe)
            self._is_first_forward = False

        # Canonicalize op kwargs
        if basic_op_kwargs is None:
            basic_op_kwargs = [{} for _ in range(self._num_basic_ops)]

        # Fuser forward pass
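        # When autograd is disabled, call forward() directly with a null
        # func_ctx to skip building an autograd graph; otherwise dispatch
        # through apply() so that backward() is registered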
        is_grad_enabled = torch.is_grad_enabled()
        if is_grad_enabled:
            forward_func = _OperationFuserAutogradFunction.apply
            args = []
        else:
            forward_func = _OperationFuserAutogradFunction.forward
            args = [None]
        args += (
            input,
            self,
            basic_op_kwargs,
            is_grad_enabled,
            *self._basic_op_params,
            *extra_inputs,
        )
        return forward_func(*args)