# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
import hashlib
16
import os
17
from contextlib import contextmanager, nullcontext
18
from dataclasses import dataclass, replace
19
from enum import Enum
20
from typing import Dict, List, Optional, Set, Tuple, Union
Aryan's avatar
Aryan committed
21

22
import safetensors.torch
Aryan's avatar
Aryan committed
23
24
25
import torch

from ..utils import get_logger, is_accelerate_available
26
from ._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
Aryan's avatar
Aryan committed
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from .hooks import HookRegistry, ModelHook


if is_accelerate_available():
    from accelerate.hooks import AlignDevicesHook, CpuOffload
    from accelerate.utils import send_to_device


logger = get_logger(__name__)  # pylint: disable=invalid-name


# fmt: off
# Names under which this module's hooks are registered in / looked up from a module's HookRegistry.
_GROUP_OFFLOADING = "group_offloading"
_LAYER_EXECUTION_TRACKER = "layer_execution_tracker"
_LAZY_PREFETCH_GROUP_OFFLOADING = "lazy_prefetch_group_offloading"
# NOTE(review): not referenced in the visible part of this file; presumably a group id used for lazily
# created leaf-level groups — confirm against its users.
_GROUP_ID_LAZY_LEAF = "lazy_leafs"
# fmt: on


class GroupOffloadingType(str, Enum):
    r"""
    Granularity at which group offloading is applied.

    Because of the `str` mixin, members compare equal to their plain string values and can be looked up from them,
    e.g. `GroupOffloadingType("leaf_level")`.
    """

    # Requires `num_blocks_per_group` when passed to `apply_group_offloading`.
    BLOCK_LEVEL = "block_level"
    LEAF_LEVEL = "leaf_level"


@dataclass
class GroupOffloadingConfig:
    r"""
    Configuration assembled by `apply_group_offloading` and handed to the level-specific helpers and hooks.

    Most fields mirror the arguments of `apply_group_offloading`; see its docstring for their meaning.
    """

    onload_device: torch.device
    offload_device: torch.device
    offload_type: GroupOffloadingType
    non_blocking: bool
    record_stream: bool
    low_cpu_mem_usage: bool
    num_blocks_per_group: Optional[int] = None
    offload_to_disk_path: Optional[str] = None
    # Side stream used for asynchronous transfers; `None` when streams are not used.
    stream: Optional[Union[torch.cuda.Stream, torch.Stream]] = None
    # Names of submodules to treat as explicit offloading blocks.
    block_modules: Optional[List[str]] = None
    # Keyword-argument names that `GroupOffloadingHook.pre_forward` leaves in place instead of moving them.
    exclude_kwargs: Optional[List[str]] = None
    # NOTE(review): not set by `apply_group_offloading`; presumably filled in when recursing into submodules.
    module_prefix: Optional[str] = ""


class ModuleGroup:
    def __init__(
        self,
        modules: List[torch.nn.Module],
        offload_device: torch.device,
        onload_device: torch.device,
        offload_leader: torch.nn.Module,
        onload_leader: Optional[torch.nn.Module] = None,
        parameters: Optional[List[torch.nn.Parameter]] = None,
        buffers: Optional[List[torch.Tensor]] = None,
        non_blocking: bool = False,
78
        stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
79
        record_stream: Optional[bool] = False,
80
        low_cpu_mem_usage: bool = False,
Aryan's avatar
Aryan committed
81
        onload_self: bool = True,
82
        offload_to_disk_path: Optional[str] = None,
83
        group_id: Optional[Union[int, str]] = None,
Aryan's avatar
Aryan committed
84
85
86
87
88
89
    ) -> None:
        self.modules = modules
        self.offload_device = offload_device
        self.onload_device = onload_device
        self.offload_leader = offload_leader
        self.onload_leader = onload_leader
90
91
        self.parameters = parameters or []
        self.buffers = buffers or []
Aryan's avatar
Aryan committed
92
93
        self.non_blocking = non_blocking or stream is not None
        self.stream = stream
94
        self.record_stream = record_stream
Aryan's avatar
Aryan committed
95
        self.onload_self = onload_self
96
        self.low_cpu_mem_usage = low_cpu_mem_usage
97
98
99
100

        self.offload_to_disk_path = offload_to_disk_path
        self._is_offloaded_to_disk = False

101
        if self.offload_to_disk_path is not None:
102
103
104
105
            # Instead of `group_id or str(id(self))` we do this because `group_id` can be "" as well.
            self.group_id = group_id if group_id is not None else str(id(self))
            short_hash = _compute_group_hash(self.group_id)
            self.safetensors_file_path = os.path.join(self.offload_to_disk_path, f"group_{short_hash}.safetensors")
106
107
108
109
110
111
112
113
114
115
116
117
118
119

            all_tensors = []
            for module in self.modules:
                all_tensors.extend(list(module.parameters()))
                all_tensors.extend(list(module.buffers()))
            all_tensors.extend(self.parameters)
            all_tensors.extend(self.buffers)
            all_tensors = list(dict.fromkeys(all_tensors))  # Remove duplicates

            self.tensor_to_key = {tensor: f"tensor_{i}" for i, tensor in enumerate(all_tensors)}
            self.key_to_tensor = {v: k for k, v in self.tensor_to_key.items()}
            self.cpu_param_dict = {}
        else:
            self.cpu_param_dict = self._init_cpu_param_dict()
120

121
122
123
124
125
126
        self._torch_accelerator_module = (
            getattr(torch, torch.accelerator.current_accelerator().type)
            if hasattr(torch, "accelerator")
            else torch.cuda
        )

127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
    def _init_cpu_param_dict(self):
        cpu_param_dict = {}
        if self.stream is None:
            return cpu_param_dict

        for module in self.modules:
            for param in module.parameters():
                cpu_param_dict[param] = param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()
            for buffer in module.buffers():
                cpu_param_dict[buffer] = (
                    buffer.data.cpu() if self.low_cpu_mem_usage else buffer.data.cpu().pin_memory()
                )

        for param in self.parameters:
            cpu_param_dict[param] = param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()

        for buffer in self.buffers:
            cpu_param_dict[buffer] = buffer.data.cpu() if self.low_cpu_mem_usage else buffer.data.cpu().pin_memory()

        return cpu_param_dict

    @contextmanager
    def _pinned_memory_tensors(self):
        try:
151
152
153
154
            pinned_dict = {
                param: tensor.pin_memory() if not tensor.is_pinned() else tensor
                for param, tensor in self.cpu_param_dict.items()
            }
155
156
157
            yield pinned_dict
        finally:
            pinned_dict = None
Aryan's avatar
Aryan committed
158

159
    def _transfer_tensor_to_device(self, tensor, source_tensor, default_stream):
160
        tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
161
        if self.record_stream:
162
            tensor.data.record_stream(default_stream)
163

164
    def _process_tensors_from_modules(self, pinned_memory=None, default_stream=None):
165
166
167
        for group_module in self.modules:
            for param in group_module.parameters():
                source = pinned_memory[param] if pinned_memory else param.data
168
                self._transfer_tensor_to_device(param, source, default_stream)
169
170
            for buffer in group_module.buffers():
                source = pinned_memory[buffer] if pinned_memory else buffer.data
171
                self._transfer_tensor_to_device(buffer, source, default_stream)
172
173
174

        for param in self.parameters:
            source = pinned_memory[param] if pinned_memory else param.data
175
            self._transfer_tensor_to_device(param, source, default_stream)
176
177
178

        for buffer in self.buffers:
            source = pinned_memory[buffer] if pinned_memory else buffer.data
179
            self._transfer_tensor_to_device(buffer, source, default_stream)
180

181
    def _onload_from_disk(self):
182
        if self.stream is not None:
183
184
            # Wait for previous Host->Device transfer to complete
            self.stream.synchronize()
185

186
187
        context = nullcontext() if self.stream is None else self._torch_accelerator_module.stream(self.stream)
        current_stream = self._torch_accelerator_module.current_stream() if self.record_stream else None
188

189
190
191
192
        with context:
            # Load to CPU (if using streams) or directly to target device, pin, and async copy to device
            device = str(self.onload_device) if self.stream is None else "cpu"
            loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=device)
193

194
            if self.stream is not None:
195
196
197
198
199
200
201
202
203
204
205
206
                for key, tensor_obj in self.key_to_tensor.items():
                    pinned_tensor = loaded_tensors[key].pin_memory()
                    tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
                    if self.record_stream:
                        tensor_obj.data.record_stream(current_stream)
            else:
                onload_device = (
                    self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
                )
                loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
                for key, tensor_obj in self.key_to_tensor.items():
                    tensor_obj.data = loaded_tensors[key]
207

208
    def _onload_from_memory(self):
Aryan's avatar
Aryan committed
209
210
211
212
        if self.stream is not None:
            # Wait for previous Host->Device transfer to complete
            self.stream.synchronize()

213
        context = nullcontext() if self.stream is None else self._torch_accelerator_module.stream(self.stream)
214
215
        default_stream = self._torch_accelerator_module.current_stream() if self.stream is not None else None

Aryan's avatar
Aryan committed
216
        with context:
217
218
            if self.stream is not None:
                with self._pinned_memory_tensors() as pinned_memory:
219
                    self._process_tensors_from_modules(pinned_memory, default_stream=default_stream)
220
            else:
221
                self._process_tensors_from_modules(None)
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241

    def _offload_to_disk(self):
        # TODO: we can potentially optimize this code path by checking if the _all_ the desired
        # safetensor files exist on the disk and if so, skip this step entirely, reducing IO
        # overhead. Currently, we just check if the given `safetensors_file_path` exists and if not
        # we perform a write.
        # Check if the file has been saved in this session or if it already exists on disk.
        if not self._is_offloaded_to_disk and not os.path.exists(self.safetensors_file_path):
            os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True)
            tensors_to_save = {key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()}
            safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path)

        # The group is now considered offloaded to disk for the rest of the session.
        self._is_offloaded_to_disk = True

        # We do this to free up the RAM which is still holding the up tensor data.
        for tensor_obj in self.tensor_to_key.keys():
            tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)

    def _offload_to_memory(self):
Aryan's avatar
Aryan committed
242
        if self.stream is not None:
243
            if not self.record_stream:
244
245
                self._torch_accelerator_module.current_stream().synchronize()

Aryan's avatar
Aryan committed
246
247
248
            for group_module in self.modules:
                for param in group_module.parameters():
                    param.data = self.cpu_param_dict[param]
249
250
251
252
            for param in self.parameters:
                param.data = self.cpu_param_dict[param]
            for buffer in self.buffers:
                buffer.data = self.cpu_param_dict[buffer]
Aryan's avatar
Aryan committed
253
254
        else:
            for group_module in self.modules:
255
                group_module.to(self.offload_device, non_blocking=False)
256
            for param in self.parameters:
257
                param.data = param.data.to(self.offload_device, non_blocking=False)
258
            for buffer in self.buffers:
