"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "67d070749ae393a234470b6ef653821bb4f02cc6"
Unverified Commit 85a916bb authored by Sayak Paul, committed by GitHub

make group offloading work with disk/nvme transfers (#11682)

* start implementing disk offloading in group.

* delete diff file.

* updates.patch

* offload_to_disk_path

* check if safetensors already exist.

* add test and clarify.

* updates

* update todos.

* update more docs.

* update docs
parent 3287ce28
@@ -302,6 +302,13 @@ compute-bound, [group-offloading](#group-offloading) tends to be better. Group o
</Tip>
### Offloading to disk
Group offloading can consume significant system RAM depending on the model size. In environments with limited RAM,
it can be useful to offload to disk instead. You can do this by setting the `offload_to_disk_path`
argument in either [`~ModelMixin.enable_group_offload`] or [`~hooks.apply_group_offloading`]. Refer to [the PR description](https://github.com/huggingface/diffusers/pull/11682#issue-3129365363) and
[this comment](https://github.com/huggingface/diffusers/pull/11682#issuecomment-2955715126) for the expected speed-memory trade-offs with this option enabled.
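For example, here is a minimal sketch of enabling disk offloading on a model (the checkpoint, offload type, and target path are illustrative; any model that supports group offloading works the same way):

```py
import torch
from diffusers import CogVideoXTransformer3DModel

transformer = CogVideoXTransformer3DModel.from_pretrained(
    "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
)
# Group parameters are staged under `offload_to_disk_path` as .safetensors files
# instead of being held in system RAM between forward passes.
transformer.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
    offload_to_disk_path="path/to/disk",
)
```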
## Layerwise casting
Layerwise casting stores weights in a smaller data format (for example, `torch.float8_e4m3fn` and `torch.float8_e5m2`) to use less memory and upcasts those weights to a higher precision like `torch.float16` or `torch.bfloat16` for computation. Certain layers (normalization and modulation related weights) are skipped because storing them in fp8 can degrade generation quality.
......
@@ -12,9 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from contextlib import contextmanager, nullcontext
from typing import Dict, List, Optional, Set, Tuple, Union
import safetensors.torch
import torch
from ..utils import get_logger, is_accelerate_available
@@ -59,6 +61,7 @@ class ModuleGroup:
record_stream: Optional[bool] = False,
low_cpu_mem_usage: bool = False,
onload_self: bool = True,
offload_to_disk_path: Optional[str] = None,
) -> None:
self.modules = modules
self.offload_device = offload_device
@@ -72,7 +75,26 @@ class ModuleGroup:
self.record_stream = record_stream
self.onload_self = onload_self
self.low_cpu_mem_usage = low_cpu_mem_usage
self.cpu_param_dict = self._init_cpu_param_dict()
self.offload_to_disk_path = offload_to_disk_path
self._is_offloaded_to_disk = False
if self.offload_to_disk_path:
    self.safetensors_file_path = os.path.join(self.offload_to_disk_path, f"group_{id(self)}.safetensors")
    all_tensors = []
    for module in self.modules:
        all_tensors.extend(list(module.parameters()))
        all_tensors.extend(list(module.buffers()))
    all_tensors.extend(self.parameters)
    all_tensors.extend(self.buffers)
    all_tensors = list(dict.fromkeys(all_tensors))  # Remove duplicates
    self.tensor_to_key = {tensor: f"tensor_{i}" for i, tensor in enumerate(all_tensors)}
    self.key_to_tensor = {v: k for k, v in self.tensor_to_key.items()}
    self.cpu_param_dict = {}
else:
    self.cpu_param_dict = self._init_cpu_param_dict()
if self.stream is None and self.record_stream:
raise ValueError("`record_stream` cannot be True when `stream` is None.")
@@ -124,6 +146,30 @@ class ModuleGroup:
context = nullcontext() if self.stream is None else torch_accelerator_module.stream(self.stream)
current_stream = torch_accelerator_module.current_stream() if self.record_stream else None
if self.offload_to_disk_path:
    if self.stream is not None:
        # Wait for previous Host->Device transfer to complete
        self.stream.synchronize()

    with context:
        if self.stream is not None:
            # Load to CPU, pin, and async copy to device for overlapping transfer and compute
            loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
            for key, tensor_obj in self.key_to_tensor.items():
                pinned_tensor = loaded_cpu_tensors[key].pin_memory()
                tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
                if self.record_stream:
                    tensor_obj.data.record_stream(current_stream)
        else:
            # Load directly to the target device (synchronous)
            onload_device = (
                self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
            )
            loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
            for key, tensor_obj in self.key_to_tensor.items():
                tensor_obj.data = loaded_tensors[key]
    return
if self.stream is not None:
# Wait for previous Host->Device transfer to complete
self.stream.synchronize()
@@ -169,6 +215,26 @@ class ModuleGroup:
@torch.compiler.disable()
def offload_(self):
r"""Offloads the group of modules to the offload_device."""
if self.offload_to_disk_path:
    # TODO: we can potentially optimize this code path by checking if _all_ the desired
    # safetensors files already exist on disk and, if so, skipping this step entirely, reducing IO
    # overhead. Currently, we just check whether the given `safetensors_file_path` exists and,
    # if not, we perform a write.
    # Check if the file has been saved in this session or if it already exists on disk.
    if not self._is_offloaded_to_disk and not os.path.exists(self.safetensors_file_path):
        os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True)
        tensors_to_save = {
            key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()
        }
        safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path)

    # The group is now considered offloaded to disk for the rest of the session.
    self._is_offloaded_to_disk = True

    # We do this to free up the RAM that is still holding the tensor data.
    for tensor_obj in self.tensor_to_key.keys():
        tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)
    return
torch_accelerator_module = (
getattr(torch, torch.accelerator.current_accelerator().type)
@@ -205,11 +271,7 @@ class GroupOffloadingHook(ModelHook):
_is_stateful = False
def __init__(
    self,
    group: ModuleGroup,
    next_group: Optional[ModuleGroup] = None,
) -> None:
def __init__(self, group: ModuleGroup, next_group: Optional[ModuleGroup] = None) -> None:
self.group = group
self.next_group = next_group
@@ -363,6 +425,7 @@ def apply_group_offloading(
use_stream: bool = False,
record_stream: bool = False,
low_cpu_mem_usage: bool = False,
offload_to_disk_path: Optional[str] = None,
) -> None:
r"""
Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
@@ -401,6 +464,9 @@ def apply_group_offloading(
offload_type (`str`, defaults to "block_level"):
The type of offloading to be applied. Can be one of "block_level" or "leaf_level". Default is
"block_level".
offload_to_disk_path (`str`, *optional*, defaults to `None`):
The path to the directory where parameters will be offloaded. Setting this option can be useful in
environments with limited RAM where a reasonable speed-memory trade-off is desired.
num_blocks_per_group (`int`, *optional*):
The number of blocks per group when using offload_type="block_level". This is required when using
offload_type="block_level".
@@ -458,6 +524,7 @@ def apply_group_offloading(
num_blocks_per_group=num_blocks_per_group,
offload_device=offload_device,
onload_device=onload_device,
offload_to_disk_path=offload_to_disk_path,
non_blocking=non_blocking,
stream=stream,
record_stream=record_stream,
@@ -468,6 +535,7 @@ def apply_group_offloading(
module=module,
offload_device=offload_device,
onload_device=onload_device,
offload_to_disk_path=offload_to_disk_path,
non_blocking=non_blocking,
stream=stream,
record_stream=record_stream,
@@ -486,6 +554,7 @@ def _apply_group_offloading_block_level(
stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
record_stream: Optional[bool] = False,
low_cpu_mem_usage: bool = False,
offload_to_disk_path: Optional[str] = None,
) -> None:
r"""
This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
@@ -496,6 +565,9 @@ def _apply_group_offloading_block_level(
The module to which group offloading is applied.
offload_device (`torch.device`):
The device to which the group of modules are offloaded. This should typically be the CPU.
offload_to_disk_path (`str`, *optional*, defaults to `None`):
The path to the directory where parameters will be offloaded. Setting this option can be useful in
environments with limited RAM where a reasonable speed-memory trade-off is desired.
onload_device (`torch.device`):
The device to which the group of modules are onloaded.
non_blocking (`bool`):
@@ -535,6 +607,7 @@ def _apply_group_offloading_block_level(
modules=current_modules,
offload_device=offload_device,
onload_device=onload_device,
offload_to_disk_path=offload_to_disk_path,
offload_leader=current_modules[-1],
onload_leader=current_modules[0],
non_blocking=non_blocking,
@@ -567,6 +640,7 @@ def _apply_group_offloading_block_level(
modules=unmatched_modules,
offload_device=offload_device,
onload_device=onload_device,
offload_to_disk_path=offload_to_disk_path,
offload_leader=module,
onload_leader=module,
parameters=parameters,
@@ -590,6 +664,7 @@ def _apply_group_offloading_leaf_level(
stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
record_stream: Optional[bool] = False,
low_cpu_mem_usage: bool = False,
offload_to_disk_path: Optional[str] = None,
) -> None:
r"""
This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory
@@ -604,6 +679,9 @@ def _apply_group_offloading_leaf_level(
The device to which the group of modules are offloaded. This should typically be the CPU.
onload_device (`torch.device`):
The device to which the group of modules are onloaded.
offload_to_disk_path (`str`, *optional*, defaults to `None`):
The path to the directory where parameters will be offloaded. Setting this option can be useful in
environments with limited RAM where a reasonable speed-memory trade-off is desired.
non_blocking (`bool`):
If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
and data transfer.
@@ -629,6 +707,7 @@ def _apply_group_offloading_leaf_level(
modules=[submodule],
offload_device=offload_device,
onload_device=onload_device,
offload_to_disk_path=offload_to_disk_path,
offload_leader=submodule,
onload_leader=submodule,
non_blocking=non_blocking,
@@ -675,6 +754,7 @@ def _apply_group_offloading_leaf_level(
onload_device=onload_device,
offload_leader=parent_module,
onload_leader=parent_module,
offload_to_disk_path=offload_to_disk_path,
parameters=parameters,
buffers=buffers,
non_blocking=non_blocking,
@@ -693,6 +773,7 @@ def _apply_group_offloading_leaf_level(
modules=[],
offload_device=offload_device,
onload_device=onload_device,
offload_to_disk_path=offload_to_disk_path,
offload_leader=module,
onload_leader=module,
parameters=None,
......
@@ -548,6 +548,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
use_stream: bool = False,
record_stream: bool = False,
low_cpu_mem_usage=False,
offload_to_disk_path: Optional[str] = None,
) -> None:
r"""
Activates group offloading for the current model.
@@ -588,15 +589,16 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
f"open an issue at https://github.com/huggingface/diffusers/issues."
)
apply_group_offloading(
    self,
    onload_device,
    offload_device,
    offload_type,
    num_blocks_per_group,
    non_blocking,
    use_stream,
    record_stream,
    low_cpu_mem_usage=low_cpu_mem_usage,
)
apply_group_offloading(
    module=self,
    onload_device=onload_device,
    offload_device=offload_device,
    offload_type=offload_type,
    num_blocks_per_group=num_blocks_per_group,
    non_blocking=non_blocking,
    use_stream=use_stream,
    record_stream=record_stream,
    low_cpu_mem_usage=low_cpu_mem_usage,
    offload_to_disk_path=offload_to_disk_path,
)
def save_pretrained(
......
@@ -15,6 +15,7 @@
import copy
import gc
import glob
import inspect
import json
import os
@@ -1693,6 +1694,35 @@ class ModelTesterMixin:
model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
_ = model(**inputs_dict)[0]
@parameterized.expand([(False, "block_level"), (True, "leaf_level")])
@require_torch_accelerator
@torch.no_grad()
def test_group_offloading_with_disk(self, record_stream, offload_type):
    torch.manual_seed(0)
    init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
    model = self.model_class(**init_dict)

    if not getattr(model, "_supports_group_offloading", True):
        return

    torch.manual_seed(0)
    init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
    model = self.model_class(**init_dict)
    model.eval()
    additional_kwargs = {} if offload_type == "leaf_level" else {"num_blocks_per_group": 1}
    with tempfile.TemporaryDirectory() as tmpdir:
        model.enable_group_offload(
            torch_device,
            offload_type=offload_type,
            offload_to_disk_path=tmpdir,
            use_stream=True,
            record_stream=record_stream,
            **additional_kwargs,
        )
        has_safetensors = glob.glob(f"{tmpdir}/*.safetensors")
        assert has_safetensors, "No safetensors found in the directory."
        _ = model(**inputs_dict)[0]
def test_auto_model(self, expected_max_diff=5e-5):
if self.forward_requires_fresh_args:
model = self.model_class(**self.init_dict)
......
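As a complement to the test above, here is a minimal, self-contained sketch (not part of the commit; the toy module, device choice, and group size are assumptions) of the hooks-level `apply_group_offloading` entry point staging group parameters on disk:

```python
import glob
import tempfile

import torch
from diffusers.hooks import apply_group_offloading


class TinyBlocks(torch.nn.Module):
    """Toy model with a ModuleList so block_level offloading has blocks to group."""

    def __init__(self):
        super().__init__()
        self.blocks = torch.nn.ModuleList([torch.nn.Linear(64, 64) for _ in range(4)])

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x


model = TinyBlocks().eval()

with tempfile.TemporaryDirectory() as tmpdir:
    apply_group_offloading(
        model,
        onload_device=torch.device("cuda"),  # requires an accelerator device
        offload_device=torch.device("cpu"),
        offload_type="block_level",
        num_blocks_per_group=1,
        offload_to_disk_path=tmpdir,  # groups are staged here as .safetensors files
    )
    # Mirrors the test above: group files should already exist on disk.
    assert glob.glob(f"{tmpdir}/*.safetensors"), "expected group files on disk"
    with torch.no_grad():
        out = model(torch.randn(1, 64, device="cuda"))
```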