Unverified Commit f20b83a0 authored by Yao Matrix's avatar Yao Matrix Committed by GitHub
Browse files

enable cpu offloading of new pipelines on XPU & use device agnostic empty to...


enable cpu offloading of new pipelines on XPU & use device agnostic empty to make pipelines work on XPU (#11671)

* commit 1
Signed-off-by: default avatarYAO Matrix <matrix.yao@intel.com>

* patch 2
Signed-off-by: default avatarYAO Matrix <matrix.yao@intel.com>

* Update pipeline_pag_sana.py

* Update pipeline_sana.py

* Update pipeline_sana_controlnet.py

* Update pipeline_sana_sprint_img2img.py

* Update pipeline_sana_sprint.py

* fix style
Signed-off-by: default avatarYAO Matrix <matrix.yao@intel.com>

* fix fat-thumb while merge conflict
Signed-off-by: default avatarYAO Matrix <matrix.yao@intel.com>

* fix ci issues
Signed-off-by: default avatarYAO Matrix <matrix.yao@intel.com>

---------
Signed-off-by: default avatarYAO Matrix <matrix.yao@intel.com>
Co-authored-by: default avatarIlyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com>
parent ee40088f
......@@ -49,7 +49,7 @@ from ...utils import (
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
from ...utils.torch_utils import empty_device_cache, randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
from .pag_utils import PAGMixin
......@@ -716,7 +716,7 @@ class StableDiffusionXLPAGImg2ImgPipeline(
# Offload text encoder if `enable_model_cpu_offload` was enabled
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.text_encoder_2.to("cpu")
torch.cuda.empty_cache()
empty_device_cache()
image = image.to(device=device, dtype=dtype)
......
......@@ -67,7 +67,7 @@ from ..utils import (
numpy_to_pil,
)
from ..utils.hub_utils import _check_legacy_sharding_variant_format, load_or_create_model_card, populate_model_card
from ..utils.torch_utils import get_device, is_compiled_module
from ..utils.torch_utils import empty_device_cache, get_device, is_compiled_module
if is_torch_npu_available():
......@@ -1203,9 +1203,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
self._offload_device = device
self.to("cpu", silence_dtype_warnings=True)
device_mod = getattr(torch, device.type, None)
if hasattr(device_mod, "empty_cache") and device_mod.is_available():
device_mod.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
empty_device_cache(device.type)
all_model_components = {k: v for k, v in self.components.items() if isinstance(v, torch.nn.Module)}
......@@ -1315,10 +1313,9 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
self._offload_device = device
if self.device.type != "cpu":
orig_device_type = self.device.type
self.to("cpu", silence_dtype_warnings=True)
device_mod = getattr(torch, self.device.type, None)
if hasattr(device_mod, "empty_cache") and device_mod.is_available():
device_mod.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
empty_device_cache(orig_device_type)
for name, model in self.components.items():
if not isinstance(model, torch.nn.Module):
......
......@@ -38,7 +38,7 @@ from ...utils import (
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
from ...utils.torch_utils import get_device, is_torch_version, randn_tensor
from ..pipeline_utils import DiffusionPipeline
from ..pixart_alpha.pipeline_pixart_alpha import (
ASPECT_RATIO_512_BIN,
......@@ -982,9 +982,15 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
image = latents
else:
latents = latents.to(self.vae.dtype)
torch_accelerator_module = getattr(torch, get_device(), torch.cuda)
oom_error = (
torch.OutOfMemoryError
if is_torch_version(">=", "2.5.0")
else torch_accelerator_module.OutOfMemoryError
)
try:
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
except torch.cuda.OutOfMemoryError as e:
except oom_error as e:
warnings.warn(
f"{e}. \n"
f"Try to use VAE tiling for large images. For example: \n"
......
......@@ -38,7 +38,7 @@ from ...utils import (
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
from ...utils.torch_utils import get_device, is_torch_version, randn_tensor
from ..pipeline_utils import DiffusionPipeline
from ..pixart_alpha.pipeline_pixart_alpha import (
ASPECT_RATIO_512_BIN,
......@@ -1078,9 +1078,15 @@ class SanaControlNetPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
image = latents
else:
latents = latents.to(self.vae.dtype)
torch_accelerator_module = getattr(torch, get_device(), torch.cuda)
oom_error = (
torch.OutOfMemoryError
if is_torch_version(">=", "2.5.0")
else torch_accelerator_module.OutOfMemoryError
)
try:
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
except torch.cuda.OutOfMemoryError as e:
except oom_error as e:
warnings.warn(
f"{e}. \n"
f"Try to use VAE tiling for large images. For example: \n"
......
......@@ -38,7 +38,7 @@ from ...utils import (
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
from ...utils.torch_utils import get_device, is_torch_version, randn_tensor
from ..pipeline_utils import DiffusionPipeline
from ..pixart_alpha.pipeline_pixart_alpha import ASPECT_RATIO_1024_BIN
from .pipeline_output import SanaPipelineOutput
......@@ -864,9 +864,15 @@ class SanaSprintPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
image = latents
else:
latents = latents.to(self.vae.dtype)
torch_accelerator_module = getattr(torch, get_device(), torch.cuda)
oom_error = (
torch.OutOfMemoryError
if is_torch_version(">=", "2.5.0")
else torch_accelerator_module.OutOfMemoryError
)
try:
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
except torch.cuda.OutOfMemoryError as e:
except oom_error as e:
warnings.warn(
f"{e}. \n"
f"Try to use VAE tiling for large images. For example: \n"
......
......@@ -39,7 +39,7 @@ from ...utils import (
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
from ...utils.torch_utils import get_device, is_torch_version, randn_tensor
from ..pipeline_utils import DiffusionPipeline
from ..pixart_alpha.pipeline_pixart_alpha import ASPECT_RATIO_1024_BIN
from .pipeline_output import SanaPipelineOutput
......@@ -952,9 +952,15 @@ class SanaSprintImg2ImgPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
image = latents
else:
latents = latents.to(self.vae.dtype)
torch_accelerator_module = getattr(torch, get_device(), torch.cuda)
oom_error = (
torch.OutOfMemoryError
if is_torch_version(">=", "2.5.0")
else torch_accelerator_module.OutOfMemoryError
)
try:
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
except torch.cuda.OutOfMemoryError as e:
except oom_error as e:
warnings.warn(
f"{e}. \n"
f"Try to use VAE tiling for large images. For example: \n"
......
......@@ -125,7 +125,7 @@ class StableCascadeCombinedPipeline(DiffusionPipeline):
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)
def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = None):
r"""
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
......@@ -135,7 +135,7 @@ class StableCascadeCombinedPipeline(DiffusionPipeline):
self.prior_pipe.enable_model_cpu_offload(gpu_id=gpu_id, device=device)
self.decoder_pipe.enable_model_cpu_offload(gpu_id=gpu_id, device=device)
def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = None):
r"""
Offloads all models (`unet`, `text_encoder`, `vae`, and `safety checker` state dicts) to CPU using 🤗
Accelerate, significantly reducing memory usage. Models are moved to a `torch.device('meta')` and loaded on a
......
......@@ -53,6 +53,7 @@ from ...schedulers import (
)
from ...utils import is_accelerate_available, logging
from ...utils.constants import DIFFUSERS_REQUEST_TIMEOUT
from ...utils.torch_utils import get_device
from ..latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
from ..paint_by_example import PaintByExampleImageEncoder
from ..pipeline_utils import DiffusionPipeline
......@@ -1272,7 +1273,7 @@ def download_from_original_stable_diffusion_ckpt(
checkpoint = safe_load(checkpoint_path_or_dict, device="cpu")
else:
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
device = get_device()
checkpoint = torch.load(checkpoint_path_or_dict, map_location=device)
else:
checkpoint = torch.load(checkpoint_path_or_dict, map_location=device)
......@@ -1842,7 +1843,7 @@ def download_controlnet_from_original_ckpt(
checkpoint[key] = f.get_tensor(key)
else:
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
device = get_device()
checkpoint = torch.load(checkpoint_path, map_location=device)
else:
checkpoint = torch.load(checkpoint_path, map_location=device)
......
......@@ -50,7 +50,7 @@ from ...utils import (
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
from ...utils.torch_utils import empty_device_cache, randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from .pipeline_output import StableDiffusionXLPipelineOutput
......@@ -704,7 +704,7 @@ class StableDiffusionXLImg2ImgPipeline(
# Offload text encoder if `enable_model_cpu_offload` was enabled
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.text_encoder_2.to("cpu")
torch.cuda.empty_cache()
empty_device_cache()
image = image.to(device=device, dtype=dtype)
......
......@@ -23,7 +23,7 @@ from ...utils import (
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
from ...utils.torch_utils import empty_device_cache, randn_tensor
from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion import StableDiffusionSafetyChecker
......@@ -760,7 +760,7 @@ class TextToVideoZeroPipeline(
# manually for max memory savings
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.unet.to("cpu")
torch.cuda.empty_cache()
empty_device_cache()
if output_type == "latent":
image = latents
......
......@@ -113,7 +113,7 @@ class WuerstchenCombinedPipeline(DeprecatedPipelineMixin, DiffusionPipeline):
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)
def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = None):
r"""
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
......@@ -123,7 +123,7 @@ class WuerstchenCombinedPipeline(DeprecatedPipelineMixin, DiffusionPipeline):
self.prior_pipe.enable_model_cpu_offload(gpu_id=gpu_id, device=device)
self.decoder_pipe.enable_model_cpu_offload(gpu_id=gpu_id, device=device)
def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = None):
r"""
Offloads all models (`unet`, `text_encoder`, `vae`, and `safety checker` state dicts) to CPU using 🤗
Accelerate, significantly reducing memory usage. Models are moved to a `torch.device('meta')` and loaded on a
......
......@@ -23,6 +23,7 @@ from packaging import version
from . import logging
from .import_utils import is_peft_available, is_peft_version, is_torch_available
from .torch_utils import empty_device_cache
logger = logging.get_logger(__name__)
......@@ -98,8 +99,7 @@ def recurse_remove_peft_layers(model):
setattr(model, name, new_module)
del module
if torch.cuda.is_available():
torch.cuda.empty_cache()
empty_device_cache()
return model
......
......@@ -172,3 +172,10 @@ def get_device():
return "xpu"
else:
return "cpu"
def empty_device_cache(device_type: Optional[str] = None):
if device_type is None:
device_type = get_device()
device_mod = getattr(torch, device_type, torch.cuda)
device_mod.empty_cache()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment