Unverified Commit 95c5ce4e authored by hlky's avatar hlky Committed by GitHub
Browse files

PyTorch/XLA support (#10498)


Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
parent c0964571
......@@ -33,6 +33,7 @@ from ...utils import (
deprecate,
is_bs4_available,
is_ftfy_available,
is_torch_xla_available,
logging,
replace_example_docstring,
)
......@@ -41,6 +42,14 @@ from ...video_processor import VideoProcessor
from .pipeline_output import AllegroPipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__)
if is_bs4_available():
......@@ -921,6 +930,9 @@ class AllegroPipeline(DiffusionPipeline):
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
if not output_type == "latent":
latents = latents.to(self.vae.dtype)
video = self.decode_latents(latents)
......
......@@ -20,10 +20,18 @@ from transformers import CLIPTextModelWithProjection, CLIPTokenizer
from ...image_processor import VaeImageProcessor
from ...models import UVit2DModel, VQModel
from ...schedulers import AmusedScheduler
from ...utils import replace_example_docstring
from ...utils import is_torch_xla_available, replace_example_docstring
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
EXAMPLE_DOC_STRING = """
Examples:
```py
......@@ -299,6 +307,9 @@ class AmusedPipeline(DiffusionPipeline):
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, timestep, latents)
if XLA_AVAILABLE:
xm.mark_step()
if output_type == "latent":
output = latents
else:
......
......@@ -20,10 +20,18 @@ from transformers import CLIPTextModelWithProjection, CLIPTokenizer
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...models import UVit2DModel, VQModel
from ...schedulers import AmusedScheduler
from ...utils import replace_example_docstring
from ...utils import is_torch_xla_available, replace_example_docstring
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
EXAMPLE_DOC_STRING = """
Examples:
```py
......@@ -325,6 +333,9 @@ class AmusedImg2ImgPipeline(DiffusionPipeline):
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, timestep, latents)
if XLA_AVAILABLE:
xm.mark_step()
if output_type == "latent":
output = latents
else:
......
......@@ -21,10 +21,18 @@ from transformers import CLIPTextModelWithProjection, CLIPTokenizer
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...models import UVit2DModel, VQModel
from ...schedulers import AmusedScheduler
from ...utils import replace_example_docstring
from ...utils import is_torch_xla_available, replace_example_docstring
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
EXAMPLE_DOC_STRING = """
Examples:
```py
......@@ -356,6 +364,9 @@ class AmusedInpaintPipeline(DiffusionPipeline):
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, timestep, latents)
if XLA_AVAILABLE:
xm.mark_step()
if output_type == "latent":
output = latents
else:
......
......@@ -34,6 +34,7 @@ from ...schedulers import (
from ...utils import (
USE_PEFT_BACKEND,
deprecate,
is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
......@@ -47,8 +48,16 @@ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from .pipeline_output import AnimateDiffPipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
......@@ -844,6 +853,9 @@ class AnimateDiffPipeline(
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
if XLA_AVAILABLE:
xm.mark_step()
# 9. Post processing
if output_type == "latent":
video = latents
......
......@@ -32,7 +32,7 @@ from ...models import (
from ...models.lora import adjust_lora_scale_text_encoder
from ...models.unets.unet_motion_model import MotionAdapter
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils import USE_PEFT_BACKEND, is_torch_xla_available, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import is_compiled_module, randn_tensor
from ...video_processor import VideoProcessor
from ..free_init_utils import FreeInitMixin
......@@ -41,8 +41,16 @@ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from .pipeline_output import AnimateDiffPipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
......@@ -1090,6 +1098,9 @@ class AnimateDiffControlNetPipeline(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
# 9. Post processing
if output_type == "latent":
video = latents
......
......@@ -48,6 +48,7 @@ from ...schedulers import (
)
from ...utils import (
USE_PEFT_BACKEND,
is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
......@@ -60,8 +61,16 @@ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from .pipeline_output import AnimateDiffPipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
......@@ -1265,6 +1274,9 @@ class AnimateDiffSDXLPipeline(
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
# make sure the VAE is in float32 mode, as it overflows in float16
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
......
......@@ -30,6 +30,7 @@ from ...models.unets.unet_motion_model import MotionAdapter
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
USE_PEFT_BACKEND,
is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
......@@ -42,8 +43,16 @@ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from .pipeline_output import AnimateDiffPipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```python
......@@ -994,6 +1003,9 @@ class AnimateDiffSparseControlNetPipeline(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
# 11. Post processing
if output_type == "latent":
video = latents
......
......@@ -31,7 +31,7 @@ from ...schedulers import (
LMSDiscreteScheduler,
PNDMScheduler,
)
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils import USE_PEFT_BACKEND, is_torch_xla_available, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..free_init_utils import FreeInitMixin
......@@ -40,8 +40,16 @@ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from .pipeline_output import AnimateDiffPipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
......@@ -1037,6 +1045,9 @@ class AnimateDiffVideoToVideoPipeline(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
# 10. Post-processing
if output_type == "latent":
video = latents
......
......@@ -39,7 +39,7 @@ from ...schedulers import (
LMSDiscreteScheduler,
PNDMScheduler,
)
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils import USE_PEFT_BACKEND, is_torch_xla_available, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import is_compiled_module, randn_tensor
from ...video_processor import VideoProcessor
from ..free_init_utils import FreeInitMixin
......@@ -48,8 +48,16 @@ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from .pipeline_output import AnimateDiffPipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
......@@ -1325,6 +1333,9 @@ class AnimateDiffVideoToVideoControlNetPipeline(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
# 11. Post-processing
if output_type == "latent":
video = latents
......
......@@ -22,13 +22,21 @@ from transformers import ClapTextModelWithProjection, RobertaTokenizer, RobertaT
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import logging, replace_example_docstring
from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline, StableDiffusionMixin
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
......@@ -530,6 +538,9 @@ class AudioLDMPipeline(DiffusionPipeline, StableDiffusionMixin):
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
if XLA_AVAILABLE:
xm.mark_step()
# 8. Post-processing
mel_spectrogram = self.decode_latents(latents)
......
......@@ -48,8 +48,20 @@ from .modeling_audioldm2 import AudioLDM2ProjectionModel, AudioLDM2UNet2DConditi
if is_librosa_available():
import librosa
from ...utils import is_torch_xla_available
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
......@@ -1033,6 +1045,9 @@ class AudioLDM2Pipeline(DiffusionPipeline):
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
if XLA_AVAILABLE:
xm.mark_step()
self.maybe_free_model_hooks()
# 8. Post-processing
......
......@@ -20,6 +20,7 @@ from transformers import CLIPTokenizer
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import PNDMScheduler
from ...utils import (
is_torch_xla_available,
logging,
replace_example_docstring,
)
......@@ -30,8 +31,16 @@ from .modeling_blip2 import Blip2QFormerModel
from .modeling_ctx_clip import ContextCLIPTextModel
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
......@@ -336,6 +345,9 @@ class BlipDiffusionPipeline(DiffusionPipeline):
latents,
)["prev_sample"]
if XLA_AVAILABLE:
xm.mark_step()
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)
......
......@@ -26,12 +26,19 @@ from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
from ...models.embeddings import get_3d_rotary_pos_embed
from ...pipelines.pipeline_utils import DiffusionPipeline
from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
from ...utils import logging, replace_example_docstring
from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from .pipeline_output import CogVideoXPipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
......@@ -753,6 +760,9 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
if not output_type == "latent":
# Discard any padding frames that were added for CogVideoX 1.5
latents = latents[:, additional_frames:]
......
......@@ -27,12 +27,19 @@ from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
from ...models.embeddings import get_3d_rotary_pos_embed
from ...pipelines.pipeline_utils import DiffusionPipeline
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import logging, replace_example_docstring
from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from .pipeline_output import CogVideoXPipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
......@@ -808,6 +815,9 @@ class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
if not output_type == "latent":
video = self.decode_latents(latents)
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
......
......@@ -29,6 +29,7 @@ from ...models.embeddings import get_3d_rotary_pos_embed
from ...pipelines.pipeline_utils import DiffusionPipeline
from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
from ...utils import (
is_torch_xla_available,
logging,
replace_example_docstring,
)
......@@ -37,6 +38,13 @@ from ...video_processor import VideoProcessor
from .pipeline_output import CogVideoXPipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
......@@ -866,6 +874,9 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
if not output_type == "latent":
# Discard any padding frames that were added for CogVideoX 1.5
latents = latents[:, additional_frames:]
......
......@@ -27,12 +27,19 @@ from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
from ...models.embeddings import get_3d_rotary_pos_embed
from ...pipelines.pipeline_utils import DiffusionPipeline
from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
from ...utils import logging, replace_example_docstring
from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from .pipeline_output import CogVideoXPipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
......@@ -834,6 +841,9 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
if not output_type == "latent":
video = self.decode_latents(latents)
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
......
......@@ -24,11 +24,18 @@ from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKL, CogView3PlusTransformer2DModel
from ...pipelines.pipeline_utils import DiffusionPipeline
from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
from ...utils import logging, replace_example_docstring
from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from .pipeline_output import CogView3PipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
......@@ -654,6 +661,9 @@ class CogView3PlusPipeline(DiffusionPipeline):
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
0
......
......@@ -19,6 +19,7 @@ import torch
from ...models import UNet2DModel
from ...schedulers import CMStochasticIterativeScheduler
from ...utils import (
is_torch_xla_available,
logging,
replace_example_docstring,
)
......@@ -26,6 +27,13 @@ from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
......@@ -263,6 +271,9 @@ class ConsistencyModelPipeline(DiffusionPipeline):
if callback is not None and i % callback_steps == 0:
callback(i, t, sample)
if XLA_AVAILABLE:
xm.mark_step()
# 6. Post-process image sample
image = self.postprocess_image(sample, output_type=output_type)
......
......@@ -21,6 +21,7 @@ from transformers import CLIPTokenizer
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ...schedulers import PNDMScheduler
from ...utils import (
is_torch_xla_available,
logging,
replace_example_docstring,
)
......@@ -31,8 +32,16 @@ from ..blip_diffusion.modeling_ctx_clip import ContextCLIPTextModel
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
......@@ -401,6 +410,10 @@ class BlipDiffusionControlNetPipeline(DiffusionPipeline):
t,
latents,
)["prev_sample"]
if XLA_AVAILABLE:
xm.mark_step()
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment