Unverified Commit 85a916bb authored by Sayak Paul, committed by GitHub

make group offloading work with disk/nvme transfers (#11682)

* start implementing disk offloading in group.

* delete diff file.

* updates.patch

* offload_to_disk_path

* check if safetensors already exist.

* add test and clarify.

* updates

* update todos.

* update more docs.

* update docs
parent 3287ce28
......@@ -302,6 +302,13 @@ compute-bound, [group-offloading](#group-offloading) tends to be better. Group o
</Tip>
### Offloading to disk
Group offloading can consume significant system RAM depending on the model size. In environments with limited RAM,
it can be useful to offload to disk instead. You can do this by setting the `offload_to_disk_path`
argument in either [`~ModelMixin.enable_group_offload`] or [`~hooks.apply_group_offloading`]. Refer to [this description](https://github.com/huggingface/diffusers/pull/11682#issue-3129365363) and
[this comment](https://github.com/huggingface/diffusers/pull/11682#issuecomment-2955715126) for the expected speed-memory trade-offs with this option enabled.
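Below is a minimal usage sketch. `transformer` is a placeholder for any already-loaded model that supports group offloading, and the disk path is illustrative:

```py
import torch

# `transformer` stands in for an already-loaded model, e.g. a pipeline's transformer component.
transformer.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
    offload_to_disk_path="path/to/disk",  # group parameters are persisted here as safetensors files
    use_stream=True,
)
```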
## Layerwise casting
Layerwise casting stores weights in a smaller data format (for example, `torch.float8_e4m3fn` and `torch.float8_e5m2`) to use less memory and upcasts those weights to a higher precision like `torch.float16` or `torch.bfloat16` for computation. Certain layers (normalization and modulation related weights) are skipped because storing them in fp8 can degrade generation quality.
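As a quick, hedged sketch (assuming `model` is an already-loaded model), enabling layerwise casting looks like:

```py
import torch

# Store weights in fp8 and upcast to bfloat16 for computation.
model.enable_layerwise_casting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
```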
......
......@@ -12,9 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from contextlib import contextmanager, nullcontext
from typing import Dict, List, Optional, Set, Tuple, Union
import safetensors.torch
import torch
from ..utils import get_logger, is_accelerate_available
......@@ -59,6 +61,7 @@ class ModuleGroup:
record_stream: Optional[bool] = False,
low_cpu_mem_usage: bool = False,
onload_self: bool = True,
offload_to_disk_path: Optional[str] = None,
) -> None:
self.modules = modules
self.offload_device = offload_device
......@@ -72,7 +75,26 @@ class ModuleGroup:
self.record_stream = record_stream
self.onload_self = onload_self
self.low_cpu_mem_usage = low_cpu_mem_usage
self.cpu_param_dict = self._init_cpu_param_dict()
self.offload_to_disk_path = offload_to_disk_path
self._is_offloaded_to_disk = False
if self.offload_to_disk_path:
self.safetensors_file_path = os.path.join(self.offload_to_disk_path, f"group_{id(self)}.safetensors")
all_tensors = []
for module in self.modules:
all_tensors.extend(list(module.parameters()))
all_tensors.extend(list(module.buffers()))
all_tensors.extend(self.parameters)
all_tensors.extend(self.buffers)
all_tensors = list(dict.fromkeys(all_tensors)) # Remove duplicates
self.tensor_to_key = {tensor: f"tensor_{i}" for i, tensor in enumerate(all_tensors)}
self.key_to_tensor = {v: k for k, v in self.tensor_to_key.items()}
self.cpu_param_dict = {}
else:
self.cpu_param_dict = self._init_cpu_param_dict()
if self.stream is None and self.record_stream:
raise ValueError("`record_stream` cannot be True when `stream` is None.")
......@@ -124,6 +146,30 @@ class ModuleGroup:
context = nullcontext() if self.stream is None else torch_accelerator_module.stream(self.stream)
current_stream = torch_accelerator_module.current_stream() if self.record_stream else None
if self.offload_to_disk_path:
if self.stream is not None:
# Wait for previous Host->Device transfer to complete
self.stream.synchronize()
with context:
if self.stream is not None:
# Load to CPU, pin, and async copy to device for overlapping transfer and compute
loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
for key, tensor_obj in self.key_to_tensor.items():
pinned_tensor = loaded_cpu_tensors[key].pin_memory()
tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
if self.record_stream:
tensor_obj.data.record_stream(current_stream)
else:
# Load directly to the target device (synchronous)
onload_device = (
self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
)
loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
for key, tensor_obj in self.key_to_tensor.items():
tensor_obj.data = loaded_tensors[key]
return
if self.stream is not None:
# Wait for previous Host->Device transfer to complete
self.stream.synchronize()
......@@ -169,6 +215,26 @@ class ModuleGroup:
@torch.compiler.disable()
def offload_(self):
r"""Offloads the group of modules to the offload_device."""
if self.offload_to_disk_path:
# TODO: we can potentially optimize this code path by checking if _all_ the desired
# safetensors files already exist on disk and, if so, skip this step entirely, reducing IO
# overhead. Currently, we just check whether the given `safetensors_file_path` exists and,
# if not, perform a write.
# Check if the file has been saved in this session or if it already exists on disk.
if not self._is_offloaded_to_disk and not os.path.exists(self.safetensors_file_path):
os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True)
tensors_to_save = {
key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()
}
safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path)
# The group is now considered offloaded to disk for the rest of the session.
self._is_offloaded_to_disk = True
# We do this to free up the RAM which is still holding the tensor data.
for tensor_obj in self.tensor_to_key.keys():
tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)
return
torch_accelerator_module = (
getattr(torch, torch.accelerator.current_accelerator().type)
......@@ -205,11 +271,7 @@ class GroupOffloadingHook(ModelHook):
_is_stateful = False
def __init__(
self,
group: ModuleGroup,
next_group: Optional[ModuleGroup] = None,
) -> None:
def __init__(self, group: ModuleGroup, next_group: Optional[ModuleGroup] = None) -> None:
self.group = group
self.next_group = next_group
......@@ -363,6 +425,7 @@ def apply_group_offloading(
use_stream: bool = False,
record_stream: bool = False,
low_cpu_mem_usage: bool = False,
offload_to_disk_path: Optional[str] = None,
) -> None:
r"""
Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
......@@ -401,6 +464,9 @@ def apply_group_offloading(
offload_type (`str`, defaults to "block_level"):
The type of offloading to be applied. Can be one of "block_level" or "leaf_level". Default is
"block_level".
offload_to_disk_path (`str`, *optional*, defaults to `None`):
The path to the directory where parameters will be offloaded. Setting this option can be useful in
environments with limited RAM where a reasonable speed-memory trade-off is desired.
num_blocks_per_group (`int`, *optional*):
The number of blocks per group when using offload_type="block_level". This is required when using
offload_type="block_level".
......@@ -458,6 +524,7 @@ def apply_group_offloading(
num_blocks_per_group=num_blocks_per_group,
offload_device=offload_device,
onload_device=onload_device,
offload_to_disk_path=offload_to_disk_path,
non_blocking=non_blocking,
stream=stream,
record_stream=record_stream,
......@@ -468,6 +535,7 @@ def apply_group_offloading(
module=module,
offload_device=offload_device,
onload_device=onload_device,
offload_to_disk_path=offload_to_disk_path,
non_blocking=non_blocking,
stream=stream,
record_stream=record_stream,
......@@ -486,6 +554,7 @@ def _apply_group_offloading_block_level(
stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
record_stream: Optional[bool] = False,
low_cpu_mem_usage: bool = False,
offload_to_disk_path: Optional[str] = None,
) -> None:
r"""
This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
......@@ -496,6 +565,9 @@ def _apply_group_offloading_block_level(
The module to which group offloading is applied.
offload_device (`torch.device`):
The device to which the group of modules are offloaded. This should typically be the CPU.
offload_to_disk_path (`str`, *optional*, defaults to `None`):
The path to the directory where parameters will be offloaded. Setting this option can be useful in
environments with limited RAM where a reasonable speed-memory trade-off is desired.
onload_device (`torch.device`):
The device to which the group of modules are onloaded.
non_blocking (`bool`):
......@@ -535,6 +607,7 @@ def _apply_group_offloading_block_level(
modules=current_modules,
offload_device=offload_device,
onload_device=onload_device,
offload_to_disk_path=offload_to_disk_path,
offload_leader=current_modules[-1],
onload_leader=current_modules[0],
non_blocking=non_blocking,
......@@ -567,6 +640,7 @@ def _apply_group_offloading_block_level(
modules=unmatched_modules,
offload_device=offload_device,
onload_device=onload_device,
offload_to_disk_path=offload_to_disk_path,
offload_leader=module,
onload_leader=module,
parameters=parameters,
......@@ -590,6 +664,7 @@ def _apply_group_offloading_leaf_level(
stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
record_stream: Optional[bool] = False,
low_cpu_mem_usage: bool = False,
offload_to_disk_path: Optional[str] = None,
) -> None:
r"""
This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory
......@@ -604,6 +679,9 @@ def _apply_group_offloading_leaf_level(
The device to which the group of modules are offloaded. This should typically be the CPU.
onload_device (`torch.device`):
The device to which the group of modules are onloaded.
offload_to_disk_path (`str`, *optional*, defaults to `None`):
The path to the directory where parameters will be offloaded. Setting this option can be useful in
environments with limited RAM where a reasonable speed-memory trade-off is desired.
non_blocking (`bool`):
If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
and data transfer.
......@@ -629,6 +707,7 @@ def _apply_group_offloading_leaf_level(
modules=[submodule],
offload_device=offload_device,
onload_device=onload_device,
offload_to_disk_path=offload_to_disk_path,
offload_leader=submodule,
onload_leader=submodule,
non_blocking=non_blocking,
......@@ -675,6 +754,7 @@ def _apply_group_offloading_leaf_level(
onload_device=onload_device,
offload_leader=parent_module,
onload_leader=parent_module,
offload_to_disk_path=offload_to_disk_path,
parameters=parameters,
buffers=buffers,
non_blocking=non_blocking,
......@@ -693,6 +773,7 @@ def _apply_group_offloading_leaf_level(
modules=[],
offload_device=offload_device,
onload_device=onload_device,
offload_to_disk_path=offload_to_disk_path,
offload_leader=module,
onload_leader=module,
parameters=None,
......
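For reference, the snippet below is an editor-added sketch (not part of the diff) of calling the hooks-level entry point with the new argument; `module` and the disk path are placeholders, and the keyword names follow the signature shown above:

```py
import torch
from diffusers.hooks import apply_group_offloading

# `module` stands in for any supported torch.nn.Module (e.g. a diffusers transformer).
apply_group_offloading(
    module,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=1,  # required when offload_type="block_level"
    offload_to_disk_path="path/to/disk",  # each module group is written here as a safetensors file
    use_stream=True,
)
```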
......@@ -548,6 +548,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
use_stream: bool = False,
record_stream: bool = False,
low_cpu_mem_usage=False,
offload_to_disk_path: Optional[str] = None,
) -> None:
r"""
Activates group offloading for the current model.
......@@ -588,15 +589,16 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
f"open an issue at https://github.com/huggingface/diffusers/issues."
)
apply_group_offloading(
self,
onload_device,
offload_device,
offload_type,
num_blocks_per_group,
non_blocking,
use_stream,
record_stream,
module=self,
onload_device=onload_device,
offload_device=offload_device,
offload_type=offload_type,
num_blocks_per_group=num_blocks_per_group,
non_blocking=non_blocking,
use_stream=use_stream,
record_stream=record_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
offload_to_disk_path=offload_to_disk_path,
)
def save_pretrained(
......
......@@ -15,6 +15,7 @@
import copy
import gc
import glob
import inspect
import json
import os
......@@ -1693,6 +1694,35 @@ class ModelTesterMixin:
model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
_ = model(**inputs_dict)[0]
@parameterized.expand([(False, "block_level"), (True, "leaf_level")])
@require_torch_accelerator
@torch.no_grad()
def test_group_offloading_with_disk(self, record_stream, offload_type):
torch.manual_seed(0)
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
if not getattr(model, "_supports_group_offloading", True):
return
torch.manual_seed(0)
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.eval()
additional_kwargs = {} if offload_type == "leaf_level" else {"num_blocks_per_group": 1}
with tempfile.TemporaryDirectory() as tmpdir:
model.enable_group_offload(
torch_device,
offload_type=offload_type,
offload_to_disk_path=tmpdir,
use_stream=True,
record_stream=record_stream,
**additional_kwargs,
)
has_safetensors = glob.glob(f"{tmpdir}/*.safetensors")
assert has_safetensors, "No safetensors found in the directory."
_ = model(**inputs_dict)[0]
def test_auto_model(self, expected_max_diff=5e-5):
if self.forward_requires_fresh_args:
model = self.model_class(**self.init_dict)
......