Unverified Commit 95c5ce4e authored by hlky, committed by GitHub

PyTorch/XLA support (#10498)


Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent c0964571
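
This commit applies one pattern to every pipeline it touches: import the `is_torch_xla_available` helper from `diffusers.utils`, record the result in a module-level `XLA_AVAILABLE` flag, and call `xm.mark_step()` at the end of each denoising step. On XLA backends such as TPUs, PyTorch ops are traced lazily; `mark_step()` compiles and executes the graph accumulated so far, so each loop iteration is materialized instead of the entire sampling loop being traced as one giant graph. A minimal sketch of the pattern as it recurs in the hunks below (the loop body is a placeholder, not any one pipeline's code):

```py
from diffusers.utils import is_torch_xla_available

if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False


def denoising_loop(timesteps):
    for i, t in enumerate(timesteps):
        # ... predict noise, call scheduler.step(), update latents ...
        if XLA_AVAILABLE:
            # Flush the lazily traced XLA graph once per step; on CUDA/CPU
            # the guard is False, so behavior there is unchanged.
            xm.mark_step()
```
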
@@ -19,6 +19,7 @@ from ...schedulers import DDIMScheduler, DPMSolverMultistepScheduler
 from ...utils import (
     USE_PEFT_BACKEND,
     deprecate,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
     scale_lora_layers,
@@ -29,8 +30,16 @@ from ..pipeline_utils import DiffusionPipeline
 from .pipeline_output import LEditsPPDiffusionPipelineOutput, LEditsPPInversionPipelineOutput
 
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -1209,6 +1218,9 @@ class LEditsPPPipelineStableDiffusion(
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         # 8. Post-processing
         if not output_type == "latent":
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
@@ -1378,6 +1390,9 @@ class LEditsPPPipelineStableDiffusion(
                 progress_bar.update()
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         self.init_latents = xts[-1].expand(self.batch_size, -1, -1, -1)
         zs = zs.flip(0)
         self.zs = zs
......
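
With the guard in place, nothing changes at the call site; running any of these pipelines on a TPU is just a device move. A hypothetical usage sketch, assuming `torch_xla` is installed and an XLA device is visible (the pipeline class, model id, and prompt are examples, not taken from this commit):

```py
import torch
import torch_xla.core.xla_model as xm

from diffusers import StableDiffusionPipeline

device = xm.xla_device()  # first available XLA device, e.g. a TPU core

# Example checkpoint; any pipeline touched by this commit works the same way.
pipe = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    torch_dtype=torch.bfloat16,  # bf16 is the usual dtype on TPU
).to(device)

# Each denoising step now ends in xm.mark_step(), so the graph is compiled
# and executed per step; expect the first call to be slow while compiling.
image = pipe("a photo of an astronaut riding a horse on mars").images[0]
```
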
@@ -31,6 +31,7 @@ from ...utils import (
     BACKENDS_MAPPING,
     is_bs4_available,
     is_ftfy_available,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
 )
@@ -38,8 +39,16 @@ from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 if is_bs4_available():
     from bs4 import BeautifulSoup
@@ -874,6 +883,9 @@ class LuminaText2ImgPipeline(DiffusionPipeline):
                     progress_bar.update()
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         if not output_type == "latent":
             latents = latents / self.vae.config.scaling_factor
             image = self.vae.decode(latents, return_dict=False)[0]
......
@@ -37,6 +37,7 @@ from ...schedulers import (
 )
 from ...utils import (
     BaseOutput,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
 )
@@ -46,6 +47,13 @@ from ..pipeline_utils import DiffusionPipeline
 from .marigold_image_processing import MarigoldImageProcessor
 
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -517,6 +525,9 @@ class MarigoldDepthPipeline(DiffusionPipeline):
                     noise, t, batch_pred_latent, generator=generator
                 ).prev_sample  # [B,4,h,w]
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
             pred_latents.append(batch_pred_latent)
 
         pred_latent = torch.cat(pred_latents, dim=0)  # [N*E,4,h,w]
......
@@ -36,6 +36,7 @@ from ...schedulers import (
 )
 from ...utils import (
     BaseOutput,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
 )
@@ -44,6 +45,13 @@ from ..pipeline_utils import DiffusionPipeline
 from .marigold_image_processing import MarigoldImageProcessor
 
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -493,6 +501,9 @@ class MarigoldNormalsPipeline(DiffusionPipeline):
                     noise, t, batch_pred_latent, generator=generator
                 ).prev_sample  # [B,4,h,w]
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
             pred_latents.append(batch_pred_latent)
 
         pred_latent = torch.cat(pred_latents, dim=0)  # [N*E,4,h,w]
......
@@ -42,8 +42,20 @@ from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline, StableDiffu
 if is_librosa_available():
     import librosa
 
+from ...utils import is_torch_xla_available
+
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -603,6 +615,9 @@ class MusicLDMPipeline(DiffusionPipeline, StableDiffusionMixin):
                         step_idx = i // getattr(self.scheduler, "order", 1)
                         callback(step_idx, t, latents)
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         self.maybe_free_model_hooks()
 
         # 8. Post-processing
......
@@ -30,6 +30,7 @@ from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
     USE_PEFT_BACKEND,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
     scale_lora_layers,
@@ -42,6 +43,13 @@ from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from .pag_utils import PAGMixin
 
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -1293,6 +1301,9 @@ class StableDiffusionControlNetPAGPipeline(
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         # If we do sequential model offloading, let's offload unet and controlnet
         # manually for max memory savings
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
......
@@ -31,6 +31,7 @@ from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
     USE_PEFT_BACKEND,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
     scale_lora_layers,
@@ -43,6 +44,13 @@ from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from .pag_utils import PAGMixin
 
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -1505,6 +1513,9 @@ class StableDiffusionControlNetPAGInpaintPipeline(
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         # If we do sequential model offloading, let's offload unet and controlnet
         # manually for max memory savings
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
......
@@ -62,6 +62,16 @@ if is_invisible_watermark_available():
     from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
 
+from ...utils import is_torch_xla_available
+
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -1564,6 +1574,9 @@ class StableDiffusionXLControlNetPAGPipeline(
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         if not output_type == "latent":
             # make sure the VAE is in float32 mode, as it overflows in float16
             needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
......
@@ -62,6 +62,16 @@ if is_invisible_watermark_available():
     from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
 
+from ...utils import is_torch_xla_available
+
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -1630,6 +1640,9 @@ class StableDiffusionXLControlNetPAGImg2ImgPipeline(
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         # If we do sequential model offloading, let's offload unet and controlnet
         # manually for max memory savings
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
......
@@ -29,6 +29,7 @@ from ...utils import (
     deprecate,
     is_bs4_available,
     is_ftfy_available,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
 )
@@ -43,8 +44,16 @@ from ..pixart_alpha.pipeline_pixart_sigma import ASPECT_RATIO_2048_BIN
 from .pag_utils import PAGMixin
 
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 if is_bs4_available():
     from bs4 import BeautifulSoup
@@ -843,6 +852,9 @@ class PixArtSigmaPAGPipeline(DiffusionPipeline, PAGMixin):
                         step_idx = i // getattr(self.scheduler, "order", 1)
                         callback(step_idx, t, latents)
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         if not output_type == "latent":
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
             if use_resolution_binning:
......
@@ -30,6 +30,7 @@ from ...utils import (
     BACKENDS_MAPPING,
     is_bs4_available,
     is_ftfy_available,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
 )
@@ -43,8 +44,16 @@ from ..pixart_alpha.pipeline_pixart_sigma import ASPECT_RATIO_2048_BIN
 from .pag_utils import PAGMixin
 
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 if is_bs4_available():
     from bs4 import BeautifulSoup
@@ -867,6 +876,9 @@ class SanaPAGPipeline(DiffusionPipeline, PAGMixin):
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         if output_type == "latent":
             image = latents
         else:
......
@@ -27,6 +27,7 @@ from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
     USE_PEFT_BACKEND,
     deprecate,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
     scale_lora_layers,
@@ -39,8 +40,16 @@ from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from .pag_utils import PAGMixin
 
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -1034,6 +1043,9 @@ class StableDiffusionPAGPipeline(
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         if not output_type == "latent":
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
                 0
......
@@ -26,6 +26,7 @@ from ...models.unets.unet_motion_model import MotionAdapter
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
     USE_PEFT_BACKEND,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
     scale_lora_layers,
@@ -40,8 +41,16 @@ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from .pag_utils import PAGMixin
 
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -847,6 +856,9 @@ class AnimateDiffPAGPipeline(
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         # 9. Post processing
         if output_type == "latent":
             video = latents
......
@@ -30,6 +30,7 @@ from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
     USE_PEFT_BACKEND,
     deprecate,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
     scale_lora_layers,
@@ -42,8 +43,16 @@ from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from .pag_utils import PAGMixin
 
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -1066,6 +1075,9 @@ class StableDiffusionPAGImg2ImgPipeline(
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         if not output_type == "latent":
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
                 0
......
@@ -28,6 +28,7 @@ from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
     USE_PEFT_BACKEND,
     deprecate,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
     scale_lora_layers,
@@ -40,8 +41,16 @@ from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from .pag_utils import PAGMixin
 
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -1318,6 +1327,9 @@ class StableDiffusionPAGInpaintPipeline(
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         if not output_type == "latent":
             condition_kwargs = {}
             if isinstance(self.vae, AsymmetricAutoencoderKL):
......
@@ -23,7 +23,7 @@ from transformers import CLIPImageProcessor
 from ...image_processor import VaeImageProcessor
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
-from ...utils import deprecate, logging
+from ...utils import deprecate, is_torch_xla_available, logging
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from ..stable_diffusion import StableDiffusionPipelineOutput
@@ -31,6 +31,13 @@ from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from .image_encoder import PaintByExampleImageEncoder
 
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -604,6 +611,9 @@ class PaintByExamplePipeline(DiffusionPipeline, StableDiffusionMixin):
                         step_idx = i // getattr(self.scheduler, "order", 1)
                         callback(step_idx, t, latents)
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         self.maybe_free_model_hooks()
 
         if not output_type == "latent":
......
@@ -37,6 +37,7 @@ from ...schedulers import (
 from ...utils import (
     USE_PEFT_BACKEND,
     BaseOutput,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
     scale_lora_layers,
@@ -48,8 +49,16 @@ from ..free_init_utils import FreeInitMixin
 from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -928,6 +937,9 @@ class PIAPipeline(
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         # 9. Post processing
         if output_type == "latent":
             video = latents
......
@@ -29,6 +29,7 @@ from ...utils import (
     deprecate,
     is_bs4_available,
     is_ftfy_available,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
 )
@@ -36,8 +37,16 @@ from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 if is_bs4_available():
     from bs4 import BeautifulSoup
@@ -943,6 +952,9 @@ class PixArtAlphaPipeline(DiffusionPipeline):
                         step_idx = i // getattr(self.scheduler, "order", 1)
                         callback(step_idx, t, latents)
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         if not output_type == "latent":
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
             if use_resolution_binning:
......
@@ -29,6 +29,7 @@ from ...utils import (
     deprecate,
     is_bs4_available,
     is_ftfy_available,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
 )
@@ -41,8 +42,16 @@ from .pipeline_pixart_alpha import (
 )
 
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 if is_bs4_available():
     from bs4 import BeautifulSoup
@@ -854,6 +863,9 @@ class PixArtSigmaPipeline(DiffusionPipeline):
                         step_idx = i // getattr(self.scheduler, "order", 1)
                         callback(step_idx, t, latents)
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         if not output_type == "latent":
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
             if use_resolution_binning:
......
@@ -9,12 +9,19 @@ from ...image_processor import VaeImageProcessor
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import deprecate, logging
+from ...utils import deprecate, is_torch_xla_available, logging
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from .pipeline_output import SemanticStableDiffusionPipelineOutput
 
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -701,6 +708,9 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
                         step_idx = i // getattr(self.scheduler, "order", 1)
                         callback(step_idx, t, latents)
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         # 8. Post-processing
         if not output_type == "latent":
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
......