cpu_offload.py 20.3 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
#
# See LICENSE for license information.

"""Functionality for CPU offloading of tensors saved for backward pass."""
6
from __future__ import annotations
7
from contextlib import nullcontext
8
9
from typing import Any, Dict, Optional

10
11
import torch

12
from .tensor.float8_tensor import Float8Tensor
13

14
__all__ = ["get_cpu_offload_context"]
15
16
17
18

# Module-level flag: set to True by CpuOffloadSavedTensorHook.__enter__ and back
# to False by its __exit__; read through is_cpu_offload_enabled().
CPUOffloadEnabled = False


19
20
21
22
23
24
25
26
27
28
29
30
31
32
def set_offloading_param(tensor, param_name, value):
    """Attach an offloading marker attribute to a tensor.

    Objects whose exact type is ``torch.Tensor`` or ``torch.nn.Parameter``
    receive the attribute directly. Any other object is expected to provide
    ``get_data_tensors()``; every non-``None`` tensor it returns receives the
    attribute instead.

    Args:
        tensor: Tensor (or tensor-like wrapper) to mark; ``None`` is ignored.
        param_name: Either ``"weight_offloading"`` or ``"activation_offloading"``.
        value: Value stored under ``param_name``.
    """
    assert param_name in ["weight_offloading", "activation_offloading"]
    if tensor is None:
        return

    # Exact-type check on purpose: tensor subclasses go through get_data_tensors().
    if type(tensor) in (torch.Tensor, torch.nn.Parameter):
        targets = (tensor,)
    else:
        targets = [part for part in tensor.get_data_tensors() if part is not None]

    for target in targets:
        setattr(target, param_name, value)


33
34
35
36
37
def is_cpu_offload_enabled() -> bool:
    """Report whether a CPU-offload saved-tensor hook context is active.

    Returns:
        The current value of the module-level ``CPUOffloadEnabled`` flag.
    """
    enabled_flag = CPUOffloadEnabled
    return enabled_flag


38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
class CpuOffloadSavedTensorHook:
    """Context-manager that executes a pair of pack/unpack hooks for saved tensors.

    In this context, the ``on_save_for_backward`` method will be called every time
    a tensor is saved for backward (this includes intermediary results saved using
    :func:`~torch.autograd.function._ContextMethodMixin.save_for_backward` but
    also those recorded by a PyTorch-defined operation).

    The ``on_get_saved_tensor`` method will be called when the backward function
    of this op attempts to retrieve the saved tensor from context (this includes
    :func:`torch.Tensor.backward()` or :func:`torch.autograd.grad()`). It takes
    as input the return value of ``on_save_for_backward``, and is meant to return
    an identical copy of the tensor being saved by ``on_save_for_backward`` in terms of
    size, device and element values.

    Entering the context also sets the module-level ``CPUOffloadEnabled`` flag to
    ``True``; exiting sets it back to ``False``.

    Example:

        >>> import torch
        >>> from typing import Any
        >>>
        >>> class DummyHook(CpuOffloadSavedTensorHook):
        ...
        ...     def on_save_for_backward(self, tensor: torch.Tensor) -> Any:
        ...         print("On save", tensor)
        ...         return (tensor,)
        ...
        ...     def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor:
        ...         print("On get", saved_state)
        ...         tensor, = saved_state
        ...         return tensor
        ...
        >>> a = torch.ones(5, requires_grad=True)
        >>> b = torch.ones(5, requires_grad=True) * 2
        >>> with DummyHook():
        ...     y = a * b
        ...
        On save tensor([1., 1., 1., 1., 1.], requires_grad=True)
        On save tensor([2., 2., 2., 2., 2.], grad_fn=<MulBackward0>)
        >>> y.sum().backward()
        On get (tensor([1., 1., 1., 1., 1.], requires_grad=True),)
        On get (tensor([2., 2., 2., 2., 2.], grad_fn=<MulBackward0>),)

    """

    def __init__(self) -> None:
        self.inside_context = False

    def __enter__(self):
        global CPUOffloadEnabled
        CPUOffloadEnabled = True

        self.inside_context = True
        # Install this object's pack/unpack pair as the default saved-tensor hooks.
        torch._C._autograd._push_saved_tensors_default_hooks(
            self.on_save_for_backward, self.on_get_saved_tensor
        )
        # Returning self enables ``with hook as h:``; callers that ignore the
        # value are unaffected.
        return self

    def __exit__(self, *args: Any):
        global CPUOffloadEnabled
        CPUOffloadEnabled = False

        self.inside_context = False
        # Remove the hook pair pushed in __enter__.
        torch._C._autograd._pop_saved_tensors_default_hooks()

    def on_save_for_backward(self, tensor: torch.Tensor) -> Any:
        """Pack hook: called with each tensor saved for backward.

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError(
            "`on_save_for_backward: Callable[[torch.Tensor], Any]` "
            "is not implemented in CpuOffloadSavedTensorHook class. Inherit "
            "this class and implement your custom hooks"
        )

    def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor:
        """Unpack hook: maps the value returned by the pack hook back to a tensor.

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError(
            "`on_get_saved_tensor: Callable[[Any], torch.Tensor]` "
            "is not implemented in CpuOffloadSavedTensorHook class. Inherit "
            "this class and implement your custom hooks"
        )
116
117
118
119
120
121
122
123
124


class CpuOffloadHookWithOffloadHandler(CpuOffloadSavedTensorHook):
    """Saved-tensor hook that delegates storage to an offload handler.

    Every tensor saved for backward is handed to ``offload_handler.tensor_push``;
    the value it returns is what autograd keeps. When backward needs the tensor,
    that value is passed to ``offload_handler.tensor_pop`` to get the tensor back.
    When and how the handler offloads, recovers or prefetches data is entirely up
    to the handler.
    """

    def __init__(
        self,
        offload_handler: OffloadHandler,
        handler_extra_kwargs: Optional[Dict[str, Any]] = None,
        debug: bool = False,
    ) -> None:
        self.debug: bool = debug
        self.offload_handler: OffloadHandler = offload_handler
        # Forwarded as keyword arguments to every tensor_push / tensor_pop call.
        self.handler_extra_kwargs: Dict[str, Any] = (
            handler_extra_kwargs if handler_extra_kwargs is not None else {}
        )
        super().__init__()

    def on_save_for_backward(self, tensor: torch.Tensor) -> Any:
        """Push ``tensor`` into the handler and return its retrieval identifier."""
        return self.offload_handler.tensor_push(tensor, **self.handler_extra_kwargs)

    def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor:
        """Pop the tensor identified by ``saved_state`` from the handler."""
        return self.offload_handler.tensor_pop(saved_state, **self.handler_extra_kwargs)


class OffloadHandler:
    """A base class for CPU offload-handler.

    Subclasses implement ``tensor_push`` (accept a tensor being saved for
    backward and return an identifier for it) and ``tensor_pop`` (return the
    tensor for a previously returned identifier).
    """

    def __init__(self) -> None:
        pass

    def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any:
        """Store ``tensor`` and return an identifier for later retrieval.

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError(
            "`tensor_push` is not implemented in OffloadHandler class. "
            "Inherit this class and implement your custom tensor_push."
        )

    def tensor_pop(self, tensor_tag: Any, **kwargs):
        """Return the tensor associated with ``tensor_tag``.

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError(
            "`tensor_pop` is not implemented in OffloadHandler class. "
            "Inherit this class and implement your custom tensor_pop."
        )
167
168
169
170
171
172
173
174


class GroupCommitFunction(torch.autograd.Function):
    """Identity autograd op that marks a group boundary for the offload handler.

    The forward output is the input tensor unchanged. The op exists only to give
    the offload handler a hook point in both passes. ``forward`` calls
    ``on_group_commit_forward`` and ``backward`` calls
    ``on_group_commit_backward``, which is where the handler does its
    synchronization. It must be an autograd function so that the second call
    happens during back-propagation.
    """

    @staticmethod
    def forward(ctx, tensor, cpu_offload_handler):
        # pylint: disable=missing-function-docstring
        handler = cpu_offload_handler
        handler.on_group_commit_forward()
        # Stash the handler so backward() can notify it too.
        ctx.cpu_offload_handler = handler
        # Pass the input straight through.
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
        # pylint: disable=missing-function-docstring
        ctx.cpu_offload_handler.on_group_commit_backward()
        # The handler argument is not a tensor, so it gets no gradient.
        return grad_output, None


