Unverified Commit 7edace9a authored by Yao Matrix, committed by GitHub

fix CPU offloading related fail cases on XPU (#11288)



* fix CPU offloading related fail cases on XPU
Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* fix style
Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* Apply style fixes

* trigger tests

* test_pipe_same_device_id_offload

---------
Signed-off-by: YAO Matrix <matrix.yao@intel.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: hlky <hlky@hlky.ac>
parent 6e80d240
@@ -65,7 +65,7 @@ from ..utils import (
     numpy_to_pil,
 )
 from ..utils.hub_utils import _check_legacy_sharding_variant_format, load_or_create_model_card, populate_model_card
-from ..utils.torch_utils import is_compiled_module
+from ..utils.torch_utils import get_device, is_compiled_module
 
 if is_torch_npu_available():
@@ -1084,19 +1084,20 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
         accelerate.hooks.remove_hook_from_module(model, recurse=True)
         self._all_hooks = []
 
-    def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
+    def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = None):
         r"""
         Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
-        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
-        method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
-        `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the accelerator when its
+        `forward` method is called, and the model remains on the accelerator until the next model runs. Memory savings
+        are lower than with `enable_sequential_cpu_offload`, but performance is much better due to the iterative
+        execution of the `unet`.
 
         Arguments:
             gpu_id (`int`, *optional*):
                 The ID of the accelerator that shall be used in inference. If not specified, it will default to 0.
-            device (`torch.Device` or `str`, *optional*, defaults to "cuda"):
+            device (`torch.Device` or `str`, *optional*, defaults to None):
                 The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will
-                default to "cuda".
+                automatically detect the available accelerator and use it.
         """
 
         self._maybe_raise_error_if_group_offload_active(raise_error=True)
@@ -1118,6 +1119,11 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
         self.remove_all_hooks()
 
+        if device is None:
+            device = get_device()
+            if device == "cpu":
+                raise RuntimeError("`enable_model_cpu_offload` requires an accelerator, but none was found")
+
         torch_device = torch.device(device)
         device_index = torch_device.index
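With this change `enable_model_cpu_offload` no longer hard-codes CUDA: when `device` is omitted it resolves the accelerator through the new `get_device()` helper and errors out on CPU-only machines. A minimal usage sketch follows; the checkpoint name is only an example.

```python
import torch
from diffusers import DiffusionPipeline

# Example checkpoint; any diffusers pipeline behaves the same way.
pipe = DiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
)

# No `device` argument: the pipeline now picks "cuda" or "xpu" automatically
# and raises a RuntimeError when only the CPU is available.
pipe.enable_model_cpu_offload()

# Passing the device explicitly still works, e.g. on an Intel GPU machine:
# pipe.enable_model_cpu_offload(device="xpu")

image = pipe("an astronaut riding a horse").images[0]
```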
@@ -1196,20 +1202,20 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
         # make sure the model is in the same state as before calling it
         self.enable_model_cpu_offload(device=getattr(self, "_offload_device", "cuda"))
 
-    def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
+    def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = None):
         r"""
         Offloads all models to CPU using 🤗 Accelerate, significantly reducing memory usage. When called, the state
         dicts of all `torch.nn.Module` components (except those in `self._exclude_from_cpu_offload`) are saved to CPU
-        and then moved to `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward`
-        method called. Offloading happens on a submodule basis. Memory savings are higher than with
+        and then moved to `torch.device('meta')` and loaded to the accelerator only when their specific submodule has
+        its `forward` method called. Offloading happens on a submodule basis. Memory savings are higher than with
         `enable_model_cpu_offload`, but performance is lower.
 
         Arguments:
             gpu_id (`int`, *optional*):
                 The ID of the accelerator that shall be used in inference. If not specified, it will default to 0.
-            device (`torch.Device` or `str`, *optional*, defaults to "cuda"):
+            device (`torch.Device` or `str`, *optional*, defaults to None):
                 The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will
-                default to "cuda".
+                automatically detect the available accelerator and use it.
         """
 
         self._maybe_raise_error_if_group_offload_active(raise_error=True)
@@ -1225,6 +1231,11 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
                 "It seems like you have activated a device mapping strategy on the pipeline so calling `enable_sequential_cpu_offload()` isn't allowed. You can call `reset_device_map()` first and then call `enable_sequential_cpu_offload()`."
             )
 
+        if device is None:
+            device = get_device()
+            if device == "cpu":
+                raise RuntimeError("`enable_sequential_cpu_offload` requires an accelerator, but none was found")
+
         torch_device = torch.device(device)
         device_index = torch_device.index
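`enable_sequential_cpu_offload` follows the same device-resolution path. Reusing `pipe` from the sketch above, on an XPU-only machine the calls below now resolve to `"xpu"` instead of failing on the old hard-coded `"cuda"` default:

```python
# Auto-detected accelerator (raises RuntimeError on CPU-only hosts).
pipe.enable_sequential_cpu_offload()

# Or pin the backend and device index explicitly.
pipe.enable_sequential_cpu_offload(gpu_id=0, device="xpu")
```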
...
@@ -159,3 +159,12 @@ def get_torch_cuda_device_capability():
         return float(compute_capability)
     else:
         return None
+
+
+def get_device():
+    if torch.cuda.is_available():
+        return "cuda"
+    elif hasattr(torch, "xpu") and torch.xpu.is_available():
+        return "xpu"
+    else:
+        return "cpu"
@@ -1816,7 +1816,12 @@ class PipelineFastTests(unittest.TestCase):
             feature_extractor=self.dummy_extractor,
         )
 
-        sd.enable_model_cpu_offload(gpu_id=5)
+        # `enable_model_cpu_offload` detects the device type when it is not passed,
+        # and raises a RuntimeError if the detected device is `cpu`.
+        # This test only checks whether `_offload_gpu_id` is set correctly, so the
+        # device passed can be any supported `torch.device` type; this lets the test
+        # stay under `PipelineFastTests`.
+        sd.enable_model_cpu_offload(gpu_id=5, device="cuda")
         assert sd._offload_gpu_id == 5
         sd.maybe_free_model_hooks()
         assert sd._offload_gpu_id == 5
...