"llama/git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "369de832cdca7680c8f50ba196d39172a895fcad"
Unverified commit 95c5ce4e authored by hlky, committed by GitHub

PyTorch/XLA support (#10498)


Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent c0964571
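
This commit applies one pattern across every pipeline touched below: a guarded module-level import of torch_xla, plus a graph cut at the end of each denoising step. A minimal sketch of that pattern, assembled from the diffs that follow (the `denoising_loop` wrapper is a hypothetical stand-in for a pipeline's `__call__` loop, not diffusers code):

```python
# Guarded module-level import: torch_xla is optional, so pipelines must still
# import cleanly on machines without it.
from diffusers.utils import is_torch_xla_available

if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False


def denoising_loop(scheduler, timesteps, latents, model_output_fn):
    # Hypothetical stand-in for a pipeline's sampling loop, for illustration only.
    for i, t in enumerate(timesteps):
        model_output = model_output_fn(latents, t)
        latents = scheduler.step(model_output, t, latents).prev_sample

        if XLA_AVAILABLE:
            # mark_step() cuts the lazily built XLA graph here, so each
            # denoising step is compiled and executed on its own instead of
            # the whole loop being traced into one giant graph.
            xm.mark_step()
    return latents
```

Gating on `XLA_AVAILABLE` keeps the new code a no-op everywhere except under torch_xla, which is why the same three-line hunk can be dropped into every pipeline's loop.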
@@ -30,6 +30,7 @@ from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
     USE_PEFT_BACKEND,
     deprecate,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
     scale_lora_layers,
@@ -41,6 +42,13 @@ from ..stable_diffusion import StableDiffusionPipelineOutput
 from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -1294,6 +1302,9 @@ class StableDiffusionControlNetImg2ImgPipeline(
                         step_idx = i // getattr(self.scheduler, "order", 1)
                         callback(step_idx, t, latents)
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         # If we do sequential model offloading, let's offload unet and controlnet
         # manually for max memory savings
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
......
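
The remaining files repeat the same hunks; only the class names and line offsets differ. For context, a hedged usage sketch of what the change enables (assumes a machine with torch_xla installed, e.g. a TPU VM; the model id and step count are illustrative):

```python
import torch_xla.core.xla_model as xm
from diffusers import DDPMPipeline

device = xm.xla_device()  # resolves to the attached XLA device (e.g. a TPU core)
pipe = DDPMPipeline.from_pretrained("google/ddpm-cat-256").to(device)

# Each scheduler step now ends with xm.mark_step(), so the per-step graph is
# compiled once and reused across the remaining iterations.
image = pipe(num_inference_steps=50).images[0]
```

Without the per-step `mark_step()`, XLA would lazily trace the entire sampling loop into a single graph before executing anything; cutting the graph each iteration lets the first step's compilation be reused by every later step.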
@@ -32,6 +32,7 @@ from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
     USE_PEFT_BACKEND,
     deprecate,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
     scale_lora_layers,
@@ -43,6 +44,13 @@ from ..stable_diffusion import StableDiffusionPipelineOutput
 from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -1476,6 +1484,9 @@ class StableDiffusionControlNetInpaintPipeline(
                         step_idx = i // getattr(self.scheduler, "order", 1)
                         callback(step_idx, t, latents)
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         # If we do sequential model offloading, let's offload unet and controlnet
         # manually for max memory savings
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
......
@@ -60,6 +60,16 @@ if is_invisible_watermark_available():
     from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
 
+from ...utils import is_torch_xla_available
+
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -1833,6 +1843,9 @@ class StableDiffusionXLControlNetInpaintPipeline(
                         step_idx = i // getattr(self.scheduler, "order", 1)
                         callback(step_idx, t, latents)
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         # make sure the VAE is in float32 mode, as it overflows in float16
         if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
             self.upcast_vae()
......
@@ -62,6 +62,16 @@ if is_invisible_watermark_available():
     from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
 
+from ...utils import is_torch_xla_available
+
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -1552,6 +1562,9 @@ class StableDiffusionXLControlNetPipeline(
                         step_idx = i // getattr(self.scheduler, "order", 1)
                         callback(step_idx, t, latents)
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         if not output_type == "latent":
             # make sure the VAE is in float32 mode, as it overflows in float16
             needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
......
@@ -62,6 +62,16 @@ if is_invisible_watermark_available():
     from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
 
+from ...utils import is_torch_xla_available
+
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -1612,6 +1622,9 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
                         step_idx = i // getattr(self.scheduler, "order", 1)
                         callback(step_idx, t, latents)
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         # If we do sequential model offloading, let's offload unet and controlnet
         # manually for max memory savings
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
......
@@ -60,6 +60,16 @@ if is_invisible_watermark_available():
     from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
 
+from ...utils import is_torch_xla_available
+
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -1759,6 +1769,9 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
                         step_idx = i // getattr(self.scheduler, "order", 1)
                         callback(step_idx, t, latents)
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         # make sure the VAE is in float32 mode, as it overflows in float16
         if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
             self.upcast_vae()
......
@@ -60,6 +60,17 @@ from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
 if is_invisible_watermark_available():
     from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
 
+from ...utils import is_torch_xla_available
+
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -1458,6 +1469,9 @@ class StableDiffusionXLControlNetUnionPipeline(
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         if not output_type == "latent":
             # make sure the VAE is in float32 mode, as it overflows in float16
             needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
......
@@ -61,6 +61,17 @@ from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
 if is_invisible_watermark_available():
     from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
 
+from ...utils import is_torch_xla_available
+
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -1577,6 +1588,9 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
                         step_idx = i // getattr(self.scheduler, "order", 1)
                         callback(step_idx, t, latents)
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         # If we do sequential model offloading, let's offload unet and controlnet
         # manually for max memory savings
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
......
@@ -30,6 +30,7 @@ from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
     USE_PEFT_BACKEND,
     deprecate,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
     scale_lora_layers,
@@ -41,6 +42,13 @@ from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
 from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -884,6 +892,9 @@ class StableDiffusionControlNetXSPipeline(
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         # If we do sequential model offloading, let's offload unet and controlnet
         # manually for max memory savings
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
......
@@ -54,6 +54,16 @@ if is_invisible_watermark_available():
     from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
 
+from ...utils import is_torch_xla_available
+
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -1078,6 +1088,9 @@ class StableDiffusionXLControlNetXSPipeline(
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         # manually for max memory savings
         if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
             self.upcast_vae()
......
@@ -17,11 +17,18 @@ from typing import List, Optional, Tuple, Union
 
 import torch
 
-from ...utils import logging
+from ...utils import is_torch_xla_available, logging
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -146,6 +153,9 @@ class DanceDiffusionPipeline(DiffusionPipeline):
             # 2. compute previous audio sample: x_t -> t_t-1
             audio = self.scheduler.step(model_output, t, audio).prev_sample
 
+            if XLA_AVAILABLE:
+                xm.mark_step()
+
         audio = audio.clamp(-1, 1).float().cpu().numpy()
 
         audio = audio[:, :, :original_sample_size]
......
@@ -17,10 +17,19 @@ from typing import List, Optional, Tuple, Union
 
 import torch
 
 from ...schedulers import DDIMScheduler
+from ...utils import is_torch_xla_available
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
+
 class DDIMPipeline(DiffusionPipeline):
     r"""
     Pipeline for image generation.
@@ -143,6 +152,9 @@ class DDIMPipeline(DiffusionPipeline):
                 model_output, t, image, eta=eta, use_clipped_model_output=use_clipped_model_output, generator=generator
             ).prev_sample
 
+            if XLA_AVAILABLE:
+                xm.mark_step()
+
         image = (image / 2 + 0.5).clamp(0, 1)
         image = image.cpu().permute(0, 2, 3, 1).numpy()
         if output_type == "pil":
......
@@ -17,10 +17,19 @@ from typing import List, Optional, Tuple, Union
 
 import torch
 
+from ...utils import is_torch_xla_available
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
+
 class DDPMPipeline(DiffusionPipeline):
     r"""
     Pipeline for image generation.
@@ -116,6 +125,9 @@ class DDPMPipeline(DiffusionPipeline):
             # 2. compute previous image: x_t -> x_t-1
             image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample
 
+            if XLA_AVAILABLE:
+                xm.mark_step()
+
         image = (image / 2 + 0.5).clamp(0, 1)
         image = image.cpu().permute(0, 2, 3, 1).numpy()
         if output_type == "pil":
......
@@ -14,6 +14,7 @@ from ...utils import (
     BACKENDS_MAPPING,
     is_bs4_available,
     is_ftfy_available,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
 )
@@ -24,8 +25,16 @@ from .safety_checker import IFSafetyChecker
 from .watermark import IFWatermarker
 
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 if is_bs4_available():
     from bs4 import BeautifulSoup
@@ -735,6 +744,9 @@ class IFPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixin):
                     if callback is not None and i % callback_steps == 0:
                         callback(i, t, intermediate_images)
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         image = intermediate_images
 
         if output_type == "pil":
......
@@ -17,6 +17,7 @@ from ...utils import (
     PIL_INTERPOLATION,
     is_bs4_available,
     is_ftfy_available,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
 )
@@ -27,8 +28,16 @@ from .safety_checker import IFSafetyChecker
 from .watermark import IFWatermarker
 
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 if is_bs4_available():
     from bs4 import BeautifulSoup
@@ -856,6 +865,9 @@ class IFImg2ImgPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixin):
                     if callback is not None and i % callback_steps == 0:
                         callback(i, t, intermediate_images)
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         image = intermediate_images
 
         if output_type == "pil":
......
@@ -35,6 +35,16 @@ if is_ftfy_available():
     import ftfy
 
+from ...utils import is_torch_xla_available
+
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -974,6 +984,9 @@ class IFImg2ImgSuperResolutionPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixin):
                     if callback is not None and i % callback_steps == 0:
                         callback(i, t, intermediate_images)
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         image = intermediate_images
 
         if output_type == "pil":
......
@@ -17,6 +17,7 @@ from ...utils import (
     PIL_INTERPOLATION,
     is_bs4_available,
     is_ftfy_available,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
 )
@@ -27,8 +28,16 @@ from .safety_checker import IFSafetyChecker
 from .watermark import IFWatermarker
 
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 if is_bs4_available():
     from bs4 import BeautifulSoup
@@ -975,6 +984,9 @@ class IFInpaintingPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixin):
                     if callback is not None and i % callback_steps == 0:
                         callback(i, t, intermediate_images)
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         image = intermediate_images
 
         if output_type == "pil":
......
@@ -35,6 +35,16 @@ if is_ftfy_available():
     import ftfy
 
+from ...utils import is_torch_xla_available
+
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -1085,6 +1095,9 @@ class IFInpaintingSuperResolutionPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixin):
                    if callback is not None and i % callback_steps == 0:
                         callback(i, t, intermediate_images)
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         image = intermediate_images
 
         if output_type == "pil":
......
@@ -34,6 +34,16 @@ if is_ftfy_available():
     import ftfy
 
+from ...utils import is_torch_xla_available
+
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -831,6 +841,9 @@ class IFSuperResolutionPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixin):
                    if callback is not None and i % callback_steps == 0:
                         callback(i, t, intermediate_images)
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         image = intermediate_images
 
         if output_type == "pil":
......
@@ -24,10 +24,19 @@ import torch
 
 from ...models import AutoencoderKL, DiTTransformer2DModel
 from ...schedulers import KarrasDiffusionSchedulers
+from ...utils import is_torch_xla_available
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
+
 class DiTPipeline(DiffusionPipeline):
     r"""
     Pipeline for image generation based on a Transformer backbone instead of a UNet.
@@ -211,6 +220,9 @@ class DiTPipeline(DiffusionPipeline):
             # compute previous image: x_t -> x_t-1
             latent_model_input = self.scheduler.step(model_output, t, latent_model_input).prev_sample
 
+            if XLA_AVAILABLE:
+                xm.mark_step()
+
         if guidance_scale > 1:
             latents, _ = latent_model_input.chunk(2, dim=0)
         else:
......