"vscode:/vscode.git/clone" did not exist on "0103f374ba47c7baa87f9378b0ba3ef6c282969d"
Unverified Commit 95c5ce4e authored by hlky, committed by GitHub

PyTorch/XLA support (#10498)


Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent c0964571
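
This change applies the same two-part pattern to every pipeline below: a guarded module-level import of torch_xla, and a call to xm.mark_step() at the end of each denoising iteration so XLA materializes its lazily traced graph once per step instead of accumulating the whole loop. A minimal sketch of the pattern (the denoising_loop function and its arguments are illustrative placeholders, not code from this PR):

```py
from diffusers.utils import is_torch_xla_available

if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False


def denoising_loop(unet, scheduler, latents, timesteps):
    # Illustrative stand-in for a pipeline's __call__ denoising loop.
    for t in timesteps:
        noise_pred = unet(latents, t)
        latents = scheduler.step(noise_pred, t, latents).prev_sample

        if XLA_AVAILABLE:
            # On XLA devices, cut the lazily built graph here so each
            # step compiles and executes instead of deferring all work.
            xm.mark_step()
    return latents
```
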
@@ -24,14 +24,22 @@ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
 from ...image_processor import PipelineImageInput
 from ...models import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel
 from ...schedulers import EulerDiscreteScheduler
-from ...utils import BaseOutput, logging, replace_example_docstring
+from ...utils import BaseOutput, is_torch_xla_available, logging, replace_example_docstring
 from ...utils.torch_utils import is_compiled_module, randn_tensor
 from ...video_processor import VideoProcessor
 from ..pipeline_utils import DiffusionPipeline
 
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -600,6 +608,9 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         if not output_type == "latent":
             # cast back to fp16 if needed
             if needs_upcasting:
......
@@ -31,6 +31,7 @@ from ...utils import (
     USE_PEFT_BACKEND,
     BaseOutput,
     deprecate,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
     scale_lora_layers,
@@ -41,6 +42,14 @@ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 
 @dataclass
 class StableDiffusionAdapterPipelineOutput(BaseOutput):
     """
@@ -59,6 +68,7 @@ class StableDiffusionAdapterPipelineOutput(BaseOutput):
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
+
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -915,6 +925,9 @@ class StableDiffusionAdapterPipeline(DiffusionPipeline, StableDiffusionMixin):
                         step_idx = i // getattr(self.scheduler, "order", 1)
                         callback(step_idx, t, latents)
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         if output_type == "latent":
             image = latents
             has_nsfw_concept = None
......
@@ -43,6 +43,7 @@ from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
     PIL_INTERPOLATION,
     USE_PEFT_BACKEND,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
     scale_lora_layers,
@@ -53,8 +54,16 @@ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
 
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -1266,6 +1275,9 @@ class StableDiffusionXLAdapterPipeline(
                         step_idx = i // getattr(self.scheduler, "order", 1)
                         callback(step_idx, t, latents)
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         if not output_type == "latent":
             # make sure the VAE is in float32 mode, as it overflows in float16
             needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
......
@@ -25,6 +25,7 @@ from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
     USE_PEFT_BACKEND,
     deprecate,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
     scale_lora_layers,
@@ -36,8 +37,16 @@ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from . import TextToVideoSDPipelineOutput
 
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -627,6 +636,9 @@ class TextToVideoSDPipeline(
                         step_idx = i // getattr(self.scheduler, "order", 1)
                         callback(step_idx, t, latents)
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         # 8. Post processing
         if output_type == "latent":
             video = latents
......
@@ -26,6 +26,7 @@ from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
     USE_PEFT_BACKEND,
     deprecate,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
     scale_lora_layers,
@@ -37,8 +38,16 @@ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from . import TextToVideoSDPipelineOutput
 
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -679,6 +688,9 @@ class VideoToVideoSDPipeline(
                         step_idx = i // getattr(self.scheduler, "order", 1)
                         callback(step_idx, t, latents)
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         # manually for max memory savings
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
             self.unet.to("cpu")
......
@@ -42,6 +42,16 @@ if is_invisible_watermark_available():
     from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
 
+from ...utils import is_torch_xla_available
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -926,6 +936,10 @@ class TextToVideoZeroSDXLPipeline(
                     progress_bar.update()
                 if callback is not None and i % callback_steps == 0:
                     callback(i, t, latents)
+
+            if XLA_AVAILABLE:
+                xm.mark_step()
+
         return latents.clone().detach()
 
     @torch.no_grad()
......
@@ -22,12 +22,19 @@ from transformers.models.clip.modeling_clip import CLIPTextModelOutput
 from ...models import PriorTransformer, UNet2DConditionModel, UNet2DModel
 from ...schedulers import UnCLIPScheduler
-from ...utils import logging
+from ...utils import is_torch_xla_available, logging
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 from .text_proj import UnCLIPTextProjModel
 
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -474,6 +481,9 @@ class UnCLIPPipeline(DiffusionPipeline):
                 noise_pred, t, super_res_latents, prev_timestep=prev_timestep, generator=generator
             ).prev_sample
 
+            if XLA_AVAILABLE:
+                xm.mark_step()
+
         image = super_res_latents
         # done super res
......
@@ -27,12 +27,19 @@ from transformers import (
 from ...models import UNet2DConditionModel, UNet2DModel
 from ...schedulers import UnCLIPScheduler
-from ...utils import logging
+from ...utils import is_torch_xla_available, logging
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 from .text_proj import UnCLIPTextProjModel
 
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -400,6 +407,9 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
                 noise_pred, t, super_res_latents, prev_timestep=prev_timestep, generator=generator
             ).prev_sample
 
+            if XLA_AVAILABLE:
+                xm.mark_step()
+
         image = super_res_latents
         # done super res
......
@@ -18,7 +18,14 @@ from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import (
+    USE_PEFT_BACKEND,
+    deprecate,
+    is_torch_xla_available,
+    logging,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
 from ...utils.outputs import BaseOutput
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline
@@ -26,6 +33,13 @@ from .modeling_text_decoder import UniDiffuserTextDecoder
 from .modeling_uvit import UniDiffuserModel
 
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -1378,6 +1392,9 @@ class UniDiffuserPipeline(DiffusionPipeline):
                         step_idx = i // getattr(self.scheduler, "order", 1)
                         callback(step_idx, t, latents)
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         # 9. Post-processing
         image = None
         text = None
......
@@ -19,15 +19,23 @@ import torch
 from transformers import CLIPTextModel, CLIPTokenizer
 
 from ...schedulers import DDPMWuerstchenScheduler
-from ...utils import deprecate, logging, replace_example_docstring
+from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 from .modeling_paella_vq_model import PaellaVQModel
 from .modeling_wuerstchen_diffnext import WuerstchenDiffNeXt
 
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -413,6 +421,9 @@ class WuerstchenDecoderPipeline(DiffusionPipeline):
                     step_idx = i // getattr(self.scheduler, "order", 1)
                     callback(step_idx, t, latents)
 
+            if XLA_AVAILABLE:
+                xm.mark_step()
+
         if output_type not in ["pt", "np", "pil", "latent"]:
             raise ValueError(
                 f"Only the output types `pt`, `np`, `pil` and `latent` are supported not output_type={output_type}"
......
@@ -22,14 +22,22 @@ from transformers import CLIPTextModel, CLIPTokenizer
 from ...loaders import StableDiffusionLoraLoaderMixin
 from ...schedulers import DDPMWuerstchenScheduler
-from ...utils import BaseOutput, deprecate, logging, replace_example_docstring
+from ...utils import BaseOutput, deprecate, is_torch_xla_available, logging, replace_example_docstring
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from .modeling_wuerstchen_prior import WuerstchenPrior
 
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 DEFAULT_STAGE_C_TIMESTEPS = list(np.linspace(1.0, 2 / 3, 20)) + list(np.linspace(2 / 3, 0.0, 11))[1:]
 
 EXAMPLE_DOC_STRING = """
@@ -502,6 +510,9 @@ class WuerstchenPriorPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixin):
                     step_idx = i // getattr(self.scheduler, "order", 1)
                     callback(step_idx, t, latents)
 
+            if XLA_AVAILABLE:
+                xm.mark_step()
+
         # 10. Denormalize the latents
         latents = latents * self.config.latent_mean - self.config.latent_std
......
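
With these guards in place, any of the pipelines touched here can run on a TPU by moving the pipeline to the XLA device; no call-site changes are needed beyond that. A hedged usage sketch (the checkpoint id and prompt are examples only, and this assumes a working torch_xla installation):

```py
import torch
import torch_xla.core.xla_model as xm

from diffusers import DiffusionPipeline

# Example checkpoint; any pipeline covered by this PR is moved to the
# XLA device the same way.
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
)
pipe.to(xm.xla_device())

# Each denoising iteration now ends with xm.mark_step(), so the traced
# graph is compiled and executed once per step.
image = pipe("An astronaut riding a green horse").images[0]
image.save("astronaut.png")
```
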