Unverified commit f12d161d, authored by swappy and committed by GitHub

Fix broken group offloading with block_level for models with standalone layers (#12692)



* fix: group offloading to support standalone computational layers in block-level offloading

* test: for models with standalone and deeply nested layers in block-level offloading

* feat: support for block-level offloading in group offloading config

* fix: group offload block modules to AutoencoderKL and AutoencoderKLWan

* fix: update group offloading tests to use AutoencoderKL and adjust input dimensions

* refactor: streamline block offloading logic

* Apply style fixes

* update tests

* update

* fix for failing tests

* clean up

* revert to use skip_keys

* clean up

---------
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
parent 8d415a6f
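In practice, the change means block-level group offloading now works on models whose top-level children are standalone layers (for example `quant_conv`/`post_quant_conv` in a VAE) rather than only `ModuleList`/`Sequential` containers. A minimal usage sketch — the checkpoint and devices below are illustrative, not part of this commit:

```python
import torch
from diffusers import AutoencoderKL

# Any group-offload-capable diffusers model works similarly; this checkpoint is just an example.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")

# Block-level group offloading: standalone layers and the encoder/decoder blocks
# are now grouped correctly instead of breaking the hook setup.
vae.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=1,
)

with torch.no_grad():
    sample = torch.randn(1, 3, 64, 64, device="cuda")
    reconstruction = vae(sample).sample
```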
@@ -15,7 +15,7 @@
import hashlib
import os
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass, replace
from enum import Enum
from typing import Dict, List, Optional, Set, Tuple, Union
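The new `dataclasses.replace` import is used further down to derive a per-submodule config when recursing into explicitly declared block modules. As a reminder of the standard-library behavior it relies on:

```python
from dataclasses import dataclass, replace

@dataclass
class Cfg:
    num_blocks_per_group: int = 1
    module_prefix: str = ""

base = Cfg()
# `replace` returns a copy with only the listed fields overridden; the original config is untouched.
child = replace(base, module_prefix="encoder.")
print(base.module_prefix, child.module_prefix)  # '' 'encoder.'
```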
@@ -59,6 +59,9 @@ class GroupOffloadingConfig:
    num_blocks_per_group: Optional[int] = None
    offload_to_disk_path: Optional[str] = None
    stream: Optional[Union[torch.cuda.Stream, torch.Stream]] = None
    block_modules: Optional[List[str]] = None
    exclude_kwargs: Optional[List[str]] = None
    module_prefix: Optional[str] = ""
@@ -77,7 +80,7 @@ class ModuleGroup:
        low_cpu_mem_usage: bool = False,
        onload_self: bool = True,
        offload_to_disk_path: Optional[str] = None,
        group_id: Optional[Union[int, str]] = None,
    ) -> None:
        self.modules = modules
        self.offload_device = offload_device
@@ -322,7 +325,21 @@ class GroupOffloadingHook(ModelHook):
            self.group.stream.synchronize()

        args = send_to_device(args, self.group.onload_device, non_blocking=self.group.non_blocking)

        # Some autoencoder models use a feature cache that is passed through submodules and modified
        # in place. `send_to_device` returns a copy of this cache object, which breaks the in-place
        # updates. Use `exclude_kwargs` to mark these cache entries so they are not copied.
        exclude_kwargs = self.config.exclude_kwargs or []
        if exclude_kwargs:
            moved_kwargs = send_to_device(
                {k: v for k, v in kwargs.items() if k not in exclude_kwargs},
                self.group.onload_device,
                non_blocking=self.group.non_blocking,
            )
            kwargs.update(moved_kwargs)
        else:
            kwargs = send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking)

        return args, kwargs

    def post_forward(self, module: torch.nn.Module, output):
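To see why these keys must be excluded, here is a toy sketch of the identity problem; `move_to_device` below is a hypothetical stand-in that only mimics what a recursive device transfer such as `send_to_device` does:

```python
import torch

def move_to_device(obj, device):
    # Rebuilds containers while moving tensors, so the returned list/dict is a *new* object.
    if isinstance(obj, torch.Tensor):
        return obj.to(device)
    if isinstance(obj, list):
        return [move_to_device(x, device) for x in obj]
    if isinstance(obj, dict):
        return {k: move_to_device(v, device) for k, v in obj.items()}
    return obj

feat_cache = [None, None]  # shared mutable state that submodules update in place
kwargs = {"hidden_states": torch.randn(2, 4), "feat_cache": feat_cache}

moved = move_to_device(kwargs, "cpu")
print(moved["feat_cache"] is feat_cache)  # False: identity lost, in-place updates become invisible

# Excluding the mutable keys (as `exclude_kwargs` does above) keeps the original object:
exclude = {"feat_cache"}
moved = {k: (v if k in exclude else move_to_device(v, "cpu")) for k, v in kwargs.items()}
print(moved["feat_cache"] is feat_cache)  # True: submodules can still mutate the shared cache
```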
@@ -455,6 +472,8 @@ def apply_group_offloading(
    record_stream: bool = False,
    low_cpu_mem_usage: bool = False,
    offload_to_disk_path: Optional[str] = None,
    block_modules: Optional[List[str]] = None,
    exclude_kwargs: Optional[List[str]] = None,
) -> None:
    r"""
    Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
@@ -512,6 +531,13 @@
            If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
            option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
            the CPU memory is a bottleneck but may counteract the benefits of using streams.
        block_modules (`List[str]`, *optional*):
            List of module names that should be treated as blocks for offloading. If provided, only these modules are
            considered for block-level offloading. If not provided, the default block detection logic is used.
        exclude_kwargs (`List[str]`, *optional*):
            List of kwarg keys that should not be processed by `send_to_device`. This is useful for mutable state,
            such as caching lists, that must keep its object identity across forward passes. If not provided, it is
            inferred from the module's `_skip_keys` attribute, if present.

    Example:
        ```python
@@ -553,6 +579,12 @@
    _raise_error_if_accelerate_model_or_sequential_hook_present(module)

    if block_modules is None:
        block_modules = getattr(module, "_group_offload_block_modules", None)

    if exclude_kwargs is None:
        exclude_kwargs = getattr(module, "_skip_keys", None)

    config = GroupOffloadingConfig(
        onload_device=onload_device,
        offload_device=offload_device,
@@ -563,6 +595,8 @@
        record_stream=record_stream,
        low_cpu_mem_usage=low_cpu_mem_usage,
        offload_to_disk_path=offload_to_disk_path,
        block_modules=block_modules,
        exclude_kwargs=exclude_kwargs,
    )

    _apply_group_offloading(module, config)
@@ -578,46 +612,66 @@ def _apply_group_offloading(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
    r"""
    This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks, and to
    explicitly defined block modules. In comparison to the "leaf_level" offloading, which is more fine-grained, this
    offloading is done at the top-level blocks and at the modules specified in `block_modules`.

    When `block_modules` is provided, only those modules are treated as blocks for offloading, and block offloading
    is applied recursively to each of them.
    """
    if config.stream is not None and config.num_blocks_per_group != 1:
        logger.warning(
            f"Using streams is only supported for num_blocks_per_group=1. Got {config.num_blocks_per_group=}. Setting it to 1."
        )
        config.num_blocks_per_group = 1

    block_modules = set(config.block_modules) if config.block_modules is not None else set()

    # Create module groups for ModuleList and Sequential blocks, and explicitly defined block modules
    modules_with_group_offloading = set()
    unmatched_modules = []
    matched_module_groups = []

    for name, submodule in module.named_children():
        # Check if this is an explicitly defined block module
        if name in block_modules:
            # Track the submodule with a prefix to avoid filename collisions during disk offload.
            # Without the prefix, submodules sharing the same model class would be assigned identical
            # filenames (derived from the class name).
            prefix = f"{config.module_prefix}{name}." if config.module_prefix else f"{name}."
            submodule_config = replace(config, module_prefix=prefix)
            _apply_group_offloading_block_level(submodule, submodule_config)
            modules_with_group_offloading.add(name)
        elif isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
            # Handle ModuleList and Sequential blocks as before
            for i in range(0, len(submodule), config.num_blocks_per_group):
                current_modules = list(submodule[i : i + config.num_blocks_per_group])
                if len(current_modules) == 0:
                    continue

                group_id = f"{config.module_prefix}{name}_{i}_{i + len(current_modules) - 1}"
                group = ModuleGroup(
                    modules=current_modules,
                    offload_device=config.offload_device,
                    onload_device=config.onload_device,
                    offload_to_disk_path=config.offload_to_disk_path,
                    offload_leader=current_modules[-1],
                    onload_leader=current_modules[0],
                    non_blocking=config.non_blocking,
                    stream=config.stream,
                    record_stream=config.record_stream,
                    low_cpu_mem_usage=config.low_cpu_mem_usage,
                    onload_self=True,
                    group_id=group_id,
                )
                matched_module_groups.append(group)
                for j in range(i, i + len(current_modules)):
                    modules_with_group_offloading.add(f"{name}.{j}")
        else:
            # This is an unmatched module
            unmatched_modules.append((name, submodule))

    # Apply group offloading hooks to the module groups
    for i, group in enumerate(matched_module_groups):
@@ -632,28 +686,29 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
    parameters = [param for _, param in parameters]
    buffers = [buffer for _, buffer in buffers]

    # Create a group for the remaining unmatched submodules of the top-level
    # module so that they are on the correct device when the forward pass is called.
    unmatched_modules = [unmatched_module for _, unmatched_module in unmatched_modules]
    if len(unmatched_modules) > 0 or len(parameters) > 0 or len(buffers) > 0:
        unmatched_group = ModuleGroup(
            modules=unmatched_modules,
            offload_device=config.offload_device,
            onload_device=config.onload_device,
            offload_to_disk_path=config.offload_to_disk_path,
            offload_leader=module,
            onload_leader=module,
            parameters=parameters,
            buffers=buffers,
            non_blocking=False,
            stream=None,
            record_stream=False,
            onload_self=True,
            group_id=f"{config.module_prefix}{module.__class__.__name__}_unmatched_group",
        )
        if config.stream is None:
            _apply_group_offloading_hook(module, unmatched_group, config=config)
        else:
            _apply_lazy_group_offloading_hook(module, unmatched_group, config=config)


def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
......
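The `module_prefix` introduced above mainly matters for disk offload: group ids must stay unique across recursed-into submodules, otherwise two groups can map to the same safetensors file. A small illustration of the collision the prefix avoids — the sha256-based filename below is an assumed scheme for illustration only; the actual `get_hashed_filename` helper may differ:

```python
import hashlib

def hashed_filename(group_id: str) -> str:
    # Assumed naming scheme, for illustration only.
    return hashlib.sha256(group_id.encode()).hexdigest()[:16] + ".safetensors"

# `quant_conv` and `post_quant_conv` are instances of the same conv class (nn.Conv2d for
# AutoencoderKL), so without a prefix both of their "unmatched" groups would get the id
# "Conv2d_unmatched_group" and overwrite each other's file. The prefix keeps them distinct:
for group_id in ("quant_conv.Conv2d_unmatched_group", "post_quant_conv.Conv2d_unmatched_group"):
    print(group_id, "->", hashed_filename(group_id))
```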
@@ -74,6 +74,7 @@ class AutoencoderKL(
    _supports_gradient_checkpointing = True
    _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D"]
    _group_offload_block_modules = ["quant_conv", "post_quant_conv", "encoder", "decoder"]

    @register_to_config
    def __init__(
......
@@ -619,6 +619,7 @@ class WanEncoder3d(nn.Module):
                feat_idx[0] += 1
            else:
                x = self.conv_out(x)

        return x
@@ -961,6 +962,7 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin):
    """

    _supports_gradient_checkpointing = False
    _group_offload_block_modules = ["quant_conv", "post_quant_conv", "encoder", "decoder"]

    # keys to ignore when AlignDeviceHook moves inputs/outputs between devices;
    # these are shared mutable state modified in-place
    _skip_keys = ["feat_cache", "feat_idx"]
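Because `apply_group_offloading` now falls back to these class attributes when `block_modules` / `exclude_kwargs` are not passed explicitly (see the `getattr` calls earlier in this diff), custom models can opt in declaratively. A minimal sketch — the dummy model is illustrative and requires an accelerator device:

```python
import torch
from diffusers.models import ModelMixin

class TinyAutoencoder(ModelMixin):
    # Picked up automatically when block_modules / exclude_kwargs are not passed to enable_group_offload.
    _group_offload_block_modules = ["encoder", "decoder"]
    _skip_keys = ["feat_cache"]

    def __init__(self):
        super().__init__()
        self.encoder = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8))
        self.decoder = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8))

    def forward(self, x, feat_cache=None):
        return self.decoder(self.encoder(x))

model = TinyAutoencoder()
model.enable_group_offload(onload_device=torch.device("cuda"), offload_type="block_level", num_blocks_per_group=1)
out = model(torch.randn(2, 8, device="cuda"))
```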
@@ -1414,6 +1416,7 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin):
        """
        x = sample
        posterior = self.encode(x).latent_dist

        if sample_posterior:
            z = posterior.sample(generator=generator)
        else:
......
@@ -531,6 +531,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
        record_stream: bool = False,
        low_cpu_mem_usage=False,
        offload_to_disk_path: Optional[str] = None,
        block_modules: Optional[List[str]] = None,
        exclude_kwargs: Optional[List[str]] = None,
    ) -> None:
        r"""
        Activates group offloading for the current model.
@@ -570,6 +572,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                f"`_supports_group_offloading` to `True` in the class definition. If you believe this is a mistake, please "
                f"open an issue at https://github.com/huggingface/diffusers/issues."
            )

        apply_group_offloading(
            module=self,
            onload_device=onload_device,
@@ -581,6 +584,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
            record_stream=record_stream,
            low_cpu_mem_usage=low_cpu_mem_usage,
            offload_to_disk_path=offload_to_disk_path,
            block_modules=block_modules,
            exclude_kwargs=exclude_kwargs,
        )

    def set_attention_backend(self, backend: str) -> None:
......
@@ -19,6 +19,7 @@ import unittest
import torch
from parameterized import parameterized

from diffusers import AutoencoderKL
from diffusers.hooks import HookRegistry, ModelHook
from diffusers.models import ModelMixin
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
@@ -149,6 +150,74 @@ class LayerOutputTrackerHook(ModelHook):
        return output


# Model with only standalone computational layers at the top level
class DummyModelWithStandaloneLayers(ModelMixin):
    def __init__(self, in_features: int, hidden_features: int, out_features: int) -> None:
        super().__init__()
        self.layer1 = torch.nn.Linear(in_features, hidden_features)
        self.activation = torch.nn.ReLU()
        self.layer2 = torch.nn.Linear(hidden_features, hidden_features)
        self.layer3 = torch.nn.Linear(hidden_features, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.layer1(x)
        x = self.activation(x)
        x = self.layer2(x)
        x = self.layer3(x)
        return x


# Model with a deeply nested structure
class DummyModelWithDeeplyNestedBlocks(ModelMixin):
    def __init__(self, in_features: int, hidden_features: int, out_features: int) -> None:
        super().__init__()
        self.input_layer = torch.nn.Linear(in_features, hidden_features)
        self.container = ContainerWithNestedModuleList(hidden_features)
        self.output_layer = torch.nn.Linear(hidden_features, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.input_layer(x)
        x = self.container(x)
        x = self.output_layer(x)
        return x


class ContainerWithNestedModuleList(torch.nn.Module):
    def __init__(self, features: int) -> None:
        super().__init__()
        # Top-level computational layer
        self.proj_in = torch.nn.Linear(features, features)
        # Nested container with a ModuleList
        self.nested_container = NestedContainer(features)
        # Another top-level computational layer
        self.proj_out = torch.nn.Linear(features, features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.proj_in(x)
        x = self.nested_container(x)
        x = self.proj_out(x)
        return x


class NestedContainer(torch.nn.Module):
    def __init__(self, features: int) -> None:
        super().__init__()
        self.blocks = torch.nn.ModuleList([torch.nn.Linear(features, features), torch.nn.Linear(features, features)])
        self.norm = torch.nn.LayerNorm(features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            x = block(x)
        x = self.norm(x)
        return x


@require_torch_accelerator
class GroupOffloadTests(unittest.TestCase):
    in_features = 64
@@ -340,7 +409,7 @@ class GroupOffloadTests(unittest.TestCase):
            out = model(x)
            self.assertTrue(torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match.")

        num_repeats = 2
        for i in range(num_repeats):
            out_ref = model_ref(x)
            out = model(x)
@@ -362,3 +431,138 @@ class GroupOffloadTests(unittest.TestCase):
            self.assertLess(
                cumulated_absmax, 1e-5, f"Output differences for {name} exceeded threshold: {cumulated_absmax:.5f}"
            )
    def test_vae_like_model_without_streams(self):
        """Test VAE-like model with block-level offloading but without streams."""
        if torch.device(torch_device).type not in ["cuda", "xpu"]:
            return

        config = self.get_autoencoder_kl_config()
        model = AutoencoderKL(**config)
        model_ref = AutoencoderKL(**config)
        model_ref.load_state_dict(model.state_dict(), strict=True)
        model_ref.to(torch_device)

        model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=False)

        x = torch.randn(2, 3, 32, 32).to(torch_device)
        with torch.no_grad():
            out_ref = model_ref(x).sample
            out = model(x).sample
            self.assertTrue(
                torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match for VAE-like model without streams."
            )

    def test_model_with_only_standalone_layers(self):
        """Test that models with only standalone layers (no ModuleList/Sequential) work with block-level offloading."""
        if torch.device(torch_device).type not in ["cuda", "xpu"]:
            return

        model = DummyModelWithStandaloneLayers(in_features=64, hidden_features=128, out_features=64)
        model_ref = DummyModelWithStandaloneLayers(in_features=64, hidden_features=128, out_features=64)
        model_ref.load_state_dict(model.state_dict(), strict=True)
        model_ref.to(torch_device)

        model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=True)

        x = torch.randn(2, 64).to(torch_device)
        with torch.no_grad():
            for i in range(2):
                out_ref = model_ref(x)
                out = model(x)
                self.assertTrue(
                    torch.allclose(out_ref, out, atol=1e-5),
                    f"Outputs do not match at iteration {i} for model with standalone layers.",
                )

    @parameterized.expand([("block_level",), ("leaf_level",)])
    def test_standalone_conv_layers_with_both_offload_types(self, offload_type: str):
        """Test that standalone Conv2d layers work correctly with both block-level and leaf-level offloading."""
        if torch.device(torch_device).type not in ["cuda", "xpu"]:
            return

        config = self.get_autoencoder_kl_config()
        model = AutoencoderKL(**config)
        model_ref = AutoencoderKL(**config)
        model_ref.load_state_dict(model.state_dict(), strict=True)
        model_ref.to(torch_device)

        model.enable_group_offload(torch_device, offload_type=offload_type, num_blocks_per_group=1, use_stream=True)

        x = torch.randn(2, 3, 32, 32).to(torch_device)
        with torch.no_grad():
            out_ref = model_ref(x).sample
            out = model(x).sample
            self.assertTrue(
                torch.allclose(out_ref, out, atol=1e-5),
                f"Outputs do not match for standalone Conv layers with {offload_type}.",
            )

    def test_multiple_invocations_with_vae_like_model(self):
        """Test that multiple forward passes work correctly with a VAE-like model."""
        if torch.device(torch_device).type not in ["cuda", "xpu"]:
            return

        config = self.get_autoencoder_kl_config()
        model = AutoencoderKL(**config)
        model_ref = AutoencoderKL(**config)
        model_ref.load_state_dict(model.state_dict(), strict=True)
        model_ref.to(torch_device)

        model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=True)

        x = torch.randn(2, 3, 32, 32).to(torch_device)
        with torch.no_grad():
            for i in range(2):
                out_ref = model_ref(x).sample
                out = model(x).sample
                self.assertTrue(torch.allclose(out_ref, out, atol=1e-5), f"Outputs do not match at iteration {i}.")

    def test_nested_container_parameters_offloading(self):
        """Test that parameters from non-computational layers in nested containers are handled correctly."""
        if torch.device(torch_device).type not in ["cuda", "xpu"]:
            return

        model = DummyModelWithDeeplyNestedBlocks(in_features=64, hidden_features=128, out_features=64)
        model_ref = DummyModelWithDeeplyNestedBlocks(in_features=64, hidden_features=128, out_features=64)
        model_ref.load_state_dict(model.state_dict(), strict=True)
        model_ref.to(torch_device)

        model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=True)

        x = torch.randn(2, 64).to(torch_device)
        with torch.no_grad():
            for i in range(2):
                out_ref = model_ref(x)
                out = model(x)
                self.assertTrue(
                    torch.allclose(out_ref, out, atol=1e-5),
                    f"Outputs do not match at iteration {i} for nested parameters.",
                )

    def get_autoencoder_kl_config(self, block_out_channels=None, norm_num_groups=None):
        block_out_channels = block_out_channels or [2, 4]
        norm_num_groups = norm_num_groups or 2
        init_dict = {
            "block_out_channels": block_out_channels,
            "in_channels": 3,
            "out_channels": 3,
            "down_block_types": ["DownEncoderBlock2D"] * len(block_out_channels),
            "up_block_types": ["UpDecoderBlock2D"] * len(block_out_channels),
            "latent_channels": 4,
            "norm_num_groups": norm_num_groups,
            "layers_per_block": 1,
        }
        return init_dict
@@ -1791,7 +1791,6 @@ class ModelTesterMixin:
            return model(**inputs_dict)[0]

        model = self.model_class(**init_dict)
        model.to(torch_device)
        output_without_group_offloading = run_forward(model)
        output_without_group_offloading = normalize_output(output_without_group_offloading)

@@ -1916,6 +1915,9 @@ class ModelTesterMixin:
                offload_to_disk_path=tmpdir,
                offload_type=offload_type,
                num_blocks_per_group=num_blocks_per_group,
                block_modules=model._group_offload_block_modules
                if hasattr(model, "_group_offload_block_modules")
                else None,
            )

            if not is_correct:
                if extra_files:
......
@@ -1424,6 +1424,8 @@ if is_torch_available():
        offload_to_disk_path: str,
        offload_type: str,
        num_blocks_per_group: Optional[int] = None,
        block_modules: Optional[List[str]] = None,
        module_prefix: str = "",
    ) -> Set[str]:
        expected_files = set()
@@ -1435,23 +1437,36 @@
        if num_blocks_per_group is None:
            raise ValueError("num_blocks_per_group must be provided for 'block_level' offloading.")

        block_modules_set = set(block_modules) if block_modules is not None else set()
        modules_with_group_offloading = set()
        unmatched_modules = []

        for name, submodule in module.named_children():
            if name in block_modules_set:
                new_prefix = f"{module_prefix}{name}." if module_prefix else f"{name}."
                submodule_files = _get_expected_safetensors_files(
                    submodule, offload_to_disk_path, offload_type, num_blocks_per_group, block_modules, new_prefix
                )
                expected_files.update(submodule_files)
                modules_with_group_offloading.add(name)
            elif isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
                for i in range(0, len(submodule), num_blocks_per_group):
                    current_modules = submodule[i : i + num_blocks_per_group]
                    if not current_modules:
                        continue
                    group_id = f"{module_prefix}{name}_{i}_{i + len(current_modules) - 1}"
                    expected_files.add(get_hashed_filename(group_id))
                    for j in range(i, i + len(current_modules)):
                        modules_with_group_offloading.add(f"{name}.{j}")
            else:
                unmatched_modules.append(submodule)

        parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading)
        buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading)

        if len(unmatched_modules) > 0 or len(parameters) > 0 or len(buffers) > 0:
            expected_files.add(get_hashed_filename(f"{module_prefix}{module.__class__.__name__}_unmatched_group"))

    elif offload_type == "leaf_level":
        # Handle leaf-level module groups
@@ -1492,12 +1507,13 @@
        offload_to_disk_path: str,
        offload_type: str,
        num_blocks_per_group: Optional[int] = None,
        block_modules: Optional[List[str]] = None,
    ) -> bool:
        if not os.path.isdir(offload_to_disk_path):
            return False, None, None

        expected_files = _get_expected_safetensors_files(
            module, offload_to_disk_path, offload_type, num_blocks_per_group, block_modules
        )
        actual_files = set(glob.glob(os.path.join(offload_to_disk_path, "*.safetensors")))
        missing_files = expected_files - actual_files
......
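For reference, the disk-offload check used by the model tests boils down to a set comparison between expected and actual safetensors files; a simplified sketch with illustrative names:

```python
import glob
import os

def offload_files_match(offload_dir: str, expected_files: set) -> bool:
    # Compare what the offloader should have written against what is actually on disk.
    actual_files = set(glob.glob(os.path.join(offload_dir, "*.safetensors")))
    missing_files = expected_files - actual_files
    extra_files = actual_files - expected_files
    return not missing_files and not extra_files
```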