Unverified commit 76ec3d1f, authored by Aryan and committed by GitHub

Support dynamically loading/unloading loras with group offloading (#11804)

* update

* add test

* address review comments

* update

* fixes

* change decorator order to fix tests

* try fix

* fight tests
parent cdaf84a7
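
The diff below wires LoRA loading and unloading into the group offloading hooks so that adapters can be added or removed after group offloading has already been enabled. A minimal sketch of the resulting workflow, assuming a transformer-based pipeline; the model id, LoRA path, and prompt are illustrative placeholders, not taken from this PR:

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("some/diffusion-model", torch_dtype=torch.bfloat16)

# Keep only one group of transformer blocks on the accelerator at a time.
pipe.transformer.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=1,
)

# New with this change: LoRAs can be loaded and unloaded *after* group offloading
# has been enabled; the hooks are removed and re-applied under the hood.
pipe.load_lora_weights("some/lora-repo")  # placeholder repo id
image = pipe("a prompt").images[0]

pipe.unload_lora_weights()
image = pipe("a prompt").images[0]
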
@@ -14,6 +14,8 @@
 import os
 from contextlib import contextmanager, nullcontext
+from dataclasses import dataclass
+from enum import Enum
 from typing import Dict, List, Optional, Set, Tuple, Union

 import safetensors.torch
@@ -46,6 +48,24 @@ _SUPPORTED_PYTORCH_LAYERS = (
 # fmt: on


+class GroupOffloadingType(str, Enum):
+    BLOCK_LEVEL = "block_level"
+    LEAF_LEVEL = "leaf_level"
+
+
+@dataclass
+class GroupOffloadingConfig:
+    onload_device: torch.device
+    offload_device: torch.device
+    offload_type: GroupOffloadingType
+    non_blocking: bool
+    record_stream: bool
+    low_cpu_mem_usage: bool
+    num_blocks_per_group: Optional[int] = None
+    offload_to_disk_path: Optional[str] = None
+    stream: Optional[Union[torch.cuda.Stream, torch.Stream]] = None
+
+
 class ModuleGroup:
     def __init__(
         self,
@@ -288,9 +308,12 @@ class GroupOffloadingHook(ModelHook):

     _is_stateful = False

-    def __init__(self, group: ModuleGroup, next_group: Optional[ModuleGroup] = None) -> None:
+    def __init__(
+        self, group: ModuleGroup, next_group: Optional[ModuleGroup] = None, *, config: GroupOffloadingConfig
+    ) -> None:
         self.group = group
         self.next_group = next_group
+        self.config = config

     def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
         if self.group.offload_leader == module:
@@ -436,7 +459,7 @@ def apply_group_offloading(
     module: torch.nn.Module,
     onload_device: torch.device,
     offload_device: torch.device = torch.device("cpu"),
-    offload_type: str = "block_level",
+    offload_type: Union[str, GroupOffloadingType] = "block_level",
     num_blocks_per_group: Optional[int] = None,
     non_blocking: bool = False,
     use_stream: bool = False,
@@ -478,7 +501,7 @@ def apply_group_offloading(
             The device to which the group of modules are onloaded.
         offload_device (`torch.device`, defaults to `torch.device("cpu")`):
             The device to which the group of modules are offloaded. This should typically be the CPU. Default is CPU.
-        offload_type (`str`, defaults to "block_level"):
+        offload_type (`str` or `GroupOffloadingType`, defaults to "block_level"):
             The type of offloading to be applied. Can be one of "block_level" or "leaf_level". Default is
             "block_level".
         offload_to_disk_path (`str`, *optional*, defaults to `None`):
@@ -521,6 +544,8 @@ def apply_group_offloading(
     ```
     """

+    offload_type = GroupOffloadingType(offload_type)
+
     stream = None
     if use_stream:
         if torch.cuda.is_available():
@@ -532,84 +557,45 @@ def apply_group_offloading(
     if not use_stream and record_stream:
         raise ValueError("`record_stream` cannot be True when `use_stream=False`.")

+    if offload_type == GroupOffloadingType.BLOCK_LEVEL and num_blocks_per_group is None:
+        raise ValueError("`num_blocks_per_group` must be provided when using `offload_type='block_level'.")
+
     _raise_error_if_accelerate_model_or_sequential_hook_present(module)

-    if offload_type == "block_level":
-        if num_blocks_per_group is None:
-            raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.")
-
-        _apply_group_offloading_block_level(
-            module=module,
-            num_blocks_per_group=num_blocks_per_group,
-            offload_device=offload_device,
-            onload_device=onload_device,
-            offload_to_disk_path=offload_to_disk_path,
-            non_blocking=non_blocking,
-            stream=stream,
-            record_stream=record_stream,
-            low_cpu_mem_usage=low_cpu_mem_usage,
-        )
-    elif offload_type == "leaf_level":
-        _apply_group_offloading_leaf_level(
-            module=module,
-            offload_device=offload_device,
-            onload_device=onload_device,
-            offload_to_disk_path=offload_to_disk_path,
-            non_blocking=non_blocking,
-            stream=stream,
-            record_stream=record_stream,
-            low_cpu_mem_usage=low_cpu_mem_usage,
-        )
-    else:
-        raise ValueError(f"Unsupported offload_type: {offload_type}")
+    config = GroupOffloadingConfig(
+        onload_device=onload_device,
+        offload_device=offload_device,
+        offload_type=offload_type,
+        num_blocks_per_group=num_blocks_per_group,
+        non_blocking=non_blocking,
+        stream=stream,
+        record_stream=record_stream,
+        low_cpu_mem_usage=low_cpu_mem_usage,
+        offload_to_disk_path=offload_to_disk_path,
+    )
+    _apply_group_offloading(module, config)
+
+
+def _apply_group_offloading(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
+    if config.offload_type == GroupOffloadingType.BLOCK_LEVEL:
+        _apply_group_offloading_block_level(module, config)
+    elif config.offload_type == GroupOffloadingType.LEAF_LEVEL:
+        _apply_group_offloading_leaf_level(module, config)
+    else:
+        assert False


-def _apply_group_offloading_block_level(
-    module: torch.nn.Module,
-    num_blocks_per_group: int,
-    offload_device: torch.device,
-    onload_device: torch.device,
-    non_blocking: bool,
-    stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
-    record_stream: Optional[bool] = False,
-    low_cpu_mem_usage: bool = False,
-    offload_to_disk_path: Optional[str] = None,
-) -> None:
+def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
     r"""
     This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
     the "leaf_level" offloading, which is more fine-grained, this offloading is done at the top-level blocks.
-
-    Args:
-        module (`torch.nn.Module`):
-            The module to which group offloading is applied.
-        offload_device (`torch.device`):
-            The device to which the group of modules are offloaded. This should typically be the CPU.
-        offload_to_disk_path (`str`, *optional*, defaults to `None`):
-            The path to the directory where parameters will be offloaded. Setting this option can be useful in limited
-            RAM environment settings where a reasonable speed-memory trade-off is desired.
-        onload_device (`torch.device`):
-            The device to which the group of modules are onloaded.
-        non_blocking (`bool`):
-            If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
-            and data transfer.
-        stream (`torch.cuda.Stream`or `torch.Stream`, *optional*):
-            If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
-            for overlapping computation and data transfer.
-        record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
-            as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to the
-            [PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) more
-            details.
-        low_cpu_mem_usage (`bool`, defaults to `False`):
-            If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
-            option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
-            the CPU memory is a bottleneck but may counteract the benefits of using streams.
     """
-    if stream is not None and num_blocks_per_group != 1:
+    if config.stream is not None and config.num_blocks_per_group != 1:
         logger.warning(
-            f"Using streams is only supported for num_blocks_per_group=1. Got {num_blocks_per_group=}. Setting it to 1."
+            f"Using streams is only supported for num_blocks_per_group=1. Got {config.num_blocks_per_group=}. Setting it to 1."
         )
-        num_blocks_per_group = 1
+        config.num_blocks_per_group = 1

     # Create module groups for ModuleList and Sequential blocks
     modules_with_group_offloading = set()
@@ -621,19 +607,19 @@ def _apply_group_offloading_block_level(
             modules_with_group_offloading.add(name)
             continue

-        for i in range(0, len(submodule), num_blocks_per_group):
-            current_modules = submodule[i : i + num_blocks_per_group]
+        for i in range(0, len(submodule), config.num_blocks_per_group):
+            current_modules = submodule[i : i + config.num_blocks_per_group]
             group = ModuleGroup(
                 modules=current_modules,
-                offload_device=offload_device,
-                onload_device=onload_device,
-                offload_to_disk_path=offload_to_disk_path,
+                offload_device=config.offload_device,
+                onload_device=config.onload_device,
+                offload_to_disk_path=config.offload_to_disk_path,
                 offload_leader=current_modules[-1],
                 onload_leader=current_modules[0],
-                non_blocking=non_blocking,
-                stream=stream,
-                record_stream=record_stream,
-                low_cpu_mem_usage=low_cpu_mem_usage,
+                non_blocking=config.non_blocking,
+                stream=config.stream,
+                record_stream=config.record_stream,
+                low_cpu_mem_usage=config.low_cpu_mem_usage,
                 onload_self=True,
             )
             matched_module_groups.append(group)
@@ -643,7 +629,7 @@ def _apply_group_offloading_block_level(
     # Apply group offloading hooks to the module groups
     for i, group in enumerate(matched_module_groups):
         for group_module in group.modules:
-            _apply_group_offloading_hook(group_module, group, None)
+            _apply_group_offloading_hook(group_module, group, None, config=config)

     # Parameters and Buffers of the top-level module need to be offloaded/onloaded separately
     # when the forward pass of this module is called. This is because the top-level module is not
@@ -658,9 +644,9 @@ def _apply_group_offloading_block_level(
     unmatched_modules = [unmatched_module for _, unmatched_module in unmatched_modules]
     unmatched_group = ModuleGroup(
         modules=unmatched_modules,
-        offload_device=offload_device,
-        onload_device=onload_device,
-        offload_to_disk_path=offload_to_disk_path,
+        offload_device=config.offload_device,
+        onload_device=config.onload_device,
+        offload_to_disk_path=config.offload_to_disk_path,
         offload_leader=module,
         onload_leader=module,
         parameters=parameters,
@@ -670,54 +656,19 @@ def _apply_group_offloading_block_level(
         record_stream=False,
         onload_self=True,
     )
-    if stream is None:
-        _apply_group_offloading_hook(module, unmatched_group, None)
+    if config.stream is None:
+        _apply_group_offloading_hook(module, unmatched_group, None, config=config)
     else:
-        _apply_lazy_group_offloading_hook(module, unmatched_group, None)
+        _apply_lazy_group_offloading_hook(module, unmatched_group, None, config=config)


-def _apply_group_offloading_leaf_level(
-    module: torch.nn.Module,
-    offload_device: torch.device,
-    onload_device: torch.device,
-    non_blocking: bool,
-    stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
-    record_stream: Optional[bool] = False,
-    low_cpu_mem_usage: bool = False,
-    offload_to_disk_path: Optional[str] = None,
-) -> None:
+def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
     r"""
     This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory
     requirements. However, it can be slower compared to other offloading methods due to the excessive number of device
     synchronizations. When using devices that support streams to overlap data transfer and computation, this method can
     reduce memory usage without any performance degradation.
-
-    Args:
-        module (`torch.nn.Module`):
-            The module to which group offloading is applied.
-        offload_device (`torch.device`):
-            The device to which the group of modules are offloaded. This should typically be the CPU.
-        onload_device (`torch.device`):
-            The device to which the group of modules are onloaded.
-        offload_to_disk_path (`str`, *optional*, defaults to `None`):
-            The path to the directory where parameters will be offloaded. Setting this option can be useful in limited
-            RAM environment settings where a reasonable speed-memory trade-off is desired.
-        non_blocking (`bool`):
-            If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
-            and data transfer.
-        stream (`torch.cuda.Stream` or `torch.Stream`, *optional*):
-            If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
-            for overlapping computation and data transfer.
-        record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
-            as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to the
-            [PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) more
-            details.
-        low_cpu_mem_usage (`bool`, defaults to `False`):
-            If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
-            option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
-            the CPU memory is a bottleneck but may counteract the benefits of using streams.
     """
     # Create module groups for leaf modules and apply group offloading hooks
     modules_with_group_offloading = set()
     for name, submodule in module.named_modules():
@@ -725,18 +676,18 @@ def _apply_group_offloading_leaf_level(
             continue
         group = ModuleGroup(
             modules=[submodule],
-            offload_device=offload_device,
-            onload_device=onload_device,
-            offload_to_disk_path=offload_to_disk_path,
+            offload_device=config.offload_device,
+            onload_device=config.onload_device,
+            offload_to_disk_path=config.offload_to_disk_path,
             offload_leader=submodule,
             onload_leader=submodule,
-            non_blocking=non_blocking,
-            stream=stream,
-            record_stream=record_stream,
-            low_cpu_mem_usage=low_cpu_mem_usage,
+            non_blocking=config.non_blocking,
+            stream=config.stream,
+            record_stream=config.record_stream,
+            low_cpu_mem_usage=config.low_cpu_mem_usage,
             onload_self=True,
         )
-        _apply_group_offloading_hook(submodule, group, None)
+        _apply_group_offloading_hook(submodule, group, None, config=config)
         modules_with_group_offloading.add(name)

     # Parameters and Buffers at all non-leaf levels need to be offloaded/onloaded separately when the forward pass
@@ -767,33 +718,32 @@ def _apply_group_offloading_leaf_level(
         parameters = parent_to_parameters.get(name, [])
         buffers = parent_to_buffers.get(name, [])
         parent_module = module_dict[name]
-        assert getattr(parent_module, "_diffusers_hook", None) is None
         group = ModuleGroup(
             modules=[],
-            offload_device=offload_device,
-            onload_device=onload_device,
+            offload_device=config.offload_device,
+            onload_device=config.onload_device,
             offload_leader=parent_module,
             onload_leader=parent_module,
-            offload_to_disk_path=offload_to_disk_path,
+            offload_to_disk_path=config.offload_to_disk_path,
             parameters=parameters,
             buffers=buffers,
-            non_blocking=non_blocking,
-            stream=stream,
-            record_stream=record_stream,
-            low_cpu_mem_usage=low_cpu_mem_usage,
+            non_blocking=config.non_blocking,
+            stream=config.stream,
+            record_stream=config.record_stream,
+            low_cpu_mem_usage=config.low_cpu_mem_usage,
             onload_self=True,
         )
-        _apply_group_offloading_hook(parent_module, group, None)
+        _apply_group_offloading_hook(parent_module, group, None, config=config)

-    if stream is not None:
+    if config.stream is not None:
         # When using streams, we need to know the layer execution order for applying prefetching (to overlap data transfer
         # and computation). Since we don't know the order beforehand, we apply a lazy prefetching hook that will find the
         # execution order and apply prefetching in the correct order.
         unmatched_group = ModuleGroup(
             modules=[],
-            offload_device=offload_device,
-            onload_device=onload_device,
-            offload_to_disk_path=offload_to_disk_path,
+            offload_device=config.offload_device,
+            onload_device=config.onload_device,
+            offload_to_disk_path=config.offload_to_disk_path,
             offload_leader=module,
             onload_leader=module,
             parameters=None,
@@ -801,23 +751,25 @@ def _apply_group_offloading_leaf_level(
             non_blocking=False,
             stream=None,
             record_stream=False,
-            low_cpu_mem_usage=low_cpu_mem_usage,
+            low_cpu_mem_usage=config.low_cpu_mem_usage,
             onload_self=True,
         )
-        _apply_lazy_group_offloading_hook(module, unmatched_group, None)
+        _apply_lazy_group_offloading_hook(module, unmatched_group, None, config=config)


 def _apply_group_offloading_hook(
     module: torch.nn.Module,
     group: ModuleGroup,
     next_group: Optional[ModuleGroup] = None,
+    *,
+    config: GroupOffloadingConfig,
 ) -> None:
     registry = HookRegistry.check_if_exists_or_initialize(module)

     # We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
     # is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
     if registry.get_hook(_GROUP_OFFLOADING) is None:
-        hook = GroupOffloadingHook(group, next_group)
+        hook = GroupOffloadingHook(group, next_group, config=config)
         registry.register_hook(hook, _GROUP_OFFLOADING)

@@ -825,13 +777,15 @@ def _apply_lazy_group_offloading_hook(
     module: torch.nn.Module,
     group: ModuleGroup,
     next_group: Optional[ModuleGroup] = None,
+    *,
+    config: GroupOffloadingConfig,
 ) -> None:
     registry = HookRegistry.check_if_exists_or_initialize(module)

     # We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
     # is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
     if registry.get_hook(_GROUP_OFFLOADING) is None:
-        hook = GroupOffloadingHook(group, next_group)
+        hook = GroupOffloadingHook(group, next_group, config=config)
         registry.register_hook(hook, _GROUP_OFFLOADING)

     lazy_prefetch_hook = LazyPrefetchGroupOffloadingHook()
@@ -898,15 +852,48 @@ def _raise_error_if_accelerate_model_or_sequential_hook_present(module: torch.nn
         )


-def _is_group_offload_enabled(module: torch.nn.Module) -> bool:
+def _get_top_level_group_offload_hook(module: torch.nn.Module) -> Optional[GroupOffloadingHook]:
     for submodule in module.modules():
-        if hasattr(submodule, "_diffusers_hook") and submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING) is not None:
-            return True
-    return False
+        if hasattr(submodule, "_diffusers_hook"):
+            group_offloading_hook = submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING)
+            if group_offloading_hook is not None:
+                return group_offloading_hook
+    return None
+
+
+def _is_group_offload_enabled(module: torch.nn.Module) -> bool:
+    top_level_group_offload_hook = _get_top_level_group_offload_hook(module)
+    return top_level_group_offload_hook is not None


 def _get_group_onload_device(module: torch.nn.Module) -> torch.device:
-    for submodule in module.modules():
-        if hasattr(submodule, "_diffusers_hook") and submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING) is not None:
-            return submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING).group.onload_device
+    top_level_group_offload_hook = _get_top_level_group_offload_hook(module)
+    if top_level_group_offload_hook is not None:
+        return top_level_group_offload_hook.config.onload_device
     raise ValueError("Group offloading is not enabled for the provided module.")
+
+
+def _maybe_remove_and_reapply_group_offloading(module: torch.nn.Module) -> None:
+    r"""
+    Removes the group offloading hook from the module and re-applies it. This is useful when the module has been
+    modified in-place and the group offloading hook references-to-tensors needs to be updated. The in-place
+    modification can happen in a number of ways, for example, fusing QKV or unloading/loading LoRAs on-the-fly.
+
+    In this implementation, we make an assumption that group offloading has only been applied at the top-level module,
+    and therefore all submodules have the same onload and offload devices. If this assumption is not true, say in the
+    case where user has applied group offloading at multiple levels, this function will not work as expected.
+
+    There is some performance penalty associated with doing this when non-default streams are used, because we need to
+    retrace the execution order of the layers with `LazyPrefetchGroupOffloadingHook`.
+    """
+    top_level_group_offload_hook = _get_top_level_group_offload_hook(module)
+
+    if top_level_group_offload_hook is None:
+        return
+
+    registry = HookRegistry.check_if_exists_or_initialize(module)
+    registry.remove_hook(_GROUP_OFFLOADING, recurse=True)
+    registry.remove_hook(_LAYER_EXECUTION_TRACKER, recurse=True)
+    registry.remove_hook(_LAZY_PREFETCH_GROUP_OFFLOADING, recurse=True)
+
+    _apply_group_offloading(module, top_level_group_offload_hook.config)
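
The new `_maybe_remove_and_reapply_group_offloading` helper is what the loader changes below call after mutating a module in place. A hypothetical wrapper sketching the intended call pattern; the function name and the PEFT config argument are illustrative, not part of this commit:

import torch
from diffusers.hooks.group_offloading import _maybe_remove_and_reapply_group_offloading


def add_adapter_with_group_offload(model: torch.nn.Module, lora_config) -> None:
    # `add_adapter` comes from PeftAdapterMixin; it injects LoRA layers in place,
    # which invalidates the tensor references held by the existing offloading hooks.
    model.add_adapter(lora_config)
    # No-op if group offloading was never enabled on `model`; otherwise the hooks
    # (including lazy prefetching hooks) are removed and rebuilt from the stored config.
    _maybe_remove_and_reapply_group_offloading(model)
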
@@ -25,6 +25,7 @@ import torch.nn as nn
 from huggingface_hub import model_info
 from huggingface_hub.constants import HF_HUB_OFFLINE

+from ..hooks.group_offloading import _is_group_offload_enabled, _maybe_remove_and_reapply_group_offloading
 from ..models.modeling_utils import ModelMixin, load_state_dict
 from ..utils import (
     USE_PEFT_BACKEND,
@@ -391,7 +392,9 @@ def _load_lora_into_text_encoder(
         adapter_name = get_adapter_name(text_encoder)

     # <Unsafe code
-    is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline)
+    is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = _func_optionally_disable_offloading(
+        _pipeline
+    )
     # inject LoRA layers and load the state dict
     # in transformers we automatically check whether the adapter name is already in use or not
     text_encoder.load_adapter(
@@ -410,6 +413,10 @@ def _load_lora_into_text_encoder(
         _pipeline.enable_model_cpu_offload()
     elif is_sequential_cpu_offload:
         _pipeline.enable_sequential_cpu_offload()
+    elif is_group_offload:
+        for component in _pipeline.components.values():
+            if isinstance(component, torch.nn.Module):
+                _maybe_remove_and_reapply_group_offloading(component)
     # Unsafe code />

     if prefix is not None and not state_dict:
@@ -433,30 +440,36 @@ def _func_optionally_disable_offloading(_pipeline):
     Returns:
         tuple:
-            A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
+            A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` or `is_group_offload` is True.
     """
     is_model_cpu_offload = False
     is_sequential_cpu_offload = False
+    is_group_offload = False

     if _pipeline is not None and _pipeline.hf_device_map is None:
         for _, component in _pipeline.components.items():
-            if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
-                if not is_model_cpu_offload:
-                    is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
-                if not is_sequential_cpu_offload:
-                    is_sequential_cpu_offload = (
-                        isinstance(component._hf_hook, AlignDevicesHook)
-                        or hasattr(component._hf_hook, "hooks")
-                        and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
-                    )
+            if not isinstance(component, nn.Module):
+                continue
+            is_group_offload = is_group_offload or _is_group_offload_enabled(component)
+            if not hasattr(component, "_hf_hook"):
+                continue
+            is_model_cpu_offload = is_model_cpu_offload or isinstance(component._hf_hook, CpuOffload)
+            is_sequential_cpu_offload = is_sequential_cpu_offload or (
+                isinstance(component._hf_hook, AlignDevicesHook)
+                or hasattr(component._hf_hook, "hooks")
+                and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
+            )

-                logger.info(
-                    "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
-                )
-                if is_sequential_cpu_offload or is_model_cpu_offload:
-                    remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
+        if is_sequential_cpu_offload or is_model_cpu_offload:
+            logger.info(
+                "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
+            )
+            for _, component in _pipeline.components.items():
+                if not isinstance(component, nn.Module) or not hasattr(component, "_hf_hook"):
+                    continue
+                remove_hook_from_module(component, recurse=is_sequential_cpu_offload)

-    return (is_model_cpu_offload, is_sequential_cpu_offload)
+    return (is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload)


 class LoraBaseMixin:
...
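
`_func_optionally_disable_offloading` now reports three flags, and group offloading is handled differently from the accelerate hooks: it is left in place while the state dict is loaded and only refreshed afterwards. A condensed sketch of the caller-side pattern the loaders in this commit follow; the `reload_with_offloading_handled` wrapper is illustrative, and the import paths are assumed to match the files shown in this diff:

import torch.nn as nn
from diffusers.hooks.group_offloading import _maybe_remove_and_reapply_group_offloading
from diffusers.loaders.lora_base import _func_optionally_disable_offloading


def reload_with_offloading_handled(_pipeline, do_load) -> None:
    # Detect (and, for the accelerate hooks, temporarily remove) any offloading.
    is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = _func_optionally_disable_offloading(_pipeline)

    do_load()  # mutates modules in place, e.g. injects LoRA layers

    if is_model_cpu_offload:
        _pipeline.enable_model_cpu_offload()
    elif is_sequential_cpu_offload:
        _pipeline.enable_sequential_cpu_offload()
    elif is_group_offload:
        # Group offloading hooks were not removed up front; re-apply them so they
        # track the newly created parameters.
        for component in _pipeline.components.values():
            if isinstance(component, nn.Module):
                _maybe_remove_and_reapply_group_offloading(component)
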
@@ -22,6 +22,7 @@ from typing import Dict, List, Literal, Optional, Union
 import safetensors
 import torch

+from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading
 from ..utils import (
     MIN_PEFT_VERSION,
     USE_PEFT_BACKEND,
@@ -256,7 +257,9 @@ class PeftAdapterMixin:
         # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
         # otherwise loading LoRA weights will lead to an error.
-        is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline)
+        is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._optionally_disable_offloading(
+            _pipeline
+        )

         peft_kwargs = {}
         if is_peft_version(">=", "0.13.1"):
             peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
@@ -347,6 +350,10 @@ class PeftAdapterMixin:
             _pipeline.enable_model_cpu_offload()
         elif is_sequential_cpu_offload:
             _pipeline.enable_sequential_cpu_offload()
+        elif is_group_offload:
+            for component in _pipeline.components.values():
+                if isinstance(component, torch.nn.Module):
+                    _maybe_remove_and_reapply_group_offloading(component)
         # Unsafe code />

         if prefix is not None and not state_dict:
@@ -687,6 +694,8 @@ class PeftAdapterMixin:
         if hasattr(self, "peft_config"):
             del self.peft_config

+        _maybe_remove_and_reapply_group_offloading(self)
+
     def disable_lora(self):
         """
         Disables the active LoRA layers of the underlying model.
...
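
At the model level, `PeftAdapterMixin.unload_lora()` now re-applies group offloading on `self`, so a denoiser with group offloading enabled can have adapters attached and detached freely. A rough sketch, assuming a Flux-style transformer; the repository id and LoRA path are placeholders:

import torch
from diffusers import FluxTransformer2DModel

transformer = FluxTransformer2DModel.from_pretrained(
    "some/flux-checkpoint", subfolder="transformer", torch_dtype=torch.bfloat16
)
transformer.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
    use_stream=True,
)

transformer.load_lora_adapter("some/lora.safetensors")  # placeholder path
# ... run inference ...
transformer.unload_lora()  # offloading hooks are removed and re-applied internally
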
@@ -22,6 +22,7 @@ import torch
 import torch.nn.functional as F
 from huggingface_hub.utils import validate_hf_hub_args

+from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading
 from ..models.embeddings import (
     ImageProjection,
     IPAdapterFaceIDImageProjection,
@@ -203,6 +204,7 @@ class UNet2DConditionLoadersMixin:
         is_lora = all(("lora" in k or k.endswith(".alpha")) for k in state_dict.keys())
         is_model_cpu_offload = False
         is_sequential_cpu_offload = False
+        is_group_offload = False

         if is_lora:
             deprecation_message = "Using the `load_attn_procs()` method has been deprecated and will be removed in a future version. Please use `load_lora_adapter()`."
@@ -211,7 +213,7 @@ class UNet2DConditionLoadersMixin:
         if is_custom_diffusion:
             attn_processors = self._process_custom_diffusion(state_dict=state_dict)
         elif is_lora:
-            is_model_cpu_offload, is_sequential_cpu_offload = self._process_lora(
+            is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._process_lora(
                 state_dict=state_dict,
                 unet_identifier_key=self.unet_name,
                 network_alphas=network_alphas,
@@ -230,7 +232,9 @@ class UNet2DConditionLoadersMixin:
         # For LoRA, the UNet is already offloaded at this stage as it is handled inside `_process_lora`.
         if is_custom_diffusion and _pipeline is not None:
-            is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline=_pipeline)
+            is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._optionally_disable_offloading(
+                _pipeline=_pipeline
+            )

         # only custom diffusion needs to set attn processors
         self.set_attn_processor(attn_processors)
@@ -241,6 +245,10 @@ class UNet2DConditionLoadersMixin:
             _pipeline.enable_model_cpu_offload()
         elif is_sequential_cpu_offload:
             _pipeline.enable_sequential_cpu_offload()
+        elif is_group_offload:
+            for component in _pipeline.components.values():
+                if isinstance(component, torch.nn.Module):
+                    _maybe_remove_and_reapply_group_offloading(component)
         # Unsafe code />

     def _process_custom_diffusion(self, state_dict):
@@ -307,6 +315,7 @@ class UNet2DConditionLoadersMixin:
         is_model_cpu_offload = False
         is_sequential_cpu_offload = False
+        is_group_offload = False
         state_dict_to_be_used = unet_state_dict if len(unet_state_dict) > 0 else state_dict

         if len(state_dict_to_be_used) > 0:
@@ -356,7 +365,9 @@ class UNet2DConditionLoadersMixin:
             # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
             # otherwise loading LoRA weights will lead to an error
-            is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline)
+            is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._optionally_disable_offloading(
+                _pipeline
+            )

             peft_kwargs = {}
             if is_peft_version(">=", "0.13.1"):
                 peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
@@ -389,7 +400,7 @@ class UNet2DConditionLoadersMixin:
             if warn_msg:
                 logger.warning(warn_msg)

-        return is_model_cpu_offload, is_sequential_cpu_offload
+        return is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload

     @classmethod
     # Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading
...
@@ -16,6 +16,7 @@ import sys
 import unittest

 import torch
+from parameterized import parameterized
 from transformers import AutoTokenizer, T5EncoderModel

 from diffusers import (
@@ -28,6 +29,7 @@ from diffusers import (
 from diffusers.utils.testing_utils import (
     floats_tensor,
     require_peft_backend,
+    require_torch_accelerator,
 )
@@ -127,6 +129,13 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     def test_lora_scale_kwargs_match_fusion(self):
         super().test_lora_scale_kwargs_match_fusion(expected_atol=9e-3, expected_rtol=9e-3)

+    @parameterized.expand([("block_level", True), ("leaf_level", False)])
+    @require_torch_accelerator
+    def test_group_offloading_inference_denoiser(self, offload_type, use_stream):
+        # TODO: We don't run the (leaf_level, True) test here that is enabled for other models.
+        # The reason for this can be found here: https://github.com/huggingface/diffusers/pull/11804#issuecomment-3013325338
+        super()._test_group_offloading_inference_denoiser(offload_type, use_stream)
+
     @unittest.skip("Not supported in CogVideoX.")
     def test_simple_inference_with_text_denoiser_block_scale(self):
         pass
...
@@ -18,10 +18,17 @@ import unittest
 import numpy as np
 import torch
+from parameterized import parameterized
 from transformers import AutoTokenizer, GlmModel

 from diffusers import AutoencoderKL, CogView4Pipeline, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler
-from diffusers.utils.testing_utils import floats_tensor, require_peft_backend, skip_mps, torch_device
+from diffusers.utils.testing_utils import (
+    floats_tensor,
+    require_peft_backend,
+    require_torch_accelerator,
+    skip_mps,
+    torch_device,
+)


 sys.path.append(".")
@@ -141,6 +148,13 @@ class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
             "Loading from saved checkpoints should give same results.",
         )

+    @parameterized.expand([("block_level", True), ("leaf_level", False)])
+    @require_torch_accelerator
+    def test_group_offloading_inference_denoiser(self, offload_type, use_stream):
+        # TODO: We don't run the (leaf_level, True) test here that is enabled for other models.
+        # The reason for this can be found here: https://github.com/huggingface/diffusers/pull/11804#issuecomment-3013325338
+        super()._test_group_offloading_inference_denoiser(offload_type, use_stream)
+
     @unittest.skip("Not supported in CogView4.")
     def test_simple_inference_with_text_denoiser_block_scale(self):
         pass
...
@@ -39,6 +39,7 @@ from diffusers.utils.testing_utils import (
     is_torch_version,
     require_peft_backend,
     require_peft_version_greater,
+    require_torch_accelerator,
     require_transformers_version_greater,
     skip_mps,
     torch_device,
@@ -2355,3 +2356,73 @@ class PeftLoraLoaderMixinTests:
                 pipe.load_lora_weights(tmpdirname)
                 output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
                 self.assertTrue(np.allclose(output_adapter_1, output_lora_loaded, atol=1e-3, rtol=1e-3))
+
+    def _test_group_offloading_inference_denoiser(self, offload_type, use_stream):
+        from diffusers.hooks.group_offloading import _get_top_level_group_offload_hook
+
+        onload_device = torch_device
+        offload_device = torch.device("cpu")
+
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+        denoiser.add_adapter(denoiser_lora_config)
+        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+            self.pipeline_class.save_lora_weights(
+                save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts
+            )
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
+
+            components, _, _ = self.get_dummy_components(self.scheduler_classes[0])
+            pipe = self.pipeline_class(**components)
+            pipe = pipe.to(torch_device)
+            pipe.set_progress_bar_config(disable=None)
+            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+
+            pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+            check_if_lora_correctly_set(denoiser)
+
+            _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+            # Test group offloading with load_lora_weights
+            denoiser.enable_group_offload(
+                onload_device=onload_device,
+                offload_device=offload_device,
+                offload_type=offload_type,
+                num_blocks_per_group=1,
+                use_stream=use_stream,
+            )
+            group_offload_hook_1 = _get_top_level_group_offload_hook(denoiser)
+            self.assertTrue(group_offload_hook_1 is not None)
+            output_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+            # Test group offloading after removing the lora
+            pipe.unload_lora_weights()
+            group_offload_hook_2 = _get_top_level_group_offload_hook(denoiser)
+            self.assertTrue(group_offload_hook_2 is not None)
+            output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]  # noqa: F841
+
+            # Add the lora again and check if group offloading works
+            pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+            check_if_lora_correctly_set(denoiser)
+            group_offload_hook_3 = _get_top_level_group_offload_hook(denoiser)
+            self.assertTrue(group_offload_hook_3 is not None)
+            output_3 = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+            self.assertTrue(np.allclose(output_1, output_3, atol=1e-3, rtol=1e-3))
+
+    @parameterized.expand([("block_level", True), ("leaf_level", False), ("leaf_level", True)])
+    @require_torch_accelerator
+    def test_group_offloading_inference_denoiser(self, offload_type, use_stream):
+        for cls in inspect.getmro(self.__class__):
+            if "test_group_offloading_inference_denoiser" in cls.__dict__ and cls is not PeftLoraLoaderMixinTests:
+                # Skip this test if it is overwritten by child class. We need to do this because parameterized
+                # materializes the test methods on invocation which cannot be overridden.
+                return
+        self._test_group_offloading_inference_denoiser(offload_type, use_stream)