259
260
261
262
263
264
265
266
267
                buffer.data = buffer.data.to(self.offload_device, non_blocking=False)

    @torch.compiler.disable()
    def onload_(self):
        r"""Onloads the group of parameters to the onload_device."""
        if self.offload_to_disk_path is not None:
            self._onload_from_disk()
        else:
            self._onload_from_memory()
Aryan's avatar
Aryan committed
268

269
270
    @torch.compiler.disable()
    def offload_(self):
271
        r"""Offloads the group of parameters to the offload_device."""
272
273
274
275
276
        if self.offload_to_disk_path:
            self._offload_to_disk()
        else:
            self._offload_to_memory()

Aryan's avatar
Aryan committed
277
278
279
280
281
282
283
284
285
286
287

class GroupOffloadingHook(ModelHook):
    r"""
    A hook that offloads groups of torch.nn.Module to the CPU for storage and onloads to accelerator device for
    computation. Each group has one "onload leader" module that is responsible for onloading, and an "offload leader"
    module that is responsible for offloading. If prefetching is enabled, the onload leader of the previous module
    group is responsible for onloading the current module group.
    """

    _is_stateful = False

    def __init__(self, group: ModuleGroup, *, config: GroupOffloadingConfig) -> None:
        self.group = group
        # Group to prefetch from this group's onload leader; assigned later (e.g. by lazy prefetching).
        self.next_group: Optional[ModuleGroup] = None
        self.config = config

    def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
        # Offload as soon as the hook is attached to the offload leader, so the group starts on offload_device.
        if self.group.offload_leader == module:
            self.group.offload_()
        return module

    def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
        # If there wasn't an onload_leader assigned, we assume that the submodule that first called its forward
        # method is the onload_leader of the group.
        if self.group.onload_leader is None:
            self.group.onload_leader = module

        # If the current module is the onload_leader of the group, we onload the group if it is supposed
        # to onload itself. In the case of using prefetching with streams, we onload the next group if
        # it is not supposed to onload itself.
        if self.group.onload_leader == module:
            if self.group.onload_self:
                self.group.onload_()

            should_onload_next_group = self.next_group is not None and not self.next_group.onload_self
            if should_onload_next_group:
                self.next_group.onload_()

            should_synchronize = (
                not self.group.onload_self and self.group.stream is not None and not should_onload_next_group
            )
            if should_synchronize:
                # If this group didn't onload itself, it means it was asynchronously onloaded by the
                # previous group. We need to synchronize the side stream to ensure parameters
                # are completely loaded to proceed with forward pass. Without this, uninitialized
                # weights will be used in the computation, leading to incorrect results
                # Also, we should only do this synchronization if we don't already do it from the sync call in
                # self.next_group.onload_, hence the `not should_onload_next_group` check.
                self.group.stream.synchronize()

        args = send_to_device(args, self.group.onload_device, non_blocking=self.group.non_blocking)

        # Some Autoencoder models use a feature cache that is passed through submodules
        # and modified in place. The `send_to_device` call returns a copy of this feature cache object
        # which breaks the inplace updates. Use `exclude_kwargs` to mark these cache features
        exclude_kwargs = self.config.exclude_kwargs or []
        if exclude_kwargs:
            moved_kwargs = send_to_device(
                {k: v for k, v in kwargs.items() if k not in exclude_kwargs},
                self.group.onload_device,
                non_blocking=self.group.non_blocking,
            )
            kwargs.update(moved_kwargs)
        else:
            kwargs = send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking)

        return args, kwargs

    def post_forward(self, module: torch.nn.Module, output):
        # Only the offload leader offloads the group, once its own forward has finished.
        if self.group.offload_leader == module:
            self.group.offload_()
        return output


class LazyPrefetchGroupOffloadingHook(ModelHook):
    r"""
    A hook, used in conjunction with GroupOffloadingHook, that applies lazy prefetching to groups of torch.nn.Module.
    This hook is used to determine the order in which the layers are executed during the forward pass. Once the layer
    invocation order is known, assignments of the next_group attribute for prefetching can be made, which allows
    prefetching groups in the correct order.
    """

    _is_stateful = False

    def __init__(self):
        # (name, submodule) pairs appended in the order the tracked submodules' forwards start.
        self.execution_order: List[Tuple[str, torch.nn.Module]] = []
        # Names of the submodules that received a LayerExecutionTrackerHook in `initialize_hook`.
        self._layer_execution_tracker_module_names = set()

    def initialize_hook(self, module):
        # Factory so that each callback captures its own (name, submodule) rather than the loop's last values.
        def make_execution_order_update_callback(current_name, current_submodule):
            def callback():
                if not torch.compiler.is_compiling():
                    logger.debug(f"Adding {current_name} to the execution order")
                self.execution_order.append((current_name, current_submodule))

            return callback

        # To every submodule that contains a group offloading hook (at this point, no prefetching is enabled for any
        # of the groups), we add a layer execution tracker hook that will be used to determine the order in which the
        # layers are executed during the forward pass.
        for name, submodule in module.named_modules():
            if name == "" or not hasattr(submodule, "_diffusers_hook"):
                continue

            registry = HookRegistry.check_if_exists_or_initialize(submodule)
            group_offloading_hook = registry.get_hook(_GROUP_OFFLOADING)

            if group_offloading_hook is not None:
                # For the first forward pass, we have to load in a blocking manner
                group_offloading_hook.group.non_blocking = False
                layer_tracker_hook = LayerExecutionTrackerHook(make_execution_order_update_callback(name, submodule))
                registry.register_hook(layer_tracker_hook, _LAYER_EXECUTION_TRACKER)
                self._layer_execution_tracker_module_names.add(name)

        return module

    def post_forward(self, module, output):
        # At this point, for the current modules' submodules, we know the execution order of the layers. We can now
        # remove the layer execution tracker hooks and apply prefetching by setting the next_group attribute for each
        # group offloading hook.
        num_executed = len(self.execution_order)
        execution_order_module_names = {name for name, _ in self.execution_order}

        # It may be possible that some layers were not executed during the forward pass. This can happen if the layer
        # is not used in the forward pass, or if the layer is not executed due to some other reason. In such cases, we
        # may not be able to apply prefetching in the correct order, which can lead to device-mismatch related errors
        # if the missing layers end up being executed in the future.
        if execution_order_module_names != self._layer_execution_tracker_module_names:
            unexecuted_layers = list(self._layer_execution_tracker_module_names - execution_order_module_names)
            if not torch.compiler.is_compiling():
                logger.warning(
                    "It seems like some layers were not executed during the forward pass. This may lead to problems when "
                    "applying lazy prefetching with automatic tracing and lead to device-mismatch related errors. Please "
                    "make sure that all layers are executed during the forward pass. The following layers were not executed:\n"
                    f"{unexecuted_layers=}"
                )

        # Remove the layer execution tracker hooks from the submodules
        base_module_registry = module._diffusers_hook
        registries = [submodule._diffusers_hook for _, submodule in self.execution_order]
        group_offloading_hooks = [registry.get_hook(_GROUP_OFFLOADING) for registry in registries]

        for registry in registries:
            registry.remove_hook(_LAYER_EXECUTION_TRACKER, recurse=False)

        # Remove the current lazy prefetch group offloading hook so that it doesn't interfere with the next forward pass
        base_module_registry.remove_hook(_LAZY_PREFETCH_GROUP_OFFLOADING, recurse=False)

        # LazyPrefetchGroupOffloadingHook is only used with streams, so we know that non_blocking should be True.
        # We disable non_blocking for the first forward pass, but need to enable it for the subsequent passes to
        # see the benefits of prefetching.
        for hook in group_offloading_hooks:
            hook.group.non_blocking = True

        # Set required attributes for prefetching: the base module prefetches the first executed group, and every
        # executed group prefetches the one that ran right after it.
        if num_executed > 0:
            base_module_group_offloading_hook = base_module_registry.get_hook(_GROUP_OFFLOADING)
            base_module_group_offloading_hook.next_group = group_offloading_hooks[0].group
            base_module_group_offloading_hook.next_group.onload_self = False

        for i in range(num_executed - 1):
            name1, _ = self.execution_order[i]
            name2, _ = self.execution_order[i + 1]
            if not torch.compiler.is_compiling():
                logger.debug(f"Applying lazy prefetch group offloading from {name1} to {name2}")
            group_offloading_hooks[i].next_group = group_offloading_hooks[i + 1].group
            group_offloading_hooks[i].next_group.onload_self = False

        return output


class LayerExecutionTrackerHook(ModelHook):
    r"""
    Reports that a layer is about to run by invoking a callback before every forward pass of the hooked module.

    `LazyPrefetchGroupOffloadingHook` installs one of these per tracked submodule. Its callback appends that
    submodule to the recorded execution order.
    """

    _is_stateful = False

    def __init__(self, execution_order_update_callback):
        # Zero-argument callable invoked once per forward of the hooked module.
        self.execution_order_update_callback = execution_order_update_callback

    def pre_forward(self, module, *args, **kwargs):
        notify = self.execution_order_update_callback
        notify()
        # The inputs themselves are passed through unchanged.
        return args, kwargs


def apply_group_offloading(
    module: torch.nn.Module,
    onload_device: Union[str, torch.device],
    offload_device: Union[str, torch.device] = torch.device("cpu"),
    offload_type: Union[str, GroupOffloadingType] = "block_level",
    num_blocks_per_group: Optional[int] = None,
    non_blocking: bool = False,
    use_stream: bool = False,
    record_stream: bool = False,
    low_cpu_mem_usage: bool = False,
    offload_to_disk_path: Optional[str] = None,
    block_modules: Optional[List[str]] = None,
    exclude_kwargs: Optional[List[str]] = None,
) -> None:
    r"""
    Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
    where it is beneficial, we need to first provide some context on how other supported offloading methods work.

    Typically, offloading is done at two levels:
    - Module-level: In Diffusers, this can be enabled using the `ModelMixin::enable_model_cpu_offload()` method. It
      works by offloading each component of a pipeline to the CPU for storage, and onloading to the accelerator device
      when needed for computation. This method is more memory-efficient than keeping all components on the accelerator,
      but the memory requirements are still quite high. For this method to work, one needs memory equivalent to size of
      the model in runtime dtype + size of largest intermediate activation tensors to be able to complete the forward
      pass.
    - Leaf-level: In Diffusers, this can be enabled using the `ModelMixin::enable_sequential_cpu_offload()` method. It
      works by offloading the lowest leaf-level parameters of the computation graph to the CPU for storage, and
      onloading only the leafs to the accelerator device for computation. This uses the lowest amount of accelerator
      memory, but can be slower due to the excessive number of device synchronizations.

    Group offloading is a middle ground between the two methods. It works by offloading groups of internal layers,
    (either `torch.nn.ModuleList` or `torch.nn.Sequential`). This method uses lower memory than module-level
    offloading. It is also faster than leaf-level/sequential offloading, as the number of device synchronizations is
    reduced.

    Another supported feature (for CUDA and Intel XPU devices with support for asynchronous data transfer streams) is
    the ability to overlap data transfer and computation to reduce the overall execution time compared to sequential
    offloading. This is enabled using layer prefetching with streams, i.e., the layer that is to be executed next
    starts onloading to the accelerator device while the current layer is being executed - this increases the memory
    requirements slightly. Note that this implementation also supports leaf-level offloading but can be made much
    faster when using streams.

    Args:
        module (`torch.nn.Module`):
            The module to which group offloading is applied.
        onload_device (`torch.device`):
            The device to which the group of modules are onloaded.
        offload_device (`torch.device`, defaults to `torch.device("cpu")`):
            The device to which the group of modules are offloaded. This should typically be the CPU. Default is CPU.
        offload_type (`str` or `GroupOffloadingType`, defaults to "block_level"):
            The type of offloading to be applied. Can be one of "block_level" or "leaf_level". Default is
            "block_level".
        offload_to_disk_path (`str`, *optional*, defaults to `None`):
            The path to the directory where parameters will be offloaded. Setting this option can be useful in limited
            RAM environment settings where a reasonable speed-memory trade-off is desired.
        num_blocks_per_group (`int`, *optional*):
            The number of blocks per group when using offload_type="block_level". This is required when using
            offload_type="block_level".
        non_blocking (`bool`, defaults to `False`):
            If True, offloading and onloading is done with non-blocking data transfer.
        use_stream (`bool`, defaults to `False`):
            If True, offloading and onloading is done asynchronously using a CUDA (or Intel XPU) stream. This can be
            useful for overlapping computation and data transfer.
        record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
            as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to the
            [PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) for more
            details.
        low_cpu_mem_usage (`bool`, defaults to `False`):
            If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
            option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
            the CPU memory is a bottleneck but may counteract the benefits of using streams.
        block_modules (`List[str]`, *optional*):
            List of module names that should be treated as blocks for offloading. If not provided, it is read from the
            module's `_group_offload_block_modules` attribute if present; otherwise the default block detection logic
            is used.
        exclude_kwargs (`List[str]`, *optional*):
            List of kwarg keys that should not be processed by send_to_device. This is useful for mutable state like
            caching lists that need to maintain their object identity across forward passes. If not provided, will be
            inferred from the module's `_skip_keys` attribute if it exists.

    Raises:
        ValueError: If `offload_type` is not a valid `GroupOffloadingType`, if `record_stream=True` without
            `use_stream`, if `num_blocks_per_group` is missing for block-level offloading, or if streams are requested
            without a CUDA or Intel XPU device.

    Example:
        ```python
        >>> import torch
        >>> from diffusers import CogVideoXTransformer3DModel
        >>> from diffusers.hooks import apply_group_offloading

        >>> transformer = CogVideoXTransformer3DModel.from_pretrained(
        ...     "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
        ... )

        >>> apply_group_offloading(
        ...     transformer,
        ...     onload_device=torch.device("cuda"),
        ...     offload_device=torch.device("cpu"),
        ...     offload_type="block_level",
        ...     num_blocks_per_group=2,
        ...     use_stream=True,
        ... )
        ```
    """

    onload_device = torch.device(onload_device) if isinstance(onload_device, str) else onload_device
    offload_device = torch.device(offload_device) if isinstance(offload_device, str) else offload_device
    offload_type = GroupOffloadingType(offload_type)

    # Validate the cheap, argument-only conditions before allocating an accelerator stream.
    if not use_stream and record_stream:
        raise ValueError("`record_stream` cannot be True when `use_stream=False`.")
    if offload_type == GroupOffloadingType.BLOCK_LEVEL and num_blocks_per_group is None:
        raise ValueError("`num_blocks_per_group` must be provided when using `offload_type='block_level'.")

    _raise_error_if_accelerate_model_or_sequential_hook_present(module)

    stream = None
    if use_stream:
        if torch.cuda.is_available():
            stream = torch.cuda.Stream()
        elif hasattr(torch, "xpu") and torch.xpu.is_available():
            stream = torch.Stream()
        else:
            raise ValueError("Using streams for data transfer requires a CUDA device, or an Intel XPU device.")

    # Fall back to defaults declared on the model itself, when present.
    if block_modules is None:
        block_modules = getattr(module, "_group_offload_block_modules", None)

    if exclude_kwargs is None:
        exclude_kwargs = getattr(module, "_skip_keys", None)

    config = GroupOffloadingConfig(
        onload_device=onload_device,
        offload_device=offload_device,
        offload_type=offload_type,
        num_blocks_per_group=num_blocks_per_group,
        non_blocking=non_blocking,
        stream=stream,
        record_stream=record_stream,
        low_cpu_mem_usage=low_cpu_mem_usage,
        offload_to_disk_path=offload_to_disk_path,
        block_modules=block_modules,
        exclude_kwargs=exclude_kwargs,
    )
    _apply_group_offloading(module, config)


def _apply_group_offloading(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
    r"""
    Dispatch group offloading to the strategy selected by `config.offload_type`.

    Args:
        module (`torch.nn.Module`):
            The module to which group offloading hooks are applied.
        config (`GroupOffloadingConfig`):
            The offloading configuration; `config.offload_type` selects block-level or leaf-level offloading.

    Raises:
        ValueError: If `config.offload_type` is neither `GroupOffloadingType.BLOCK_LEVEL` nor
            `GroupOffloadingType.LEAF_LEVEL`.
    """
    if config.offload_type == GroupOffloadingType.BLOCK_LEVEL:
        _apply_group_offloading_block_level(module, config)
    elif config.offload_type == GroupOffloadingType.LEAF_LEVEL:
        _apply_group_offloading_leaf_level(module, config)
    else:
        # An explicit raise (instead of `assert False`) still fires under `python -O`, where asserts are stripped and
        # an unsupported type would otherwise silently leave the module without any offloading.
        raise ValueError(f"Unsupported group offloading type: {config.offload_type!r}.")
Aryan's avatar
Aryan committed
611
612


613
def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
    r"""
    Apply block-level group offloading to the direct children of `module`.

    This offloading is coarser than "leaf_level" offloading. Children are handled as follows:

      - A child whose name is in `config.block_modules` is recursed into with this same function. Its config gets an
        extended `module_prefix` so that group ids (and therefore disk-offload filenames) stay unique.
      - A `torch.nn.ModuleList` / `torch.nn.Sequential` child is split into consecutive groups of
        `config.num_blocks_per_group` blocks. Each group is onloaded by its first block and offloaded by its last.
      - All remaining children form a single "unmatched" group attached to `module` itself. So do any parameters
        and buffers of `module` that are not owned by an already-grouped child.

    Args:
        module (`torch.nn.Module`):
            The module whose direct children are grouped for offloading.
        config (`GroupOffloadingConfig`):
            The offloading configuration. When `config.stream` is set and `config.num_blocks_per_group != 1`, a
            copy with `num_blocks_per_group=1` is used instead. The caller's config object is not modified.
    """
    if config.stream is not None and config.num_blocks_per_group != 1:
        logger.warning(
            f"Using streams is only supported for num_blocks_per_group=1. Got {config.num_blocks_per_group=}. Setting it to 1."
        )
        # Work on a copy rather than mutating the caller's config; the hooks created below store this copy.
        config = replace(config, num_blocks_per_group=1)

    block_modules = set(config.block_modules) if config.block_modules is not None else set()

    # Create module groups for ModuleList and Sequential blocks, and explicitly defined block modules
    modules_with_group_offloading = set()
    unmatched_modules = []
    matched_module_groups = []

    for name, submodule in module.named_children():
        # Check if this is an explicitly defined block module
        if name in block_modules:
            # Track submodule using a prefix to avoid filename collisions during disk offload.
            # Without this, submodules sharing the same model class would be assigned identical
            # filenames (derived from the class name).
            prefix = f"{config.module_prefix}{name}." if config.module_prefix else f"{name}."
            submodule_config = replace(config, module_prefix=prefix)

            _apply_group_offloading_block_level(submodule, submodule_config)
            modules_with_group_offloading.add(name)

        elif isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
            # Split the container into consecutive chunks of `num_blocks_per_group` blocks. Every chunk produced by
            # this range is non-empty because `i < len(submodule)`.
            for i in range(0, len(submodule), config.num_blocks_per_group):
                current_modules = list(submodule[i : i + config.num_blocks_per_group])

                group_id = f"{config.module_prefix}{name}_{i}_{i + len(current_modules) - 1}"
                group = ModuleGroup(
                    modules=current_modules,
                    offload_device=config.offload_device,
                    onload_device=config.onload_device,
                    offload_to_disk_path=config.offload_to_disk_path,
                    offload_leader=current_modules[-1],
                    onload_leader=current_modules[0],
                    non_blocking=config.non_blocking,
                    stream=config.stream,
                    record_stream=config.record_stream,
                    low_cpu_mem_usage=config.low_cpu_mem_usage,
                    onload_self=True,
                    group_id=group_id,
                )
                matched_module_groups.append(group)
                for j in range(i, i + len(current_modules)):
                    modules_with_group_offloading.add(f"{name}.{j}")
        else:
            # This is an unmatched module; it is handled by the top-level group below.
            unmatched_modules.append(submodule)

    # Apply group offloading hooks to the module groups
    for group in matched_module_groups:
        for group_module in group.modules:
            _apply_group_offloading_hook(group_module, group, config=config)

    # Parameters and Buffers of the top-level module need to be offloaded/onloaded separately
    # when the forward pass of this module is called. This is because the top-level module is not
    # part of any group (as doing so would lead to no VRAM savings).
    parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading)
    buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading)
    parameters = [param for _, param in parameters]
    buffers = [buffer for _, buffer in buffers]

    # Create a group for the remaining unmatched submodules of the top-level
    # module so that they are on the correct device when the forward pass is called.
    if unmatched_modules or parameters or buffers:
        unmatched_group = ModuleGroup(
            modules=unmatched_modules,
            offload_device=config.offload_device,
            onload_device=config.onload_device,
            offload_to_disk_path=config.offload_to_disk_path,
            offload_leader=module,
            onload_leader=module,
            parameters=parameters,
            buffers=buffers,
            non_blocking=False,
            stream=None,
            record_stream=False,
            onload_self=True,
            group_id=f"{config.module_prefix}{module.__class__.__name__}_unmatched_group",
        )
        if config.stream is None:
            _apply_group_offloading_hook(module, unmatched_group, config=config)
        else:
            # With streams, the lazy hook also records the execution order needed for prefetching.
            _apply_lazy_group_offloading_hook(module, unmatched_group, config=config)
Aryan's avatar
Aryan committed
712
713


714
def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
    r"""
    Apply leaf-level group offloading to `module`.

    Every submodule that is an instance of `_GO_LC_SUPPORTED_PYTORCH_LAYERS` becomes its own offloading group. Any
    parameters or buffers not owned by such a layer are grouped per owning module and hooked onto that owner. This
    keeps the resident footprint small, but it triggers a device transfer per layer, which can be slow. When
    `config.stream` is set, a lazy prefetch hook is also attached to `module`. The prefetch lets data transfer overlap
    with computation.

    Args:
        module (`torch.nn.Module`):
            The root module to offload.
        config (`GroupOffloadingConfig`):
            The offloading configuration shared by every group created here.
    """
    # Step 1: one group per supported leaf layer, hooked onto the layer itself.
    grouped_leaf_names = set()
    for leaf_name, leaf in module.named_modules():
        if not isinstance(leaf, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
            continue
        leaf_group = ModuleGroup(
            modules=[leaf],
            offload_device=config.offload_device,
            onload_device=config.onload_device,
            offload_to_disk_path=config.offload_to_disk_path,
            offload_leader=leaf,
            onload_leader=leaf,
            non_blocking=config.non_blocking,
            stream=config.stream,
            record_stream=config.record_stream,
            low_cpu_mem_usage=config.low_cpu_mem_usage,
            onload_self=True,
            group_id=leaf_name,
        )
        _apply_group_offloading_hook(leaf, leaf_group, config=config)
        grouped_leaf_names.add(leaf_name)

    # Step 2: tensors that no grouped leaf owns are moved when the forward pass of their closest owning module runs,
    # so bucket them by that owner.
    modules_by_name = dict(module.named_modules())
    orphan_params_by_owner = {}
    for tensor_name, param in _gather_parameters_with_no_group_offloading_parent(module, grouped_leaf_names):
        owner_name = _find_parent_module_in_module_dict(tensor_name, modules_by_name)
        orphan_params_by_owner.setdefault(owner_name, []).append(param)

    orphan_buffers_by_owner = {}
    for tensor_name, buf in _gather_buffers_with_no_group_offloading_parent(module, grouped_leaf_names):
        owner_name = _find_parent_module_in_module_dict(tensor_name, modules_by_name)
        orphan_buffers_by_owner.setdefault(owner_name, []).append(buf)

    for owner_name in orphan_params_by_owner.keys() | orphan_buffers_by_owner.keys():
        owner_module = modules_by_name[owner_name]
        owner_group = ModuleGroup(
            modules=[],
            offload_device=config.offload_device,
            onload_device=config.onload_device,
            offload_leader=owner_module,
            onload_leader=owner_module,
            offload_to_disk_path=config.offload_to_disk_path,
            parameters=orphan_params_by_owner.get(owner_name, []),
            buffers=orphan_buffers_by_owner.get(owner_name, []),
            non_blocking=config.non_blocking,
            stream=config.stream,
            record_stream=config.record_stream,
            low_cpu_mem_usage=config.low_cpu_mem_usage,
            onload_self=True,
            group_id=owner_name,
        )
        _apply_group_offloading_hook(owner_module, owner_group, config=config)

    # Step 3: with streams, prefetching needs the layer execution order, which is not known up front. A lazy
    # prefetching hook on the root module discovers that order and then sets up prefetching accordingly.
    if config.stream is not None:
        lazy_group = ModuleGroup(
            modules=[],
            offload_device=config.offload_device,
            onload_device=config.onload_device,
            offload_to_disk_path=config.offload_to_disk_path,
            offload_leader=module,
            onload_leader=module,
            parameters=None,
            buffers=None,
            non_blocking=False,
            stream=None,
            record_stream=False,
            low_cpu_mem_usage=config.low_cpu_mem_usage,
            onload_self=True,
            group_id=_GROUP_ID_LAZY_LEAF,
        )
        _apply_lazy_group_offloading_hook(module, lazy_group, config=config)
Aryan's avatar
Aryan committed
810
811
812
813
814


def _apply_group_offloading_hook(
    module: torch.nn.Module,
    group: ModuleGroup,
    *,
    config: GroupOffloadingConfig,
) -> None:
    """
    Register a `GroupOffloadingHook` for `group` on `module`, unless one is already registered there.

    Args:
        module: The module whose hook registry receives the hook (the registry is created if missing).
        group: The `ModuleGroup` the hook onloads/offloads.
        config: The offloading configuration stored on the hook.
    """
    hook_registry = HookRegistry.check_if_exists_or_initialize(module)

    # A module can already carry a group offloading hook (e.g. it was hooked for an earlier group); never overwrite
    # an existing one.
    if hook_registry.get_hook(_GROUP_OFFLOADING) is not None:
        return

    hook_registry.register_hook(GroupOffloadingHook(group, config=config), _GROUP_OFFLOADING)


def _apply_lazy_group_offloading_hook(
    module: torch.nn.Module,
    group: ModuleGroup,
    *,
    config: GroupOffloadingConfig,
) -> None:
    """
    Register a `GroupOffloadingHook` (if absent) and a `LazyPrefetchGroupOffloadingHook` on `module`.

    This is used when streams are enabled. The lazy prefetch hook is registered unconditionally, after the group
    offloading hook.

    Args:
        module: The module whose hook registry receives the hooks (the registry is created if missing).
        group: The `ModuleGroup` the group offloading hook onloads/offloads.
        config: The offloading configuration stored on the group offloading hook.
    """
    hook_registry = HookRegistry.check_if_exists_or_initialize(module)

    # Keep any group offloading hook that is already present instead of overwriting it.
    already_hooked = hook_registry.get_hook(_GROUP_OFFLOADING) is not None
    if not already_hooked:
        offload_hook = GroupOffloadingHook(group, config=config)
        hook_registry.register_hook(offload_hook, _GROUP_OFFLOADING)

    prefetch_hook = LazyPrefetchGroupOffloadingHook()
    hook_registry.register_hook(prefetch_hook, _LAZY_PREFETCH_GROUP_OFFLOADING)


def _gather_parameters_with_no_group_offloading_parent(
    module: torch.nn.Module, modules_with_group_offloading: Set[str]
) -> List[torch.nn.Parameter]:
    parameters = []
    for name, parameter in module.named_parameters():
        has_parent_with_group_offloading = False
        atoms = name.split(".")
        while len(atoms) > 0:
            parent_name = ".".join(atoms)
            if parent_name in modules_with_group_offloading:
                has_parent_with_group_offloading = True
                break
            atoms.pop()
        if not has_parent_with_group_offloading:
            parameters.append((name, parameter))
    return parameters


def _gather_buffers_with_no_group_offloading_parent(
    module: torch.nn.Module, modules_with_group_offloading: Set[str]
) -> List[torch.Tensor]:
    buffers = []
    for name, buffer in module.named_buffers():
        has_parent_with_group_offloading = False
        atoms = name.split(".")
        while len(atoms) > 0:
            parent_name = ".".join(atoms)
            if parent_name in modules_with_group_offloading:
                has_parent_with_group_offloading = True
                break
            atoms.pop()
        if not has_parent_with_group_offloading:
            buffers.append((name, buffer))
    return buffers


def _find_parent_module_in_module_dict(name: str, module_dict: Dict[str, torch.nn.Module]) -> str:
    atoms = name.split(".")
    while len(atoms) > 0:
        parent_name = ".".join(atoms)
        if parent_name in module_dict:
            return parent_name
        atoms.pop()
    return ""


def _raise_error_if_accelerate_model_or_sequential_hook_present(module: torch.nn.Module) -> None:
    """
    Refuse to combine group offloading with Accelerate's own offloading.

    This is a no-op when Accelerate is not installed.

    Raises:
        ValueError: If any submodule of `module` (including `module`) has an `_hf_hook` that is an Accelerate
            `AlignDevicesHook` or `CpuOffload` instance.
    """
    if is_accelerate_available():
        accelerate_offload_hook_types = (AlignDevicesHook, CpuOffload)
        for name, submodule in module.named_modules():
            if hasattr(submodule, "_hf_hook") and isinstance(submodule._hf_hook, accelerate_offload_hook_types):
                raise ValueError(
                    f"Cannot apply group offloading to a module that is already applying an alternative "
                    f"offloading strategy from Accelerate. If you want to apply group offloading, please "
                    f"disable the existing offloading strategy first. Offending module: {name} ({type(submodule)})"
                )


905
def _get_top_level_group_offload_hook(module: torch.nn.Module) -> Optional[GroupOffloadingHook]:
    """
    Return the first group offloading hook found while walking `module.modules()`, or `None` if there is none.

    `module.modules()` visits `module` itself first, so a hook registered directly on `module` wins over hooks on its
    submodules.
    """
    hook_registries = (
        candidate._diffusers_hook for candidate in module.modules() if hasattr(candidate, "_diffusers_hook")
    )
    for hook_registry in hook_registries:
        offload_hook = hook_registry.get_hook(_GROUP_OFFLOADING)
        if offload_hook is not None:
            return offload_hook
    return None


def _is_group_offload_enabled(module: torch.nn.Module) -> bool:
    """Return whether `module` or any of its submodules carries a group offloading hook."""
    return _get_top_level_group_offload_hook(module) is not None
Aryan's avatar
Aryan committed
917
918
919


def _get_group_onload_device(module: torch.nn.Module) -> torch.device:
    """
    Return the onload device recorded in the config of the first group offloading hook found in `module`.

    Raises:
        ValueError: If no group offloading hook is registered on `module` or its submodules.
    """
    offload_hook = _get_top_level_group_offload_hook(module)
    if offload_hook is None:
        raise ValueError("Group offloading is not enabled for the provided module.")
    return offload_hook.config.onload_device
924
925


926
927
928
929
930
931
def _compute_group_hash(group_id):
    hashed_id = hashlib.sha256(group_id.encode("utf-8")).hexdigest()
    # first 16 characters for a reasonably short but unique name
    return hashed_id[:16]


932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
def _maybe_remove_and_reapply_group_offloading(module: torch.nn.Module) -> None:
    r"""
    Strip all group offloading hooks from `module` and apply group offloading again with the same configuration.

    Use this after `module` has been modified in place, so that the hooks' references to tensors are refreshed.
    Examples of in-place modification are QKV fusion and on-the-fly LoRA loading/unloading. If `module` has no group
    offloading hook, this is a no-op.

    Assumption: group offloading was applied only at the top level, so every submodule shares the same onload and
    offload devices. The configuration of the first hook found by `_get_top_level_group_offload_hook` is reused for
    the whole tree. If offloading was applied at several levels with different settings, the result will not match
    the original setup.

    When non-default streams are in use, there is a performance cost. The layer execution order has to be retraced
    by `LazyPrefetchGroupOffloadingHook`.
    """
    previous_hook = _get_top_level_group_offload_hook(module)
    if previous_hook is None:
        return

    hook_registry = HookRegistry.check_if_exists_or_initialize(module)
    # Remove every hook kind group offloading can install, recursively over the module tree.
    for hook_name in (_GROUP_OFFLOADING, _LAYER_EXECUTION_TRACKER, _LAZY_PREFETCH_GROUP_OFFLOADING):
        hook_registry.remove_hook(hook_name, recurse=True)

    _apply_group_offloading(module, previous_hook.config)