# Functional entry point: group_prefetch_offload_commit(tensor, handler) runs
# GroupCommitFunction (identity in forward) so the handler is notified at this
# group boundary in both forward and backward.
group_prefetch_offload_commit = GroupCommitFunction.apply


class SynchronizedGroupOffloadHandler(OffloadHandler):
    """Offload handler that copies tensors in a synchronized way.

    Device-to-host and host-to-device copies are issued on the same stream as
    the computation kernels, so copying is ordered with (and blocks)
    computation. A pushed tensor is offloaded immediately only if its group
    index is below ``num_offload_group`` and ``tensor_need_offloading_checker``
    accepts it; otherwise the handler keeps a plain reference to it.
    """

    def __init__(
        self, num_offload_group, tensor_need_offloading_checker=(lambda _: True), debug=False
    ) -> None:
        super().__init__()

        self.num_offload_group = num_offload_group
        self.tensor_need_offloading_checker = tensor_need_offloading_checker
        self.debug = debug

        self.groupid_reset()

    def groupid_reset(self):
        """Reset the group counters and drop all bookkept tensor states."""
        # A tag is (group index, running index inside that group); the group
        # index advances on every forward group commit.
        self.current_group = 0
        self.tensor_count_current_group = 0
        # Counter for tags of tensors that bypass group tagging (used by subclasses).
        self.torch_tensor_count = 0
        # tag -> the tensor itself, or an offloaded (device, cpu_copy) tuple.
        self.tensor_tag_to_state = {}

    def on_group_commit_forward(self):
        """Close the current forward group and open the next one."""
        self.tensor_count_current_group = 0  # per-group index starts over
        self.current_group += 1

    def on_group_commit_backward(self):
        """Step back by one group during the backward pass."""
        self.current_group = self.current_group - 1
        assert self.current_group >= 0

    @staticmethod
    def offload(src_tensor, pin_memory=True):
        """Copy ``src_tensor`` into a freshly allocated CPU tensor.

        A Float8Tensor input gets a uint8 CPU buffer that is wrapped back into a
        Float8Tensor via ``Float8Tensor.make_like``. The copy is requested as
        non-blocking exactly when the CPU buffer is pinned.

        Returns:
            ``(original_device, cpu_copy)``, the state consumed by ``reload``.
        """
        is_fp8 = isinstance(src_tensor, Float8Tensor)
        backup_dtype = torch.uint8 if is_fp8 else src_tensor.dtype
        cpu_backup = torch.empty(
            src_tensor.size(),
            dtype=backup_dtype,
            layout=src_tensor.layout,
            device="cpu",
            pin_memory=pin_memory,
        )
        if is_fp8:
            cpu_backup = Float8Tensor.make_like(src_tensor, data=cpu_backup)

        cpu_backup.copy_(src_tensor, non_blocking=pin_memory)
        return src_tensor.device, cpu_backup

    @staticmethod
    def reload(state, non_blocking=None):
        """Copy an offloaded tensor back to the device it came from.

        ``non_blocking`` defaults to whether the CPU copy is pinned.
        """
        dev, cpu_backup = state
        use_non_blocking = cpu_backup.is_pinned() if non_blocking is None else non_blocking
        return cpu_backup.to(dev, non_blocking=use_non_blocking)

    def tensor_push(self, tensor: torch.Tensor, **kwargs):
        """Register ``tensor`` under a fresh tag, offloading it now if eligible."""
        tensor_tag = (self.current_group, self.tensor_count_current_group)
        self.tensor_count_current_group += 1
        assert tensor_tag not in self.tensor_tag_to_state

        # The checker is only consulted for groups inside the offload range.
        wants_offload = self.current_group < self.num_offload_group and (
            self.tensor_need_offloading_checker(tensor)
        )
        if wants_offload:
            # Synchronous path: copy to CPU right away and keep only the CPU state.
            self.tensor_tag_to_state[tensor_tag] = SynchronizedGroupOffloadHandler.offload(tensor)
        else:
            # Not offloaded by this handler: keep a direct reference.
            self.tensor_tag_to_state[tensor_tag] = tensor
        return tensor_tag

    def tensor_pop(self, tensor_tag, **kwargs):
        """Return the tensor registered under ``tensor_tag`` and forget the tag."""
        assert tensor_tag in self.tensor_tag_to_state
        state = self.tensor_tag_to_state.pop(tensor_tag)
        if not isinstance(state, tuple):
            return state
        # Offloaded entries are (device, cpu_copy) pairs; bring them back.
        return SynchronizedGroupOffloadHandler.reload(state)


class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
    """Compared to synchronize, this uses more memory because of the buffer but
    achieves better performance due to the overlapping. D2h and h2d copying are
    completely hidden behind computation if computation time of a layer is longer
    than host-device communication time. Bulk offloading with delay and bulk reloading
    with prefetch are implemented.

    Offloads are issued on a dedicated device-to-host stream and reloads on a
    dedicated host-to-device stream. ``layer_window_map[i]`` is the group index
    at whose forward commit the offload of group ``i`` is synchronized and its
    buffers released. In backward, the same index is where that group is reloaded.
    """

    def __init__(
        self,
        num_offload_group,  # must be <= actual number of groups (number of commits)
        num_model_group,
        tensor_need_offloading_checker=(lambda t: True),
        debug=False,
    ) -> None:
        super().__init__(
            num_offload_group=num_offload_group,
            tensor_need_offloading_checker=tensor_need_offloading_checker,
            debug=debug,
        )
        # Number of layers in the model
        self.num_layers = num_model_group
        # Data Structure to maintain reference to activation tensors
        self.tensor_tag_to_buf = {}
        # Tracking the number of layers offloaded
        self.offloaded_group_count = 0
        # Core data structure that decides the window for offloading
        self.layer_window_map = {}

        # Logic to make offloading load balance across computation
        # for optimal CPU/GPU interconnect usage. The remainder layers
        # (num_layers % num_offload_group) are spread over the first windows,
        # so the final window always equals num_layers - 1.
        constant = 0
        for i in range(self.num_offload_group):
            self.layer_window_map[i] = ((self.num_layers // self.num_offload_group) * (i + 1)) - 1
            if i < (self.num_layers % self.num_offload_group):
                self.layer_window_map[i] += i + 1
                constant = i + 1
            else:
                self.layer_window_map[i] += constant

        # allocate streams and events for synchronization
        self.d2h_stream = torch.cuda.Stream()
        self.h2d_stream = torch.cuda.Stream()

    def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any:
        """Register ``tensor`` under a unique tag and return the tag.

        Tensors are never copied here; offloading happens in bulk at group
        commits. Offload candidates are additionally tracked in
        ``tensor_tag_to_buf``. FakeTensor / FunctionalTensor instances get tags
        ``(-1, n)``, which never match a group being offloaded.
        """
        torch_stray_tensor = isinstance(
            tensor,
            (
                torch._subclasses.fake_tensor.FakeTensor,
                torch._subclasses.functional_tensor.FunctionalTensor,
            ),
        )

        if not torch_stray_tensor:
            # obtain a unique tensor tag
            tensor_tag = (self.current_group, self.tensor_count_current_group)
            self.tensor_count_current_group += 1
            assert tensor_tag not in self.tensor_tag_to_state

            self.tensor_tag_to_state[tensor_tag] = tensor

            if self.current_group < self.num_offload_group and self.tensor_need_offloading_checker(
                tensor
            ):
                # Entry is cleared once this group's offload window is synchronized.
                self.tensor_tag_to_buf[tensor_tag] = tensor
        else:
            tensor_tag = (-1, self.torch_tensor_count)
            self.torch_tensor_count += 1
            self.tensor_tag_to_state[tensor_tag] = tensor

        return tensor_tag

    def tensor_pop(self, tensor_tag, **kwargs):
        """Tensor pop."""
        assert tensor_tag in self.tensor_tag_to_state
        tensor = self.tensor_tag_to_state.pop(tensor_tag)
        self.tensor_tag_to_buf.pop(tensor_tag, None)
        # the tensor should have been copied back in on_group_commit_backward()
        # which invokes bulk_reload_group.
        assert not isinstance(tensor, tuple)
        return tensor

    def bulk_offload_group(self, group_to_offload):
        """Bulk offload group."""
        with torch.cuda.stream(self.d2h_stream):
            for tensor_tag, state in self.tensor_tag_to_state.items():
                group_id, _ = tensor_tag
                if group_id == group_to_offload:
                    assert not isinstance(state, tuple)
                    tensor_on_device = state

                    # if offload, return the reference to cpu copy
                    if self.tensor_need_offloading_checker(tensor_on_device):
                        state = SynchronizedGroupOffloadHandler.offload(tensor_on_device)
                        self.tensor_tag_to_state[tensor_tag] = state
                        tensor_on_device.data = torch.Tensor()  # Force to release memory

    def synchronize_on_group_commit_forward(self, current_group):
        """Synchronize on group commit forward."""
        # For the first group, kickstart the offload after we have
        # the first compute completion (only if there is anything to offload).
        if current_group == 0 and self.num_offload_group > 0:
            self.d2h_stream.wait_stream(torch.cuda.current_stream())
            self.bulk_offload_group(current_group)

        # Every offload window has already been handled (more commits than
        # num_model_group, or no offload groups at all): nothing left to
        # synchronize, and layer_window_map has no entry for this count.
        if self.offloaded_group_count >= self.num_offload_group:
            return

        # Window map data structure helps us synchronize based on number
        # of layers offloaded
        if self.layer_window_map[self.offloaded_group_count] == current_group:

            # Stream synchronization both ways
            self.d2h_stream.wait_stream(torch.cuda.current_stream())
            torch.cuda.current_stream().wait_stream(self.d2h_stream)

            # Time to free the activation memory after usage
            for tensor_tag in self.tensor_tag_to_buf:
                if tensor_tag[0] == self.offloaded_group_count:
                    self.tensor_tag_to_buf[tensor_tag] = None

            # Time to offload the next group
            if self.offloaded_group_count < (self.num_offload_group - 1):
                self.bulk_offload_group(self.offloaded_group_count + 1)

            # Increment the offload group count to keep track
            self.offloaded_group_count += 1

    def on_group_commit_forward(self):
        """This function will cause host device synchronization"""
        # handle synchronization events
        self.synchronize_on_group_commit_forward(self.current_group)

        super().on_group_commit_forward()

    def bulk_reload_group(self, group_to_reload):
        """Bulk reload group."""
        assert group_to_reload < self.num_offload_group
        with torch.cuda.stream(self.h2d_stream):
            # move back tensors
            for tensor_label, state in self.tensor_tag_to_state.items():
                group_id, _ = tensor_label
                if group_id == group_to_reload:
                    if isinstance(state, tuple):
                        recovered_tensor = SynchronizedGroupOffloadHandler.reload(state)
                        self.tensor_tag_to_state[tensor_label] = recovered_tensor

    def on_group_commit_backward(self):
        """Step back one group and reload an offloaded group when its window is reached."""
        # first decrement the current group.
        # after last commit in forward, the group will +1; in backward it -1.
        # Finally it should be decremented to 0.
        self.current_group -= 1
        assert self.current_group >= 0

        # Layer window data structure helps us to reload at right times.
        # With offloaded_group_count == 0 no window completed in forward, so
        # there is nothing to reload (and layer_window_map[-1] does not exist).
        if (
            self.offloaded_group_count > 0
            and self.layer_window_map[self.offloaded_group_count - 1] == self.current_group
        ):
            # Stream synchronization both ways
            self.h2d_stream.wait_stream(torch.cuda.current_stream())
            torch.cuda.current_stream().wait_stream(self.h2d_stream)

            # Time to reload the next group
            self.bulk_reload_group(self.offloaded_group_count - 1)

            # Decrease the offloading group counter, but never below 1 here, so
            # the next comparison still targets group 0's window.
            if self.offloaded_group_count > 1:
                self.offloaded_group_count -= 1

        # Last group computation needs to wait till all the reloads complete
        if self.current_group == 0:
            torch.cuda.current_stream().wait_stream(self.h2d_stream)
            self.offloaded_group_count = 0
459
460
461
462
463


def get_cpu_offload_context(
    enabled: bool = False,
    num_layers: int = 1,
    model_layers: int = 1,
    offload_activations: bool = True,
    offload_weights: bool = True,
):
    """
    This function returns the CPU Offload context and the synchronizer function that needs to be
    used after every transformer layer. When offloading is not enabled, the returned context is
    `nullcontext()`; the synchronizer is returned in both cases.

    Usage:

    .. code-block:: python

        cpu_offload_context, cpu_offload_synchronizer = get_cpu_offload_context(enabled=True)

        with cpu_offload_context:
            out = te_layer.forward(inp_tensor)
        out = cpu_offload_synchronizer(out)

    Parameters
    ----------
    enabled: bool, default = `False`
             When set to True, CPU Offloading functionality is enabled.
    num_layers: int, default = 1
                Determines the number of transformer layers
                you want to offload activations/weights for.
    model_layers: int, default = 1
                  Number of layers in the model that will be used under this context.
    offload_activations: bool, default = `True`
                         When set to `True`, offloads the activations for the TE layer.
    offload_weights: bool, default = `True`
                     When set to `True`, offloads the weights for the TE layer.

    Returns
    -------
    (context, synchronizer)
        `context` is the saved-tensor hook context manager (or `nullcontext()` when
        disabled). `synchronizer(tensor)` marks a layer boundary and returns the
        tensor's values unchanged.

    Raises
    ------
    ValueError
        If `enabled` is `True` but neither `offload_activations` nor
        `offload_weights` is set.
    """

    # A tensor qualifies when it carries the marker attribute (presence, not its value).
    def tensor_need_offloading_checker_activations(tensor):
        return hasattr(tensor, "activation_offloading")

    # This includes the Gradient Accumulation Buffer
    def tensor_need_offloading_checker_weights(tensor):
        return hasattr(tensor, "weight_offloading")

    def tensor_need_offloading_checker_all(tensor):
        return hasattr(tensor, "activation_offloading") or hasattr(tensor, "weight_offloading")

    def tensor_need_offloading_checker_none(_tensor):
        return False

    if offload_activations and offload_weights:
        tensor_need_offloading_checker = tensor_need_offloading_checker_all
    elif offload_activations:
        tensor_need_offloading_checker = tensor_need_offloading_checker_activations
    elif offload_weights:
        tensor_need_offloading_checker = tensor_need_offloading_checker_weights
    elif enabled:
        raise ValueError(
            "CPU Offloading is enabled while it is not "
            "mentioned what to offload (weights/activations)"
        )
    else:
        # Offloading is disabled and nothing was requested: nothing qualifies.
        tensor_need_offloading_checker = tensor_need_offloading_checker_none

    cpu_offload_handler = AsyncDoubleBufferGroupOffloadHandler(
        num_offload_group=num_layers,
        num_model_group=model_layers,
        tensor_need_offloading_checker=tensor_need_offloading_checker,
    )

    def group_prefetch_offload_commit_async(tensor):
        return group_prefetch_offload_commit(tensor, cpu_offload_handler)

    if enabled:
        return (
            CpuOffloadHookWithOffloadHandler(offload_handler=cpu_offload_handler),
            group_prefetch_offload_commit_async,
        )
    return nullcontext(), group_prefetch_offload_commit_async