Unverified Commit 2d3d376b authored by Sayak Paul, committed by GitHub

Fix unique memory address when doing group-offloading with disk (#11767)



* fix memory address problem

* add more tests

* updates

* updates

* update

* _group_id = group_id

* update

* Apply suggestions from code review
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* update

* update

* update

* fix

---------
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
parent db715e2c
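Previously, each on-disk group file was named after the Python object id of its `ModuleGroup` (`f"group_{id(self)}.safetensors"`). That name depends on where the object happens to sit in memory, so it changes between runs and cannot be derived from the model structure, which makes it hard to verify or reuse the serialized files. The diff below switches to the first 16 hex characters of a SHA-256 hash of a stable `group_id`. A minimal standalone sketch of the old vs. new naming (the offload directory and the group id are illustrative):

```python
import hashlib
import os


class _GroupStandIn:  # hypothetical stand-in for diffusers' ModuleGroup
    pass


offload_dir = "offload"  # illustrative offload_to_disk_path

# Old scheme: tied to the object's memory address; differs on every run/process.
old_name = os.path.join(offload_dir, f"group_{id(_GroupStandIn())}.safetensors")

# New scheme: hash of a stable group id (e.g. a block-level id of the form f"{name}_{i}_{j}";
# "blocks_0_0" here is illustrative), so the same model always maps to the same file set.
group_id = "blocks_0_0"
short_hash = hashlib.sha256(group_id.encode("utf-8")).hexdigest()[:16]
new_name = os.path.join(offload_dir, f"group_{short_hash}.safetensors")

print(old_name)  # changes between runs (memory-address based)
print(new_name)  # offload/group_<16 hex chars>.safetensors, deterministic
```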
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import hashlib
 import os
 from contextlib import contextmanager, nullcontext
 from dataclasses import dataclass
@@ -37,7 +38,7 @@ logger = get_logger(__name__)  # pylint: disable=invalid-name
 _GROUP_OFFLOADING = "group_offloading"
 _LAYER_EXECUTION_TRACKER = "layer_execution_tracker"
 _LAZY_PREFETCH_GROUP_OFFLOADING = "lazy_prefetch_group_offloading"
+_GROUP_ID_LAZY_LEAF = "lazy_leafs"
 _SUPPORTED_PYTORCH_LAYERS = (
     torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
     torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
@@ -82,6 +83,7 @@ class ModuleGroup:
         low_cpu_mem_usage: bool = False,
         onload_self: bool = True,
         offload_to_disk_path: Optional[str] = None,
+        group_id: Optional[int] = None,
     ) -> None:
         self.modules = modules
         self.offload_device = offload_device
@@ -100,7 +102,10 @@
         self._is_offloaded_to_disk = False

         if self.offload_to_disk_path:
-            self.safetensors_file_path = os.path.join(self.offload_to_disk_path, f"group_{id(self)}.safetensors")
+            # Instead of `group_id or str(id(self))` we do this because `group_id` can be "" as well.
+            self.group_id = group_id if group_id is not None else str(id(self))
+            short_hash = _compute_group_hash(self.group_id)
+            self.safetensors_file_path = os.path.join(self.offload_to_disk_path, f"group_{short_hash}.safetensors")

             all_tensors = []
             for module in self.modules:
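The explicit `is not None` check above (rather than `group_id or str(id(self))`) matters because, as the inline comment notes, a legitimate `group_id` can be the empty string, which is falsy and would otherwise silently fall back to the unstable `id(self)` name. A tiny illustration with a hypothetical stand-in class, not the real `ModuleGroup`:

```python
class _NamingStandIn:
    def __init__(self, group_id=None):
        # Buggy: "" is falsy, so an empty-string group id would get an id(self)-based name again.
        self.buggy = group_id or str(id(self))
        # Correct: only fall back when no group id was provided at all.
        self.fixed = group_id if group_id is not None else str(id(self))


g = _NamingStandIn(group_id="")
print(repr(g.buggy))  # something like '140352184098128'
print(repr(g.fixed))  # ''
```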
@@ -609,6 +614,7 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf
         for i in range(0, len(submodule), config.num_blocks_per_group):
             current_modules = submodule[i : i + config.num_blocks_per_group]
+            group_id = f"{name}_{i}_{i + len(current_modules) - 1}"
             group = ModuleGroup(
                 modules=current_modules,
                 offload_device=config.offload_device,
@@ -621,6 +627,7 @@
                 record_stream=config.record_stream,
                 low_cpu_mem_usage=config.low_cpu_mem_usage,
                 onload_self=True,
+                group_id=group_id,
             )
             matched_module_groups.append(group)
             for j in range(i, i + len(current_modules)):
@@ -655,6 +662,7 @@
         stream=None,
         record_stream=False,
         onload_self=True,
+        group_id=f"{module.__class__.__name__}_unmatched_group",
     )
     if config.stream is None:
         _apply_group_offloading_hook(module, unmatched_group, None, config=config)
@@ -686,6 +694,7 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOff
             record_stream=config.record_stream,
             low_cpu_mem_usage=config.low_cpu_mem_usage,
             onload_self=True,
+            group_id=name,
         )
         _apply_group_offloading_hook(submodule, group, None, config=config)
         modules_with_group_offloading.add(name)
@@ -732,6 +741,7 @@
             record_stream=config.record_stream,
             low_cpu_mem_usage=config.low_cpu_mem_usage,
             onload_self=True,
+            group_id=name,
         )
         _apply_group_offloading_hook(parent_module, group, None, config=config)
@@ -753,6 +763,7 @@
             record_stream=False,
             low_cpu_mem_usage=config.low_cpu_mem_usage,
             onload_self=True,
+            group_id=_GROUP_ID_LAZY_LEAF,
         )
         _apply_lazy_group_offloading_hook(module, unmatched_group, None, config=config)
@@ -873,6 +884,12 @@ def _get_group_onload_device(module: torch.nn.Module) -> torch.device:
     raise ValueError("Group offloading is not enabled for the provided module.")
+
+
+def _compute_group_hash(group_id):
+    hashed_id = hashlib.sha256(group_id.encode("utf-8")).hexdigest()
+    # first 16 characters for a reasonably short but unique name
+    return hashed_id[:16]


 def _maybe_remove_and_reapply_group_offloading(module: torch.nn.Module) -> None:
     r"""
     Removes the group offloading hook from the module and re-applies it. This is useful when the module has been
......
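With a stable `group_id` per group, the expected set of files can now be computed from the model alone, which is what the test utilities below rely on. A small sketch of how group ids map to file names under the new scheme (the directory and module names are illustrative; `_compute_group_hash` mirrors the helper added above):

```python
import hashlib
import os


def _compute_group_hash(group_id: str) -> str:
    return hashlib.sha256(group_id.encode("utf-8")).hexdigest()[:16]


offload_dir = "offload"  # illustrative offload_to_disk_path
group_ids = [
    "blocks_0_0",               # block_level: f"{name}_{i}_{j}" per chunk of a ModuleList/Sequential
    "blocks_1_1",
    "MyModel_unmatched_group",  # block_level: everything not captured by a matched block group
    "encoder.conv_in",          # leaf_level: fully qualified name of a supported leaf layer
    "lazy_leafs",               # leaf_level: sentinel id for the lazy prefetch group (_GROUP_ID_LAZY_LEAF)
]
for gid in group_ids:
    filename = f"group_{_compute_group_hash(gid)}.safetensors"
    print(gid, "->", os.path.join(offload_dir, filename))
```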
 import functools
+import glob
 import importlib
 import importlib.metadata
 import inspect
@@ -18,7 +19,7 @@ from collections import UserDict
 from contextlib import contextmanager
 from io import BytesIO, StringIO
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tuple, Union

 import numpy as np
 import PIL.Image
@@ -1392,6 +1393,103 @@ if TYPE_CHECKING:
 else:
     DevicePropertiesUserDict = UserDict

+if is_torch_available():
+    from diffusers.hooks.group_offloading import (
+        _GROUP_ID_LAZY_LEAF,
+        _SUPPORTED_PYTORCH_LAYERS,
+        _compute_group_hash,
+        _find_parent_module_in_module_dict,
+        _gather_buffers_with_no_group_offloading_parent,
+        _gather_parameters_with_no_group_offloading_parent,
+    )
+
+    def _get_expected_safetensors_files(
+        module: torch.nn.Module,
+        offload_to_disk_path: str,
+        offload_type: str,
+        num_blocks_per_group: Optional[int] = None,
+    ) -> Set[str]:
+        expected_files = set()
+
+        def get_hashed_filename(group_id: str) -> str:
+            short_hash = _compute_group_hash(group_id)
+            return os.path.join(offload_to_disk_path, f"group_{short_hash}.safetensors")
+
+        if offload_type == "block_level":
+            if num_blocks_per_group is None:
+                raise ValueError("num_blocks_per_group must be provided for 'block_level' offloading.")
+
+            # Handle groups of ModuleList and Sequential blocks
+            unmatched_modules = []
+            for name, submodule in module.named_children():
+                if not isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
+                    unmatched_modules.append(module)
+                    continue
+
+                for i in range(0, len(submodule), num_blocks_per_group):
+                    current_modules = submodule[i : i + num_blocks_per_group]
+                    if not current_modules:
+                        continue
+                    group_id = f"{name}_{i}_{i + len(current_modules) - 1}"
+                    expected_files.add(get_hashed_filename(group_id))
+
+            # Handle the group for unmatched top-level modules and parameters
+            for module in unmatched_modules:
+                expected_files.add(get_hashed_filename(f"{module.__class__.__name__}_unmatched_group"))
+
+        elif offload_type == "leaf_level":
+            # Handle leaf-level module groups
+            for name, submodule in module.named_modules():
+                if isinstance(submodule, _SUPPORTED_PYTORCH_LAYERS):
+                    # These groups will always have parameters, so a file is expected
+                    expected_files.add(get_hashed_filename(name))
+
+            # Handle groups for non-leaf parameters/buffers
+            modules_with_group_offloading = {
+                name for name, sm in module.named_modules() if isinstance(sm, _SUPPORTED_PYTORCH_LAYERS)
+            }
+            parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading)
+            buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading)
+            all_orphans = parameters + buffers
+            if all_orphans:
+                parent_to_tensors = {}
+                module_dict = dict(module.named_modules())
+                for tensor_name, _ in all_orphans:
+                    parent_name = _find_parent_module_in_module_dict(tensor_name, module_dict)
+                    if parent_name not in parent_to_tensors:
+                        parent_to_tensors[parent_name] = []
+                    parent_to_tensors[parent_name].append(tensor_name)
+
+                for parent_name in parent_to_tensors:
+                    # A file is expected for each parent that gathers orphaned tensors
+                    expected_files.add(get_hashed_filename(parent_name))
+
+            expected_files.add(get_hashed_filename(_GROUP_ID_LAZY_LEAF))
+
+        else:
+            raise ValueError(f"Unsupported offload_type: {offload_type}")
+
+        return expected_files
+
+    def _check_safetensors_serialization(
+        module: torch.nn.Module,
+        offload_to_disk_path: str,
+        offload_type: str,
+        num_blocks_per_group: Optional[int] = None,
+    ) -> bool:
+        if not os.path.isdir(offload_to_disk_path):
+            return False, None, None
+
+        expected_files = _get_expected_safetensors_files(
+            module, offload_to_disk_path, offload_type, num_blocks_per_group
+        )
+        actual_files = set(glob.glob(os.path.join(offload_to_disk_path, "*.safetensors")))
+        missing_files = expected_files - actual_files
+        extra_files = actual_files - expected_files
+        is_correct = not missing_files and not extra_files
+        return is_correct, extra_files, missing_files
+

 class Expectations(DevicePropertiesUserDict):
     def get_expectation(self) -> Any:
......
@@ -61,6 +61,7 @@ from diffusers.utils import (
 from diffusers.utils.hub_utils import _add_variant
 from diffusers.utils.testing_utils import (
     CaptureLogger,
+    _check_safetensors_serialization,
     backend_empty_cache,
     backend_max_memory_allocated,
     backend_reset_peak_memory_stats,
@@ -1702,18 +1703,43 @@ class ModelTesterMixin:
         model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
         _ = model(**inputs_dict)[0]

-    @parameterized.expand([(False, "block_level"), (True, "leaf_level")])
+    @parameterized.expand([("block_level", False), ("leaf_level", True)])
     @require_torch_accelerator
     @torch.no_grad()
-    def test_group_offloading_with_disk(self, record_stream, offload_type):
+    @torch.inference_mode()
+    def test_group_offloading_with_disk(self, offload_type, record_stream, atol=1e-5):
         if not self.model_class._supports_group_offloading:
             pytest.skip("Model does not support group offloading.")

-        torch.manual_seed(0)
+        def _has_generator_arg(model):
+            sig = inspect.signature(model.forward)
+            params = sig.parameters
+            return "generator" in params
+
+        def _run_forward(model, inputs_dict):
+            accepts_generator = _has_generator_arg(model)
+            if accepts_generator:
+                inputs_dict["generator"] = torch.manual_seed(0)
+            torch.manual_seed(0)
+            return model(**inputs_dict)[0]
+
+        if self.__class__.__name__ == "AutoencoderKLCosmosTests" and offload_type == "leaf_level":
+            pytest.skip("With `leaf_type` as the offloading type, it fails. Needs investigation.")
+
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        torch.manual_seed(0)
         model = self.model_class(**init_dict)
         model.eval()
-        additional_kwargs = {} if offload_type == "leaf_level" else {"num_blocks_per_group": 1}
+        model.to(torch_device)
+        output_without_group_offloading = _run_forward(model, inputs_dict)
+
+        torch.manual_seed(0)
+        model = self.model_class(**init_dict)
+        model.eval()
+
+        num_blocks_per_group = None if offload_type == "leaf_level" else 1
+        additional_kwargs = {} if offload_type == "leaf_level" else {"num_blocks_per_group": num_blocks_per_group}
         with tempfile.TemporaryDirectory() as tmpdir:
             model.enable_group_offload(
                 torch_device,
@@ -1724,8 +1750,25 @@
                 **additional_kwargs,
             )
             has_safetensors = glob.glob(f"{tmpdir}/*.safetensors")
-            self.assertTrue(len(has_safetensors) > 0, "No safetensors found in the offload directory.")
-            _ = model(**inputs_dict)[0]
+            self.assertTrue(has_safetensors, "No safetensors found in the directory.")
+
+            # For "leaf-level", there is a prefetching hook which makes this check a bit non-deterministic
+            # in nature. So, skip it.
+            if offload_type != "leaf_level":
+                is_correct, extra_files, missing_files = _check_safetensors_serialization(
+                    module=model,
+                    offload_to_disk_path=tmpdir,
+                    offload_type=offload_type,
+                    num_blocks_per_group=num_blocks_per_group,
+                )
+                if not is_correct:
+                    if extra_files:
+                        raise ValueError(f"Found extra files: {', '.join(extra_files)}")
+                    elif missing_files:
+                        raise ValueError(f"Following files are missing: {', '.join(missing_files)}")
+
+            output_with_group_offloading = _run_forward(model, inputs_dict)
+            self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading, atol=atol))

     def test_auto_model(self, expected_max_diff=5e-5):
         if self.forward_requires_fresh_args:
......
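For context, a minimal end-to-end sketch of what the updated test exercises: group offloading with group weights serialized to disk. The checkpoint name is illustrative, and keyword arguments other than the ones appearing in the diff (`offload_type`, `num_blocks_per_group`, `offload_to_disk_path`) should be checked against the diffusers group-offloading docs:

```python
import glob

import torch
from diffusers import AutoencoderKL  # any diffusers model with group-offloading support

model = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")  # illustrative checkpoint
model.eval()

# Parameters are grouped per block and streamed from ./offload_dir during the forward pass.
# With this fix, the files are named group_<sha256(group_id)[:16]>.safetensors, so the same
# model produces the same file set on every run instead of id(self)-based names.
model.enable_group_offload(
    torch.device("cuda"),  # onload device, passed positionally as in the test above
    offload_type="block_level",
    num_blocks_per_group=1,
    offload_to_disk_path="./offload_dir",
)

with torch.no_grad():
    sample = torch.randn(1, 3, 64, 64, device="cuda")  # inputs live on the onload device
    _ = model(sample).sample

print(sorted(glob.glob("./offload_dir/*.safetensors")))  # deterministic, hash-based names
```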