"src/diffusers/models/unet_blocks.py" did not exist on "107986639d7d07d893cfda2492011d3b81022d46"
Unverified Commit 95c5ce4e authored by hlky's avatar hlky Committed by GitHub
Browse files

PyTorch/XLA support (#10498)


Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
parent c0964571
......@@ -23,15 +23,23 @@ from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...loaders import HunyuanVideoLoraLoaderMixin
from ...models import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import logging, replace_example_docstring
from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import HunyuanVideoPipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```python
......@@ -667,6 +675,9 @@ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
if not output_type == "latent":
latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor
video = self.vae.decode(latents, return_dict=False)[0]
......
......@@ -27,6 +27,7 @@ from ...models.unets.unet_i2vgen_xl import I2VGenXLUNet
from ...schedulers import DDIMScheduler
from ...utils import (
BaseOutput,
is_torch_xla_available,
logging,
replace_example_docstring,
)
......@@ -35,8 +36,16 @@ from ...video_processor import VideoProcessor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
......@@ -711,6 +720,9 @@ class I2VGenXLPipeline(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
# 8. Post processing
if output_type == "latent":
video = latents
......
......@@ -22,6 +22,7 @@ from transformers import (
from ...models import UNet2DConditionModel, VQModel
from ...schedulers import DDIMScheduler, DDPMScheduler
from ...utils import (
is_torch_xla_available,
logging,
replace_example_docstring,
)
......@@ -30,8 +31,16 @@ from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from .text_encoder import MultilingualCLIP
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
......@@ -385,6 +394,9 @@ class KandinskyPipeline(DiffusionPipeline):
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
if XLA_AVAILABLE:
xm.mark_step()
# post-processing
image = self.movq.decode(latents, force_not_quantize=True)["sample"]
......
......@@ -25,6 +25,7 @@ from transformers import (
from ...models import UNet2DConditionModel, VQModel
from ...schedulers import DDIMScheduler
from ...utils import (
is_torch_xla_available,
logging,
replace_example_docstring,
)
......@@ -33,8 +34,16 @@ from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from .text_encoder import MultilingualCLIP
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
......@@ -478,6 +487,9 @@ class KandinskyImg2ImgPipeline(DiffusionPipeline):
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
if XLA_AVAILABLE:
xm.mark_step()
# 7. post-processing
image = self.movq.decode(latents, force_not_quantize=True)["sample"]
......
......@@ -29,6 +29,7 @@ from ... import __version__
from ...models import UNet2DConditionModel, VQModel
from ...schedulers import DDIMScheduler
from ...utils import (
is_torch_xla_available,
logging,
replace_example_docstring,
)
......@@ -37,8 +38,16 @@ from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from .text_encoder import MultilingualCLIP
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
......@@ -613,6 +622,9 @@ class KandinskyInpaintPipeline(DiffusionPipeline):
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
if XLA_AVAILABLE:
xm.mark_step()
# post-processing
image = self.movq.decode(latents, force_not_quantize=True)["sample"]
......
......@@ -24,6 +24,7 @@ from ...models import PriorTransformer
from ...schedulers import UnCLIPScheduler
from ...utils import (
BaseOutput,
is_torch_xla_available,
logging,
replace_example_docstring,
)
......@@ -31,8 +32,16 @@ from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
......@@ -519,6 +528,9 @@ class KandinskyPriorPipeline(DiffusionPipeline):
prev_timestep=prev_timestep,
).prev_sample
if XLA_AVAILABLE:
xm.mark_step()
latents = self.prior.post_process_latents(latents)
image_embeddings = latents
......
......@@ -18,13 +18,21 @@ import torch
from ...models import UNet2DConditionModel, VQModel
from ...schedulers import DDPMScheduler
from ...utils import deprecate, logging, replace_example_docstring
from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
......@@ -296,6 +304,9 @@ class KandinskyV22Pipeline(DiffusionPipeline):
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
if XLA_AVAILABLE:
xm.mark_step()
if output_type not in ["pt", "np", "pil", "latent"]:
raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}")
......
......@@ -19,14 +19,23 @@ import torch
from ...models import UNet2DConditionModel, VQModel
from ...schedulers import DDPMScheduler
from ...utils import (
is_torch_xla_available,
logging,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
......@@ -297,6 +306,10 @@ class KandinskyV22ControlnetPipeline(DiffusionPipeline):
if callback is not None and i % callback_steps == 0:
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
if XLA_AVAILABLE:
xm.mark_step()
# post-processing
image = self.movq.decode(latents, force_not_quantize=True)["sample"]
......
......@@ -22,14 +22,23 @@ from PIL import Image
from ...models import UNet2DConditionModel, VQModel
from ...schedulers import DDPMScheduler
from ...utils import (
is_torch_xla_available,
logging,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
......@@ -358,6 +367,9 @@ class KandinskyV22ControlnetImg2ImgPipeline(DiffusionPipeline):
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
if XLA_AVAILABLE:
xm.mark_step()
# post-processing
image = self.movq.decode(latents, force_not_quantize=True)["sample"]
......
......@@ -21,13 +21,21 @@ from PIL import Image
from ...models import UNet2DConditionModel, VQModel
from ...schedulers import DDPMScheduler
from ...utils import deprecate, logging
from ...utils import deprecate, is_torch_xla_available, logging
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
......@@ -372,6 +380,9 @@ class KandinskyV22Img2ImgPipeline(DiffusionPipeline):
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
if XLA_AVAILABLE:
xm.mark_step()
if output_type not in ["pt", "np", "pil", "latent"]:
raise ValueError(
f"Only the output types `pt`, `pil` ,`np` and `latent` are supported not output_type={output_type}"
......
......@@ -25,13 +25,21 @@ from PIL import Image
from ... import __version__
from ...models import UNet2DConditionModel, VQModel
from ...schedulers import DDPMScheduler
from ...utils import deprecate, logging
from ...utils import deprecate, is_torch_xla_available, logging
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
......@@ -526,6 +534,9 @@ class KandinskyV22InpaintPipeline(DiffusionPipeline):
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
if XLA_AVAILABLE:
xm.mark_step()
# post-processing
latents = mask_image[:1] * image[:1] + (1 - mask_image[:1]) * latents
......
......@@ -7,6 +7,7 @@ from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTo
from ...models import PriorTransformer
from ...schedulers import UnCLIPScheduler
from ...utils import (
is_torch_xla_available,
logging,
replace_example_docstring,
)
......@@ -15,8 +16,16 @@ from ..kandinsky import KandinskyPriorPipelineOutput
from ..pipeline_utils import DiffusionPipeline
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
......@@ -524,6 +533,9 @@ class KandinskyV22PriorPipeline(DiffusionPipeline):
)
text_mask = callback_outputs.pop("text_mask", text_mask)
if XLA_AVAILABLE:
xm.mark_step()
latents = self.prior.post_process_latents(latents)
image_embeddings = latents
......
......@@ -7,6 +7,7 @@ from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTo
from ...models import PriorTransformer
from ...schedulers import UnCLIPScheduler
from ...utils import (
is_torch_xla_available,
logging,
replace_example_docstring,
)
......@@ -15,8 +16,16 @@ from ..kandinsky import KandinskyPriorPipelineOutput
from ..pipeline_utils import DiffusionPipeline
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
......@@ -538,6 +547,9 @@ class KandinskyV22PriorEmb2EmbPipeline(DiffusionPipeline):
prev_timestep=prev_timestep,
).prev_sample
if XLA_AVAILABLE:
xm.mark_step()
latents = self.prior.post_process_latents(latents)
image_embeddings = latents
......
......@@ -8,6 +8,7 @@ from ...models import Kandinsky3UNet, VQModel
from ...schedulers import DDPMScheduler
from ...utils import (
deprecate,
is_torch_xla_available,
logging,
replace_example_docstring,
)
......@@ -15,8 +16,16 @@ from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
......@@ -549,6 +558,9 @@ class Kandinsky3Pipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixin):
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
if XLA_AVAILABLE:
xm.mark_step()
# post-processing
if output_type not in ["pt", "np", "pil", "latent"]:
raise ValueError(
......
......@@ -12,6 +12,7 @@ from ...models import Kandinsky3UNet, VQModel
from ...schedulers import DDPMScheduler
from ...utils import (
deprecate,
is_torch_xla_available,
logging,
replace_example_docstring,
)
......@@ -19,8 +20,16 @@ from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
......@@ -617,6 +626,9 @@ class Kandinsky3Img2ImgPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixi
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
if XLA_AVAILABLE:
xm.mark_step()
# post-processing
if output_type not in ["pt", "np", "pil", "latent"]:
raise ValueError(
......
......@@ -30,6 +30,7 @@ from ...schedulers import LCMScheduler
from ...utils import (
USE_PEFT_BACKEND,
deprecate,
is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
......@@ -40,6 +41,13 @@ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
......@@ -952,6 +960,9 @@ class LatentConsistencyModelImg2ImgPipeline(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
if XLA_AVAILABLE:
xm.mark_step()
denoised = denoised.to(prompt_embeds.dtype)
if not output_type == "latent":
image = self.vae.decode(denoised / self.vae.config.scaling_factor, return_dict=False)[0]
......
......@@ -29,6 +29,7 @@ from ...schedulers import LCMScheduler
from ...utils import (
USE_PEFT_BACKEND,
deprecate,
is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
......@@ -39,8 +40,16 @@ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
......@@ -881,6 +890,9 @@ class LatentConsistencyModelPipeline(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
if XLA_AVAILABLE:
xm.mark_step()
denoised = denoised.to(prompt_embeds.dtype)
if not output_type == "latent":
image = self.vae.decode(denoised / self.vae.config.scaling_factor, return_dict=False)[0]
......
......@@ -25,10 +25,19 @@ from transformers.utils import logging
from ...models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...utils import is_torch_xla_available
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
class LDMTextToImagePipeline(DiffusionPipeline):
r"""
Pipeline for text-to-image generation using latent diffusion.
......@@ -202,6 +211,9 @@ class LDMTextToImagePipeline(DiffusionPipeline):
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs).prev_sample
if XLA_AVAILABLE:
xm.mark_step()
# scale and decode the image latents with vae
latents = 1 / self.vqvae.config.scaling_factor * latents
image = self.vqvae.decode(latents).sample
......
......@@ -15,11 +15,19 @@ from ...schedulers import (
LMSDiscreteScheduler,
PNDMScheduler,
)
from ...utils import PIL_INTERPOLATION
from ...utils import PIL_INTERPOLATION, is_torch_xla_available
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
def preprocess(image):
w, h = image.size
w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32
......@@ -174,6 +182,9 @@ class LDMSuperResolutionPipeline(DiffusionPipeline):
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs).prev_sample
if XLA_AVAILABLE:
xm.mark_step()
# decode the image latents with the VQVAE
image = self.vqvae.decode(latents).sample
image = torch.clamp(image, -1.0, 1.0)
......
......@@ -32,6 +32,7 @@ from ...utils import (
BaseOutput,
is_bs4_available,
is_ftfy_available,
is_torch_xla_available,
logging,
replace_example_docstring,
)
......@@ -39,8 +40,16 @@ from ...utils.torch_utils import is_compiled_module, randn_tensor
from ...video_processor import VideoProcessor
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
if is_bs4_available():
from bs4 import BeautifulSoup
......@@ -836,6 +845,9 @@ class LattePipeline(DiffusionPipeline):
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
if not output_type == "latents":
video = self.decode_latents(latents, video_length, decode_chunk_size=14)
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment