Unverified Commit cfd6ec74 authored by Aryan, committed by GitHub

[refactor] condense group offloading (#11990)

* update

* update

* refactor

* add test

* address review comment

* nit
parent 1082c46a
@@ -95,7 +95,7 @@ class ModuleGroup:
         self.offload_to_disk_path = offload_to_disk_path
         self._is_offloaded_to_disk = False
 
-        if self.offload_to_disk_path:
+        if self.offload_to_disk_path is not None:
             # Instead of `group_id or str(id(self))` we do this because `group_id` can be "" as well.
             self.group_id = group_id if group_id is not None else str(id(self))
             short_hash = _compute_group_hash(self.group_id)
@@ -115,6 +115,12 @@ class ModuleGroup:
         else:
             self.cpu_param_dict = self._init_cpu_param_dict()
 
+        self._torch_accelerator_module = (
+            getattr(torch, torch.accelerator.current_accelerator().type)
+            if hasattr(torch, "accelerator")
+            else torch.cuda
+        )
+
     def _init_cpu_param_dict(self):
         cpu_param_dict = {}
         if self.stream is None:
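For context on the attribute added in this hunk: `torch.accelerator` only exists on newer PyTorch releases, so the expression resolves the backend-specific module (`torch.cuda`, `torch.xpu`, ...) by name and falls back to `torch.cuda` otherwise. A minimal standalone sketch of that lookup (the function name is illustrative, not part of the PR, and it assumes an accelerator device is actually present, as the PR's code path does):

```python
import torch


def resolve_accelerator_module():
    # Newer PyTorch exposes torch.accelerator; read the active backend's name
    # ("cuda", "xpu", ...) and fetch the matching torch submodule. Older
    # releases fall back to torch.cuda.
    if hasattr(torch, "accelerator"):
        return getattr(torch, torch.accelerator.current_accelerator().type)
    return torch.cuda


accel = resolve_accelerator_module()
# The resolved module exposes the uniform per-backend stream API used
# throughout this diff: accel.Stream(), accel.stream(...), accel.current_stream().
```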
@@ -138,112 +144,76 @@ class ModuleGroup:
     @contextmanager
     def _pinned_memory_tensors(self):
-        pinned_dict = {}
         try:
-            for param, tensor in self.cpu_param_dict.items():
-                if not tensor.is_pinned():
-                    pinned_dict[param] = tensor.pin_memory()
-                else:
-                    pinned_dict[param] = tensor
-
+            pinned_dict = {
+                param: tensor.pin_memory() if not tensor.is_pinned() else tensor
+                for param, tensor in self.cpu_param_dict.items()
+            }
             yield pinned_dict
-
         finally:
             pinned_dict = None
 
-    def _transfer_tensor_to_device(self, tensor, source_tensor, current_stream=None):
+    def _transfer_tensor_to_device(self, tensor, source_tensor):
         tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
-        if self.record_stream and current_stream is not None:
-            tensor.data.record_stream(current_stream)
+        if self.record_stream:
+            tensor.data.record_stream(self._torch_accelerator_module.current_stream())
 
-    def _process_tensors_from_modules(self, pinned_memory=None, current_stream=None):
+    def _process_tensors_from_modules(self, pinned_memory=None):
         for group_module in self.modules:
             for param in group_module.parameters():
                 source = pinned_memory[param] if pinned_memory else param.data
-                self._transfer_tensor_to_device(param, source, current_stream)
+                self._transfer_tensor_to_device(param, source)
             for buffer in group_module.buffers():
                 source = pinned_memory[buffer] if pinned_memory else buffer.data
-                self._transfer_tensor_to_device(buffer, source, current_stream)
+                self._transfer_tensor_to_device(buffer, source)
 
         for param in self.parameters:
             source = pinned_memory[param] if pinned_memory else param.data
-            self._transfer_tensor_to_device(param, source, current_stream)
+            self._transfer_tensor_to_device(param, source)
 
         for buffer in self.buffers:
             source = pinned_memory[buffer] if pinned_memory else buffer.data
-            self._transfer_tensor_to_device(buffer, source, current_stream)
+            self._transfer_tensor_to_device(buffer, source)
 
-    def _onload_from_disk(self, current_stream):
-        if self.stream is not None:
-            loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
-
-            for key, tensor_obj in self.key_to_tensor.items():
-                self.cpu_param_dict[tensor_obj] = loaded_cpu_tensors[key]
-
-            with self._pinned_memory_tensors() as pinned_memory:
-                for key, tensor_obj in self.key_to_tensor.items():
-                    self._transfer_tensor_to_device(tensor_obj, pinned_memory[tensor_obj], current_stream)
-
-            self.cpu_param_dict.clear()
-
-        else:
-            onload_device = (
-                self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
-            )
-            loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
-            for key, tensor_obj in self.key_to_tensor.items():
-                tensor_obj.data = loaded_tensors[key]
-
-    def _onload_from_memory(self, current_stream):
-        if self.stream is not None:
-            with self._pinned_memory_tensors() as pinned_memory:
-                self._process_tensors_from_modules(pinned_memory, current_stream)
-        else:
-            self._process_tensors_from_modules(None, current_stream)
-
-    @torch.compiler.disable()
-    def onload_(self):
-        torch_accelerator_module = (
-            getattr(torch, torch.accelerator.current_accelerator().type)
-            if hasattr(torch, "accelerator")
-            else torch.cuda
-        )
-        context = nullcontext() if self.stream is None else torch_accelerator_module.stream(self.stream)
-        current_stream = torch_accelerator_module.current_stream() if self.record_stream else None
-
-        if self.offload_to_disk_path:
-            if self.stream is not None:
-                # Wait for previous Host->Device transfer to complete
-                self.stream.synchronize()
-
-            with context:
-                if self.stream is not None:
-                    # Load to CPU, pin, and async copy to device for overlapping transfer and compute
-                    loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
-                    for key, tensor_obj in self.key_to_tensor.items():
-                        pinned_tensor = loaded_cpu_tensors[key].pin_memory()
-                        tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
-                        if self.record_stream:
-                            tensor_obj.data.record_stream(current_stream)
-                else:
-                    # Load directly to the target device (synchronous)
-                    onload_device = (
-                        self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
-                    )
-                    loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
-                    for key, tensor_obj in self.key_to_tensor.items():
-                        tensor_obj.data = loaded_tensors[key]
-            return
-
-        if self.stream is not None:
-            # Wait for previous Host->Device transfer to complete
-            self.stream.synchronize()
-
-        with context:
-            if self.offload_to_disk_path:
-                self._onload_from_disk(current_stream)
-            else:
-                self._onload_from_memory(current_stream)
+    def _onload_from_disk(self):
+        if self.stream is not None:
+            # Wait for previous Host->Device transfer to complete
+            self.stream.synchronize()
+
+        context = nullcontext() if self.stream is None else self._torch_accelerator_module.stream(self.stream)
+        current_stream = self._torch_accelerator_module.current_stream() if self.record_stream else None
+
+        with context:
+            # Load to CPU (if using streams) or directly to target device, pin, and async copy to device
+            device = str(self.onload_device) if self.stream is None else "cpu"
+            loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=device)
+
+            if self.stream is not None:
+                for key, tensor_obj in self.key_to_tensor.items():
+                    pinned_tensor = loaded_tensors[key].pin_memory()
+                    tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
+                    if self.record_stream:
+                        tensor_obj.data.record_stream(current_stream)
+            else:
+                onload_device = (
+                    self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
+                )
+                loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
+                for key, tensor_obj in self.key_to_tensor.items():
+                    tensor_obj.data = loaded_tensors[key]
+
+    def _onload_from_memory(self):
+        if self.stream is not None:
+            # Wait for previous Host->Device transfer to complete
+            self.stream.synchronize()
+
+        context = nullcontext() if self.stream is None else self._torch_accelerator_module.stream(self.stream)
+        with context:
+            if self.stream is not None:
+                with self._pinned_memory_tensors() as pinned_memory:
+                    self._process_tensors_from_modules(pinned_memory)
+            else:
+                self._process_tensors_from_modules(None)
 
     def _offload_to_disk(self):
         # TODO: we can potentially optimize this code path by checking if the _all_ the desired
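The onload paths consolidated above all rely on one recurring pattern: stage weights in pinned (page-locked) host memory, issue the host-to-device copy with `non_blocking=True` on a side stream so it can overlap with compute, and either `record_stream` the result or make the consumer wait on the transfer stream. A small sketch of that pattern in plain PyTorch, independent of diffusers (tensor shapes and variable names are illustrative, and it assumes a CUDA device is available):

```python
import torch

assert torch.cuda.is_available()
device = torch.device("cuda")

# Only pinned (page-locked) host memory makes non_blocking=True copies truly asynchronous.
weight_cpu = torch.randn(4096, 4096).pin_memory()

transfer_stream = torch.cuda.Stream()
compute_stream = torch.cuda.current_stream()  # the stream that will consume the weights

with torch.cuda.stream(transfer_stream):
    # Issue the H2D copy on the side stream so it can overlap with compute.
    weight_gpu = weight_cpu.to(device, non_blocking=True)
    # Mirrors the record_stream=True path: tells the caching allocator the tensor
    # is also used on compute_stream, so its memory is not reused prematurely.
    weight_gpu.record_stream(compute_stream)

# The consumer still orders itself after the copy before reading the data
# (the record_stream=False path in the diff synchronizes streams instead).
compute_stream.wait_stream(transfer_stream)
out = weight_gpu @ weight_gpu.T
```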
@@ -264,14 +234,10 @@ class ModuleGroup:
             tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)
 
     def _offload_to_memory(self):
-        torch_accelerator_module = (
-            getattr(torch, torch.accelerator.current_accelerator().type)
-            if hasattr(torch, "accelerator")
-            else torch.cuda
-        )
         if self.stream is not None:
             if not self.record_stream:
-                torch_accelerator_module.current_stream().synchronize()
+                self._torch_accelerator_module.current_stream().synchronize()
             for group_module in self.modules:
                 for param in group_module.parameters():
                     param.data = self.cpu_param_dict[param]
@@ -282,15 +248,23 @@ class ModuleGroup:
         else:
             for group_module in self.modules:
-                group_module.to(self.offload_device, non_blocking=self.non_blocking)
+                group_module.to(self.offload_device, non_blocking=False)
             for param in self.parameters:
-                param.data = param.data.to(self.offload_device, non_blocking=self.non_blocking)
+                param.data = param.data.to(self.offload_device, non_blocking=False)
             for buffer in self.buffers:
-                buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
+                buffer.data = buffer.data.to(self.offload_device, non_blocking=False)
 
+    @torch.compiler.disable()
+    def onload_(self):
+        r"""Onloads the group of parameters to the onload_device."""
+        if self.offload_to_disk_path is not None:
+            self._onload_from_disk()
+        else:
+            self._onload_from_memory()
+
     @torch.compiler.disable()
     def offload_(self):
-        r"""Offloads the group of modules to the offload_device."""
+        r"""Offloads the group of parameters to the offload_device."""
         if self.offload_to_disk_path:
             self._offload_to_disk()
         else:
@@ -307,11 +281,9 @@ class GroupOffloadingHook(ModelHook):
 
     _is_stateful = False
 
-    def __init__(
-        self, group: ModuleGroup, next_group: Optional[ModuleGroup] = None, *, config: GroupOffloadingConfig
-    ) -> None:
+    def __init__(self, group: ModuleGroup, *, config: GroupOffloadingConfig) -> None:
         self.group = group
-        self.next_group = next_group
+        self.next_group: Optional[ModuleGroup] = None
         self.config = config
 
     def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
@@ -459,8 +431,8 @@ class LayerExecutionTrackerHook(ModelHook):
 
 def apply_group_offloading(
     module: torch.nn.Module,
-    onload_device: torch.device,
-    offload_device: torch.device = torch.device("cpu"),
+    onload_device: Union[str, torch.device],
+    offload_device: Union[str, torch.device] = torch.device("cpu"),
     offload_type: Union[str, GroupOffloadingType] = "block_level",
     num_blocks_per_group: Optional[int] = None,
     non_blocking: bool = False,
@@ -546,6 +518,8 @@ def apply_group_offloading(
         ```
     """
 
+    onload_device = torch.device(onload_device) if isinstance(onload_device, str) else onload_device
+    offload_device = torch.device(offload_device) if isinstance(offload_device, str) else offload_device
     offload_type = GroupOffloadingType(offload_type)
 
     stream = None
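The two lines added above are what make the widened `Union[str, torch.device]` annotations work in practice: string device names are coerced to `torch.device` before use. A hedged usage sketch of the public entry point (the toy module is a placeholder, the argument names are taken from the signature in this diff, and `"cuda"` assumes a CUDA machine):

```python
import torch
from diffusers.hooks import apply_group_offloading


class ToyBlocks(torch.nn.Module):
    # Stand-in for a diffusers model; block_level grouping works over ModuleList/Sequential children.
    def __init__(self):
        super().__init__()
        self.blocks = torch.nn.ModuleList([torch.nn.Linear(64, 64) for _ in range(4)])

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x


model = ToyBlocks()
apply_group_offloading(
    model,
    onload_device="cuda",   # plain strings are now accepted and coerced to torch.device
    offload_device="cpu",
    offload_type="block_level",
    num_blocks_per_group=1,
    non_blocking=True,
)
```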
@@ -633,7 +607,7 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
     # Apply group offloading hooks to the module groups
     for i, group in enumerate(matched_module_groups):
         for group_module in group.modules:
-            _apply_group_offloading_hook(group_module, group, None, config=config)
+            _apply_group_offloading_hook(group_module, group, config=config)
 
     # Parameters and Buffers of the top-level module need to be offloaded/onloaded separately
     # when the forward pass of this module is called. This is because the top-level module is not
@@ -662,9 +636,9 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
         group_id=f"{module.__class__.__name__}_unmatched_group",
     )
     if config.stream is None:
-        _apply_group_offloading_hook(module, unmatched_group, None, config=config)
+        _apply_group_offloading_hook(module, unmatched_group, config=config)
     else:
-        _apply_lazy_group_offloading_hook(module, unmatched_group, None, config=config)
+        _apply_lazy_group_offloading_hook(module, unmatched_group, config=config)
 
 
 def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
@@ -693,7 +667,7 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
             onload_self=True,
             group_id=name,
         )
-        _apply_group_offloading_hook(submodule, group, None, config=config)
+        _apply_group_offloading_hook(submodule, group, config=config)
         modules_with_group_offloading.add(name)
 
     # Parameters and Buffers at all non-leaf levels need to be offloaded/onloaded separately when the forward pass
@@ -740,7 +714,7 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
                 onload_self=True,
                 group_id=name,
             )
-            _apply_group_offloading_hook(parent_module, group, None, config=config)
+            _apply_group_offloading_hook(parent_module, group, config=config)
 
     if config.stream is not None:
         # When using streams, we need to know the layer execution order for applying prefetching (to overlap data transfer
@@ -762,13 +736,12 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
         onload_self=True,
         group_id=_GROUP_ID_LAZY_LEAF,
     )
-    _apply_lazy_group_offloading_hook(module, unmatched_group, None, config=config)
+    _apply_lazy_group_offloading_hook(module, unmatched_group, config=config)
 
 
 def _apply_group_offloading_hook(
     module: torch.nn.Module,
     group: ModuleGroup,
-    next_group: Optional[ModuleGroup] = None,
     *,
     config: GroupOffloadingConfig,
 ) -> None:
@@ -777,14 +750,13 @@ def _apply_group_offloading_hook(
     # We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
     # is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
     if registry.get_hook(_GROUP_OFFLOADING) is None:
-        hook = GroupOffloadingHook(group, next_group, config=config)
+        hook = GroupOffloadingHook(group, config=config)
         registry.register_hook(hook, _GROUP_OFFLOADING)
 
 
 def _apply_lazy_group_offloading_hook(
     module: torch.nn.Module,
     group: ModuleGroup,
-    next_group: Optional[ModuleGroup] = None,
     *,
     config: GroupOffloadingConfig,
 ) -> None:
@@ -793,7 +765,7 @@ def _apply_lazy_group_offloading_hook(
     # We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
    # is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
     if registry.get_hook(_GROUP_OFFLOADING) is None:
-        hook = GroupOffloadingHook(group, next_group, config=config)
+        hook = GroupOffloadingHook(group, config=config)
         registry.register_hook(hook, _GROUP_OFFLOADING)
 
     lazy_prefetch_hook = LazyPrefetchGroupOffloadingHook()
...
@@ -17,7 +17,9 @@ import gc
 import unittest
 
 import torch
+from parameterized import parameterized
 
+from diffusers.hooks import HookRegistry, ModelHook
 from diffusers.models import ModelMixin
 from diffusers.pipelines.pipeline_utils import DiffusionPipeline
 from diffusers.utils import get_logger
@@ -99,6 +101,29 @@ class DummyModelWithMultipleBlocks(ModelMixin):
         return x
 
 
+# Test for https://github.com/huggingface/diffusers/pull/12077
+class DummyModelWithLayerNorm(ModelMixin):
+    def __init__(self, in_features: int, hidden_features: int, out_features: int, num_layers: int) -> None:
+        super().__init__()
+        self.linear_1 = torch.nn.Linear(in_features, hidden_features)
+        self.activation = torch.nn.ReLU()
+        self.blocks = torch.nn.ModuleList(
+            [DummyBlock(hidden_features, hidden_features, hidden_features) for _ in range(num_layers)]
+        )
+        self.layer_norm = torch.nn.LayerNorm(hidden_features, elementwise_affine=True)
+        self.linear_2 = torch.nn.Linear(hidden_features, out_features)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.linear_1(x)
+        x = self.activation(x)
+        for block in self.blocks:
+            x = block(x)
+        x = self.layer_norm(x)
+        x = self.linear_2(x)
+        return x
+
+
 class DummyPipeline(DiffusionPipeline):
     model_cpu_offload_seq = "model"
@@ -113,6 +138,16 @@ class DummyPipeline(DiffusionPipeline):
         return x
 
 
+class LayerOutputTrackerHook(ModelHook):
+    def __init__(self):
+        super().__init__()
+        self.outputs = []
+
+    def post_forward(self, module, output):
+        self.outputs.append(output)
+        return output
+
+
 @require_torch_accelerator
 class GroupOffloadTests(unittest.TestCase):
     in_features = 64
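`LayerOutputTrackerHook` is a minimal `ModelHook` that records every output a module produces; the new test further down attaches one per submodule through `HookRegistry` and later reads the recorded outputs back. A condensed sketch of that attach-and-read cycle, using only API calls that appear in this diff (the single `Linear` module is just a stand-in):

```python
import torch
from diffusers.hooks import HookRegistry, ModelHook


class LayerOutputTrackerHook(ModelHook):
    def __init__(self):
        super().__init__()
        self.outputs = []

    def post_forward(self, module, output):
        self.outputs.append(output)
        return output


layer = torch.nn.Linear(8, 8)

# Attaches a hook registry to the module on first use, so registered hooks
# run around its forward pass.
registry = HookRegistry.check_if_exists_or_initialize(layer)
registry.register_hook(LayerOutputTrackerHook(), "layer_output_tracker")

layer(torch.randn(2, 8))
recorded = registry.get_hook("layer_output_tracker").outputs
print(len(recorded))  # -> 1, one tracked forward pass
```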
@@ -258,6 +293,7 @@ class GroupOffloadTests(unittest.TestCase):
     def test_block_level_stream_with_invocation_order_different_from_initialization_order(self):
         if torch.device(torch_device).type not in ["cuda", "xpu"]:
             return
+
         model = DummyModelWithMultipleBlocks(
             in_features=self.in_features,
             hidden_features=self.hidden_features,
@@ -274,3 +310,54 @@ class GroupOffloadTests(unittest.TestCase):
 
         with context:
             model(self.input)
+
+    @parameterized.expand([("block_level",), ("leaf_level",)])
+    def test_block_level_offloading_with_parameter_only_module_group(self, offload_type: str):
+        if torch.device(torch_device).type not in ["cuda", "xpu"]:
+            return
+
+        def apply_layer_output_tracker_hook(model: DummyModelWithLayerNorm):
+            for name, module in model.named_modules():
+                registry = HookRegistry.check_if_exists_or_initialize(module)
+                hook = LayerOutputTrackerHook()
+                registry.register_hook(hook, "layer_output_tracker")
+
+        model_ref = DummyModelWithLayerNorm(128, 256, 128, 2)
+        model = DummyModelWithLayerNorm(128, 256, 128, 2)
+        model.load_state_dict(model_ref.state_dict(), strict=True)
+
+        model_ref.to(torch_device)
+        model.enable_group_offload(torch_device, offload_type=offload_type, num_blocks_per_group=1, use_stream=True)
+
+        apply_layer_output_tracker_hook(model_ref)
+        apply_layer_output_tracker_hook(model)
+
+        x = torch.randn(2, 128).to(torch_device)
+
+        out_ref = model_ref(x)
+        out = model(x)
+        self.assertTrue(torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match.")
+
+        num_repeats = 4
+        for i in range(num_repeats):
+            out_ref = model_ref(x)
+            out = model(x)
+            self.assertTrue(torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match after multiple invocations.")
+
+        for (ref_name, ref_module), (name, module) in zip(model_ref.named_modules(), model.named_modules()):
+            assert ref_name == name
+            ref_outputs = (
+                HookRegistry.check_if_exists_or_initialize(ref_module).get_hook("layer_output_tracker").outputs
+            )
+            outputs = HookRegistry.check_if_exists_or_initialize(module).get_hook("layer_output_tracker").outputs
+            cumulated_absmax = 0.0
+            for i in range(len(outputs)):
+                diff = ref_outputs[0] - outputs[i]
+                absdiff = diff.abs()
+                absmax = absdiff.max().item()
+                cumulated_absmax += absmax
+            self.assertLess(
+                cumulated_absmax, 1e-5, f"Output differences for {name} exceeded threshold: {cumulated_absmax:.5f}"
+            )