Unverified commit 95c5ce4e authored by hlky, committed by GitHub

PyTorch/XLA support (#10498)


Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent c0964571
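
Every file touched by this commit applies the same two-part pattern: the `torch_xla` import is guarded behind `is_torch_xla_available()` at module level, and `xm.mark_step()` is called at the end of each denoising iteration so the lazily traced XLA graph is compiled and executed once per step rather than growing across the whole loop. The sketch below illustrates that pattern in isolation; the `run_denoising_loop` helper and its `step_fn` argument are hypothetical stand-ins, not code from this diff.

```py
# Minimal sketch of the guard + mark_step pattern added in this commit.
# `run_denoising_loop` and `step_fn` are hypothetical stand-ins for a
# pipeline's __call__ loop and its UNet forward pass.
from diffusers.utils import is_torch_xla_available

if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False


def run_denoising_loop(scheduler, step_fn, latents, timesteps):
    for t in timesteps:
        noise_pred = step_fn(latents, t)  # stand-in for the UNet forward pass
        latents = scheduler.step(noise_pred, t, latents).prev_sample
        if XLA_AVAILABLE:
            # On XLA devices, mark_step() cuts the lazily traced graph here so
            # each denoising step is compiled and executed immediately instead
            # of deferring all work to the end of the loop.
            xm.mark_step()
    return latents
```

In the pipelines themselves, the guard sits next to the module-level `logger`, and the `mark_step()` call goes right after the progress-bar/callback block inside `__call__`, as the hunks below show.
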
@@ -24,14 +24,22 @@ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
 from ...image_processor import PipelineImageInput
 from ...models import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel
 from ...schedulers import EulerDiscreteScheduler
-from ...utils import BaseOutput, logging, replace_example_docstring
+from ...utils import BaseOutput, is_torch_xla_available, logging, replace_example_docstring
 from ...utils.torch_utils import is_compiled_module, randn_tensor
 from ...video_processor import VideoProcessor
 from ..pipeline_utils import DiffusionPipeline


+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -600,6 +608,9 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()

+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         if not output_type == "latent":
             # cast back to fp16 if needed
             if needs_upcasting:
...
@@ -31,6 +31,7 @@ from ...utils import (
     USE_PEFT_BACKEND,
     BaseOutput,
     deprecate,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
     scale_lora_layers,
@@ -41,6 +42,14 @@ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker


+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
+
 @dataclass
 class StableDiffusionAdapterPipelineOutput(BaseOutput):
     """
@@ -59,6 +68,7 @@ class StableDiffusionAdapterPipelineOutput(BaseOutput):
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

+
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -915,6 +925,9 @@ class StableDiffusionAdapterPipeline(DiffusionPipeline, StableDiffusionMixin):
                         step_idx = i // getattr(self.scheduler, "order", 1)
                         callback(step_idx, t, latents)

+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         if output_type == "latent":
             image = latents
             has_nsfw_concept = None
...
@@ -43,6 +43,7 @@ from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
     PIL_INTERPOLATION,
     USE_PEFT_BACKEND,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
     scale_lora_layers,
@@ -53,8 +54,16 @@ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput


+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -1266,6 +1275,9 @@ class StableDiffusionXLAdapterPipeline(
                         step_idx = i // getattr(self.scheduler, "order", 1)
                         callback(step_idx, t, latents)

+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         if not output_type == "latent":
             # make sure the VAE is in float32 mode, as it overflows in float16
             needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
...
@@ -25,6 +25,7 @@ from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
     USE_PEFT_BACKEND,
     deprecate,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
     scale_lora_layers,
@@ -36,8 +37,16 @@ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from . import TextToVideoSDPipelineOutput


+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -627,6 +636,9 @@ class TextToVideoSDPipeline(
                         step_idx = i // getattr(self.scheduler, "order", 1)
                         callback(step_idx, t, latents)

+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         # 8. Post processing
         if output_type == "latent":
             video = latents
...
@@ -26,6 +26,7 @@ from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
     USE_PEFT_BACKEND,
     deprecate,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
     scale_lora_layers,
@@ -37,8 +38,16 @@ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from . import TextToVideoSDPipelineOutput


+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -679,6 +688,9 @@ class VideoToVideoSDPipeline(
                         step_idx = i // getattr(self.scheduler, "order", 1)
                         callback(step_idx, t, latents)

+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         # manually for max memory savings
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
             self.unet.to("cpu")
...
@@ -42,6 +42,16 @@ if is_invisible_watermark_available():
     from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker

+from ...utils import is_torch_xla_available
+
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -926,6 +936,10 @@ class TextToVideoZeroSDXLPipeline(
                     progress_bar.update()
                     if callback is not None and i % callback_steps == 0:
                         callback(i, t, latents)
+
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         return latents.clone().detach()

     @torch.no_grad()
...
@@ -22,12 +22,19 @@ from transformers.models.clip.modeling_clip import CLIPTextModelOutput
 from ...models import PriorTransformer, UNet2DConditionModel, UNet2DModel
 from ...schedulers import UnCLIPScheduler
-from ...utils import logging
+from ...utils import is_torch_xla_available, logging
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 from .text_proj import UnCLIPTextProjModel


+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -474,6 +481,9 @@ class UnCLIPPipeline(DiffusionPipeline):
                 noise_pred, t, super_res_latents, prev_timestep=prev_timestep, generator=generator
             ).prev_sample

+            if XLA_AVAILABLE:
+                xm.mark_step()
+
         image = super_res_latents
         # done super res
...
@@ -27,12 +27,19 @@ from transformers import (
 from ...models import UNet2DConditionModel, UNet2DModel
 from ...schedulers import UnCLIPScheduler
-from ...utils import logging
+from ...utils import is_torch_xla_available, logging
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 from .text_proj import UnCLIPTextProjModel


+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -400,6 +407,9 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
                 noise_pred, t, super_res_latents, prev_timestep=prev_timestep, generator=generator
             ).prev_sample

+            if XLA_AVAILABLE:
+                xm.mark_step()
+
         image = super_res_latents
         # done super res
...
@@ -18,7 +18,14 @@ from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMix
 from ...models import AutoencoderKL
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import (
+    USE_PEFT_BACKEND,
+    deprecate,
+    is_torch_xla_available,
+    logging,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
 from ...utils.outputs import BaseOutput
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline
@@ -26,6 +33,13 @@ from .modeling_text_decoder import UniDiffuserTextDecoder
 from .modeling_uvit import UniDiffuserModel


+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -1378,6 +1392,9 @@ class UniDiffuserPipeline(DiffusionPipeline):
                         step_idx = i // getattr(self.scheduler, "order", 1)
                         callback(step_idx, t, latents)

+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         # 9. Post-processing
         image = None
         text = None
...
@@ -19,15 +19,23 @@ import torch
 from transformers import CLIPTextModel, CLIPTokenizer

 from ...schedulers import DDPMWuerstchenScheduler
-from ...utils import deprecate, logging, replace_example_docstring
+from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 from .modeling_paella_vq_model import PaellaVQModel
 from .modeling_wuerstchen_diffnext import WuerstchenDiffNeXt


+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -413,6 +421,9 @@ class WuerstchenDecoderPipeline(DiffusionPipeline):
                 step_idx = i // getattr(self.scheduler, "order", 1)
                 callback(step_idx, t, latents)

+            if XLA_AVAILABLE:
+                xm.mark_step()
+
         if output_type not in ["pt", "np", "pil", "latent"]:
             raise ValueError(
                 f"Only the output types `pt`, `np`, `pil` and `latent` are supported not output_type={output_type}"
...
@@ -22,14 +22,22 @@ from transformers import CLIPTextModel, CLIPTokenizer
 from ...loaders import StableDiffusionLoraLoaderMixin
 from ...schedulers import DDPMWuerstchenScheduler
-from ...utils import BaseOutput, deprecate, logging, replace_example_docstring
+from ...utils import BaseOutput, deprecate, is_torch_xla_available, logging, replace_example_docstring
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from .modeling_wuerstchen_prior import WuerstchenPrior


+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

 DEFAULT_STAGE_C_TIMESTEPS = list(np.linspace(1.0, 2 / 3, 20)) + list(np.linspace(2 / 3, 0.0, 11))[1:]

 EXAMPLE_DOC_STRING = """
@@ -502,6 +510,9 @@ class WuerstchenPriorPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixin)
                 step_idx = i // getattr(self.scheduler, "order", 1)
                 callback(step_idx, t, latents)

+            if XLA_AVAILABLE:
+                xm.mark_step()
+
         # 10. Denormalize the latents
         latents = latents * self.config.latent_mean - self.config.latent_std
...
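
For context, here is a hedged usage sketch of one of the touched pipelines on an XLA (TPU) device; the checkpoint id, dtype, prompt, and step count are illustrative assumptions, not part of this commit.

```py
import torch
import torch_xla.core.xla_model as xm

from diffusers import TextToVideoSDPipeline

device = xm.xla_device()  # current TPU/XLA device
pipe = TextToVideoSDPipeline.from_pretrained(
    "damo-vilab/text-to-video-ms-1.7b",  # assumed checkpoint, for illustration only
    torch_dtype=torch.bfloat16,
).to(device)

# With this commit the pipeline calls xm.mark_step() after every scheduler step,
# so each denoising iteration is compiled and run on the device instead of the
# whole loop being deferred into a single oversized graph.
video_frames = pipe("an astronaut riding a horse", num_inference_steps=25).frames
```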