Unverified commit 95c5ce4e, authored by hlky, committed by GitHub

PyTorch/XLA support (#10498)


Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent c0964571
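
Every file in this commit applies the same two-part pattern: a module-level guard that imports `torch_xla` only when it is installed, and an `xm.mark_step()` call at the end of each denoising step, which cuts the lazy-tensor graph so XLA compiles and executes one step's worth of work at a time instead of deferring everything to the end of the loop. A minimal standalone sketch of the pattern (the loop body below is an illustrative placeholder, not code from the diff):

```py
import torch

from diffusers.utils import is_torch_xla_available

# Same guard the commit adds at the top of each pipeline module.
if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False


def denoising_loop(latents: torch.Tensor, num_steps: int = 4) -> torch.Tensor:
    for _ in range(num_steps):
        latents = latents * 0.5  # placeholder for scheduler.step(unet(...), ...)
        if XLA_AVAILABLE:
            # Materialize the pending XLA graph once per step; this branch
            # is skipped entirely on non-XLA installs.
            xm.mark_step()
    return latents


latents = denoising_loop(torch.randn(1, 4, 8, 8))
```
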
@@ -25,6 +25,7 @@ from ...models import PriorTransformer
 from ...schedulers import HeunDiscreteScheduler
 from ...utils import (
     BaseOutput,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
 )
@@ -33,8 +34,16 @@ from ..pipeline_utils import DiffusionPipeline
 from .renderer import ShapERenderer
 
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -291,6 +300,9 @@ class ShapEPipeline(DiffusionPipeline):
                 sample=latents,
             ).prev_sample
 
+            if XLA_AVAILABLE:
+                xm.mark_step()
+
         # Offload all models
         self.maybe_free_model_hooks()
...
@@ -24,6 +24,7 @@ from ...models import PriorTransformer
 from ...schedulers import HeunDiscreteScheduler
 from ...utils import (
     BaseOutput,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
 )
@@ -32,8 +33,16 @@ from ..pipeline_utils import DiffusionPipeline
 from .renderer import ShapERenderer
 
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -278,6 +287,9 @@ class ShapEImg2ImgPipeline(DiffusionPipeline):
                 sample=latents,
             ).prev_sample
 
+            if XLA_AVAILABLE:
+                xm.mark_step()
+
         if output_type not in ["np", "pil", "latent", "mesh"]:
             raise ValueError(
                 f"Only the output types `pil`, `np`, `latent` and `mesh` are supported not output_type={output_type}"
...
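
With the guard in place, the two Shap-E pipelines above run unmodified on an XLA device. A minimal usage sketch, assuming a TPU runtime with `torch_xla` installed; the checkpoint and generation parameters follow the Shap-E docstring:

```py
import torch_xla.core.xla_model as xm

from diffusers import ShapEPipeline
from diffusers.utils import export_to_gif

# Move the pipeline to the XLA device (e.g. a TPU core).
pipe = ShapEPipeline.from_pretrained("openai/shap-e").to(xm.xla_device())

# Each of the 64 denoising steps now ends with xm.mark_step().
images = pipe(
    "a shark",
    guidance_scale=15.0,
    num_inference_steps=64,
    frame_size=256,
).images
export_to_gif(images[0], "shark_3d.gif")
```
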
@@ -19,14 +19,22 @@ from transformers import CLIPTextModel, CLIPTokenizer
 from ...models import StableCascadeUNet
 from ...schedulers import DDPMWuerstchenScheduler
-from ...utils import is_torch_version, logging, replace_example_docstring
+from ...utils import is_torch_version, is_torch_xla_available, logging, replace_example_docstring
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 from ..wuerstchen.modeling_paella_vq_model import PaellaVQModel
 
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -503,6 +511,9 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
                 prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                 negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
 
+            if XLA_AVAILABLE:
+                xm.mark_step()
+
         if output_type not in ["pt", "np", "pil", "latent"]:
             raise ValueError(
                 f"Only the output types `pt`, `np`, `pil` and `latent` are supported not output_type={output_type}"
...
@@ -23,13 +23,21 @@ from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer
 from ...models import StableCascadeUNet
 from ...schedulers import DDPMWuerstchenScheduler
-from ...utils import BaseOutput, logging, replace_example_docstring
+from ...utils import BaseOutput, is_torch_xla_available, logging, replace_example_docstring
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 DEFAULT_STAGE_C_TIMESTEPS = list(np.linspace(1.0, 2 / 3, 20)) + list(np.linspace(2 / 3, 0.0, 11))[1:]
 
 EXAMPLE_DOC_STRING = """
@@ -611,6 +619,9 @@ class StableCascadePriorPipeline(DiffusionPipeline):
                 prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                 negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
 
+            if XLA_AVAILABLE:
+                xm.mark_step()
+
         # Offload all models
         self.maybe_free_model_hooks()
...
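
The same applies to the two-stage Stable Cascade stack: both the prior and the decoder loops above now break the graph once per denoising step. A sketch assuming a TPU runtime with `torch_xla` installed; the checkpoints, prompt, and step counts mirror the Stable Cascade docstrings:

```py
import torch
import torch_xla.core.xla_model as xm

from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline

device = xm.xla_device()
prior = StableCascadePriorPipeline.from_pretrained(
    "stabilityai/stable-cascade-prior", torch_dtype=torch.bfloat16
).to(device)
decoder = StableCascadeDecoderPipeline.from_pretrained(
    "stabilityai/stable-cascade", torch_dtype=torch.bfloat16
).to(device)

prompt = "an image of a shiba inu, donning a spacesuit and helmet"
# Both pipelines' loops now call xm.mark_step() once per step.
prior_output = prior(prompt=prompt, num_inference_steps=20)
image = decoder(
    image_embeddings=prior_output.image_embeddings,
    prompt=prompt,
    num_inference_steps=10,
).images[0]
image.save("cascade.png")
```
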
@@ -28,11 +28,26 @@ from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import PIL_INTERPOLATION, USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import (
+    PIL_INTERPOLATION,
+    USE_PEFT_BACKEND,
+    deprecate,
+    is_torch_xla_available,
+    logging,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -861,6 +876,9 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
                     step_idx = i // getattr(self.scheduler, "order", 1)
                     callback(step_idx, t, latents)
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         if not output_type == "latent":
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
         else:
...
@@ -24,13 +24,20 @@ from ...configuration_utils import FrozenDict
 from ...image_processor import VaeImageProcessor
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import deprecate, logging
+from ...utils import deprecate, is_torch_xla_available, logging
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from . import StableDiffusionPipelineOutput
 from .safety_checker import StableDiffusionSafetyChecker
 
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -401,6 +408,9 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline, StableDiffusionMixin
                     step_idx = i // getattr(self.scheduler, "order", 1)
                     callback(step_idx, t, latents)
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         self.maybe_free_model_hooks()
 
         if not output_type == "latent":
...
@@ -32,6 +32,7 @@ from ...utils import (
     PIL_INTERPOLATION,
     USE_PEFT_BACKEND,
     deprecate,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
     scale_lora_layers,
@@ -43,8 +44,16 @@ from . import StableDiffusionPipelineOutput
 from .safety_checker import StableDiffusionSafetyChecker
 
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -1120,6 +1129,9 @@ class StableDiffusionImg2ImgPipeline(
                     step_idx = i // getattr(self.scheduler, "order", 1)
                     callback(step_idx, t, latents)
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         if not output_type == "latent":
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
                 0
...
@@ -27,13 +27,27 @@ from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin
 from ...models import AsymmetricAutoencoderKL, AutoencoderKL, ImageProjection, UNet2DConditionModel
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import (
+    USE_PEFT_BACKEND,
+    deprecate,
+    is_torch_xla_available,
+    logging,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from . import StableDiffusionPipelineOutput
 from .safety_checker import StableDiffusionSafetyChecker
 
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -1303,6 +1317,9 @@ class StableDiffusionInpaintPipeline(
                     step_idx = i // getattr(self.scheduler, "order", 1)
                     callback(step_idx, t, latents)
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         if not output_type == "latent":
             condition_kwargs = {}
             if isinstance(self.vae, AsymmetricAutoencoderKL):
...
@@ -25,11 +25,18 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor
 from ...loaders import FromSingleFileMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import EulerDiscreteScheduler
-from ...utils import deprecate, logging
+from ...utils import deprecate, is_torch_xla_available, logging
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput, StableDiffusionMixin
 
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -640,6 +647,9 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMixin
                     step_idx = i // getattr(self.scheduler, "order", 1)
                     callback(step_idx, t, latents)
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         if not output_type == "latent":
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
         else:
...
@@ -30,12 +30,26 @@ from ...models.attention_processor import (
 )
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import DDPMScheduler, KarrasDiffusionSchedulers
-from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import (
+    USE_PEFT_BACKEND,
+    deprecate,
+    is_torch_xla_available,
+    logging,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from . import StableDiffusionPipelineOutput
 
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -769,6 +783,9 @@ class StableDiffusionUpscalePipeline(
                     step_idx = i // getattr(self.scheduler, "order", 1)
                     callback(step_idx, t, latents)
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         if not output_type == "latent":
             # make sure the VAE is in float32 mode, as it overflows in float16
             needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
...
@@ -28,6 +28,7 @@ from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
     USE_PEFT_BACKEND,
     deprecate,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
     scale_lora_layers,
@@ -38,8 +39,16 @@ from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput, StableDiffusionMixin
 from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
 
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -924,6 +933,9 @@ class StableUnCLIPPipeline(
                     step_idx = i // getattr(self.scheduler, "order", 1)
                     callback(step_idx, t, latents)
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         if not output_type == "latent":
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
         else:
...
@@ -28,6 +28,7 @@ from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
     USE_PEFT_BACKEND,
     deprecate,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
     scale_lora_layers,
@@ -38,8 +39,16 @@ from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput, StableDiffusionMixin
 from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
 
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -829,6 +838,9 @@ class StableUnCLIPImg2ImgPipeline(
                     step_idx = i // getattr(self.scheduler, "order", 1)
                     callback(step_idx, t, latents)
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         # 9. Post-processing
         if not output_type == "latent":
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
...
@@ -30,6 +30,7 @@ from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
     USE_PEFT_BACKEND,
     deprecate,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
     scale_lora_layers,
@@ -41,6 +42,14 @@ from ..stable_diffusion import StableDiffusionPipelineOutput
 from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
+
 logger = logging.get_logger(__name__)
 
 EXAMPLE_DOC_STRING = """
@@ -1008,6 +1017,9 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, StableDiffusionMixin
                     step_idx = i // getattr(self.scheduler, "order", 1)
                     callback(step_idx, t, latents)
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         # 8. Post-processing
         if not output_type == "latent":
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
...
@@ -33,6 +33,7 @@ from ...utils import (
     USE_PEFT_BACKEND,
     BaseOutput,
     deprecate,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
     scale_lora_layers,
@@ -44,6 +45,13 @@ from ..stable_diffusion import StableDiffusionPipelineOutput
 from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -1508,6 +1516,9 @@ class StableDiffusionDiffEditPipeline(
                     step_idx = i // getattr(self.scheduler, "order", 1)
                     callback(step_idx, t, latents)
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         if not output_type == "latent":
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
             image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
...
@@ -29,6 +29,7 @@ from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
     USE_PEFT_BACKEND,
     deprecate,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
     scale_lora_layers,
@@ -40,8 +41,16 @@ from ..stable_diffusion import StableDiffusionPipelineOutput
 from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -828,6 +837,9 @@ class StableDiffusionGLIGENPipeline(DiffusionPipeline, StableDiffusionMixin):
                     step_idx = i // getattr(self.scheduler, "order", 1)
                     callback(step_idx, t, latents)
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         if not output_type == "latent":
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
             image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
...
@@ -32,7 +32,14 @@ from ...models import AutoencoderKL, UNet2DConditionModel
 from ...models.attention import GatedSelfAttentionDense
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import USE_PEFT_BACKEND, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
+from ...utils import (
+    USE_PEFT_BACKEND,
+    is_torch_xla_available,
+    logging,
+    replace_example_docstring,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from ..stable_diffusion import StableDiffusionPipelineOutput
@@ -40,8 +47,16 @@ from ..stable_diffusion.clip_image_project_model import CLIPImageProjection
 from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -1010,6 +1025,9 @@ class StableDiffusionGLIGENTextImagePipeline(DiffusionPipeline, StableDiffusionMixin
                     step_idx = i // getattr(self.scheduler, "order", 1)
                     callback(step_idx, t, latents)
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         if not output_type == "latent":
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
             image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
...
@@ -30,6 +30,7 @@ from ...utils import (
     USE_PEFT_BACKEND,
     BaseOutput,
     deprecate,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
     scale_lora_layers,
@@ -40,8 +41,16 @@ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 EXAMPLE_DOC_STRING = """
     Examples:
         ```python
@@ -1002,6 +1011,9 @@ class StableDiffusionLDM3DPipeline(
                     step_idx = i // getattr(self.scheduler, "order", 1)
                     callback(step_idx, t, latents)
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         if not output_type == "latent":
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
             image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
...
@@ -26,6 +26,7 @@ from ...schedulers import DDIMScheduler
 from ...utils import (
     USE_PEFT_BACKEND,
     deprecate,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
     scale_lora_layers,
@@ -37,8 +38,16 @@ from ..stable_diffusion import StableDiffusionPipelineOutput
 from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -1155,6 +1164,9 @@ class StableDiffusionPanoramaPipeline(
                     step_idx = i // getattr(self.scheduler, "order", 1)
                     callback(step_idx, t, latents)
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         if output_type != "latent":
             if circular_padding:
                 image = self.decode_latents_with_padding(latents)
...
@@ -12,13 +12,20 @@ from ...image_processor import PipelineImageInput
 from ...loaders import IPAdapterMixin
 from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import deprecate, logging
+from ...utils import deprecate, is_torch_xla_available, logging
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from . import StableDiffusionSafePipelineOutput
 from .safety_checker import SafeStableDiffusionSafetyChecker
 
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -739,6 +746,9 @@ class StableDiffusionPipelineSafe(DiffusionPipeline, StableDiffusionMixin, IPAdapterMixin
                     step_idx = i // getattr(self.scheduler, "order", 1)
                     callback(step_idx, t, latents)
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         # 8. Post-processing
         image = self.decode_latents(latents)
...
@@ -27,6 +27,7 @@ from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
     USE_PEFT_BACKEND,
     deprecate,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
     scale_lora_layers,
@@ -38,8 +39,16 @@ from ..stable_diffusion import StableDiffusionPipelineOutput
 from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -840,6 +849,9 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin
                     step_idx = i // getattr(self.scheduler, "order", 1)
                     callback(step_idx, t, latents)
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         if not output_type == "latent":
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
             image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
...
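
Taken together, any of the Stable Diffusion pipelines touched above runs on XLA with nothing more than a device move. A final usage sketch for the img2img pipeline, assuming a TPU runtime with `torch_xla` installed; the checkpoint, image URL, and prompt are illustrative (they follow the pipeline's docstring):

```py
import torch
import torch_xla.core.xla_model as xm

from diffusers import StableDiffusionImg2ImgPipeline
from diffusers.utils import load_image

pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.bfloat16
).to(xm.xla_device())

init_image = load_image(
    "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
).resize((768, 512))

# The per-step xm.mark_step() keeps each compiled graph bounded to a single
# denoising step, so step two onward reuses the cached program.
image = pipe(
    prompt="A fantasy landscape, trending on artstation",
    image=init_image,
    strength=0.75,
).images[0]
image.save("fantasy_landscape.png")
```
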