# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Manager class for a pipeline of fusible operations."""

from __future__ import annotations
from collections.abc import Callable, Iterable
from typing import Any, Optional
import itertools

import torch

from transformer_engine.pytorch.fp8 import FP8GlobalStateManager, Recipe, DelayedScaling
from transformer_engine.pytorch.ops.op import (
    BasicOperation,
    FusibleOperation,
    OperationContext,
)
from transformer_engine.pytorch.ops.fused import (
    fuse_backward_activation_bias,
    fuse_backward_linear_add,
    fuse_forward_linear_bias_activation,
    fuse_forward_linear_bias_add,
    fuse_userbuffers_backward_linear,
    fuse_userbuffers_forward_linear,
)
from transformer_engine.pytorch.tensor.quantized_tensor import (
    prepare_for_saving,
    restore_from_saved,
)


def _split_tuple(t: tuple, idx: int) -> tuple[tuple, tuple]:
    """Split tuple at index"""
    return t[:idx], t[idx:]


# Lazily imported function used in _is_graph_capturing
_is_graph_capturing_function: Optional[Callable[[], bool]] = None


def _is_graph_capturing() -> bool:
    """Whether function is called within `make_graphed_callables`

    Avoid circular import with lazy import.

    """
    global _is_graph_capturing_function
    if _is_graph_capturing_function is None:
        from ..graph import is_graph_capturing

        _is_graph_capturing_function = is_graph_capturing
    return _is_graph_capturing_function()


class _OperationFuserAutogradFunction(torch.autograd.Function):
    """Autograd function for a pipeline of operations

    Autograd must be done at the pipeline level since we may apply
    different fusions in the forward and backward passes.

    """

    # pylint: disable=unused-argument
    @staticmethod
    def forward(
        func_ctx: Optional[torch.autograd.function.FunctionCtx],
        input_: torch.Tensor,
        fuser: OperationFuser,
        basic_op_kwargs: list[dict[str, Any]],
        *params_and_extra_inputs: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Forward pass

        Parameters
        ----------
        func_ctx: torch.autograd.function.FunctionCtx
            Context for PyTorch autograd function
        input_: torch.Tensor
            Input to first operation in pipeline
        fuser: OperationFuser
            Container for the pipeline of operations to run
        basic_op_kwargs: list of dict
            Keyword arguments to BasicOperation
        *params_and_extra_inputs: torch.Tensor
            Other tensor inputs to include in autograd graph. Consists
            of parameter tensors, followed by extra operation inputs.

        Returns
        -------
        Output tensor(s). If none of the operations have any extra
        tensor outputs, then the pipeline's output tensor is returned.
        Otherwise, a tuple with the pipeline's output tensor and extra
        tensor outputs is returned.

        """

        # Operation autograd contexts
        basic_op_ctxs = [OperationContext() for _ in range(fuser._num_basic_ops)]

        # Mark input tensors as not deletable in backward
        for tensor in (input_,) + params_and_extra_inputs:
            tensor.do_not_clear = True

        # Unflatten list of parameters and extra tensor inputs
        extra_inputs = params_and_extra_inputs[-fuser.num_extra_inputs :]
        basic_op_extra_inputs = []
        for op in fuser._basic_ops:
            xs, extra_inputs = _split_tuple(extra_inputs, op.num_extra_inputs)
            basic_op_extra_inputs.append(xs)

        # Get environment state
        with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
        recipe = FP8GlobalStateManager.get_fp8_recipe() if with_quantized_compute else None
        is_grad_enabled = func_ctx is not None

        # Attempt to fuse operations if necessary
        fuser.maybe_fuse_ops(is_grad_enabled, recipe, input_, basic_op_extra_inputs)

        # Apply forward ops
        x = input_
        extra_outputs = [None] * fuser._num_basic_ops
        for op, basic_op_idxs in fuser._forward_ops:

            # Set if backward op is required
            for idx in basic_op_idxs:
                basic_op_ctxs[idx].requires_grad = idx >= fuser.first_op_requiring_backward

            # Forward op
            extra_inputs = [basic_op_extra_inputs[idx] for idx in basic_op_idxs]
            prev_op_idx = basic_op_idxs[0] - 1
            prev_op = fuser._basic_ops[prev_op_idx] if prev_op_idx >= 0 else None
            prev_op_grad_output_quantizer = None
            if prev_op is not None:
                prev_op_grad_output_quantizer = prev_op.get_grad_output_quantizer()
            next_op_idx = basic_op_idxs[-1] + 1
            next_op = fuser._basic_ops[next_op_idx] if next_op_idx < fuser._num_basic_ops else None
            next_op_input_quantizer = None
            if next_op is not None:
                next_op_input_quantizer = next_op.get_input_quantizer()

            x, fused_op_extra_outputs = op.fuser_forward(
                [basic_op_ctxs[idx] for idx in basic_op_idxs],
                x,
                basic_op_extra_inputs=extra_inputs,
                prev_op_grad_output_quantizer=prev_op_grad_output_quantizer,
                next_op_input_quantizer=next_op_input_quantizer,
                basic_op_kwargs=[basic_op_kwargs[idx] for idx in basic_op_idxs],
            )
            for idx, ys in zip(basic_op_idxs, fused_op_extra_outputs):
                for y in ys:
                    y.requires_grad_(idx >= fuser.first_op_requiring_backward)
                extra_outputs[idx] = ys

        # Flatten list of extra outputs
        extra_outputs_flat = []
        for idx, ys in enumerate(extra_outputs):
            ys = list(ys)
            num_extra_outputs = fuser._basic_ops[idx].num_extra_outputs
            if len(ys) != num_extra_outputs:
                raise RuntimeError(
                    f"Expected op {idx} to generate "
                    "{num_extra_outputs} extra inputs, "
                    f"but got {len(ys)}"
                )
            extra_outputs_flat.extend(ys)

        # Save context for backward pass
        if is_grad_enabled:

            # Flatten list of saved tensors
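            # Each op context records a (start, end) slice into this flat list.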
            to_save = []
            for ctx in basic_op_ctxs:
                range_start = len(to_save)
                if ctx.to_save is not None:
                    to_save.extend(ctx.to_save)
                range_end = len(to_save)
                ctx.to_save = None
                ctx._saved_tensors_range = (range_start, range_end)

            # Save tensors for backward
            if with_quantized_compute:
                tensors_to_save, tensor_objects = prepare_for_saving(*to_save)
                func_ctx.save_for_backward(*tensors_to_save)
                func_ctx.tensor_objects = tensor_objects
            else:
                func_ctx.save_for_backward(*to_save)

            # Other context
            func_ctx.backward_ops = fuser._backward_ops
            func_ctx.basic_ops = fuser._basic_ops
            func_ctx.basic_op_ctxs = basic_op_ctxs
            func_ctx.basic_op_num_params = fuser._basic_op_num_params
            func_ctx.num_extra_inputs = fuser.num_extra_inputs
            func_ctx.num_extra_outputs = len(extra_outputs_flat)
            func_ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module()
            func_ctx.with_quantized_compute = with_quantized_compute

        # Mark output tensors as not deletable in backward
        for tensor in [x] + extra_outputs_flat:
            tensor.do_not_clear = True

        x.requires_grad_(fuser.first_op_requiring_backward < fuser._num_basic_ops)

        if extra_outputs_flat:
            return x, *extra_outputs_flat

        return x

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(
        func_ctx: Any,
        grad_output: torch.Tensor,
        *grad_extra_outputs: torch.Tensor,
    ) -> tuple[Optional[torch.Tensor], ...]:
        """Backward pass"""

        # Operations and autograd state
        backward_ops = func_ctx.backward_ops
        basic_ops = func_ctx.basic_ops
        basic_op_ctxs = func_ctx.basic_op_ctxs

        # Restore saved tensors
        if func_ctx.with_quantized_compute:
            saved_tensors = restore_from_saved(func_ctx.tensor_objects, func_ctx.saved_tensors)
        else:
            saved_tensors = func_ctx.saved_tensors

        # Unflatten list of saved tensors
        for ctx in basic_op_ctxs:
            ctx.saved_tensors = saved_tensors[slice(*ctx._saved_tensors_range)]
            ctx._saved_tensors_range = None

        # Unflatten list of extra tensor output grads
        if len(grad_extra_outputs) != func_ctx.num_extra_outputs:
            raise ValueError(
                f"Expected grads for {func_ctx.num_extra_outputs} extra tensor outputs, "
                f"but got {len(grad_extra_outputs)}"
            )
        basic_op_grad_extra_outputs = []
        for op in basic_ops:
            dys, grad_extra_outputs = _split_tuple(grad_extra_outputs, op.num_extra_outputs)
            basic_op_grad_extra_outputs.append(dys)

        # Apply backward ops
        dx = grad_output
        grad_params = [None for _ in range(len(basic_ops))]
        grad_extra_inputs = [None for _ in range(len(basic_ops))]
        for op, basic_op_idxs in backward_ops:

            # Stop if no more gradients are required
            if all(not basic_op_ctxs[idx].requires_grad for idx in basic_op_idxs):
                dx = None
                break

            # Backward op
            grad_extra_outputs = [basic_op_grad_extra_outputs[idx] for idx in basic_op_idxs]
            dx, fused_op_grad_params, fused_op_grad_extra_inputs = op.fuser_backward(
                [basic_op_ctxs[idx] for idx in basic_op_idxs],
                dx,
                basic_op_grad_extra_outputs=grad_extra_outputs,
            )
            for idx, dparams in zip(basic_op_idxs, fused_op_grad_params):
                grad_params[idx] = dparams
                basic_op_ctxs[idx].saved_tensors = None
            for idx, dxs in zip(basic_op_idxs, fused_op_grad_extra_inputs):
                grad_extra_inputs[idx] = dxs

        # Flatten list of parameter gradients
        grad_params_flat = []
        for idx, dparams in enumerate(grad_params):
            num_params = func_ctx.basic_op_num_params[idx]
            if dparams is None:
                dparams = [None for _ in range(num_params)]
            else:
                dparams = list(dparams)
            if len(dparams) != num_params:
                raise RuntimeError(
                    f"Expected op {idx} to generate {num_params} param grads, "
                    f"but got {len(dparams)}"
                )
            grad_params_flat.extend(dparams)

        # Flatten list of extra input gradients
        grad_extra_inputs_flat = []
        for idx, dxs in enumerate(grad_extra_inputs):
            num_extra_inputs = basic_ops[idx].num_extra_inputs
            if dxs is None:
                dxs = [None for _ in range(num_extra_inputs)]
            else:
                dxs = list(dxs)
            if len(dxs) != num_extra_inputs:
                raise RuntimeError(
                    f"Expected op {idx} to generate grads "
                    f"for {num_extra_inputs} extra inputs, "
                    f"but got {len(dxs)}"
                )
            grad_extra_inputs_flat.extend(dxs)

        # Update FP8 scaling factors
        if func_ctx.is_first_module and not _is_graph_capturing():
            FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)

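        # Gradients are returned in the same order as the forward inputs:
        # input_, fuser, basic_op_kwargs, then the flattened parameters and
        # extra inputs.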
        return (
            dx,  # input_
            None,  # fuser
            None,  # basic_op_kwargs
            *grad_params_flat,
            *grad_extra_inputs_flat,
        )


class OperationFuser:
    """Manages forward and backward passes for a pipeline of operations

    Parameters
    ----------
    ops: list of FusibleOperation
        Pipeline of operations

    """

    def __init__(
        self,
        ops: list[FusibleOperation],
    ) -> None:

        # Get list of basic operations
        basic_ops = []
        for op in ops:
            if op.is_fused_op:
                basic_ops.extend(op.basic_ops)
            else:
                basic_ops.append(op)
        self._num_basic_ops: int = len(basic_ops)
        self._basic_ops: list[BasicOperation] = basic_ops

        # Number of extra tensor inputs
        self._basic_op_num_extra_inputs: list[int] = list(op.num_extra_inputs for op in basic_ops)
        self.num_extra_inputs: int = sum(self._basic_op_num_extra_inputs)

        # Ops for forward and backward pass, will be populated in maybe_fuse_ops
        self._forward_ops: list[tuple[FusibleOperation, list[int]]]
        self._backward_ops: list[tuple[FusibleOperation, list[int]]]

        # Cache and detect change of state relevant for fusing operations
        self.recipe_type = None
        self.first_op_requiring_backward = 0
        self._last_amax_history_len = 0

        # Flatten list of parameters
        self._basic_op_params = [list(op.parameters()) for op in self._basic_ops]
        self._basic_op_num_params = list(map(len, self._basic_op_params))
        self._flat_basic_op_params = sum(self._basic_op_params, [])

    @classmethod
    def _fuse_forward_ops(
        cls,
        ops: list[tuple[FusibleOperation, list[int]]],
        recipe: Optional[Recipe],  # pylint: disable=unused-argument
    ) -> list[tuple[FusibleOperation, list[int]]]:
        """Attempt to fuse operations in forward pass"""
        ops = fuse_userbuffers_forward_linear(ops)
        ops = fuse_forward_linear_bias_add(ops)
        ops = fuse_forward_linear_bias_activation(ops)
        return ops

    @classmethod
    def _fuse_backward_ops(
        cls,
        ops: list[tuple[FusibleOperation, list[int]]],
        recipe: Optional[Recipe],
    ) -> list[tuple[FusibleOperation, list[int]]]:
        """Attempt to fuse operations in backward pass"""
        ops = fuse_userbuffers_backward_linear(ops)
        ops = fuse_backward_linear_add(ops)
        ops = fuse_backward_activation_bias(ops, recipe)
        return ops

    def maybe_fuse_ops(
        self,
        is_grad_enabled: bool,
        recipe: Optional[Recipe],
        input_: torch.Tensor,
        extra_inputs: list[Iterable[torch.Tensor]],
    ):
        """Attempt to fuse operations if neccesary"""

        # Determine which basic ops require backward
        if not is_grad_enabled:
            first_op_requiring_backward = self._num_basic_ops
        elif input_.requires_grad:
            first_op_requiring_backward = 0
        else:
            first_op_requiring_backward = self._num_basic_ops
            for op_idx in range(self._num_basic_ops):
                op_inputs = itertools.chain(self._basic_op_params[op_idx], extra_inputs[op_idx])
                if any(tensor.requires_grad for tensor in op_inputs):
                    first_op_requiring_backward = op_idx
                    break
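
        # Example: if only op 2 has a parameter or extra input that requires
        # grad, first_op_requiring_backward == 2 and ops 0 and 1 are excluded
        # from the backward op list built below.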

        # Early exit if fusion parameters haven't changed
        need_reset = False
        recipe_type = type(recipe)
        fusion_params = (recipe_type, first_op_requiring_backward)
        if fusion_params != (self.recipe_type, self.first_op_requiring_backward):
            # Recipe type or grad requirements have changed
            need_reset = True
        elif (
            recipe is not None
            and recipe.delayed()
            and self._last_amax_history_len != recipe.amax_history_len
        ):
            # FP8 delayed scaling has changed amax history length
            need_reset = True
        if not need_reset:
            return

        # Reset recipe state
        for op in self._basic_ops:
            op.reset_recipe_state(recipe=recipe)

        # Check if this is the first iteration
        if self.recipe_type is None:
            for op in self._basic_ops:
                op.pre_first_fuser_forward()

        # Prepare basic op lists for fusions
        forward_ops = [(op, [idx]) for idx, op in enumerate(self._basic_ops)]
        backward_ops = list(reversed(forward_ops[first_op_requiring_backward:]))

        # Fuse ops
        self._forward_ops = self._fuse_forward_ops(forward_ops, recipe)
        self._backward_ops = self._fuse_backward_ops(backward_ops, recipe)

        # Save current fusion params
        self.recipe_type, self.first_op_requiring_backward = fusion_params

        # Save amax history length
        if isinstance(recipe, DelayedScaling):
            self._last_amax_history_len = recipe.amax_history_len
        else:
            self._last_amax_history_len = 0

    def __call__(
        self,
        input: torch.Tensor,  # pylint: disable=redefined-builtin
        *extra_inputs: torch.Tensor,
        basic_op_kwargs: Optional[list[dict[str, Any]]] = None,
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        # Verify extra input count
        if len(extra_inputs) != self.num_extra_inputs:
            raise ValueError(
                f"Expected {self.num_extra_inputs} extra inputs but got {len(extra_inputs)}"
            )

        # Canonicalize op kwargs
        if basic_op_kwargs is None:
            basic_op_kwargs = [{}] * self._num_basic_ops

        # Fuser forward pass
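        # With grad enabled, go through the autograd function's apply();
        # otherwise call forward() directly with func_ctx=None so that no
        # autograd state is set up or saved.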
        if torch.is_grad_enabled():
            forward_func = _OperationFuserAutogradFunction.apply
            args = []
        else:
            forward_func = _OperationFuserAutogradFunction.forward
            args = [None]
        args += (
            input,
            self,
            basic_op_kwargs,
            *self._flat_basic_op_params,
            *extra_inputs,
        )
        return forward_func(*args)