Unverified Commit 9254d1f3 authored by Disty0's avatar Disty0 Committed by GitHub
Browse files

Pass device to enable_model_cpu_offload in maybe_free_model_hooks (#6937)

parent e1bdcc7a
...@@ -1423,6 +1423,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): ...@@ -1423,6 +1423,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
device_type = torch_device.type device_type = torch_device.type
device = torch.device(f"{device_type}:{self._offload_gpu_id}") device = torch.device(f"{device_type}:{self._offload_gpu_id}")
self._offload_device = device
if self.device.type != "cpu": if self.device.type != "cpu":
self.to("cpu", silence_dtype_warnings=True) self.to("cpu", silence_dtype_warnings=True)
...@@ -1472,7 +1473,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): ...@@ -1472,7 +1473,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
hook.remove() hook.remove()
# make sure the model is in the same state as before calling it # make sure the model is in the same state as before calling it
self.enable_model_cpu_offload() self.enable_model_cpu_offload(device=getattr(self, "_offload_device", "cuda"))
def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"): def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
r""" r"""
...@@ -1508,6 +1509,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): ...@@ -1508,6 +1509,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
device_type = torch_device.type device_type = torch_device.type
device = torch.device(f"{device_type}:{self._offload_gpu_id}") device = torch.device(f"{device_type}:{self._offload_gpu_id}")
self._offload_device = device
if self.device.type != "cpu": if self.device.type != "cpu":
self.to("cpu", silence_dtype_warnings=True) self.to("cpu", silence_dtype_warnings=True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment