Unverified Commit f20b83a0 authored by Yao Matrix's avatar Yao Matrix Committed by GitHub
Browse files

Enable CPU offloading of new pipelines on XPU & use device-agnostic empty cache to...


Enable CPU offloading of new pipelines on XPU & use device-agnostic empty cache to make pipelines work on XPU (#11671)

* commit 1
Signed-off-by: default avatarYAO Matrix <matrix.yao@intel.com>

* patch 2
Signed-off-by: default avatarYAO Matrix <matrix.yao@intel.com>

* Update pipeline_pag_sana.py

* Update pipeline_sana.py

* Update pipeline_sana_controlnet.py

* Update pipeline_sana_sprint_img2img.py

* Update pipeline_sana_sprint.py

* fix style
Signed-off-by: default avatarYAO Matrix <matrix.yao@intel.com>

* fix fat-thumb mistake made while resolving merge conflict
Signed-off-by: default avatarYAO Matrix <matrix.yao@intel.com>

* fix ci issues
Signed-off-by: default avatarYAO Matrix <matrix.yao@intel.com>

---------
Signed-off-by: default avatarYAO Matrix <matrix.yao@intel.com>
Co-authored-by: default avatarIlyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com>
parent ee40088f
...@@ -41,7 +41,7 @@ from ...utils import ( ...@@ -41,7 +41,7 @@ from ...utils import (
replace_example_docstring, replace_example_docstring,
) )
from ...utils.import_utils import is_transformers_version from ...utils.import_utils import is_transformers_version
from ...utils.torch_utils import randn_tensor from ...utils.torch_utils import empty_device_cache, randn_tensor
from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
from .modeling_audioldm2 import AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel from .modeling_audioldm2 import AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel
...@@ -267,9 +267,7 @@ class AudioLDM2Pipeline(DiffusionPipeline): ...@@ -267,9 +267,7 @@ class AudioLDM2Pipeline(DiffusionPipeline):
if self.device.type != "cpu": if self.device.type != "cpu":
self.to("cpu", silence_dtype_warnings=True) self.to("cpu", silence_dtype_warnings=True)
device_mod = getattr(torch, device.type, None) empty_device_cache(device.type)
if hasattr(device_mod, "empty_cache") and device_mod.is_available():
device_mod.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
model_sequence = [ model_sequence = [
self.text_encoder.text_model, self.text_encoder.text_model,
......
...@@ -294,7 +294,7 @@ def prepare_face_models(model_path, device, dtype): ...@@ -294,7 +294,7 @@ def prepare_face_models(model_path, device, dtype):
Parameters: Parameters:
- model_path: Path to the directory containing model files. - model_path: Path to the directory containing model files.
- device: The device (e.g., 'cuda', 'cpu') where models will be loaded. - device: The device (e.g., 'cuda', 'xpu', 'cpu') where models will be loaded.
- dtype: Data type (e.g., torch.float32) for model inference. - dtype: Data type (e.g., torch.float32) for model inference.
Returns: Returns:
......
...@@ -37,7 +37,7 @@ from ...utils import ( ...@@ -37,7 +37,7 @@ from ...utils import (
scale_lora_layers, scale_lora_layers,
unscale_lora_layers, unscale_lora_layers,
) )
from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor from ...utils.torch_utils import empty_device_cache, is_compiled_module, is_torch_version, randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
...@@ -1339,7 +1339,7 @@ class StableDiffusionControlNetPipeline( ...@@ -1339,7 +1339,7 @@ class StableDiffusionControlNetPipeline(
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.unet.to("cpu") self.unet.to("cpu")
self.controlnet.to("cpu") self.controlnet.to("cpu")
torch.cuda.empty_cache() empty_device_cache()
if not output_type == "latent": if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
......
...@@ -36,7 +36,7 @@ from ...utils import ( ...@@ -36,7 +36,7 @@ from ...utils import (
scale_lora_layers, scale_lora_layers,
unscale_lora_layers, unscale_lora_layers,
) )
from ...utils.torch_utils import is_compiled_module, randn_tensor from ...utils.torch_utils import empty_device_cache, is_compiled_module, randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion import StableDiffusionPipelineOutput from ..stable_diffusion import StableDiffusionPipelineOutput
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
...@@ -1311,7 +1311,7 @@ class StableDiffusionControlNetImg2ImgPipeline( ...@@ -1311,7 +1311,7 @@ class StableDiffusionControlNetImg2ImgPipeline(
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.unet.to("cpu") self.unet.to("cpu")
self.controlnet.to("cpu") self.controlnet.to("cpu")
torch.cuda.empty_cache() empty_device_cache()
if not output_type == "latent": if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
......
...@@ -38,7 +38,7 @@ from ...utils import ( ...@@ -38,7 +38,7 @@ from ...utils import (
scale_lora_layers, scale_lora_layers,
unscale_lora_layers, unscale_lora_layers,
) )
from ...utils.torch_utils import is_compiled_module, randn_tensor from ...utils.torch_utils import empty_device_cache, is_compiled_module, randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion import StableDiffusionPipelineOutput from ..stable_diffusion import StableDiffusionPipelineOutput
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
...@@ -1500,7 +1500,7 @@ class StableDiffusionControlNetInpaintPipeline( ...@@ -1500,7 +1500,7 @@ class StableDiffusionControlNetInpaintPipeline(
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.unet.to("cpu") self.unet.to("cpu")
self.controlnet.to("cpu") self.controlnet.to("cpu")
torch.cuda.empty_cache() empty_device_cache()
if not output_type == "latent": if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
......
...@@ -51,7 +51,7 @@ from ...utils import ( ...@@ -51,7 +51,7 @@ from ...utils import (
scale_lora_layers, scale_lora_layers,
unscale_lora_layers, unscale_lora_layers,
) )
from ...utils.torch_utils import is_compiled_module, randn_tensor from ...utils.torch_utils import empty_device_cache, is_compiled_module, randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
...@@ -1858,7 +1858,7 @@ class StableDiffusionXLControlNetInpaintPipeline( ...@@ -1858,7 +1858,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.unet.to("cpu") self.unet.to("cpu")
self.controlnet.to("cpu") self.controlnet.to("cpu")
torch.cuda.empty_cache() empty_device_cache()
if not output_type == "latent": if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
......
...@@ -1465,7 +1465,11 @@ class StableDiffusionXLControlNetPipeline( ...@@ -1465,7 +1465,11 @@ class StableDiffusionXLControlNetPipeline(
# Relevant thread: # Relevant thread:
# https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1: if (
torch.cuda.is_available()
and (is_unet_compiled and is_controlnet_compiled)
and is_torch_higher_equal_2_1
):
torch._inductor.cudagraph_mark_step_begin() torch._inductor.cudagraph_mark_step_begin()
# expand the latents if we are doing classifier free guidance # expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
......
...@@ -53,7 +53,7 @@ from ...utils import ( ...@@ -53,7 +53,7 @@ from ...utils import (
scale_lora_layers, scale_lora_layers,
unscale_lora_layers, unscale_lora_layers,
) )
from ...utils.torch_utils import is_compiled_module, randn_tensor from ...utils.torch_utils import empty_device_cache, is_compiled_module, randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
...@@ -921,7 +921,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline( ...@@ -921,7 +921,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
# Offload text encoder if `enable_model_cpu_offload` was enabled # Offload text encoder if `enable_model_cpu_offload` was enabled
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.text_encoder_2.to("cpu") self.text_encoder_2.to("cpu")
torch.cuda.empty_cache() empty_device_cache()
image = image.to(device=device, dtype=dtype) image = image.to(device=device, dtype=dtype)
...@@ -1632,7 +1632,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline( ...@@ -1632,7 +1632,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.unet.to("cpu") self.unet.to("cpu")
self.controlnet.to("cpu") self.controlnet.to("cpu")
torch.cuda.empty_cache() empty_device_cache()
if not output_type == "latent": if not output_type == "latent":
# make sure the VAE is in float32 mode, as it overflows in float16 # make sure the VAE is in float32 mode, as it overflows in float16
......
...@@ -51,7 +51,7 @@ from ...utils import ( ...@@ -51,7 +51,7 @@ from ...utils import (
scale_lora_layers, scale_lora_layers,
unscale_lora_layers, unscale_lora_layers,
) )
from ...utils.torch_utils import is_compiled_module, randn_tensor from ...utils.torch_utils import empty_device_cache, is_compiled_module, randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
...@@ -1766,7 +1766,7 @@ class StableDiffusionXLControlNetUnionInpaintPipeline( ...@@ -1766,7 +1766,7 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.unet.to("cpu") self.unet.to("cpu")
self.controlnet.to("cpu") self.controlnet.to("cpu")
torch.cuda.empty_cache() empty_device_cache()
if not output_type == "latent": if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
......
...@@ -53,7 +53,7 @@ from ...utils import ( ...@@ -53,7 +53,7 @@ from ...utils import (
scale_lora_layers, scale_lora_layers,
unscale_lora_layers, unscale_lora_layers,
) )
from ...utils.torch_utils import is_compiled_module, randn_tensor from ...utils.torch_utils import empty_device_cache, is_compiled_module, randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
...@@ -876,7 +876,7 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline( ...@@ -876,7 +876,7 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
# Offload text encoder if `enable_model_cpu_offload` was enabled # Offload text encoder if `enable_model_cpu_offload` was enabled
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.text_encoder_2.to("cpu") self.text_encoder_2.to("cpu")
torch.cuda.empty_cache() empty_device_cache()
image = image.to(device=device, dtype=dtype) image = image.to(device=device, dtype=dtype)
...@@ -1574,7 +1574,7 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline( ...@@ -1574,7 +1574,7 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.unet.to("cpu") self.unet.to("cpu")
self.controlnet.to("cpu") self.controlnet.to("cpu")
torch.cuda.empty_cache() empty_device_cache()
if not output_type == "latent": if not output_type == "latent":
# make sure the VAE is in float32 mode, as it overflows in float16 # make sure the VAE is in float32 mode, as it overflows in float16
......
...@@ -36,7 +36,7 @@ from ...utils import ( ...@@ -36,7 +36,7 @@ from ...utils import (
scale_lora_layers, scale_lora_layers,
unscale_lora_layers, unscale_lora_layers,
) )
from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor from ...utils.torch_utils import empty_device_cache, is_compiled_module, is_torch_version, randn_tensor
from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
...@@ -853,7 +853,7 @@ class StableDiffusionControlNetXSPipeline( ...@@ -853,7 +853,7 @@ class StableDiffusionControlNetXSPipeline(
for i, t in enumerate(timesteps): for i, t in enumerate(timesteps):
# Relevant thread: # Relevant thread:
# https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
if is_controlnet_compiled and is_torch_higher_equal_2_1: if torch.cuda.is_available() and is_controlnet_compiled and is_torch_higher_equal_2_1:
torch._inductor.cudagraph_mark_step_begin() torch._inductor.cudagraph_mark_step_begin()
# expand the latents if we are doing classifier free guidance # expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
...@@ -902,7 +902,7 @@ class StableDiffusionControlNetXSPipeline( ...@@ -902,7 +902,7 @@ class StableDiffusionControlNetXSPipeline(
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.unet.to("cpu") self.unet.to("cpu")
self.controlnet.to("cpu") self.controlnet.to("cpu")
torch.cuda.empty_cache() empty_device_cache()
if not output_type == "latent": if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
......
...@@ -193,7 +193,7 @@ class KandinskyCombinedPipeline(DiffusionPipeline): ...@@ -193,7 +193,7 @@ class KandinskyCombinedPipeline(DiffusionPipeline):
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None): def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op) self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)
def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"): def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = None):
r""" r"""
Offloads all models (`unet`, `text_encoder`, `vae`, and `safety checker` state dicts) to CPU using 🤗 Offloads all models (`unet`, `text_encoder`, `vae`, and `safety checker` state dicts) to CPU using 🤗
Accelerate, significantly reducing memory usage. Models are moved to a `torch.device('meta')` and loaded on a Accelerate, significantly reducing memory usage. Models are moved to a `torch.device('meta')` and loaded on a
...@@ -411,7 +411,7 @@ class KandinskyImg2ImgCombinedPipeline(DiffusionPipeline): ...@@ -411,7 +411,7 @@ class KandinskyImg2ImgCombinedPipeline(DiffusionPipeline):
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None): def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op) self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)
def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"): def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = None):
r""" r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
...@@ -652,7 +652,7 @@ class KandinskyInpaintCombinedPipeline(DiffusionPipeline): ...@@ -652,7 +652,7 @@ class KandinskyInpaintCombinedPipeline(DiffusionPipeline):
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None): def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op) self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)
def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"): def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = None):
r""" r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
......
...@@ -179,7 +179,7 @@ class KandinskyV22CombinedPipeline(DiffusionPipeline): ...@@ -179,7 +179,7 @@ class KandinskyV22CombinedPipeline(DiffusionPipeline):
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None): def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op) self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)
def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"): def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = None):
r""" r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
...@@ -407,7 +407,7 @@ class KandinskyV22Img2ImgCombinedPipeline(DiffusionPipeline): ...@@ -407,7 +407,7 @@ class KandinskyV22Img2ImgCombinedPipeline(DiffusionPipeline):
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None): def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op) self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)
def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"): def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = None):
r""" r"""
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
...@@ -417,7 +417,7 @@ class KandinskyV22Img2ImgCombinedPipeline(DiffusionPipeline): ...@@ -417,7 +417,7 @@ class KandinskyV22Img2ImgCombinedPipeline(DiffusionPipeline):
self.prior_pipe.enable_model_cpu_offload(gpu_id=gpu_id, device=device) self.prior_pipe.enable_model_cpu_offload(gpu_id=gpu_id, device=device)
self.decoder_pipe.enable_model_cpu_offload(gpu_id=gpu_id, device=device) self.decoder_pipe.enable_model_cpu_offload(gpu_id=gpu_id, device=device)
def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"): def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = None):
r""" r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
...@@ -656,7 +656,7 @@ class KandinskyV22InpaintCombinedPipeline(DiffusionPipeline): ...@@ -656,7 +656,7 @@ class KandinskyV22InpaintCombinedPipeline(DiffusionPipeline):
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None): def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op) self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)
def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"): def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = None):
r""" r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
......
...@@ -25,7 +25,7 @@ from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel ...@@ -25,7 +25,7 @@ from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor
from ...schedulers import KarrasDiffusionSchedulers from ...schedulers import KarrasDiffusionSchedulers
from ...utils import is_torch_xla_available, logging, replace_example_docstring from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor from ...utils.torch_utils import empty_device_cache, randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from .pipeline_output import KolorsPipelineOutput from .pipeline_output import KolorsPipelineOutput
from .text_encoder import ChatGLMModel from .text_encoder import ChatGLMModel
...@@ -618,7 +618,7 @@ class KolorsImg2ImgPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffu ...@@ -618,7 +618,7 @@ class KolorsImg2ImgPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffu
# Offload text encoder if `enable_model_cpu_offload` was enabled # Offload text encoder if `enable_model_cpu_offload` was enabled
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.text_encoder_2.to("cpu") self.text_encoder_2.to("cpu")
torch.cuda.empty_cache() empty_device_cache()
image = image.to(device=device, dtype=dtype) image = image.to(device=device, dtype=dtype)
......
...@@ -35,7 +35,7 @@ from ...utils import ( ...@@ -35,7 +35,7 @@ from ...utils import (
logging, logging,
replace_example_docstring, replace_example_docstring,
) )
from ...utils.torch_utils import randn_tensor from ...utils.torch_utils import empty_device_cache, get_device, randn_tensor
from ..pipeline_utils import AudioPipelineOutput, DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin from ..pipeline_utils import AudioPipelineOutput, DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
...@@ -397,20 +397,22 @@ class MusicLDMPipeline(DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusi ...@@ -397,20 +397,22 @@ class MusicLDMPipeline(DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusi
def enable_model_cpu_offload(self, gpu_id=0): def enable_model_cpu_offload(self, gpu_id=0):
r""" r"""
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the accelerator when its
method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with `forward` method is called, and the model remains in accelerator until the next model runs. Memory savings are
`enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. lower than with `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution
of the `unet`.
""" """
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
from accelerate import cpu_offload_with_hook from accelerate import cpu_offload_with_hook
else: else:
raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
device = torch.device(f"cuda:{gpu_id}") device_type = get_device()
device = torch.device(f"{device_type}:{gpu_id}")
if self.device.type != "cpu": if self.device.type != "cpu":
self.to("cpu", silence_dtype_warnings=True) self.to("cpu", silence_dtype_warnings=True)
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) empty_device_cache() # otherwise we don't see the memory savings (but they probably exist)
model_sequence = [ model_sequence = [
self.text_encoder.text_model, self.text_encoder.text_model,
......
...@@ -36,7 +36,7 @@ from ...utils import ( ...@@ -36,7 +36,7 @@ from ...utils import (
scale_lora_layers, scale_lora_layers,
unscale_lora_layers, unscale_lora_layers,
) )
from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor from ...utils.torch_utils import empty_device_cache, is_compiled_module, is_torch_version, randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
...@@ -1228,7 +1228,11 @@ class StableDiffusionControlNetPAGPipeline( ...@@ -1228,7 +1228,11 @@ class StableDiffusionControlNetPAGPipeline(
for i, t in enumerate(timesteps): for i, t in enumerate(timesteps):
# Relevant thread: # Relevant thread:
# https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1: if (
torch.cuda.is_available()
and (is_unet_compiled and is_controlnet_compiled)
and is_torch_higher_equal_2_1
):
torch._inductor.cudagraph_mark_step_begin() torch._inductor.cudagraph_mark_step_begin()
# expand the latents if we are doing classifier free guidance # expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0]))
...@@ -1309,7 +1313,7 @@ class StableDiffusionControlNetPAGPipeline( ...@@ -1309,7 +1313,7 @@ class StableDiffusionControlNetPAGPipeline(
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.unet.to("cpu") self.unet.to("cpu")
self.controlnet.to("cpu") self.controlnet.to("cpu")
torch.cuda.empty_cache() empty_device_cache()
if not output_type == "latent": if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
......
...@@ -37,7 +37,7 @@ from ...utils import ( ...@@ -37,7 +37,7 @@ from ...utils import (
scale_lora_layers, scale_lora_layers,
unscale_lora_layers, unscale_lora_layers,
) )
from ...utils.torch_utils import is_compiled_module, randn_tensor from ...utils.torch_utils import empty_device_cache, is_compiled_module, randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion import StableDiffusionPipelineOutput from ..stable_diffusion import StableDiffusionPipelineOutput
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
...@@ -1521,7 +1521,7 @@ class StableDiffusionControlNetPAGInpaintPipeline( ...@@ -1521,7 +1521,7 @@ class StableDiffusionControlNetPAGInpaintPipeline(
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.unet.to("cpu") self.unet.to("cpu")
self.controlnet.to("cpu") self.controlnet.to("cpu")
torch.cuda.empty_cache() empty_device_cache()
if not output_type == "latent": if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
......
...@@ -1498,7 +1498,11 @@ class StableDiffusionXLControlNetPAGPipeline( ...@@ -1498,7 +1498,11 @@ class StableDiffusionXLControlNetPAGPipeline(
for i, t in enumerate(timesteps): for i, t in enumerate(timesteps):
# Relevant thread: # Relevant thread:
# https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1: if (
torch.cuda.is_available()
and (is_unet_compiled and is_controlnet_compiled)
and is_torch_higher_equal_2_1
):
torch._inductor.cudagraph_mark_step_begin() torch._inductor.cudagraph_mark_step_begin()
# expand the latents if we are doing classifier free guidance # expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0]))
......
...@@ -52,7 +52,7 @@ from ...utils import ( ...@@ -52,7 +52,7 @@ from ...utils import (
scale_lora_layers, scale_lora_layers,
unscale_lora_layers, unscale_lora_layers,
) )
from ...utils.torch_utils import is_compiled_module, randn_tensor from ...utils.torch_utils import empty_device_cache, is_compiled_module, randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
from .pag_utils import PAGMixin from .pag_utils import PAGMixin
...@@ -926,7 +926,7 @@ class StableDiffusionXLControlNetPAGImg2ImgPipeline( ...@@ -926,7 +926,7 @@ class StableDiffusionXLControlNetPAGImg2ImgPipeline(
# Offload text encoder if `enable_model_cpu_offload` was enabled # Offload text encoder if `enable_model_cpu_offload` was enabled
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.text_encoder_2.to("cpu") self.text_encoder_2.to("cpu")
torch.cuda.empty_cache() empty_device_cache()
image = image.to(device=device, dtype=dtype) image = image.to(device=device, dtype=dtype)
...@@ -1648,7 +1648,7 @@ class StableDiffusionXLControlNetPAGImg2ImgPipeline( ...@@ -1648,7 +1648,7 @@ class StableDiffusionXLControlNetPAGImg2ImgPipeline(
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.unet.to("cpu") self.unet.to("cpu")
self.controlnet.to("cpu") self.controlnet.to("cpu")
torch.cuda.empty_cache() empty_device_cache()
if not output_type == "latent": if not output_type == "latent":
# make sure the VAE is in float32 mode, as it overflows in float16 # make sure the VAE is in float32 mode, as it overflows in float16
......
...@@ -35,7 +35,7 @@ from ...utils import ( ...@@ -35,7 +35,7 @@ from ...utils import (
logging, logging,
replace_example_docstring, replace_example_docstring,
) )
from ...utils.torch_utils import randn_tensor from ...utils.torch_utils import get_device, is_torch_version, randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from ..pixart_alpha.pipeline_pixart_alpha import ( from ..pixart_alpha.pipeline_pixart_alpha import (
ASPECT_RATIO_512_BIN, ASPECT_RATIO_512_BIN,
...@@ -917,9 +917,15 @@ class SanaPAGPipeline(DiffusionPipeline, PAGMixin): ...@@ -917,9 +917,15 @@ class SanaPAGPipeline(DiffusionPipeline, PAGMixin):
image = latents image = latents
else: else:
latents = latents.to(self.vae.dtype) latents = latents.to(self.vae.dtype)
torch_accelerator_module = getattr(torch, get_device(), torch.cuda)
oom_error = (
torch.OutOfMemoryError
if is_torch_version(">=", "2.5.0")
else torch_accelerator_module.OutOfMemoryError
)
try: try:
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
except torch.cuda.OutOfMemoryError as e: except oom_error as e:
warnings.warn( warnings.warn(
f"{e}. \n" f"{e}. \n"
f"Try to use VAE tiling for large images. For example: \n" f"Try to use VAE tiling for large images. For example: \n"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment