"tools/vscode:/vscode.git/clone" did not exist on "840bb5bc647b629d4c2196cf5a6582e95cb2f35b"
Unverified commit 95c5ce4e authored by hlky, committed by GitHub

PyTorch/XLA support (#10498)


Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent c0964571
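
Every file touched by this commit applies the same two-part pattern: a guarded module-level import of torch_xla, and an `xm.mark_step()` call at the end of each denoising step. A minimal, self-contained sketch of that pattern (the `run_denoising_loop` helper is hypothetical, for illustration only):

```py
from diffusers.utils import is_torch_xla_available

if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False


def run_denoising_loop(unet, scheduler, latents, timesteps):
    # Hypothetical helper showing where the commit places mark_step().
    for t in timesteps:
        noise_pred = unet(latents, t).sample
        latents = scheduler.step(noise_pred, t, latents).prev_sample

        if XLA_AVAILABLE:
            # torch_xla traces ops lazily; mark_step() cuts the graph here so
            # each denoising step is compiled and dispatched immediately,
            # instead of the whole loop accumulating into one huge graph.
            xm.mark_step()
    return latents
```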
@@ -19,6 +19,7 @@ from ...schedulers import DDIMScheduler, DPMSolverMultistepScheduler
 from ...utils import (
     USE_PEFT_BACKEND,
     deprecate,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
     scale_lora_layers,
@@ -29,8 +30,16 @@ from ..pipeline_utils import DiffusionPipeline
 from .pipeline_output import LEditsPPDiffusionPipelineOutput, LEditsPPInversionPipelineOutput
 
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -1209,6 +1218,9 @@ class LEditsPPPipelineStableDiffusion(
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         # 8. Post-processing
         if not output_type == "latent":
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
@@ -1378,6 +1390,9 @@ class LEditsPPPipelineStableDiffusion(
                 progress_bar.update()
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         self.init_latents = xts[-1].expand(self.batch_size, -1, -1, -1)
         zs = zs.flip(0)
         self.zs = zs
......
@@ -31,6 +31,7 @@ from ...utils import (
     BACKENDS_MAPPING,
     is_bs4_available,
     is_ftfy_available,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
 )
@@ -38,8 +39,16 @@ from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 if is_bs4_available():
     from bs4 import BeautifulSoup
@@ -874,6 +883,9 @@ class LuminaText2ImgPipeline(DiffusionPipeline):
                 progress_bar.update()
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         if not output_type == "latent":
             latents = latents / self.vae.config.scaling_factor
             image = self.vae.decode(latents, return_dict=False)[0]
......
@@ -37,6 +37,7 @@ from ...schedulers import (
 )
 from ...utils import (
     BaseOutput,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
 )
@@ -46,6 +47,13 @@ from ..pipeline_utils import DiffusionPipeline
 from .marigold_image_processing import MarigoldImageProcessor
 
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -517,6 +525,9 @@ class MarigoldDepthPipeline(DiffusionPipeline):
                         noise, t, batch_pred_latent, generator=generator
                     ).prev_sample  # [B,4,h,w]
 
+                    if XLA_AVAILABLE:
+                        xm.mark_step()
+
                 pred_latents.append(batch_pred_latent)
 
         pred_latent = torch.cat(pred_latents, dim=0)  # [N*E,4,h,w]
......
@@ -36,6 +36,7 @@ from ...schedulers import (
 )
 from ...utils import (
     BaseOutput,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
 )
@@ -44,6 +45,13 @@ from ..pipeline_utils import DiffusionPipeline
 from .marigold_image_processing import MarigoldImageProcessor
 
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -493,6 +501,9 @@ class MarigoldNormalsPipeline(DiffusionPipeline):
                         noise, t, batch_pred_latent, generator=generator
                     ).prev_sample  # [B,4,h,w]
 
+                    if XLA_AVAILABLE:
+                        xm.mark_step()
+
                 pred_latents.append(batch_pred_latent)
 
         pred_latent = torch.cat(pred_latents, dim=0)  # [N*E,4,h,w]
......
@@ -42,8 +42,20 @@ from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline, StableDiffusionMixin
 if is_librosa_available():
     import librosa
 
+
+from ...utils import is_torch_xla_available
+
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -603,6 +615,9 @@ class MusicLDMPipeline(DiffusionPipeline, StableDiffusionMixin):
                         step_idx = i // getattr(self.scheduler, "order", 1)
                         callback(step_idx, t, latents)
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         self.maybe_free_model_hooks()
 
         # 8. Post-processing
......
@@ -30,6 +30,7 @@ from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
     USE_PEFT_BACKEND,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
     scale_lora_layers,
@@ -42,6 +43,13 @@ from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from .pag_utils import PAGMixin
 
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -1293,6 +1301,9 @@ class StableDiffusionControlNetPAGPipeline(
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         # If we do sequential model offloading, let's offload unet and controlnet
         # manually for max memory savings
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
......
@@ -31,6 +31,7 @@ from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
     USE_PEFT_BACKEND,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
     scale_lora_layers,
@@ -43,6 +44,13 @@ from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from .pag_utils import PAGMixin
 
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -1505,6 +1513,9 @@ class StableDiffusionControlNetPAGInpaintPipeline(
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         # If we do sequential model offloading, let's offload unet and controlnet
         # manually for max memory savings
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
......
@@ -62,6 +62,16 @@ if is_invisible_watermark_available():
     from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
 
+
+from ...utils import is_torch_xla_available
+
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -1564,6 +1574,9 @@ class StableDiffusionXLControlNetPAGPipeline(
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         if not output_type == "latent":
             # make sure the VAE is in float32 mode, as it overflows in float16
             needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
......
@@ -62,6 +62,16 @@ if is_invisible_watermark_available():
     from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
 
+
+from ...utils import is_torch_xla_available
+
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -1630,6 +1640,9 @@ class StableDiffusionXLControlNetPAGImg2ImgPipeline(
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         # If we do sequential model offloading, let's offload unet and controlnet
         # manually for max memory savings
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
......
@@ -29,6 +29,7 @@ from ...utils import (
     deprecate,
     is_bs4_available,
     is_ftfy_available,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
 )
@@ -43,8 +44,16 @@ from ..pixart_alpha.pipeline_pixart_sigma import ASPECT_RATIO_2048_BIN
 from .pag_utils import PAGMixin
 
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 if is_bs4_available():
     from bs4 import BeautifulSoup
@@ -843,6 +852,9 @@ class PixArtSigmaPAGPipeline(DiffusionPipeline, PAGMixin):
                         step_idx = i // getattr(self.scheduler, "order", 1)
                         callback(step_idx, t, latents)
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         if not output_type == "latent":
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
             if use_resolution_binning:
......
@@ -30,6 +30,7 @@ from ...utils import (
     BACKENDS_MAPPING,
     is_bs4_available,
     is_ftfy_available,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
 )
@@ -43,8 +44,16 @@ from ..pixart_alpha.pipeline_pixart_sigma import ASPECT_RATIO_2048_BIN
 from .pag_utils import PAGMixin
 
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 if is_bs4_available():
     from bs4 import BeautifulSoup
@@ -867,6 +876,9 @@ class SanaPAGPipeline(DiffusionPipeline, PAGMixin):
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         if output_type == "latent":
             image = latents
         else:
......
@@ -27,6 +27,7 @@ from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
     USE_PEFT_BACKEND,
     deprecate,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
     scale_lora_layers,
@@ -39,8 +40,16 @@ from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from .pag_utils import PAGMixin
 
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -1034,6 +1043,9 @@ class StableDiffusionPAGPipeline(
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         if not output_type == "latent":
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
                 0
......
@@ -26,6 +26,7 @@ from ...models.unets.unet_motion_model import MotionAdapter
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
     USE_PEFT_BACKEND,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
     scale_lora_layers,
@@ -40,8 +41,16 @@ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from .pag_utils import PAGMixin
 
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -847,6 +856,9 @@ class AnimateDiffPAGPipeline(
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         # 9. Post processing
         if output_type == "latent":
             video = latents
......
@@ -30,6 +30,7 @@ from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
     USE_PEFT_BACKEND,
     deprecate,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
     scale_lora_layers,
@@ -42,6 +43,16 @@ from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from .pag_utils import PAGMixin
 
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -1066,6 +1075,9 @@ class StableDiffusionPAGImg2ImgPipeline(
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         if not output_type == "latent":
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
                 0
......
@@ -28,6 +28,7 @@ from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
     USE_PEFT_BACKEND,
     deprecate,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
     scale_lora_layers,
@@ -40,8 +41,16 @@ from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from .pag_utils import PAGMixin
 
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -1318,6 +1327,9 @@ class StableDiffusionPAGInpaintPipeline(
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         if not output_type == "latent":
             condition_kwargs = {}
             if isinstance(self.vae, AsymmetricAutoencoderKL):
......
@@ -23,7 +23,7 @@ from transformers import CLIPImageProcessor
 from ...image_processor import VaeImageProcessor
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
-from ...utils import deprecate, logging
+from ...utils import deprecate, is_torch_xla_available, logging
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from ..stable_diffusion import StableDiffusionPipelineOutput
@@ -31,6 +31,13 @@ from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from .image_encoder import PaintByExampleImageEncoder
 
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -604,6 +611,9 @@ class PaintByExamplePipeline(DiffusionPipeline, StableDiffusionMixin):
                         step_idx = i // getattr(self.scheduler, "order", 1)
                         callback(step_idx, t, latents)
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         self.maybe_free_model_hooks()
 
         if not output_type == "latent":
......
@@ -37,6 +37,7 @@ from ...schedulers import (
 from ...utils import (
     USE_PEFT_BACKEND,
     BaseOutput,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
     scale_lora_layers,
@@ -48,8 +49,16 @@ from ..free_init_utils import FreeInitMixin
 from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -928,6 +937,9 @@ class PIAPipeline(
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         # 9. Post processing
         if output_type == "latent":
             video = latents
......
@@ -29,6 +29,7 @@ from ...utils import (
     deprecate,
     is_bs4_available,
     is_ftfy_available,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
 )
@@ -36,8 +37,16 @@ from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 if is_bs4_available():
     from bs4 import BeautifulSoup
@@ -943,6 +952,9 @@ class PixArtAlphaPipeline(DiffusionPipeline):
                         step_idx = i // getattr(self.scheduler, "order", 1)
                         callback(step_idx, t, latents)
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         if not output_type == "latent":
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
             if use_resolution_binning:
......
@@ -29,6 +29,7 @@ from ...utils import (
     deprecate,
     is_bs4_available,
     is_ftfy_available,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
 )
@@ -41,8 +42,16 @@ from .pipeline_pixart_alpha import (
 )
 
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 if is_bs4_available():
     from bs4 import BeautifulSoup
@@ -854,6 +863,9 @@ class PixArtSigmaPipeline(DiffusionPipeline):
                         step_idx = i // getattr(self.scheduler, "order", 1)
                         callback(step_idx, t, latents)
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         if not output_type == "latent":
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
             if use_resolution_binning:
......
@@ -9,12 +9,19 @@ from ...image_processor import VaeImageProcessor
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import deprecate, logging
+from ...utils import deprecate, is_torch_xla_available, logging
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from .pipeline_output import SemanticStableDiffusionPipelineOutput
 
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -701,6 +708,9 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
                         step_idx = i // getattr(self.scheduler, "order", 1)
                         callback(step_idx, t, latents)
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         # 8. Post-processing
         if not output_type == "latent":
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
......
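
With the guards in place, any of the touched pipelines can run on an XLA device with no further changes. A hedged usage sketch (assumes torch_xla is installed; the model id and prompt are illustrative):

```py
import torch
import torch_xla.core.xla_model as xm

from diffusers import PixArtAlphaPipeline

device = xm.xla_device()  # e.g. one TPU core

pipe = PixArtAlphaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.bfloat16
).to(device)

# Each scheduler step now ends with xm.mark_step(), so sampling compiles
# and executes one step at a time on the XLA device.
image = pipe("an astronaut riding a horse, watercolor", num_inference_steps=20).images[0]
image.save("astronaut.png")
```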