# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Functionality for CPU offloading of tensors saved for backward pass."""
from __future__ import annotations
7
from contextlib import nullcontext
8
9
from typing import Any, Dict, Optional

10
11
import torch

12
from .tensor.float8_tensor import Float8Tensor
13

14
__all__ = ["get_cpu_offload_context"]
15
16
17
18

# Module-wide flag: set to True when a CpuOffloadSavedTensorHook context is
# entered and back to False when it exits; read via is_cpu_offload_enabled().
CPUOffloadEnabled = False


19
def mark_activation_offload(*tensors):
    """Flag tensors so the offload checker treats them as offloadable activations.

    ``None`` entries are skipped. Objects whose exact type is ``torch.Tensor`` or
    ``torch.nn.Parameter`` get ``activation_offloading = True``. Every other
    object must provide ``get_data_tensors()``; each non-``None`` tensor it
    returns gets ``activation_offloading = True`` and ``needs_force_clear = True``.
    """
    plain_types = (torch.Tensor, torch.nn.Parameter)
    for item in tensors:
        if item is None:
            continue
        # Exact type match on purpose: tensor subclasses take the wrapper path.
        if type(item) in plain_types:
            item.activation_offloading = True
            continue
        for data_tensor in item.get_data_tensors():
            if data_tensor is None:
                continue
            data_tensor.activation_offloading = True
            # This is a hack to force clear the tensor after it is offloaded.
            # It is needed, because .*TensorBase classes are saved in the ctx,
            # and they contain the reference to their data tensors.
            data_tensor.needs_force_clear = True
35
36


37
38
39
40
41
def is_cpu_offload_enabled() -> bool:
    """Return the current value of the module-level ``CPUOffloadEnabled`` flag.

    The flag is set to ``True`` when a ``CpuOffloadSavedTensorHook`` context is
    entered and reset to ``False`` when that context exits.
    """
    return CPUOffloadEnabled


42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
class CpuOffloadSavedTensorHook:
    """Context manager that installs a pair of pack/unpack hooks for saved tensors.

    In this context, the ``on_save_for_backward`` method will be called every time
    a tensor is saved for backward (this includes intermediary results saved using
    :func:`~torch.autograd.function._ContextMethodMixin.save_for_backward` but
    also those recorded by a PyTorch-defined operation).

    The ``on_get_saved_tensor`` method will be called when the backward function
    of this op attempts to retrieve the saved tensor from context (this includes
    :func:`torch.Tensor.backward()` or :func:`torch.autograd.grad()`). It takes as
    input the return value of ``on_save_for_backward``, and is meant to return
    an identical copy of the tensor being saved by ``on_save_for_backward`` in terms of
    size, device and element values.

    Entering the context sets the module-level ``CPUOffloadEnabled`` flag to ``True``
    and exiting sets it to ``False`` unconditionally, so nesting these contexts leaves
    the flag cleared as soon as the innermost one exits.

    Example:

        >>> import torch
        >>> from typing import Any
        >>>
        >>> class DummyHook(CpuOffloadSavedTensorHook):
        ...
        ...     def on_save_for_backward(self, tensor: torch.Tensor) -> Any:
        ...         print("On save", tensor)
        ...         return (tensor,)
        ...
        ...     def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor:
        ...         print("On get", saved_state)
        ...         tensor, = saved_state
        ...         return tensor
        ...
        >>> a = torch.ones(5, requires_grad=True)
        >>> b = torch.ones(5, requires_grad=True) * 2
        >>> with DummyHook():
        ...     y = a * b
        ...
        On save tensor([1., 1., 1., 1., 1.], requires_grad=True)
        On save tensor([2., 2., 2., 2., 2.], grad_fn=<MulBackward0>)
        >>> y.sum().backward()
        On get (tensor([1., 1., 1., 1., 1.], requires_grad=True),)
        On get (tensor([2., 2., 2., 2., 2.], grad_fn=<MulBackward0>),)

    """

    def __init__(self) -> None:
        # True only between __enter__ and __exit__.
        self.inside_context = False

    def __enter__(self):
        """Enable the offload flag and push this object's hooks as autograd defaults."""
        global CPUOffloadEnabled
        CPUOffloadEnabled = True

        self.inside_context = True
        torch._C._autograd._push_saved_tensors_default_hooks(
            self.on_save_for_backward, self.on_get_saved_tensor
        )

    def __exit__(self, *args: Any):
        """Disable the offload flag and pop the hooks pushed by ``__enter__``."""
        global CPUOffloadEnabled
        CPUOffloadEnabled = False

        self.inside_context = False
        torch._C._autograd._pop_saved_tensors_default_hooks()

    def on_save_for_backward(self, tensor: torch.Tensor) -> Any:
        """Pack hook: called with every tensor saved for backward.

        Subclasses must override this and return the state that will later be
        handed to ``on_get_saved_tensor``.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError(
            "`on_save_for_backward: Callable[[torch.Tensor], Any]` "
            "is not implemented in CpuOffloadSavedTensorHook class. Inherit "
            "this class and implement your custom hooks"
        )

    def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor:
        """Unpack hook: called with the state returned by ``on_save_for_backward``.

        Subclasses must override this and return the tensor to use in backward.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError(
            "`on_get_saved_tensor: Callable[[Any], torch.Tensor]` "
            "is not implemented in CpuOffloadSavedTensorHook class. Inherit "
            "this class and implement your custom hooks"
        )
120
121
122
123
124
125
126
127
128


class CpuOffloadHookWithOffloadHandler(CpuOffloadSavedTensorHook):
    """Saved-tensor hook that delegates packing and unpacking to an offload handler.

    Saved tensors are forwarded to the handler's ``tensor_push`` and retrieved
    through its ``tensor_pop``. When and how the handler offloads, reloads or
    prefetches is entirely up to the handler and invisible to this hook.
    """

    def __init__(
        self,
        offload_handler: OffloadHandler,
        handler_extra_kwargs: Optional[Dict[str, Any]] = None,
        debug: bool = False,
    ) -> None:
        super().__init__()
        self.debug: bool = debug
        self.offload_handler: OffloadHandler = offload_handler
        # Extra keyword arguments forwarded verbatim on every push/pop call.
        self.handler_extra_kwargs: Dict[str, Any] = (
            {} if handler_extra_kwargs is None else handler_extra_kwargs
        )

    def on_save_for_backward(self, tensor: torch.Tensor) -> Any:
        """Hand ``tensor`` to the handler and return the identifier it issues."""
        return self.offload_handler.tensor_push(tensor, **self.handler_extra_kwargs)

    def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor:
        """Ask the handler for the tensor registered under ``saved_state``."""
        return self.offload_handler.tensor_pop(saved_state, **self.handler_extra_kwargs)


class OffloadHandler:
    """A base class for CPU offload-handlers.

    Subclasses implement ``tensor_push`` (called when a tensor is saved for
    backward) and ``tensor_pop`` (called when it is retrieved again).
    """

    def __init__(self) -> None:
        pass

    def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any:
        """Register ``tensor`` and return a tag that ``tensor_pop`` accepts.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError(
            "`tensor_push` is not implemented in OffloadHandler class. "
            "Inherit this class and implement your custom tensor_push."
        )

    def tensor_pop(self, tensor_tag: Any, **kwargs):
        """Return the tensor registered under ``tensor_tag``.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError(
            "`tensor_pop` is not implemented in OffloadHandler class. "
            "Inherit this class and implement your custom tensor_pop."
        )
171
172
173
174
175
176
177
178


class GroupCommitFunction(torch.autograd.Function):
    """Identity op that marks a group boundary for the offload handler.

    The output equals the input. The op exists so that the handler is notified
    at a well-defined point in both the forward and the backward pass, which is
    why it has to be an autograd function rather than a plain call.
    """

    @staticmethod
    def forward(ctx, tensor, cpu_offload_handler):
        """Notify the handler of a forward group commit and pass ``tensor`` through."""
        cpu_offload_handler.on_group_commit_forward()
        # Keep the handler around so backward can notify it as well.
        ctx.cpu_offload_handler = cpu_offload_handler
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
        """Notify the handler of a backward group commit and pass the gradient through."""
        ctx.cpu_offload_handler.on_group_commit_backward()
        # No gradient for the handler argument.
        return grad_output, None


# Functional alias: ``group_prefetch_offload_commit(tensor, handler)`` applies
# GroupCommitFunction to ``tensor`` with the given offload handler.
group_prefetch_offload_commit = GroupCommitFunction.apply


class SynchronizedGroupOffloadHandler(OffloadHandler):
    """Offload handler that offloads and reloads synchronously.

    Device-to-host and host-to-device copies are issued on the same stream as
    the computation kernels, so the copies block computation.
    """

    def __init__(
        self, num_offload_group, tensor_need_offloading_checker=(lambda _: True), debug=False
    ) -> None:
        super().__init__()

        self.debug = debug
        self.num_offload_group = num_offload_group
        self.tensor_need_offloading_checker = tensor_need_offloading_checker

        self.groupid_reset()

    def groupid_reset(self):
        """Reset the group/tensor counters and forget all tracked tensors."""
        # Counters used to build tensor tags; the group counter moves whenever a
        # group commit happens.
        self.current_group = 0
        self.tensor_count_current_group = 0
        self.torch_tensor_count = 0
        # tag -> either the original tensor or a (device, cpu_copy) tuple.
        self.tensor_tag_to_state = {}

    def on_group_commit_forward(self):
        """Advance to the next group and restart the per-group tensor counter."""
        self.current_group += 1
        self.tensor_count_current_group = 0

    def on_group_commit_backward(self):
        """Step back to the previous group; the group index must stay non-negative."""
        self.current_group -= 1
        assert self.current_group >= 0

    @staticmethod
    def offload(src_tensor, pin_memory=True):
        """Copy ``src_tensor`` into a new CPU tensor.

        Returns a ``(original_device, cpu_copy)`` tuple. The copy is requested
        as non-blocking exactly when ``pin_memory`` is true.
        """
        host_copy = torch.empty(
            src_tensor.size(),
            dtype=src_tensor.dtype,
            layout=src_tensor.layout,
            device="cpu",
            pin_memory=pin_memory,
        )
        host_copy.copy_(src_tensor, non_blocking=pin_memory)
        return (src_tensor.device, host_copy)

    @staticmethod
    def reload(state, non_blocking=None):
        """Move the CPU copy stored in ``state`` back to its original device.

        When ``non_blocking`` is ``None`` it defaults to whether the CPU copy
        is pinned.
        """
        device, host_copy = state
        use_non_blocking = host_copy.is_pinned() if non_blocking is None else non_blocking
        return host_copy.to(device, non_blocking=use_non_blocking)

    def tensor_push(self, tensor: torch.Tensor, **kwargs):
        """Register ``tensor`` under a fresh ``(group, index)`` tag and return the tag.

        The tensor is offloaded immediately when the current group is one of
        the offloaded groups and the checker accepts it; otherwise the tensor
        itself is stored.
        """
        tag = (self.current_group, self.tensor_count_current_group)
        self.tensor_count_current_group += 1
        assert tag not in self.tensor_tag_to_state

        should_offload = (
            self.current_group < self.num_offload_group
            and self.tensor_need_offloading_checker(tensor)
        )
        self.tensor_tag_to_state[tag] = (
            SynchronizedGroupOffloadHandler.offload(tensor) if should_offload else tensor
        )
        return tag

    def tensor_pop(self, tensor_tag, **kwargs):
        """Remove and return the tensor for ``tensor_tag``, reloading it if it was offloaded."""
        assert tensor_tag in self.tensor_tag_to_state
        saved = self.tensor_tag_to_state.pop(tensor_tag)
        # A tuple is the (device, cpu_copy) marker produced by offload().
        if not isinstance(saved, tuple):
            return saved
        return SynchronizedGroupOffloadHandler.reload(saved)


class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
    """Offload handler that overlaps copies with computation.

    Compared to the synchronized handler this uses more memory because of the
    buffer, but achieves better performance due to the overlapping. D2H and H2D
    copies are completely hidden behind computation if the computation time of
    a layer is longer than the host-device communication time. Bulk offloading
    with delay and bulk reloading with prefetch are implemented.
    """

    def __init__(
        self,
        num_offload_group,  # must be <= actual number of groups (number of commits)
        num_model_group,
        tensor_need_offloading_checker=(lambda t: True),
        debug=False,
    ) -> None:
        super().__init__(
            num_offload_group=num_offload_group,
            tensor_need_offloading_checker=tensor_need_offloading_checker,
            debug=debug,
        )
        # Number of layers in the model.
        self.num_layers = num_model_group
        # Keeps references to activation tensors until their group is offloaded.
        self.tensor_tag_to_buf = {}
        # Quantized (FP8/MXFP8) tensor objects and the Float8 transpose flag, by tag.
        self.fp8_tensor_object_map = {}
        self.float8_transpose_cache_valid = {}
        # Number of groups whose offload has been synchronized so far.
        self.offloaded_group_count = 0
        # offload-group index -> group index at which that offload is synchronized.
        self.layer_window_map = {}

        # Spread the offload windows over the model's groups so that the
        # CPU/GPU interconnect is used evenly across the computation.
        carry = 0
        for group in range(self.num_offload_group):
            window = ((self.num_layers // self.num_offload_group) * (group + 1)) - 1
            if group < (self.num_layers % self.num_offload_group):
                carry = group + 1
            self.layer_window_map[group] = window + carry

        # Dedicated streams for device-to-host and host-to-device copies.
        self.d2h_stream = torch.cuda.Stream()
        self.h2d_stream = torch.cuda.Stream()

    def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any:
        """Register ``tensor`` for later offloading and return its tag.

        Fake/functional tensors are stored under a ``(-1, n)`` tag and never
        offloaded. Objects with a callable ``prepare_for_saving`` are treated as
        quantized tensors: their data tensors are tracked as a list and the
        object itself is remembered so it can be restored later.
        """
        is_stray = isinstance(
            tensor,
            (
                torch._subclasses.fake_tensor.FakeTensor,
                torch._subclasses.functional_tensor.FunctionalTensor,
            ),
        )
        is_quantized = callable(getattr(tensor, "prepare_for_saving", None))

        if is_stray:
            tag = (-1, self.torch_tensor_count)
            self.torch_tensor_count += 1
            self.tensor_tag_to_state[tag] = tensor
            return tag

        # Obtain a unique tensor tag.
        tag = (self.current_group, self.tensor_count_current_group)
        self.tensor_count_current_group += 1
        assert tag not in self.tensor_tag_to_state

        if not is_quantized:
            self.tensor_tag_to_state[tag] = tensor
            if self.current_group < self.num_offload_group and self.tensor_need_offloading_checker(
                tensor
            ):
                self.tensor_tag_to_buf[tag] = tensor
            return tag

        data_tensors, _ = tensor.prepare_for_saving()
        self.tensor_tag_to_state[tag] = []
        self.tensor_tag_to_buf[tag] = []
        self.fp8_tensor_object_map[tag] = tensor
        if isinstance(tensor, Float8Tensor):
            self.float8_transpose_cache_valid[tag] = getattr(tensor, "_transpose_invalid")

        for data_tensor in data_tensors:
            self.tensor_tag_to_state[tag].append(data_tensor)
            if self.current_group < self.num_offload_group and self.tensor_need_offloading_checker(
                data_tensor
            ):
                self.tensor_tag_to_buf[tag].append(data_tensor)
                # Need to clear the internal data reference for the quantized tensors
                tensor.clear()
        return tag

    def tensor_pop(self, tensor_tag, **kwargs):
        """Remove and return the tensor registered under ``tensor_tag``.

        A list state (quantized tensor that was not bulk-reloaded) is restored
        into its owning tensor object, which is then returned.
        """
        assert tensor_tag in self.tensor_tag_to_state
        saved = self.tensor_tag_to_state.pop(tensor_tag)

        # Handling the quantized tensor case specially here
        if isinstance(saved, list):
            owner = self.fp8_tensor_object_map[tensor_tag]
            owner.restore_from_saved(saved)
            del self.fp8_tensor_object_map[tensor_tag]
            saved = owner

        self.tensor_tag_to_buf.pop(tensor_tag, None)

        # The tensor should have been copied back in on_group_commit_backward(),
        # which invokes bulk_reload_group.
        assert not isinstance(saved, tuple)
        return saved

    def bulk_offload_group(self, group_to_offload):
        """Offload every tracked tensor of ``group_to_offload`` on the D2H stream."""
        with torch.cuda.stream(self.d2h_stream):
            for tag, saved in self.tensor_tag_to_state.items():
                if tag[0] != group_to_offload:
                    continue
                assert not isinstance(saved, tuple)

                if not isinstance(saved, list):
                    # Plain tensor: store the CPU copy if it qualifies, else keep it.
                    if self.tensor_need_offloading_checker(saved):
                        saved = SynchronizedGroupOffloadHandler.offload(saved)
                    self.tensor_tag_to_state[tag] = saved
                    continue

                # Quantized tensor: the state is a list of data tensors. Entries the
                # checker rejects (e.g. a `None` rowwise tensor of a columnwise-only
                # quantized tensor) are kept as they are, so the list keeps its
                # shape. Quantized tensors are tracked as lists and plain tensors
                # as standalone objects throughout. TODO(@sanandaraj5597)
                self.tensor_tag_to_state[tag] = []
                for data_tensor in saved:
                    if self.tensor_need_offloading_checker(data_tensor):
                        entry = SynchronizedGroupOffloadHandler.offload(data_tensor)
                    else:
                        entry = data_tensor
                    self.tensor_tag_to_state[tag].append(entry)

    def synchronize_on_group_commit_forward(self, current_group):
        """Synchronize streams and schedule offloads at a forward group commit."""
        # For the first group, kickstart the offload after we have
        # the first compute completion
        if current_group == 0:
            self.d2h_stream.wait_stream(torch.cuda.current_stream())
            self.bulk_offload_group(current_group)

        # The window map decides at which group the pending offload is synchronized.
        if self.layer_window_map[self.offloaded_group_count] != current_group:
            return

        # Stream synchronization both ways
        compute_stream = torch.cuda.current_stream()
        self.d2h_stream.wait_stream(compute_stream)
        compute_stream.wait_stream(self.d2h_stream)

        # Time to free the activation memory after usage
        for tag, buf in self.tensor_tag_to_buf.items():
            if tag[0] != self.offloaded_group_count:
                continue
            if hasattr(buf, "needs_force_clear"):
                # Need to clear activation tensor - sometimes references persist in the code.
                # This is the case for example with the Float8TensorBase class,
                # which is saved directly inside the ctx while its internal tensors are
                # saved inside save_for_backward.
                buf.data = torch.Tensor()
            # Release the pointer to the tensor
            self.tensor_tag_to_buf[tag] = None

        # Time to offload the next group
        if self.offloaded_group_count < (self.num_offload_group - 1):
            self.bulk_offload_group(self.offloaded_group_count + 1)

        # Increment the offload group count to keep track
        self.offloaded_group_count += 1

    def on_group_commit_forward(self):
        """This function will cause host device synchronization"""
        # Handle synchronization first, then advance the group counters.
        self.synchronize_on_group_commit_forward(self.current_group)

        super().on_group_commit_forward()

    def bulk_reload_group(self, group_to_reload):
        """Reload every offloaded tensor of ``group_to_reload`` on the H2D stream."""
        assert group_to_reload < self.num_offload_group

        with torch.cuda.stream(self.h2d_stream):
            # Move back tensors
            for tag, saved in self.tensor_tag_to_state.items():
                if tag[0] != group_to_reload:
                    continue
                if isinstance(saved, tuple):
                    self.tensor_tag_to_state[tag] = SynchronizedGroupOffloadHandler.reload(saved)
                elif isinstance(saved, list):
                    # Quantized tensor: reload the offloaded entries, keep the rest.
                    reloaded = [
                        (
                            SynchronizedGroupOffloadHandler.reload(entry)
                            if isinstance(entry, tuple)
                            else entry
                        )
                        for entry in saved
                    ]
                    _ = self.fp8_tensor_object_map[tag].restore_from_saved(reloaded)
                    if isinstance(self.fp8_tensor_object_map[tag], Float8Tensor):
                        self.fp8_tensor_object_map[tag]._transpose_invalid = (
                            self.float8_transpose_cache_valid.pop(tag)
                        )
                    self.tensor_tag_to_state[tag] = self.fp8_tensor_object_map.pop(tag)

    def on_group_commit_backward(self):
        """Step back one group and reload the matching offloaded group, if any."""
        # First decrement the current group.
        # After the last commit in forward the group is +1; in backward it is -1.
        # Finally it should be decremented to 0.
        self.current_group -= 1
        assert self.current_group >= 0

        # Layer window data structure helps us to reload at right times
        if self.layer_window_map[self.offloaded_group_count - 1] == self.current_group:
            # Stream synchronization both ways
            compute_stream = torch.cuda.current_stream()
            self.h2d_stream.wait_stream(compute_stream)
            compute_stream.wait_stream(self.h2d_stream)

            # Time to reload the next group
            self.bulk_reload_group(self.offloaded_group_count - 1)

            # Decrease the offloading group counter, but never below 1 here.
            if self.offloaded_group_count > 1:
                self.offloaded_group_count -= 1

        # Last group computation needs to wait till all the reloads complete
        if self.current_group == 0:
            torch.cuda.current_stream().wait_stream(self.h2d_stream)
            self.offloaded_group_count = 0
545
546
547
548
549


def get_cpu_offload_context(
    enabled: bool = False,
    num_layers: int = 1,
    model_layers: int = 1,
    offload_activations: bool = True,
    offload_weights: bool = False,
):
    """
    This function returns the CPU Offload context and the synchronizer function that needs to be
    used after every transformer layer.

    When offloading is not enabled, or only the deprecated weight offloading is requested,
    `nullcontext()` and an identity synchronizer are returned and no offload handler (and
    therefore no CUDA stream) is created.

    Usage:

    .. code-block:: python

        cpu_offload_context, cpu_offload_synchronizer = get_cpu_offload_context(enabled=True)

        with cpu_offload_context:
            out = te_layer.forward(inp_tensor)
        out = cpu_offload_synchronizer(out)

    Parameters
    ----------
    enabled: bool, default = `False`
             When set to True, CPU Offloading functionality is enabled.
    num_layers: int, default = 1
                Determines the number of transformer layers
                you want to offload activations/weights for.
    model_layers: int, default = 1
                  Number of layers in the model that will be used under this context.
    offload_activations: bool, default = `True`
                         When set to `True`, offloads the activations for the TE layer.
    offload_weights: bool, default = `False`
                     Deprecated. Setting it to `True` only emits a `DeprecationWarning`;
                     weights are not offloaded.

    Returns
    -------
    A `(context_manager, synchronizer)` tuple. The synchronizer takes a tensor and
    returns it.

    Raises
    ------
    ValueError
        If both `offload_activations` and `offload_weights` are `False`
        (checked regardless of `enabled`).

    """

    if not offload_weights and not offload_activations:
        raise ValueError(
            "CPU Offloading is enabled while it is not "
            "mentioned what to offload (weights/activations)"
        )

    if offload_weights:
        import warnings

        warnings.warn(
            "Offloading weights is deprecated. Using offload_weights=True does not have any"
            " effect.",
            DeprecationWarning,
            stacklevel=2,
        )

        # Weights offloading is deprecated but we maintain backward compatibility by doing nothing.
        if not offload_activations:
            return nullcontext(), lambda x: x

    if not enabled:
        # Nothing is pushed to a handler outside the offload context, so do not
        # build one: it would allocate CUDA streams and its synchronizer would
        # still step through the group bookkeeping on every call.
        return nullcontext(), lambda x: x

    def tensor_need_offloading_checker_activations(tensor):
        # Only tensors flagged with `activation_offloading` are offloaded.
        return hasattr(tensor, "activation_offloading")

    cpu_offload_handler = AsyncDoubleBufferGroupOffloadHandler(
        num_offload_group=num_layers,
        num_model_group=model_layers,
        tensor_need_offloading_checker=tensor_need_offloading_checker_activations,
    )

    def group_prefetch_offload_commit_async(tensor):
        # Marks a group boundary for the handler in both forward and backward.
        return group_prefetch_offload_commit(tensor, cpu_offload_handler)

    return (
        CpuOffloadHookWithOffloadHandler(offload_handler=cpu_offload_handler),
        group_prefetch_offload_commit_async,
    )