Unverified Commit f20b83a0 authored by Yao Matrix's avatar Yao Matrix Committed by GitHub
Browse files

enable cpu offloading of new pipelines on XPU & use device agnostic empty to...


enable cpu offloading of new pipelines on XPU & use device agnostic empty to make pipelines work on XPU (#11671)

* commit 1
Signed-off-by: default avatarYAO Matrix <matrix.yao@intel.com>

* patch 2
Signed-off-by: default avatarYAO Matrix <matrix.yao@intel.com>

* Update pipeline_pag_sana.py

* Update pipeline_sana.py

* Update pipeline_sana_controlnet.py

* Update pipeline_sana_sprint_img2img.py

* Update pipeline_sana_sprint.py

* fix style
Signed-off-by: default avatarYAO Matrix <matrix.yao@intel.com>

* fix fat-thumb while merge conflict
Signed-off-by: default avatarYAO Matrix <matrix.yao@intel.com>

* fix ci issues
Signed-off-by: default avatarYAO Matrix <matrix.yao@intel.com>

---------
Signed-off-by: default avatarYAO Matrix <matrix.yao@intel.com>
Co-authored-by: default avatarIlyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com>
parent ee40088f
......@@ -41,7 +41,7 @@ from ...utils import (
replace_example_docstring,
)
from ...utils.import_utils import is_transformers_version
from ...utils.torch_utils import randn_tensor
from ...utils.torch_utils import empty_device_cache, randn_tensor
from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
from .modeling_audioldm2 import AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel
......@@ -267,9 +267,7 @@ class AudioLDM2Pipeline(DiffusionPipeline):
if self.device.type != "cpu":
self.to("cpu", silence_dtype_warnings=True)
device_mod = getattr(torch, device.type, None)
if hasattr(device_mod, "empty_cache") and device_mod.is_available():
device_mod.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
empty_device_cache(device.type)
model_sequence = [
self.text_encoder.text_model,
......
......@@ -294,7 +294,7 @@ def prepare_face_models(model_path, device, dtype):
Parameters:
- model_path: Path to the directory containing model files.
- device: The device (e.g., 'cuda', 'cpu') where models will be loaded.
- device: The device (e.g., 'cuda', 'xpu', 'cpu') where models will be loaded.
- dtype: Data type (e.g., torch.float32) for model inference.
Returns:
......
......@@ -37,7 +37,7 @@ from ...utils import (
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
from ...utils.torch_utils import empty_device_cache, is_compiled_module, is_torch_version, randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
......@@ -1339,7 +1339,7 @@ class StableDiffusionControlNetPipeline(
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.unet.to("cpu")
self.controlnet.to("cpu")
torch.cuda.empty_cache()
empty_device_cache()
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
......
......@@ -36,7 +36,7 @@ from ...utils import (
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import is_compiled_module, randn_tensor
from ...utils.torch_utils import empty_device_cache, is_compiled_module, randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion import StableDiffusionPipelineOutput
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
......@@ -1311,7 +1311,7 @@ class StableDiffusionControlNetImg2ImgPipeline(
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.unet.to("cpu")
self.controlnet.to("cpu")
torch.cuda.empty_cache()
empty_device_cache()
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
......
......@@ -38,7 +38,7 @@ from ...utils import (
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import is_compiled_module, randn_tensor
from ...utils.torch_utils import empty_device_cache, is_compiled_module, randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion import StableDiffusionPipelineOutput
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
......@@ -1500,7 +1500,7 @@ class StableDiffusionControlNetInpaintPipeline(
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.unet.to("cpu")
self.controlnet.to("cpu")
torch.cuda.empty_cache()
empty_device_cache()
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
......
......@@ -51,7 +51,7 @@ from ...utils import (
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import is_compiled_module, randn_tensor
from ...utils.torch_utils import empty_device_cache, is_compiled_module, randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
......@@ -1858,7 +1858,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.unet.to("cpu")
self.controlnet.to("cpu")
torch.cuda.empty_cache()
empty_device_cache()
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
......
......@@ -1465,7 +1465,11 @@ class StableDiffusionXLControlNetPipeline(
# Relevant thread:
# https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
if (
torch.cuda.is_available()
and (is_unet_compiled and is_controlnet_compiled)
and is_torch_higher_equal_2_1
):
torch._inductor.cudagraph_mark_step_begin()
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
......
......@@ -53,7 +53,7 @@ from ...utils import (
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import is_compiled_module, randn_tensor
from ...utils.torch_utils import empty_device_cache, is_compiled_module, randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
......@@ -921,7 +921,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
# Offload text encoder if `enable_model_cpu_offload` was enabled
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.text_encoder_2.to("cpu")
torch.cuda.empty_cache()
empty_device_cache()
image = image.to(device=device, dtype=dtype)
......@@ -1632,7 +1632,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.unet.to("cpu")
self.controlnet.to("cpu")
torch.cuda.empty_cache()
empty_device_cache()
if not output_type == "latent":
# make sure the VAE is in float32 mode, as it overflows in float16
......
......@@ -51,7 +51,7 @@ from ...utils import (
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import is_compiled_module, randn_tensor
from ...utils.torch_utils import empty_device_cache, is_compiled_module, randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
......@@ -1766,7 +1766,7 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.unet.to("cpu")
self.controlnet.to("cpu")
torch.cuda.empty_cache()
empty_device_cache()
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
......
......@@ -53,7 +53,7 @@ from ...utils import (
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import is_compiled_module, randn_tensor
from ...utils.torch_utils import empty_device_cache, is_compiled_module, randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
......@@ -876,7 +876,7 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
# Offload text encoder if `enable_model_cpu_offload` was enabled
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.text_encoder_2.to("cpu")
torch.cuda.empty_cache()
empty_device_cache()
image = image.to(device=device, dtype=dtype)
......@@ -1574,7 +1574,7 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.unet.to("cpu")
self.controlnet.to("cpu")
torch.cuda.empty_cache()
empty_device_cache()
if not output_type == "latent":
# make sure the VAE is in float32 mode, as it overflows in float16
......
......@@ -36,7 +36,7 @@ from ...utils import (
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
from ...utils.torch_utils import empty_device_cache, is_compiled_module, is_torch_version, randn_tensor
from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
......@@ -853,7 +853,7 @@ class StableDiffusionControlNetXSPipeline(
for i, t in enumerate(timesteps):
# Relevant thread:
# https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
if is_controlnet_compiled and is_torch_higher_equal_2_1:
if torch.cuda.is_available() and is_controlnet_compiled and is_torch_higher_equal_2_1:
torch._inductor.cudagraph_mark_step_begin()
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
......@@ -902,7 +902,7 @@ class StableDiffusionControlNetXSPipeline(
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.unet.to("cpu")
self.controlnet.to("cpu")
torch.cuda.empty_cache()
empty_device_cache()
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
......
......@@ -193,7 +193,7 @@ class KandinskyCombinedPipeline(DiffusionPipeline):
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)
def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = None):
r"""
Offloads all models (`unet`, `text_encoder`, `vae`, and `safety checker` state dicts) to CPU using 🤗
Accelerate, significantly reducing memory usage. Models are moved to a `torch.device('meta')` and loaded on a
......@@ -411,7 +411,7 @@ class KandinskyImg2ImgCombinedPipeline(DiffusionPipeline):
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)
def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = None):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
......@@ -652,7 +652,7 @@ class KandinskyInpaintCombinedPipeline(DiffusionPipeline):
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)
def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = None):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
......
......@@ -179,7 +179,7 @@ class KandinskyV22CombinedPipeline(DiffusionPipeline):
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)
def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = None):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
......@@ -407,7 +407,7 @@ class KandinskyV22Img2ImgCombinedPipeline(DiffusionPipeline):
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)
def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = None):
r"""
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
......@@ -417,7 +417,7 @@ class KandinskyV22Img2ImgCombinedPipeline(DiffusionPipeline):
self.prior_pipe.enable_model_cpu_offload(gpu_id=gpu_id, device=device)
self.decoder_pipe.enable_model_cpu_offload(gpu_id=gpu_id, device=device)
def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = None):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
......@@ -656,7 +656,7 @@ class KandinskyV22InpaintCombinedPipeline(DiffusionPipeline):
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)
def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = None):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
......
......@@ -25,7 +25,7 @@ from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...utils.torch_utils import empty_device_cache, randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from .pipeline_output import KolorsPipelineOutput
from .text_encoder import ChatGLMModel
......@@ -618,7 +618,7 @@ class KolorsImg2ImgPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffu
# Offload text encoder if `enable_model_cpu_offload` was enabled
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.text_encoder_2.to("cpu")
torch.cuda.empty_cache()
empty_device_cache()
image = image.to(device=device, dtype=dtype)
......
......@@ -35,7 +35,7 @@ from ...utils import (
logging,
replace_example_docstring,
)
from ...utils.torch_utils import randn_tensor
from ...utils.torch_utils import empty_device_cache, get_device, randn_tensor
from ..pipeline_utils import AudioPipelineOutput, DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
......@@ -397,20 +397,22 @@ class MusicLDMPipeline(DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusi
def enable_model_cpu_offload(self, gpu_id=0):
r"""
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
`enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the accelerator when its
`forward` method is called, and the model remains in accelerator until the next model runs. Memory savings are
lower than with `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution
of the `unet`.
"""
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
from accelerate import cpu_offload_with_hook
else:
raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
device = torch.device(f"cuda:{gpu_id}")
device_type = get_device()
device = torch.device(f"{device_type}:{gpu_id}")
if self.device.type != "cpu":
self.to("cpu", silence_dtype_warnings=True)
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
empty_device_cache() # otherwise we don't see the memory savings (but they probably exist)
model_sequence = [
self.text_encoder.text_model,
......
......@@ -36,7 +36,7 @@ from ...utils import (
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
from ...utils.torch_utils import empty_device_cache, is_compiled_module, is_torch_version, randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
......@@ -1228,7 +1228,11 @@ class StableDiffusionControlNetPAGPipeline(
for i, t in enumerate(timesteps):
# Relevant thread:
# https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
if (
torch.cuda.is_available()
and (is_unet_compiled and is_controlnet_compiled)
and is_torch_higher_equal_2_1
):
torch._inductor.cudagraph_mark_step_begin()
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0]))
......@@ -1309,7 +1313,7 @@ class StableDiffusionControlNetPAGPipeline(
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.unet.to("cpu")
self.controlnet.to("cpu")
torch.cuda.empty_cache()
empty_device_cache()
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
......
......@@ -37,7 +37,7 @@ from ...utils import (
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import is_compiled_module, randn_tensor
from ...utils.torch_utils import empty_device_cache, is_compiled_module, randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion import StableDiffusionPipelineOutput
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
......@@ -1521,7 +1521,7 @@ class StableDiffusionControlNetPAGInpaintPipeline(
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.unet.to("cpu")
self.controlnet.to("cpu")
torch.cuda.empty_cache()
empty_device_cache()
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
......
......@@ -1498,7 +1498,11 @@ class StableDiffusionXLControlNetPAGPipeline(
for i, t in enumerate(timesteps):
# Relevant thread:
# https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
if (
torch.cuda.is_available()
and (is_unet_compiled and is_controlnet_compiled)
and is_torch_higher_equal_2_1
):
torch._inductor.cudagraph_mark_step_begin()
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0]))
......
......@@ -52,7 +52,7 @@ from ...utils import (
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import is_compiled_module, randn_tensor
from ...utils.torch_utils import empty_device_cache, is_compiled_module, randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
from .pag_utils import PAGMixin
......@@ -926,7 +926,7 @@ class StableDiffusionXLControlNetPAGImg2ImgPipeline(
# Offload text encoder if `enable_model_cpu_offload` was enabled
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.text_encoder_2.to("cpu")
torch.cuda.empty_cache()
empty_device_cache()
image = image.to(device=device, dtype=dtype)
......@@ -1648,7 +1648,7 @@ class StableDiffusionXLControlNetPAGImg2ImgPipeline(
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.unet.to("cpu")
self.controlnet.to("cpu")
torch.cuda.empty_cache()
empty_device_cache()
if not output_type == "latent":
# make sure the VAE is in float32 mode, as it overflows in float16
......
......@@ -35,7 +35,7 @@ from ...utils import (
logging,
replace_example_docstring,
)
from ...utils.torch_utils import randn_tensor
from ...utils.torch_utils import get_device, is_torch_version, randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from ..pixart_alpha.pipeline_pixart_alpha import (
ASPECT_RATIO_512_BIN,
......@@ -917,9 +917,15 @@ class SanaPAGPipeline(DiffusionPipeline, PAGMixin):
image = latents
else:
latents = latents.to(self.vae.dtype)
torch_accelerator_module = getattr(torch, get_device(), torch.cuda)
oom_error = (
torch.OutOfMemoryError
if is_torch_version(">=", "2.5.0")
else torch_accelerator_module.OutOfMemoryError
)
try:
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
except torch.cuda.OutOfMemoryError as e:
except oom_error as e:
warnings.warn(
f"{e}. \n"
f"Try to use VAE tiling for large images. For example: \n"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment