Unverified Commit f238cb07 authored by YiYi Xu, committed by GitHub

cpu_offload: remove all hooks before offload (#7448)



* add remove_all_hooks

* a few more fixes and tests

* up

* Update src/diffusers/pipelines/pipeline_utils.py
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* split tests

* add

---------
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
parent d78acded
@@ -371,9 +371,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
             if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"):
                 return False
 
-            return hasattr(module, "_hf_hook") and not isinstance(
-                module._hf_hook, (accelerate.hooks.CpuOffload, accelerate.hooks.AlignDevicesHook)
-            )
+            return hasattr(module, "_hf_hook") and isinstance(module._hf_hook, accelerate.hooks.AlignDevicesHook)
 
         def module_is_offloaded(module):
             if not is_accelerate_available() or is_accelerate_version("<", "0.17.0.dev0"):
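The change above narrows the detection of sequential offloading: a module now counts as sequentially offloaded only when its `_hf_hook` is an `accelerate.hooks.AlignDevicesHook`, while the `CpuOffload` hook installed by `enable_model_cpu_offload` is handled by the separate `module_is_offloaded` check. A minimal standalone sketch of the distinction (assuming `accelerate` >= 0.14 is installed; the second helper is illustrative and not part of this diff):

import accelerate
import torch


def module_is_sequentially_offloaded(module: torch.nn.Module) -> bool:
    # Sequential offload (enable_sequential_cpu_offload) attaches AlignDevicesHook
    # instances via accelerate.cpu_offload, so only that hook type counts here.
    return hasattr(module, "_hf_hook") and isinstance(module._hf_hook, accelerate.hooks.AlignDevicesHook)


def module_is_model_offloaded(module: torch.nn.Module) -> bool:
    # Model offload (enable_model_cpu_offload) wraps each component in a
    # CpuOffload hook instead; illustrative helper, not taken from the diff.
    return hasattr(module, "_hf_hook") and isinstance(module._hf_hook, accelerate.hooks.CpuOffload)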
@@ -939,6 +937,16 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
                     return torch.device(module._hf_hook.execution_device)
         return self.device
 
+    def remove_all_hooks(self):
+        r"""
+        Removes all hooks that were added when using `enable_sequential_cpu_offload` or `enable_model_cpu_offload`.
+        """
+        for _, model in self.components.items():
+            if isinstance(model, torch.nn.Module) and hasattr(model, "_hf_hook"):
+                is_sequential_cpu_offload = isinstance(getattr(model, "_hf_hook"), accelerate.hooks.AlignDevicesHook)
+                accelerate.hooks.remove_hook_from_module(model, recurse=is_sequential_cpu_offload)
+        self._all_hooks = []
+
     def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
         r"""
         Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
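`remove_all_hooks` detaches whatever accelerate hooks a previous offload call installed (recursively for `AlignDevicesHook`, since sequential offload hooks every submodule with weights) before a new offload strategy is applied. A rough usage sketch of what this enables, assuming the checkpoint name and prompt are only illustrative:

import torch
from diffusers import DiffusionPipeline

# "runwayml/stable-diffusion-v1-5" is an illustrative checkpoint, not part of the diff.
pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)

pipe.enable_model_cpu_offload()        # installs CpuOffload hooks on each component
image = pipe("an astronaut riding a horse").images[0]

# Switching strategies on the same pipeline now works: the stale CpuOffload hooks
# are removed before AlignDevicesHook-based sequential offload is set up.
pipe.enable_sequential_cpu_offload()
image = pipe("an astronaut riding a horse").images[0]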
@@ -963,6 +971,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
         else:
             raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
 
+        self.remove_all_hooks()
+
         torch_device = torch.device(device)
         device_index = torch_device.index
@@ -979,15 +989,13 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
             device = torch.device(f"{device_type}:{self._offload_gpu_id}")
         self._offload_device = device
 
-        if self.device.type != "cpu":
-            self.to("cpu", silence_dtype_warnings=True)
-            device_mod = getattr(torch, self.device.type, None)
-            if hasattr(device_mod, "empty_cache") and device_mod.is_available():
-                device_mod.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+        self.to("cpu", silence_dtype_warnings=True)
+        device_mod = getattr(torch, device.type, None)
+        if hasattr(device_mod, "empty_cache") and device_mod.is_available():
+            device_mod.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
 
         all_model_components = {k: v for k, v in self.components.items() if isinstance(v, torch.nn.Module)}
-        self._all_hooks = []
 
         hook = None
         for model_str in self.model_cpu_offload_seq.split("->"):
             model = all_model_components.pop(model_str, None)
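The rewritten block always moves the pipeline to CPU first and resolves the cache-clearing helper from the target `device` rather than from `self.device`, which is already `"cpu"` by that point and exposes no `empty_cache`. A small standalone sketch of that lookup pattern, assuming a CUDA target:

import torch

device = torch.device("cuda")

# getattr(torch, "cuda", None) resolves to the torch.cuda module; the same pattern
# works for any backend module that exposes is_available() and empty_cache().
device_mod = getattr(torch, device.type, None)
if device_mod is not None and hasattr(device_mod, "empty_cache") and device_mod.is_available():
    device_mod.empty_cache()  # release cached blocks so the memory savings become visible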
@@ -1021,11 +1029,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
             # `enable_model_cpu_offload` has not be called, so silently do nothing
             return
 
-        for hook in self._all_hooks:
-            # offload model and remove hook from model
-            hook.offload()
-            hook.remove()
-
         # make sure the model is in the same state as before calling it
         self.enable_model_cpu_offload(device=getattr(self, "_offload_device", "cuda"))
@@ -1048,6 +1051,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
             from accelerate import cpu_offload
         else:
             raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
+        self.remove_all_hooks()
 
         torch_device = torch.device(device)
         device_index = torch_device.index
@@ -1107,6 +1107,98 @@ class PipelineTesterMixin:
             f"Not offloaded: {[v for v in offloaded_modules if v.device.type != 'cpu']}",
         )
 
+    @unittest.skipIf(
+        torch_device != "cuda" or not is_accelerate_available() or is_accelerate_version("<", "0.17.0"),
+        reason="CPU offload is only available with CUDA and `accelerate v0.17.0` or higher",
+    )
+    def test_cpu_offload_forward_pass_twice(self, expected_max_diff=2e-4):
+        import accelerate
+
+        generator_device = "cpu"
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+
+        for component in pipe.components.values():
+            if hasattr(component, "set_default_attn_processor"):
+                component.set_default_attn_processor()
+
+        pipe.set_progress_bar_config(disable=None)
+
+        pipe.enable_model_cpu_offload()
+        inputs = self.get_dummy_inputs(generator_device)
+        output_with_offload = pipe(**inputs)[0]
+
+        pipe.enable_model_cpu_offload()
+        inputs = self.get_dummy_inputs(generator_device)
+        output_with_offload_twice = pipe(**inputs)[0]
+
+        max_diff = np.abs(to_np(output_with_offload) - to_np(output_with_offload_twice)).max()
+        self.assertLess(
+            max_diff, expected_max_diff, "running CPU offloading a 2nd time should not affect the inference results"
+        )
+
+        offloaded_modules = [
+            v
+            for k, v in pipe.components.items()
+            if isinstance(v, torch.nn.Module) and k not in pipe._exclude_from_cpu_offload
+        ]
+        self.assertTrue(
+            all(v.device.type == "cpu" for v in offloaded_modules),
+            f"Not offloaded: {[v for v in offloaded_modules if v.device.type != 'cpu']}",
+        )
+
+        offloaded_modules_with_hooks = [v for v in offloaded_modules if hasattr(v, "_hf_hook")]
+        self.assertTrue(
+            all(isinstance(v._hf_hook, accelerate.hooks.CpuOffload) for v in offloaded_modules_with_hooks),
+            f"Not installed correct hook: {[v for v in offloaded_modules_with_hooks if not isinstance(v._hf_hook, accelerate.hooks.CpuOffload)]}",
+        )
+
+    @unittest.skipIf(
+        torch_device != "cuda" or not is_accelerate_available() or is_accelerate_version("<", "0.14.0"),
+        reason="CPU offload is only available with CUDA and `accelerate v0.14.0` or higher",
+    )
+    def test_sequential_offload_forward_pass_twice(self, expected_max_diff=2e-4):
+        import accelerate
+
+        generator_device = "cpu"
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+
+        for component in pipe.components.values():
+            if hasattr(component, "set_default_attn_processor"):
+                component.set_default_attn_processor()
+
+        pipe.set_progress_bar_config(disable=None)
+
+        pipe.enable_sequential_cpu_offload()
+        inputs = self.get_dummy_inputs(generator_device)
+        output_with_offload = pipe(**inputs)[0]
+
+        pipe.enable_sequential_cpu_offload()
+        inputs = self.get_dummy_inputs(generator_device)
+        output_with_offload_twice = pipe(**inputs)[0]
+
+        max_diff = np.abs(to_np(output_with_offload) - to_np(output_with_offload_twice)).max()
+        self.assertLess(
+            max_diff, expected_max_diff, "running sequential offloading a 2nd time should not affect the inference results"
+        )
+
+        offloaded_modules = [
+            v
+            for k, v in pipe.components.items()
+            if isinstance(v, torch.nn.Module) and k not in pipe._exclude_from_cpu_offload
+        ]
+        self.assertTrue(
+            all(v.device.type == "meta" for v in offloaded_modules),
+            f"Not offloaded: {[v for v in offloaded_modules if v.device.type != 'meta']}",
+        )
+
+        offloaded_modules_with_hooks = [v for v in offloaded_modules if hasattr(v, "_hf_hook")]
+        self.assertTrue(
+            all(isinstance(v._hf_hook, accelerate.hooks.AlignDevicesHook) for v in offloaded_modules_with_hooks),
+            f"Not installed correct hook: {[v for v in offloaded_modules_with_hooks if not isinstance(v._hf_hook, accelerate.hooks.AlignDevicesHook)]}",
+        )
 
     @unittest.skipIf(
         torch_device != "cuda" or not is_xformers_available(),
         reason="XFormers attention is only available with CUDA and `xformers` installed",