cpu_offload.py 20.9 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Functionality for CPU offloading of tensors saved for backward pass."""
from typing import Any
from contextlib import nullcontext
import torch

from .float8_tensor import Float8Tensor

__all__ = ["get_cpu_offload_context"]

# Set to True by CpuOffloadSavedTensorHook.__enter__ and reset by its __exit__.
CPUOffloadEnabled = False


class CpuOffloadSavedTensorHook:
    """Context-manager that executes a pair of pack/unpack hooks for saved tensors.

    In this context, the ``on_save_for_backward`` method will be called every time
    a tensor is saved for backward (this includes intermediary results saved using
    :func:`~torch.autograd.function._ContextMethodMixin.save_for_backward` but
    also those recorded by a PyTorch-defined operation).

    The ``on_get_saved_tensor`` method will be called when the backward function
    of this op attempts to retrieve the saved tensor from context (this includes
    :func:`torch.Tensor.backward` or :func:`torch.autograd.grad`). It takes as
    input the return value of ``on_save_for_backward``, and is meant to return
    an identical copy of the tensor being saved by ``on_save_for_backward`` in terms
    of size, device and element values.

    While the context is active, the module-level ``CPUOffloadEnabled`` flag is
    ``True``. On exit it is restored to the value it had on entry, so leaving a
    nested context does not switch it off for an enclosing one. Entering the
    context returns the hook object itself.

    Subclasses must implement both ``on_save_for_backward`` and
    ``on_get_saved_tensor``; the base implementations raise
    ``NotImplementedError``.

    Example:

        >>> import torch
        >>> from typing import Any
        >>>
        >>> class DummyHook(CpuOffloadSavedTensorHook):
        ...
        ...     def on_save_for_backward(self, tensor: torch.Tensor) -> Any:
        ...         print("On save", tensor)
        ...         return (tensor,)
        ...
        ...     def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor:
        ...         print("On get", saved_state)
        ...         tensor, = saved_state
        ...         return tensor
        ...
        >>> a = torch.ones(5, requires_grad=True)
        >>> b = torch.ones(5, requires_grad=True) * 2
        >>> with DummyHook():
        ...     y = a * b
        ...
        On save tensor([1., 1., 1., 1., 1.], requires_grad=True)
        On save tensor([2., 2., 2., 2., 2.], grad_fn=<MulBackward0>)
        >>> y.sum().backward()
        On get (tensor([1., 1., 1., 1., 1.], requires_grad=True),)
        On get (tensor([2., 2., 2., 2., 2.], grad_fn=<MulBackward0>),)

    """

    def __init__(self) -> None:
        self.inside_context = False

    def __enter__(self):
        global CPUOffloadEnabled
        # Remember the flag's previous value so that leaving a nested context
        # restores it instead of unconditionally disabling offloading.
        self._prev_cpu_offload_enabled = CPUOffloadEnabled
        torch._C._autograd._push_saved_tensors_default_hooks(
            self.on_save_for_backward,
            self.on_get_saved_tensor,
        )
        CPUOffloadEnabled = True
        self.inside_context = True
        return self

    def __exit__(self, *args: Any):
        global CPUOffloadEnabled
        torch._C._autograd._pop_saved_tensors_default_hooks()
        CPUOffloadEnabled = self._prev_cpu_offload_enabled
        self.inside_context = False

    def on_save_for_backward(self, tensor: torch.Tensor) -> Any:
        """Called when autograd saves ``tensor``; returns the state to keep instead."""
        raise NotImplementedError("`on_save_for_backward: Callable[[torch.Tensor], Any]` "
                                  "is not implemented in CpuOffloadSavedTensorHook class. "
                                  "Inherit this class and implement your custom hooks")

    def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor:
        """Called when autograd needs the tensor back; maps ``saved_state`` to it."""
        raise NotImplementedError("`on_get_saved_tensor: Callable[[Any], torch.Tensor]` "
                                  "is not implemented in CpuOffloadSavedTensorHook class. "
                                  "Inherit this class and implement your custom hooks")


class CpuOffloadHookWithOffloadHandler(CpuOffloadSavedTensorHook):
    """Context-manager that offloads/recovers tensors through an offload handler.

    The hook just forwards each saved tensor to the handler's ``tensor_push``
    and forwards the returned identifier back to ``tensor_pop`` when autograd
    needs the tensor again. How the offload-handler manages the offloading,
    recovering or prefetching timing is transparent to this hook.

    Args:
        offload_handler: object providing ``tensor_push(tensor, **kwargs)`` and
            ``tensor_pop(identifier, **kwargs)``.
        handler_extra_kwargs: extra keyword arguments passed to both handler
            calls. Defaults to no extra arguments.
        debug: stored on the instance; not used by this class itself.
    """
    def __init__(self, offload_handler, handler_extra_kwargs=None, debug=False) -> None:
        self.debug = debug
        self.offload_handler = offload_handler
        # Use a fresh dict per instance: a shared `{}` default would let a
        # mutation made through one hook leak into every other hook.
        self.handler_extra_kwargs = {} if handler_extra_kwargs is None else handler_extra_kwargs
        super().__init__()

    def on_save_for_backward(self, tensor: torch.Tensor) -> Any:
        """Hand ``tensor`` to the handler and return its retrieval identifier."""
        retrieve_identifier = self.offload_handler.tensor_push(
            tensor,
            **self.handler_extra_kwargs
        )
        return retrieve_identifier

    def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor:
        """Ask the handler for the tensor stored under ``saved_state``."""
        tensor = self.offload_handler.tensor_pop(
            saved_state,
            **self.handler_extra_kwargs
        )
        return tensor


class OffloadHandler:
    """A base class for CPU offload-handlers.

    Subclasses implement ``tensor_push`` (store a tensor, return an identifier)
    and ``tensor_pop`` (return the tensor stored under that identifier). Both
    base implementations raise ``NotImplementedError``.
    """
    def __init__(self) -> None:
        pass

    def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any:
        """Store ``tensor`` and return an identifier accepted by ``tensor_pop``."""
        raise NotImplementedError("`tensor_push` is not implemented in OffloadHandler class. "
                                  "Inherit this class and implement your custom tensor_push.")

    def tensor_pop(self, tensor_tag: Any, **kwargs):
        """Return the tensor previously stored under ``tensor_tag``."""
        raise NotImplementedError("`tensor_pop` is not implemented in OffloadHandler class. "
                                  "Inherit this class and implement your custom tensor_pop.")


class GroupCommitFunction(torch.autograd.Function):
    """Identity autograd op that marks a group boundary for an offload handler.

    Forward returns its input tensor unchanged after calling the handler's
    ``on_group_commit_forward()``; backward passes the incoming gradient
    straight through after calling ``on_group_commit_backward()``. Wrapping the
    two callbacks in an ``autograd.Function`` is what lets one call site
    trigger handler work in both the forward and the backward pass.
    """
    @staticmethod
    def forward(ctx, tensor, cpu_offload_handler):
        handler = cpu_offload_handler
        handler.on_group_commit_forward()
        # Keep the handler so backward can notify the same object.
        ctx.cpu_offload_handler = handler
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
        ctx.cpu_offload_handler.on_group_commit_backward()
        # Gradient flows through untouched; the handler argument gets none.
        return grad_output, None


# Functional form: group_prefetch_offload_commit(tensor, handler) -> tensor
group_prefetch_offload_commit = GroupCommitFunction.apply


class SynchronizedGroupOffloadHandler(OffloadHandler):
    """Offload handler that offloads/reloads in a synchronized way.

    The device-to-host and host-to-device copies are issued on the current
    stream, i.e. the same stream as the computation kernels, so the copying is
    ordered with (and blocks) computation.

    Each saved tensor is labelled with a ``(group, index_within_group)`` tag.
    A tensor pushed while ``current_group < num_offload_group`` and accepted by
    ``tensor_need_offloading_checker`` is copied to CPU immediately and copied
    back to its original device when popped; every other tensor is kept as-is.
    """
    def __init__(self,
                 num_offload_group,
                 tensor_need_offloading_checker=(lambda _: True),
                 debug=False
                 ) -> None:
        super().__init__()

        # Only tensors pushed while `current_group < num_offload_group` are
        # candidates for offloading.
        self.num_offload_group = num_offload_group
        # Per-tensor predicate deciding whether a candidate is offloaded.
        self.tensor_need_offloading_checker = tensor_need_offloading_checker
        self.debug = debug

        self.groupid_reset()

    def groupid_reset(self):
        """Reset the group counters and drop all saved-tensor state."""
        # Data structures to label saved tensors and book-keep their cpu copies.
        # `current_group` is incremented by `on_group_commit_forward()` and
        # decremented by `on_group_commit_backward()`; the per-group tensor
        # count is reset to 0 on each forward commit.
        self.current_group, self.tensor_count_current_group = (0, 0)
        # tag -> saved state: a (device, cpu_backup) tuple from `offload()`
        # for offloaded tensors, otherwise the tensor itself.
        self.tensor_tag_to_state = {}

    def on_group_commit_forward(self):
        """Close the current group in forward and start the next one."""
        # finishing up with updating current group and tensor count
        self.current_group += 1             # increment
        self.tensor_count_current_group = 0 # reset

    def on_group_commit_backward(self):
        """Step back one group in backward."""
        self.current_group -= 1
        assert self.current_group >= 0

    @staticmethod
    def offload(src_tensor, pin_memory=True):
        """Copy ``src_tensor`` into a newly allocated CPU tensor.

        Float8Tensor inputs are stored in a ``torch.uint8`` CPU buffer wrapped
        with ``Float8Tensor.make_like``. When ``pin_memory`` is True the CPU
        buffer is pinned and the copy is issued with ``non_blocking=True``.

        Returns:
            A ``(device, cpu_backup)`` tuple, accepted by ``reload()``.
        """
        fp8_offload = isinstance(src_tensor, Float8Tensor)

        cpu_backup = torch.empty(
            src_tensor.size(), dtype=torch.uint8 if fp8_offload else src_tensor.dtype,
            layout=src_tensor.layout, device="cpu", pin_memory=pin_memory)

        if fp8_offload:
            cpu_backup = Float8Tensor.make_like(src_tensor, data=cpu_backup)

        cpu_backup.copy_(src_tensor, non_blocking=pin_memory)
        state = (src_tensor.device, cpu_backup)
        return state

    @staticmethod
    def reload(state, non_blocking=None):
        """Copy a ``(device, cpu_backup)`` state back to its original device.

        When ``non_blocking`` is None, the copy is non-blocking exactly when the
        CPU backup is in pinned memory.
        """
        dev, cpu_backup = state
        if non_blocking is None:
            non_blocking = cpu_backup.is_pinned()
        return cpu_backup.to(dev, non_blocking=non_blocking)

    def tensor_push(self, tensor: torch.Tensor, **kwargs):
        """Save ``tensor`` (offloading it if eligible) and return its tag."""
        # obtain a unique tensor tag
        tensor_tag = (self.current_group, self.tensor_count_current_group)
        self.tensor_count_current_group += 1
        assert tensor_tag not in self.tensor_tag_to_state
        if (self.current_group < self.num_offload_group
            and self.tensor_need_offloading_checker(tensor)):
            state = SynchronizedGroupOffloadHandler.offload(tensor)
            self.tensor_tag_to_state[tensor_tag] = state
        else:
            # not offloaded: keep a reference to the tensor itself
            self.tensor_tag_to_state[tensor_tag] = tensor
        return tensor_tag

    def tensor_pop(self, tensor_tag, **kwargs):
        """Remove and return the tensor saved under ``tensor_tag``.

        Offloaded tensors are copied back to their original device first.
        """
        assert tensor_tag in self.tensor_tag_to_state
        state = self.tensor_tag_to_state.pop(tensor_tag)
        if isinstance(state, tuple):
            tensor = SynchronizedGroupOffloadHandler.reload(state)
        else:
            tensor = state
        return tensor


class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
    """Offload handler that overlaps host-device copies with computation.

    Compared to the synchronized handler, this uses more memory because of the
    device-side staging buffers, but achieves better performance through
    overlapping: D2H copies are issued on ``self.d2h_stream`` and H2D copies
    on ``self.h2d_stream``. They are completely hidden behind computation if
    the computation time of a layer is longer than the host-device
    communication time.

    Offloading is done in bulk per group when the group is committed in
    forward (``bulk_offload_group``). Reloading is done in bulk per group
    during backward, with prefetching (``bulk_reload_group``).

    Staging buffers live in two dicts selected by ``group_id % 2`` and keyed
    by the tensor's index within its group, so consecutive groups use
    different buffer sets.
    """
    def __init__(self,
                 num_offload_group,     # must be <= actual number of groups (number of commits)
                 num_prefetch_group=1,
                 tensor_need_offloading_checker=(lambda t: True),
                 debug=False
                 ) -> None:
        super().__init__(num_offload_group=num_offload_group,
                         tensor_need_offloading_checker=tensor_need_offloading_checker,
                         debug=debug)
        # In backward, groups down to `current_group - num_prefetch_group`
        # are reloaded (see on_group_commit_backward).
        self.num_prefetch_group = num_prefetch_group

        # prepare for tensor buffer
        # Two {index_within_group: device buffer} dicts, selected by group parity.
        self.tensor_id_to_tensor_buf_double_bufs = []
        for _ in range(2):
            self.tensor_id_to_tensor_buf_double_bufs.append({})

        # allocate streams and events for synchronization
        self.d2h_stream = torch.cuda.Stream()
        self.h2d_stream = torch.cuda.Stream()
        # One event of each kind per offloaded group, indexed by group id.
        self.h2d_finish_events = []
        self.compute_stream_bwd_start_events = []
        for _ in range(self.num_offload_group):
            self.h2d_finish_events.append(torch.cuda.Event())
            self.compute_stream_bwd_start_events.append(torch.cuda.Event())
        # Recorded on d2h_stream after the last offloaded group's copies are issued.
        self.d2h_final_event = torch.cuda.Event()

    def get_tensor_buf_for_offloaded_tensor(self, tensor, tensor_tag):
        """Return a device staging buffer for ``tensor`` at ``tensor_tag``.

        Reuses the buffer cached for the same in-group index in this group's
        parity buffer set when its size and dtype match. Otherwise it allocates
        a new buffer on ``tensor.device`` (for Float8Tensor inputs, a
        ``torch.uint8`` buffer wrapped with ``Float8Tensor.make_like``) and
        caches it.
        """
        group_id, tensor_id = tensor_tag
        # obtain ping-pong buffer
        id_buf_map = self.tensor_id_to_tensor_buf_double_bufs[(group_id % 2)]

        if not tensor_id in id_buf_map:
            allocate_new_buf = True
        else:
            tensor_buf = id_buf_map[tensor_id]
            # NOTE(review): reuse is decided by size and dtype only; device,
            # layout and Float8Tensor-ness of the cached buffer are not compared.
            if not (tensor_buf.size() == tensor.size() and tensor_buf.dtype == tensor.dtype): # pylint: disable=simplifiable-if-statement
                allocate_new_buf = True
            else:
                allocate_new_buf = False # in this case, reuse the old buffer

        if allocate_new_buf:
            # supposed to only execute once
            # (presumably once per slot, assuming the same tensor shapes recur
            # in every iteration — TODO confirm)
            fp8_offload = isinstance(tensor, Float8Tensor)
            buffer = torch.empty(
                tensor.size(), dtype=torch.uint8 if fp8_offload else tensor.dtype,
                layout=tensor.layout, device=tensor.device)

            if isinstance(tensor, Float8Tensor):
                id_buf_map[tensor_id] = Float8Tensor.make_like(tensor, data=buffer)
            else:
                id_buf_map[tensor_id] = buffer

        return id_buf_map[tensor_id]


    def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any:
        """Record ``tensor`` under a new (group, index) tag and return the tag.

        Tensors selected for offloading are copied into a device staging
        buffer here. The D2H copy itself happens later, in
        ``bulk_offload_group``, when the group is committed.
        """
        # obtain a unique tensor tag
        tensor_tag = (self.current_group, self.tensor_count_current_group)
        self.tensor_count_current_group += 1
        assert tensor_tag not in self.tensor_tag_to_state

        if (self.current_group < self.num_offload_group
            and self.tensor_need_offloading_checker(tensor)):
            # first copy the tensor to tensorbuf, so that the original tensor will not be deleted
            # NOTE(review): the saved state below is `tensor_buf`, so this
            # handler keeps no reference to the original `tensor`.
            tensor_buf = self.get_tensor_buf_for_offloaded_tensor(tensor, tensor_tag)
            tensor_buf.copy_(tensor)
            # Copy the marker attributes so that tensor_need_offloading_checker,
            # which bulk_offload_group applies to the buffer, still sees them.
            if hasattr(tensor,"weight_offloading"):
                tensor_buf.weight_offloading = True
            if hasattr(tensor,"activation_offloading"):
                tensor_buf.activation_offloading = True
            # Here we just save it, and at commit, bulk_offload_group will handle it
            self.tensor_tag_to_state[tensor_tag] = tensor_buf
        else:
            self.tensor_tag_to_state[tensor_tag] = tensor
        return tensor_tag

    def tensor_pop(self, tensor_tag, **kwargs):
        """Remove and return the device tensor saved under ``tensor_tag``.

        The group must already have been reloaded; a (device, cpu_backup)
        tuple here trips the assertion.
        """
        assert tensor_tag in self.tensor_tag_to_state
        tensor = self.tensor_tag_to_state.pop(tensor_tag)
        # the tensor should have been copied back in on_group_commit_backward()
        # which invokes bulk_reload_group.
        assert not isinstance(tensor, tuple)
        return tensor

    def bulk_offload_group(self, group_to_offload):
        """Issue D2H copies on ``d2h_stream`` for every saved tensor of a group.

        Each eligible device tensor's state is replaced by the
        (device, cpu_backup) tuple returned by ``offload``.
        """
        with torch.cuda.stream(self.d2h_stream):
            # Only values of existing keys are replaced, so mutating the dict
            # while iterating over it is safe here.
            for tensor_tag, state in self.tensor_tag_to_state.items():
                group_id, _ = tensor_tag
                if group_id == group_to_offload:
                    assert not isinstance(state, tuple)
                    tensor_on_device = state

                    # if offload, return the reference to cpu copy
                    if self.tensor_need_offloading_checker(tensor_on_device):
                        state = SynchronizedGroupOffloadHandler.offload(tensor_on_device)
                        self.tensor_tag_to_state[tensor_tag] = state

    def synchronize_on_group_commit_forward(self, current_group):
        """Synchronize on group commit forward.

        Blocks the host (``torch.cuda.synchronize()``) whenever
        ``current_group - 1 < num_offload_group``, optionally releases a
        staging buffer set, and then issues the D2H copies for
        ``current_group`` if that group is offloaded.
        """
        # the host should wait for the copying of previous group
        # to avoid overwriting buffer
        previous_group = current_group - 1
        if previous_group < self.num_offload_group:
            torch.cuda.synchronize()
            # TODO (guyueh): this part is originally designed to reduce the peak memory usage. # pylint: disable=fixme
            # however, uncommenting this part will cause illegal access, have not figured out why.
            # NOTE(review): the buffer-release `if` below is live code, not
            # commented out — confirm which part the TODO above refers to.

            if previous_group + 2 >= self.num_offload_group:
                # this buffer is no longer required
                # (group previous_group + 2, the next user of this parity set,
                # is not offloaded, so tensor_push will not request it again)
                self.tensor_id_to_tensor_buf_double_bufs[(previous_group % 2)] = {}

        # the copying of this group should wait for the computation stream event
        # (whenever this branch runs, previous_group < num_offload_group also
        # holds, so the torch.cuda.synchronize() above has already run)
        if current_group < self.num_offload_group:
            # perform bulk offloading
            self.bulk_offload_group(current_group)
            if current_group == self.num_offload_group - 1:
                self.d2h_stream.record_event(self.d2h_final_event)

    def on_group_commit_forward(self):
        """This function will cause host device synchronization"""
        # handle synchronization events
        self.synchronize_on_group_commit_forward(self.current_group)

        # during forward, the next_group_to_fetch always points to the min of
        # the last commited group, and the last offloaded group
        self.next_group_to_fetch = min(self.current_group, self.num_offload_group -1)

        super().on_group_commit_forward()

    def bulk_reload_group(self, group_to_reload):
        """Issue H2D copies on ``h2d_stream`` for every offloaded tensor of a group."""
        assert group_to_reload < self.num_offload_group
        # The last offloaded group's D2H copies may not be covered by a host
        # synchronize, so order the reload after d2h_final_event.
        if group_to_reload == self.num_offload_group - 1:
            self.h2d_stream.wait_event(self.d2h_final_event)
        with torch.cuda.stream(self.h2d_stream):
            # move back tensors
            for tensor_label, state in self.tensor_tag_to_state.items():
                group_id, _ = tensor_label
                if group_id == group_to_reload:
                    if isinstance(state, tuple):
                        recovered_tensor = SynchronizedGroupOffloadHandler.reload(state)
                        self.tensor_tag_to_state[tensor_label] = recovered_tensor

    def on_group_commit_backward(self):
        """Step back one group in backward, prefetching groups and waiting for the current one.

        Issues reloads for groups from ``next_group_to_fetch`` down to
        ``max(current_group - num_prefetch_group, 0)``. Each reload is fenced
        by events: it starts after the compute stream reaches this point, and
        the compute stream waits on the current group's H2D finish event.
        """
        # first decrement the current group.
        # after last commit in forward, the group will +1; in backward it -1.
        # Finally it should be decremented to 0.
        self.current_group -= 1
        assert self.current_group >= 0

        # decide the range of group to prefetch
        should_prefetch_until_group = self.current_group - self.num_prefetch_group
        should_prefetch_until_group = max(should_prefetch_until_group, 0)

        # do prefetch
        for group_num_to_prefetch in range(
            self.next_group_to_fetch, should_prefetch_until_group - 1, -1
        ):
            # record the event in the compute stream, for h2d to wait
            torch.cuda.current_stream().record_event(
                self.compute_stream_bwd_start_events[group_num_to_prefetch])

            # start of h2d should wait for the compute and the d2h
            self.h2d_stream.wait_event(self.compute_stream_bwd_start_events[group_num_to_prefetch])

            #recover tensors (copy back from host)
            self.bulk_reload_group(group_num_to_prefetch)

            # record an event for the backward of this layer to wait
            self.h2d_stream.record_event(self.h2d_finish_events[group_num_to_prefetch])

        # always is set to -1 at the end of the backward
        self.next_group_to_fetch = min(self.num_offload_group - 1, should_prefetch_until_group - 1)

        # wait for the current group
        if self.current_group < self.num_offload_group:
            torch.cuda.current_stream().wait_event(self.h2d_finish_events[self.current_group])


def get_cpu_offload_context(
    enabled: bool = False,
    num_layers: int = 1,
    offload_activations: bool = True,
    offload_weights: bool = True):
    """
    This function returns the CPU Offload context and the synchronizer function that needs to be
    used after every transformer layer. If offloading is not enabled, it returns `nullcontext()`
    and an identity synchronizer.

    Usage:

    .. code-block:: python

        cpu_offload_context, cpu_offload_synchronizer = get_cpu_offload_context(enabled=True)

        with cpu_offload_context:
            out = te_layer.forward(inp_tensor)
        out = cpu_offload_synchronizer(out)

    Parameters
    ----------
    enabled: bool, default = `False`
             When set to True, CPU Offloading functionality is enabled.
    num_layers: int, default = 1
                Determines the number of transformer layers
                you want to offload activations/weights for.
    offload_activations: bool, default = `True`
                         When set to `True`, offloads the activations for the TE layer.
    offload_weights: bool, default = `True`
                     When set to `True`, offloads the weights for the TE layer.

    Returns
    -------
    (context, synchronizer): tuple
        `context` is the context manager to run each layer's forward pass in;
        `synchronizer` takes the layer's output tensor and returns it, and must
        be called once after every layer.

    Raises
    ------
    ValueError
        If neither `offload_activations` nor `offload_weights` is set.
    """

    def tensor_need_offloading_checker_activations(tensor):
        return hasattr(tensor, "activation_offloading")

    # This includes the Gradient Accumulation Buffer
    def tensor_need_offloading_checker_weights(tensor):
        return hasattr(tensor, "weight_offloading")

    def tensor_need_offloading_checker_all(tensor):
        return hasattr(tensor, "activation_offloading") or hasattr(tensor, "weight_offloading")

    if offload_activations and offload_weights:
        tensor_need_offloading_checker = tensor_need_offloading_checker_all
    elif offload_activations:
        tensor_need_offloading_checker = tensor_need_offloading_checker_activations
    elif offload_weights:
        tensor_need_offloading_checker = tensor_need_offloading_checker_weights
    else:
        raise ValueError(
            "CPU Offloading is enabled while it is not "
            "mentioned what to offload (weights/activations)")

    if not enabled:
        # Offloading is off: no saved-tensor hooks are installed, so the async
        # handler would have nothing to move. Building it anyway would create
        # CUDA streams/events (failing on CPU-only setups) and make every
        # synchronizer call run the handler's commit logic, including a
        # torch.cuda.synchronize(). Return an identity synchronizer instead.
        def group_prefetch_offload_commit_noop(tensor):
            return tensor

        return nullcontext(), group_prefetch_offload_commit_noop

    cpu_offload_handler = AsyncDoubleBufferGroupOffloadHandler(
                          num_offload_group=num_layers,
                          num_prefetch_group=1,
                          tensor_need_offloading_checker=tensor_need_offloading_checker
                          )

    def group_prefetch_offload_commit_async(tensor):
        return group_prefetch_offload_commit(tensor, cpu_offload_handler)

    return (
        CpuOffloadHookWithOffloadHandler(offload_handler=cpu_offload_handler),
        group_prefetch_offload_commit_async,
    )