Unverified Commit 5bb38586 authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[Core] fix offload behaviour when device_map is enabled. (#7919)

fix offload behaviour when device_map is enabled.
parent ec9e8813
...@@ -363,7 +363,7 @@ class LoraLoaderMixin: ...@@ -363,7 +363,7 @@ class LoraLoaderMixin:
is_model_cpu_offload = False is_model_cpu_offload = False
is_sequential_cpu_offload = False is_sequential_cpu_offload = False
if _pipeline is not None: if _pipeline is not None and _pipeline.hf_device_map is None:
for _, component in _pipeline.components.items(): for _, component in _pipeline.components.items():
if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"): if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
if not is_model_cpu_offload: if not is_model_cpu_offload:
......
...@@ -419,19 +419,20 @@ class TextualInversionLoaderMixin: ...@@ -419,19 +419,20 @@ class TextualInversionLoaderMixin:
# 7.1 Offload all hooks in case the pipeline was cpu offloaded before make sure, we offload and onload again # 7.1 Offload all hooks in case the pipeline was cpu offloaded before make sure, we offload and onload again
is_model_cpu_offload = False is_model_cpu_offload = False
is_sequential_cpu_offload = False is_sequential_cpu_offload = False
for _, component in self.components.items(): if self.hf_device_map is None:
if isinstance(component, nn.Module): for _, component in self.components.items():
if hasattr(component, "_hf_hook"): if isinstance(component, nn.Module):
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload) if hasattr(component, "_hf_hook"):
is_sequential_cpu_offload = ( is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
isinstance(getattr(component, "_hf_hook"), AlignDevicesHook) is_sequential_cpu_offload = (
or hasattr(component._hf_hook, "hooks") isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
and isinstance(component._hf_hook.hooks[0], AlignDevicesHook) or hasattr(component._hf_hook, "hooks")
) and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
logger.info( )
"Accelerate hooks detected. Since you have called `load_textual_inversion()`, the previous hooks will be first removed. Then the textual inversion parameters will be loaded and the hooks will be applied again." logger.info(
) "Accelerate hooks detected. Since you have called `load_textual_inversion()`, the previous hooks will be first removed. Then the textual inversion parameters will be loaded and the hooks will be applied again."
remove_hook_from_module(component, recurse=is_sequential_cpu_offload) )
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
# 7.2 save expected device and dtype # 7.2 save expected device and dtype
device = text_encoder.device device = text_encoder.device
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment