# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Manager class for a pipeline of fusible operations."""

from __future__ import annotations
from collections.abc import Callable, Iterable
from typing import Any, Optional
import itertools

import torch

from transformer_engine.pytorch.fp8 import FP8GlobalStateManager, Recipe, DelayedScaling
from transformer_engine.pytorch.ops.op import (
    BasicOperation,
    FusibleOperation,
    OperationContext,
)
from transformer_engine.pytorch.ops.fused import (
    fuse_backward_activation_bias,
    fuse_backward_add_rmsnorm,
    fuse_backward_linear_add,
    fuse_backward_linear_scale,
    fuse_forward_linear_bias_activation,
    fuse_forward_linear_bias_add,
    fuse_forward_linear_scale_add,
    fuse_userbuffers_backward_linear,
    fuse_userbuffers_forward_linear,
)

from transformer_engine.pytorch.tensor.quantized_tensor import (
    prepare_for_saving,
    restore_from_saved,
)


def _split_tuple(t: tuple, idx: int) -> tuple[tuple, tuple]:
    """Split tuple at index"""
    return t[:idx], t[idx:]


# Lazily imported function used in _is_graph_capturing
_is_graph_capturing_function: Optional[Callable[[], bool]] = None


def _is_graph_capturing() -> bool:
    """Whether function is called within `make_graphed_callables`

    Avoid circular import with lazy import.

    """
    global _is_graph_capturing_function
    if _is_graph_capturing_function is None:
        from ..graph import is_graph_capturing

        _is_graph_capturing_function = is_graph_capturing
    return _is_graph_capturing_function()


class _OperationFuserAutogradFunction(torch.autograd.Function):
    """Autograd function for a pipeline of operations

    Autograd must be done at the pipeline level since we may apply
    different fusions in the forward and backward passes.

    """

    # pylint: disable=unused-argument
    @staticmethod
    def forward(
        func_ctx: Optional[torch.autograd.function.FunctionCtx],
        input_: torch.Tensor,
        fuser: OperationFuser,
        basic_op_kwargs: list[dict[str, Any]],
        *params_and_extra_inputs: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Forward pass

        Parameters
        ----------
        func_ctx: torch.autograd.function.FunctionCtx
            Context for PyTorch autograd function
        input_: torch.Tensor
            Input to first operation in pipeline
        fuser: OperationFuser
            Container for the pipeline of operations to run
        basic_op_kwargs: list of dict
            Keyword arguments to BasicOperation
        *params_and_extra_inputs: torch.Tensor
            Other tensor inputs to include in autograd graph. Consists
            of parameter tensors, followed by extra operation inputs.

        Returns
        -------
        Output tensor(s). If none of the operations have any extra
        tensor outputs, then the pipeline's output tensor is returned.
        Otherwise, a tuple with the pipeline's output tensor and extra
        tensor outputs is returned.

        """

        # Operation autograd contexts
        basic_op_ctxs = [OperationContext() for _ in range(fuser._num_basic_ops)]

        # Mark input tensors as not deletable in backward
        for tensor in (input_,) + params_and_extra_inputs:
            tensor._do_not_clear = True

        # Unflatten list of parameters and extra tensor inputs
        extra_inputs = params_and_extra_inputs[-fuser.num_extra_inputs :]
        basic_op_extra_inputs = []
        for op in fuser._basic_ops:
            xs, extra_inputs = _split_tuple(extra_inputs, op.num_extra_inputs)
            basic_op_extra_inputs.append(xs)

        # Apply forward ops
        x = input_
        extra_outputs = [None] * fuser._num_basic_ops
        for op, basic_op_idxs in fuser._forward_ops:

            # Set if backward op is required
            for idx in basic_op_idxs:
                basic_op_ctxs[idx].requires_grad = idx >= fuser.first_op_requiring_backward

            # Forward op
            extra_inputs = [basic_op_extra_inputs[idx] for idx in basic_op_idxs]
            prev_op_idx = basic_op_idxs[0] - 1
            prev_op = fuser._basic_ops[prev_op_idx] if prev_op_idx >= 0 else None
            prev_op_grad_output_quantizer = None
            if prev_op is not None:
                prev_op_grad_output_quantizer = prev_op.get_grad_output_quantizer()
            next_op_idx = basic_op_idxs[-1] + 1
            next_op = fuser._basic_ops[next_op_idx] if next_op_idx < fuser._num_basic_ops else None
            next_op_input_quantizer = None
            if next_op is not None:
                next_op_input_quantizer = next_op.get_input_quantizer()

            x, fused_op_extra_outputs = op.fuser_forward(
                [basic_op_ctxs[idx] for idx in basic_op_idxs],
                x,
                basic_op_extra_inputs=extra_inputs,
                prev_op_grad_output_quantizer=prev_op_grad_output_quantizer,
                next_op_input_quantizer=next_op_input_quantizer,
                basic_op_kwargs=[basic_op_kwargs[idx] for idx in basic_op_idxs],
            )
            for idx, ys in zip(basic_op_idxs, fused_op_extra_outputs):
                for y in ys:
                    y.requires_grad_(idx >= fuser.first_op_requiring_backward)
                extra_outputs[idx] = ys

        # Flatten list of extra outputs
        extra_outputs_flat = []
        for idx, ys in enumerate(extra_outputs):
            ys = list(ys)
            num_extra_outputs = fuser._basic_ops[idx].num_extra_outputs
            if len(ys) != num_extra_outputs:
                raise RuntimeError(
                    f"Expected op {idx} to generate "
                    f"{num_extra_outputs} extra outputs, "
                    f"but got {len(ys)}"
                )
            extra_outputs_flat.extend(ys)

        # Save context for backward pass
        if func_ctx is not None:

            # Flatten list of saved tensors
            to_save = []
            for ctx in basic_op_ctxs:
                range_start = len(to_save)
                if ctx.to_save is not None:
                    to_save.extend(ctx.to_save)
                range_end = len(to_save)
                ctx.to_save = None
                ctx._saved_tensors_range = (range_start, range_end)

            # Save tensors for backward
            tensors_to_save, tensor_objects = prepare_for_saving(*to_save)
            func_ctx.save_for_backward(*tensors_to_save)
            func_ctx.tensor_objects = tensor_objects

            # Whether to perform recipe update in backward pass
            is_first_module = False
            if fuser.first_op_requiring_backward < fuser._num_basic_ops:
                is_first_module = FP8GlobalStateManager.is_first_fp8_module()

            # Other context
            func_ctx.backward_ops = fuser._backward_ops
            func_ctx.basic_ops = fuser._basic_ops
            func_ctx.basic_op_ctxs = basic_op_ctxs
            func_ctx.basic_op_num_params = fuser._basic_op_num_params
            func_ctx.num_extra_inputs = fuser.num_extra_inputs
            func_ctx.num_extra_outputs = len(extra_outputs_flat)
            func_ctx.is_first_module = is_first_module

        # Mark output tensors as not deletable in backward
        for tensor in [x] + extra_outputs_flat:
            tensor._do_not_clear = True

        x.requires_grad_(fuser.first_op_requiring_backward < fuser._num_basic_ops)

        if extra_outputs_flat:
            return x, *extra_outputs_flat

        return x

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(
        func_ctx: Any,
        grad_output: torch.Tensor,
        *grad_extra_outputs: torch.Tensor,
    ) -> tuple[Optional[torch.Tensor], ...]:
        """Backward pass"""

        # Operations and autograd state
        backward_ops = func_ctx.backward_ops
        basic_ops = func_ctx.basic_ops
        basic_op_ctxs = func_ctx.basic_op_ctxs

        # Restore saved tensors
        saved_tensors = restore_from_saved(func_ctx.tensor_objects, func_ctx.saved_tensors)

        # Unflatten list of saved tensors
        for ctx in basic_op_ctxs:
            ctx.saved_tensors = saved_tensors[slice(*ctx._saved_tensors_range)]
            ctx._saved_tensors_range = None

        # Unflatten list of extra tensor output grads
        if len(grad_extra_outputs) != func_ctx.num_extra_outputs:
            raise ValueError(
                f"Expected grads for {func_ctx.num_extra_outputs} extra tensor outputs, "
                f"but got {len(grad_extra_outputs)}"
            )
        basic_op_grad_extra_outputs = []
        for op in basic_ops:
            dys, grad_extra_outputs = _split_tuple(grad_extra_outputs, op.num_extra_outputs)
            basic_op_grad_extra_outputs.append(dys)

        # Apply backward ops
        dx = grad_output
        grad_params = [None for _ in range(len(basic_ops))]
        grad_extra_inputs = [None for _ in range(len(basic_ops))]
        for op, basic_op_idxs in backward_ops:

            # Stop if no more gradients are required
            if all(not basic_op_ctxs[idx].requires_grad for idx in basic_op_idxs):
                dx = None
                break

            # Backward op
            grad_extra_outputs = [basic_op_grad_extra_outputs[idx] for idx in basic_op_idxs]
            dx, fused_op_grad_params, fused_op_grad_extra_inputs = op.fuser_backward(
                [basic_op_ctxs[idx] for idx in basic_op_idxs],
                dx,
                basic_op_grad_extra_outputs=grad_extra_outputs,
            )
            for idx, dparams in zip(basic_op_idxs, fused_op_grad_params):
                grad_params[idx] = dparams
                basic_op_ctxs[idx].saved_tensors = None
            for idx, dxs in zip(basic_op_idxs, fused_op_grad_extra_inputs):
                grad_extra_inputs[idx] = dxs

        # Flatten list of parameter gradients
        grad_params_flat = []
        for idx, dparams in enumerate(grad_params):
            num_params = func_ctx.basic_op_num_params[idx]
            if dparams is None:
                dparams = [None for _ in range(num_params)]
            else:
                dparams = list(dparams)
            if len(dparams) != num_params:
                raise RuntimeError(
                    f"Expected op {idx} to generate {num_params} param grads, "
                    f"but got {len(dparams)}"
                )
            grad_params_flat.extend(dparams)

        # Flatten list of extra input gradients
        grad_extra_inputs_flat = []
        for idx, dxs in enumerate(grad_extra_inputs):
            num_extra_inputs = basic_ops[idx].num_extra_inputs
            if dxs is None:
                dxs = [None for _ in range(num_extra_inputs)]
            else:
                dxs = list(dxs)
            if len(dxs) != num_extra_inputs:
                raise RuntimeError(
                    f"Expected op {idx} to generate grads "
                    f"for {num_extra_inputs} extra inputs, "
                    f"but got {len(dxs)}"
                )
            grad_extra_inputs_flat.extend(dxs)

        # Update FP8 scaling factors
        if func_ctx.is_first_module and not _is_graph_capturing():
            FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)

        return (
            dx,  # input_
            None,  # fuser
            None,  # basic_op_kwargs
            *grad_params_flat,
            *grad_extra_inputs_flat,
        )


class OperationFuser:
    """Manages forward and backward passes for a pipeline of operations

    Parameters
    ----------
    ops: list of FusibleOperation
        Pipeline of operations
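
    Examples
    --------
    An illustrative sketch, not exercised in this module (assumes the
    ``Linear`` and ``Bias`` basic ops from ``transformer_engine.pytorch.ops``):

    >>> from transformer_engine.pytorch import ops
    >>> fuser = OperationFuser([ops.Linear(128, 128), ops.Bias(128)])
    >>> y = fuser(torch.randn(32, 128))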

    """

    def __init__(
        self,
        ops: list[FusibleOperation],
    ) -> None:

        # Get list of basic operations
        basic_ops = []
        for op in ops:
            if op.is_fused_op:
                basic_ops.extend(op.basic_ops)
            else:
                basic_ops.append(op)
        self._num_basic_ops: int = len(basic_ops)
        self._basic_ops: list[BasicOperation] = basic_ops

        # Number of extra tensor inputs
        self._basic_op_num_extra_inputs: list[int] = list(op.num_extra_inputs for op in basic_ops)
        self.num_extra_inputs: int = sum(self._basic_op_num_extra_inputs)

        # Ops for forward and backward pass, will be populated in maybe_fuse_ops
        self._forward_ops: list[tuple[FusibleOperation, list[int]]]
        self._backward_ops: list[tuple[FusibleOperation, list[int]]]
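        # For illustration only: after fusion these lists hold (op, basic_op_indices)
        # pairs, e.g. [(fused_linear_bias, [0, 1]), (activation, [2])], mapping each
        # (possibly fused) op to the basic ops it covers.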

        # Cache and detect change of state relevant for fusing operations
        self.recipe_type = None
        self.first_op_requiring_backward = 0
        self._last_amax_history_len = 0

        # Flatten list of parameters
        self._basic_op_params = [list(op.parameters()) for op in self._basic_ops]
        self._basic_op_num_params = list(map(len, self._basic_op_params))
        self._flat_basic_op_params = sum(self._basic_op_params, [])

    @classmethod
    def _fuse_forward_ops(
        cls,
        ops: list[tuple[FusibleOperation, list[int]]],
        recipe: Optional[Recipe],  # pylint: disable=unused-argument
    ) -> list[tuple[FusibleOperation, list[int]]]:
        """Attempt to fuse operations in forward pass"""
        ops = fuse_userbuffers_forward_linear(ops)
        ops = fuse_forward_linear_bias_add(ops)
        ops = fuse_forward_linear_bias_activation(ops)
        ops = fuse_forward_linear_scale_add(ops)
        return ops

    @classmethod
    def _fuse_backward_ops(
        cls,
        ops: list[tuple[FusibleOperation, list[int]]],
        recipe: Optional[Recipe],
    ) -> list[tuple[FusibleOperation, list[int]]]:
        """Attempt to fuse operations in backward pass"""
        ops = fuse_userbuffers_backward_linear(ops)
        ops = fuse_backward_linear_add(ops)
        ops = fuse_backward_linear_scale(ops)
        ops = fuse_backward_activation_bias(ops, recipe)
        ops = fuse_backward_add_rmsnorm(ops)
        return ops

    def maybe_fuse_ops(
        self,
        is_grad_enabled: bool,
        recipe: Optional[Recipe],
        input_: torch.Tensor,
        extra_inputs: list[Iterable[torch.Tensor]],
    ):
        """Attempt to fuse operations if neccesary"""

        # Determine which basic ops require backward
        if not is_grad_enabled:
            first_op_requiring_backward = self._num_basic_ops
        elif input_.requires_grad:
            first_op_requiring_backward = 0
        else:
            first_op_requiring_backward = self._num_basic_ops
            for op_idx in range(self._num_basic_ops):
                op_inputs = itertools.chain(self._basic_op_params[op_idx], extra_inputs[op_idx])
                if any(tensor.requires_grad for tensor in op_inputs):
                    first_op_requiring_backward = op_idx
                    break

        # Early exit if fusion parameters haven't changed
        need_reset = False
        recipe_type = type(recipe)
        fusion_params = (recipe_type, first_op_requiring_backward)
        if fusion_params != (self.recipe_type, self.first_op_requiring_backward):
            # Recipe type or grad requirements have changed
            need_reset = True
        elif (
            recipe is not None
            and recipe.delayed()
            and self._last_amax_history_len != recipe.amax_history_len
        ):
            # FP8 delayed scaling has changed amax history length
            need_reset = True
        if not need_reset:
            return

        # Reset recipe state
        for op in self._basic_ops:
            op.reset_recipe_state(recipe=recipe)

        # Check if this is the first iteration
        if self.recipe_type is None:
            for op in self._basic_ops:
                op.pre_first_fuser_forward()

        # Prepare basic op lists for fusions
        forward_ops = [(op, [idx]) for idx, op in enumerate(self._basic_ops)]
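        # The backward list only covers ops from the first one requiring grad,
        # traversed in reverse order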
        backward_ops = list(reversed(forward_ops[first_op_requiring_backward:]))

        # Fuse ops
        self._forward_ops = self._fuse_forward_ops(forward_ops, recipe)
        self._backward_ops = self._fuse_backward_ops(backward_ops, recipe)

        # Save current fusion params
        self.recipe_type, self.first_op_requiring_backward = fusion_params

        # Save amax history length
        if isinstance(recipe, DelayedScaling):
            self._last_amax_history_len = recipe.amax_history_len
        else:
            self._last_amax_history_len = 0

    def __call__(
        self,
        input: torch.Tensor,  # pylint: disable=redefined-builtin
        *extra_inputs: torch.Tensor,
        basic_op_kwargs: Optional[list[dict[str, Any]]] = None,
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
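        """Forward pass through the pipeline of operations

        When gradients are enabled, this dispatches to the autograd function;
        otherwise it calls the function's ``forward`` directly with no autograd
        context.

        """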
        # Verify extra input count
        if len(extra_inputs) != self.num_extra_inputs:
            raise ValueError(
                f"Expected {self.num_extra_inputs} extra inputs but got {len(extra_inputs)}"
            )

        # Canonicalize op kwargs
        if basic_op_kwargs is None:
            basic_op_kwargs = [{}] * self._num_basic_ops

        # Unflatten list of extra tensor inputs
        extra_inputs_copy = list(extra_inputs)
        basic_op_extra_inputs = []
        for op in self._basic_ops:
            xs, extra_inputs_copy = _split_tuple(extra_inputs_copy, op.num_extra_inputs)
            basic_op_extra_inputs.append(xs)

        # Get environment state
        recipe = None
        if FP8GlobalStateManager.is_fp8_enabled():
            recipe = FP8GlobalStateManager.get_fp8_recipe()
        is_grad_enabled = torch.is_grad_enabled()

        # Attempt to fuse operations if necessary
        self.maybe_fuse_ops(is_grad_enabled, recipe, input, basic_op_extra_inputs)

        # Fuser forward pass
        if is_grad_enabled:
            forward_func = _OperationFuserAutogradFunction.apply
            args = []
        else:
            forward_func = _OperationFuserAutogradFunction.forward
            args = [None]
        args += (
            input,
            self,
            basic_op_kwargs,
            *self._flat_basic_op_params,
            *extra_inputs,
        )
        return forward_func(*args)