Unverified commit 95c5ce4e, authored by hlky and committed by GitHub
Browse files

PyTorch/XLA support (#10498)


Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent c0964571
...@@ -33,6 +33,7 @@ from ...utils import ( ...@@ -33,6 +33,7 @@ from ...utils import (
deprecate, deprecate,
is_bs4_available, is_bs4_available,
is_ftfy_available, is_ftfy_available,
is_torch_xla_available,
logging, logging,
replace_example_docstring, replace_example_docstring,
) )
...@@ -41,6 +42,14 @@ from ...video_processor import VideoProcessor ...@@ -41,6 +42,14 @@ from ...video_processor import VideoProcessor
from .pipeline_output import AllegroPipelineOutput from .pipeline_output import AllegroPipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
if is_bs4_available(): if is_bs4_available():
...@@ -921,6 +930,9 @@ class AllegroPipeline(DiffusionPipeline): ...@@ -921,6 +930,9 @@ class AllegroPipeline(DiffusionPipeline):
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update() progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
if not output_type == "latent": if not output_type == "latent":
latents = latents.to(self.vae.dtype) latents = latents.to(self.vae.dtype)
video = self.decode_latents(latents) video = self.decode_latents(latents)
......
...@@ -20,10 +20,18 @@ from transformers import CLIPTextModelWithProjection, CLIPTokenizer ...@@ -20,10 +20,18 @@ from transformers import CLIPTextModelWithProjection, CLIPTokenizer
from ...image_processor import VaeImageProcessor from ...image_processor import VaeImageProcessor
from ...models import UVit2DModel, VQModel from ...models import UVit2DModel, VQModel
from ...schedulers import AmusedScheduler from ...schedulers import AmusedScheduler
from ...utils import replace_example_docstring from ...utils import is_torch_xla_available, replace_example_docstring
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
EXAMPLE_DOC_STRING = """ EXAMPLE_DOC_STRING = """
Examples: Examples:
```py ```py
...@@ -299,6 +307,9 @@ class AmusedPipeline(DiffusionPipeline): ...@@ -299,6 +307,9 @@ class AmusedPipeline(DiffusionPipeline):
step_idx = i // getattr(self.scheduler, "order", 1) step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, timestep, latents) callback(step_idx, timestep, latents)
if XLA_AVAILABLE:
xm.mark_step()
if output_type == "latent": if output_type == "latent":
output = latents output = latents
else: else:
......
...@@ -20,10 +20,18 @@ from transformers import CLIPTextModelWithProjection, CLIPTokenizer ...@@ -20,10 +20,18 @@ from transformers import CLIPTextModelWithProjection, CLIPTokenizer
from ...image_processor import PipelineImageInput, VaeImageProcessor from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...models import UVit2DModel, VQModel from ...models import UVit2DModel, VQModel
from ...schedulers import AmusedScheduler from ...schedulers import AmusedScheduler
from ...utils import replace_example_docstring from ...utils import is_torch_xla_available, replace_example_docstring
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
EXAMPLE_DOC_STRING = """ EXAMPLE_DOC_STRING = """
Examples: Examples:
```py ```py
...@@ -325,6 +333,9 @@ class AmusedImg2ImgPipeline(DiffusionPipeline): ...@@ -325,6 +333,9 @@ class AmusedImg2ImgPipeline(DiffusionPipeline):
step_idx = i // getattr(self.scheduler, "order", 1) step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, timestep, latents) callback(step_idx, timestep, latents)
if XLA_AVAILABLE:
xm.mark_step()
if output_type == "latent": if output_type == "latent":
output = latents output = latents
else: else:
......
...@@ -21,10 +21,18 @@ from transformers import CLIPTextModelWithProjection, CLIPTokenizer ...@@ -21,10 +21,18 @@ from transformers import CLIPTextModelWithProjection, CLIPTokenizer
from ...image_processor import PipelineImageInput, VaeImageProcessor from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...models import UVit2DModel, VQModel from ...models import UVit2DModel, VQModel
from ...schedulers import AmusedScheduler from ...schedulers import AmusedScheduler
from ...utils import replace_example_docstring from ...utils import is_torch_xla_available, replace_example_docstring
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
EXAMPLE_DOC_STRING = """ EXAMPLE_DOC_STRING = """
Examples: Examples:
```py ```py
...@@ -356,6 +364,9 @@ class AmusedInpaintPipeline(DiffusionPipeline): ...@@ -356,6 +364,9 @@ class AmusedInpaintPipeline(DiffusionPipeline):
step_idx = i // getattr(self.scheduler, "order", 1) step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, timestep, latents) callback(step_idx, timestep, latents)
if XLA_AVAILABLE:
xm.mark_step()
if output_type == "latent": if output_type == "latent":
output = latents output = latents
else: else:
......
...@@ -34,6 +34,7 @@ from ...schedulers import ( ...@@ -34,6 +34,7 @@ from ...schedulers import (
from ...utils import ( from ...utils import (
USE_PEFT_BACKEND, USE_PEFT_BACKEND,
deprecate, deprecate,
is_torch_xla_available,
logging, logging,
replace_example_docstring, replace_example_docstring,
scale_lora_layers, scale_lora_layers,
...@@ -47,8 +48,16 @@ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin ...@@ -47,8 +48,16 @@ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from .pipeline_output import AnimateDiffPipelineOutput from .pipeline_output import AnimateDiffPipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """ EXAMPLE_DOC_STRING = """
Examples: Examples:
```py ```py
...@@ -844,6 +853,9 @@ class AnimateDiffPipeline( ...@@ -844,6 +853,9 @@ class AnimateDiffPipeline(
if callback is not None and i % callback_steps == 0: if callback is not None and i % callback_steps == 0:
callback(i, t, latents) callback(i, t, latents)
if XLA_AVAILABLE:
xm.mark_step()
# 9. Post processing # 9. Post processing
if output_type == "latent": if output_type == "latent":
video = latents video = latents
......
...@@ -32,7 +32,7 @@ from ...models import ( ...@@ -32,7 +32,7 @@ from ...models import (
from ...models.lora import adjust_lora_scale_text_encoder from ...models.lora import adjust_lora_scale_text_encoder
from ...models.unets.unet_motion_model import MotionAdapter from ...models.unets.unet_motion_model import MotionAdapter
from ...schedulers import KarrasDiffusionSchedulers from ...schedulers import KarrasDiffusionSchedulers
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ...utils import USE_PEFT_BACKEND, is_torch_xla_available, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import is_compiled_module, randn_tensor from ...utils.torch_utils import is_compiled_module, randn_tensor
from ...video_processor import VideoProcessor from ...video_processor import VideoProcessor
from ..free_init_utils import FreeInitMixin from ..free_init_utils import FreeInitMixin
...@@ -41,8 +41,16 @@ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin ...@@ -41,8 +41,16 @@ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from .pipeline_output import AnimateDiffPipelineOutput from .pipeline_output import AnimateDiffPipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """ EXAMPLE_DOC_STRING = """
Examples: Examples:
```py ```py
...@@ -1090,6 +1098,9 @@ class AnimateDiffControlNetPipeline( ...@@ -1090,6 +1098,9 @@ class AnimateDiffControlNetPipeline(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update() progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
# 9. Post processing # 9. Post processing
if output_type == "latent": if output_type == "latent":
video = latents video = latents
......
...@@ -48,6 +48,7 @@ from ...schedulers import ( ...@@ -48,6 +48,7 @@ from ...schedulers import (
) )
from ...utils import ( from ...utils import (
USE_PEFT_BACKEND, USE_PEFT_BACKEND,
is_torch_xla_available,
logging, logging,
replace_example_docstring, replace_example_docstring,
scale_lora_layers, scale_lora_layers,
...@@ -60,8 +61,16 @@ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin ...@@ -60,8 +61,16 @@ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from .pipeline_output import AnimateDiffPipelineOutput from .pipeline_output import AnimateDiffPipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """ EXAMPLE_DOC_STRING = """
Examples: Examples:
```py ```py
...@@ -1265,6 +1274,9 @@ class AnimateDiffSDXLPipeline( ...@@ -1265,6 +1274,9 @@ class AnimateDiffSDXLPipeline(
progress_bar.update() progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
# make sure the VAE is in float32 mode, as it overflows in float16 # make sure the VAE is in float32 mode, as it overflows in float16
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
......
...@@ -30,6 +30,7 @@ from ...models.unets.unet_motion_model import MotionAdapter ...@@ -30,6 +30,7 @@ from ...models.unets.unet_motion_model import MotionAdapter
from ...schedulers import KarrasDiffusionSchedulers from ...schedulers import KarrasDiffusionSchedulers
from ...utils import ( from ...utils import (
USE_PEFT_BACKEND, USE_PEFT_BACKEND,
is_torch_xla_available,
logging, logging,
replace_example_docstring, replace_example_docstring,
scale_lora_layers, scale_lora_layers,
...@@ -42,8 +43,16 @@ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin ...@@ -42,8 +43,16 @@ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from .pipeline_output import AnimateDiffPipelineOutput from .pipeline_output import AnimateDiffPipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """ EXAMPLE_DOC_STRING = """
Examples: Examples:
```python ```python
...@@ -994,6 +1003,9 @@ class AnimateDiffSparseControlNetPipeline( ...@@ -994,6 +1003,9 @@ class AnimateDiffSparseControlNetPipeline(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update() progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
# 11. Post processing # 11. Post processing
if output_type == "latent": if output_type == "latent":
video = latents video = latents
......
...@@ -31,7 +31,7 @@ from ...schedulers import ( ...@@ -31,7 +31,7 @@ from ...schedulers import (
LMSDiscreteScheduler, LMSDiscreteScheduler,
PNDMScheduler, PNDMScheduler,
) )
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ...utils import USE_PEFT_BACKEND, is_torch_xla_available, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import randn_tensor from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor from ...video_processor import VideoProcessor
from ..free_init_utils import FreeInitMixin from ..free_init_utils import FreeInitMixin
...@@ -40,8 +40,16 @@ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin ...@@ -40,8 +40,16 @@ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from .pipeline_output import AnimateDiffPipelineOutput from .pipeline_output import AnimateDiffPipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """ EXAMPLE_DOC_STRING = """
Examples: Examples:
```py ```py
...@@ -1037,6 +1045,9 @@ class AnimateDiffVideoToVideoPipeline( ...@@ -1037,6 +1045,9 @@ class AnimateDiffVideoToVideoPipeline(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update() progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
# 10. Post-processing # 10. Post-processing
if output_type == "latent": if output_type == "latent":
video = latents video = latents
......
...@@ -39,7 +39,7 @@ from ...schedulers import ( ...@@ -39,7 +39,7 @@ from ...schedulers import (
LMSDiscreteScheduler, LMSDiscreteScheduler,
PNDMScheduler, PNDMScheduler,
) )
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ...utils import USE_PEFT_BACKEND, is_torch_xla_available, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import is_compiled_module, randn_tensor from ...utils.torch_utils import is_compiled_module, randn_tensor
from ...video_processor import VideoProcessor from ...video_processor import VideoProcessor
from ..free_init_utils import FreeInitMixin from ..free_init_utils import FreeInitMixin
...@@ -48,8 +48,16 @@ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin ...@@ -48,8 +48,16 @@ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from .pipeline_output import AnimateDiffPipelineOutput from .pipeline_output import AnimateDiffPipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """ EXAMPLE_DOC_STRING = """
Examples: Examples:
```py ```py
...@@ -1325,6 +1333,9 @@ class AnimateDiffVideoToVideoControlNetPipeline( ...@@ -1325,6 +1333,9 @@ class AnimateDiffVideoToVideoControlNetPipeline(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update() progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
# 11. Post-processing # 11. Post-processing
if output_type == "latent": if output_type == "latent":
video = latents video = latents
......
...@@ -22,13 +22,21 @@ from transformers import ClapTextModelWithProjection, RobertaTokenizer, RobertaT ...@@ -22,13 +22,21 @@ from transformers import ClapTextModelWithProjection, RobertaTokenizer, RobertaT
from ...models import AutoencoderKL, UNet2DConditionModel from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers from ...schedulers import KarrasDiffusionSchedulers
from ...utils import logging, replace_example_docstring from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline, StableDiffusionMixin from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline, StableDiffusionMixin
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """ EXAMPLE_DOC_STRING = """
Examples: Examples:
```py ```py
...@@ -530,6 +538,9 @@ class AudioLDMPipeline(DiffusionPipeline, StableDiffusionMixin): ...@@ -530,6 +538,9 @@ class AudioLDMPipeline(DiffusionPipeline, StableDiffusionMixin):
step_idx = i // getattr(self.scheduler, "order", 1) step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents) callback(step_idx, t, latents)
if XLA_AVAILABLE:
xm.mark_step()
# 8. Post-processing # 8. Post-processing
mel_spectrogram = self.decode_latents(latents) mel_spectrogram = self.decode_latents(latents)
......
...@@ -48,8 +48,20 @@ from .modeling_audioldm2 import AudioLDM2ProjectionModel, AudioLDM2UNet2DConditi ...@@ -48,8 +48,20 @@ from .modeling_audioldm2 import AudioLDM2ProjectionModel, AudioLDM2UNet2DConditi
if is_librosa_available(): if is_librosa_available():
import librosa import librosa
from ...utils import is_torch_xla_available
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """ EXAMPLE_DOC_STRING = """
Examples: Examples:
```py ```py
...@@ -1033,6 +1045,9 @@ class AudioLDM2Pipeline(DiffusionPipeline): ...@@ -1033,6 +1045,9 @@ class AudioLDM2Pipeline(DiffusionPipeline):
step_idx = i // getattr(self.scheduler, "order", 1) step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents) callback(step_idx, t, latents)
if XLA_AVAILABLE:
xm.mark_step()
self.maybe_free_model_hooks() self.maybe_free_model_hooks()
# 8. Post-processing # 8. Post-processing
......
...@@ -20,6 +20,7 @@ from transformers import CLIPTokenizer ...@@ -20,6 +20,7 @@ from transformers import CLIPTokenizer
from ...models import AutoencoderKL, UNet2DConditionModel from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import PNDMScheduler from ...schedulers import PNDMScheduler
from ...utils import ( from ...utils import (
is_torch_xla_available,
logging, logging,
replace_example_docstring, replace_example_docstring,
) )
...@@ -30,8 +31,16 @@ from .modeling_blip2 import Blip2QFormerModel ...@@ -30,8 +31,16 @@ from .modeling_blip2 import Blip2QFormerModel
from .modeling_ctx_clip import ContextCLIPTextModel from .modeling_ctx_clip import ContextCLIPTextModel
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """ EXAMPLE_DOC_STRING = """
Examples: Examples:
```py ```py
...@@ -336,6 +345,9 @@ class BlipDiffusionPipeline(DiffusionPipeline): ...@@ -336,6 +345,9 @@ class BlipDiffusionPipeline(DiffusionPipeline):
latents, latents,
)["prev_sample"] )["prev_sample"]
if XLA_AVAILABLE:
xm.mark_step()
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type) image = self.image_processor.postprocess(image, output_type=output_type)
......
...@@ -26,12 +26,19 @@ from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel ...@@ -26,12 +26,19 @@ from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
from ...models.embeddings import get_3d_rotary_pos_embed from ...models.embeddings import get_3d_rotary_pos_embed
from ...pipelines.pipeline_utils import DiffusionPipeline from ...pipelines.pipeline_utils import DiffusionPipeline
from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
from ...utils import logging, replace_example_docstring from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor from ...video_processor import VideoProcessor
from .pipeline_output import CogVideoXPipelineOutput from .pipeline_output import CogVideoXPipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
...@@ -753,6 +760,9 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): ...@@ -753,6 +760,9 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update() progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
if not output_type == "latent": if not output_type == "latent":
# Discard any padding frames that were added for CogVideoX 1.5 # Discard any padding frames that were added for CogVideoX 1.5
latents = latents[:, additional_frames:] latents = latents[:, additional_frames:]
......
...@@ -27,12 +27,19 @@ from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel ...@@ -27,12 +27,19 @@ from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
from ...models.embeddings import get_3d_rotary_pos_embed from ...models.embeddings import get_3d_rotary_pos_embed
from ...pipelines.pipeline_utils import DiffusionPipeline from ...pipelines.pipeline_utils import DiffusionPipeline
from ...schedulers import KarrasDiffusionSchedulers from ...schedulers import KarrasDiffusionSchedulers
from ...utils import logging, replace_example_docstring from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor from ...video_processor import VideoProcessor
from .pipeline_output import CogVideoXPipelineOutput from .pipeline_output import CogVideoXPipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
...@@ -808,6 +815,9 @@ class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): ...@@ -808,6 +815,9 @@ class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update() progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
if not output_type == "latent": if not output_type == "latent":
video = self.decode_latents(latents) video = self.decode_latents(latents)
video = self.video_processor.postprocess_video(video=video, output_type=output_type) video = self.video_processor.postprocess_video(video=video, output_type=output_type)
......
...@@ -29,6 +29,7 @@ from ...models.embeddings import get_3d_rotary_pos_embed ...@@ -29,6 +29,7 @@ from ...models.embeddings import get_3d_rotary_pos_embed
from ...pipelines.pipeline_utils import DiffusionPipeline from ...pipelines.pipeline_utils import DiffusionPipeline
from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
from ...utils import ( from ...utils import (
is_torch_xla_available,
logging, logging,
replace_example_docstring, replace_example_docstring,
) )
...@@ -37,6 +38,13 @@ from ...video_processor import VideoProcessor ...@@ -37,6 +38,13 @@ from ...video_processor import VideoProcessor
from .pipeline_output import CogVideoXPipelineOutput from .pipeline_output import CogVideoXPipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
...@@ -866,6 +874,9 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin) ...@@ -866,6 +874,9 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update() progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
if not output_type == "latent": if not output_type == "latent":
# Discard any padding frames that were added for CogVideoX 1.5 # Discard any padding frames that were added for CogVideoX 1.5
latents = latents[:, additional_frames:] latents = latents[:, additional_frames:]
......
...@@ -27,12 +27,19 @@ from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel ...@@ -27,12 +27,19 @@ from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
from ...models.embeddings import get_3d_rotary_pos_embed from ...models.embeddings import get_3d_rotary_pos_embed
from ...pipelines.pipeline_utils import DiffusionPipeline from ...pipelines.pipeline_utils import DiffusionPipeline
from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
from ...utils import logging, replace_example_docstring from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor from ...video_processor import VideoProcessor
from .pipeline_output import CogVideoXPipelineOutput from .pipeline_output import CogVideoXPipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
...@@ -834,6 +841,9 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin) ...@@ -834,6 +841,9 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update() progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
if not output_type == "latent": if not output_type == "latent":
video = self.decode_latents(latents) video = self.decode_latents(latents)
video = self.video_processor.postprocess_video(video=video, output_type=output_type) video = self.video_processor.postprocess_video(video=video, output_type=output_type)
......
...@@ -24,11 +24,18 @@ from ...image_processor import VaeImageProcessor ...@@ -24,11 +24,18 @@ from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKL, CogView3PlusTransformer2DModel from ...models import AutoencoderKL, CogView3PlusTransformer2DModel
from ...pipelines.pipeline_utils import DiffusionPipeline from ...pipelines.pipeline_utils import DiffusionPipeline
from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
from ...utils import logging, replace_example_docstring from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor from ...utils.torch_utils import randn_tensor
from .pipeline_output import CogView3PipelineOutput from .pipeline_output import CogView3PipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
...@@ -654,6 +661,9 @@ class CogView3PlusPipeline(DiffusionPipeline): ...@@ -654,6 +661,9 @@ class CogView3PlusPipeline(DiffusionPipeline):
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update() progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
if not output_type == "latent": if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
0 0
......
...@@ -19,6 +19,7 @@ import torch ...@@ -19,6 +19,7 @@ import torch
from ...models import UNet2DModel from ...models import UNet2DModel
from ...schedulers import CMStochasticIterativeScheduler from ...schedulers import CMStochasticIterativeScheduler
from ...utils import ( from ...utils import (
is_torch_xla_available,
logging, logging,
replace_example_docstring, replace_example_docstring,
) )
...@@ -26,6 +27,13 @@ from ...utils.torch_utils import randn_tensor ...@@ -26,6 +27,13 @@ from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
...@@ -263,6 +271,9 @@ class ConsistencyModelPipeline(DiffusionPipeline): ...@@ -263,6 +271,9 @@ class ConsistencyModelPipeline(DiffusionPipeline):
if callback is not None and i % callback_steps == 0: if callback is not None and i % callback_steps == 0:
callback(i, t, sample) callback(i, t, sample)
if XLA_AVAILABLE:
xm.mark_step()
# 6. Post-process image sample # 6. Post-process image sample
image = self.postprocess_image(sample, output_type=output_type) image = self.postprocess_image(sample, output_type=output_type)
......
...@@ -21,6 +21,7 @@ from transformers import CLIPTokenizer ...@@ -21,6 +21,7 @@ from transformers import CLIPTokenizer
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ...schedulers import PNDMScheduler from ...schedulers import PNDMScheduler
from ...utils import ( from ...utils import (
is_torch_xla_available,
logging, logging,
replace_example_docstring, replace_example_docstring,
) )
...@@ -31,8 +32,16 @@ from ..blip_diffusion.modeling_ctx_clip import ContextCLIPTextModel ...@@ -31,8 +32,16 @@ from ..blip_diffusion.modeling_ctx_clip import ContextCLIPTextModel
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """ EXAMPLE_DOC_STRING = """
Examples: Examples:
```py ```py
...@@ -401,6 +410,10 @@ class BlipDiffusionControlNetPipeline(DiffusionPipeline): ...@@ -401,6 +410,10 @@ class BlipDiffusionControlNetPipeline(DiffusionPipeline):
t, t,
latents, latents,
)["prev_sample"] )["prev_sample"]
if XLA_AVAILABLE:
xm.mark_step()
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type) image = self.image_processor.postprocess(image, output_type=output_type)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment