"vscode:/vscode.git/clone" did not exist on "fb8d91375a745644d78d6f5e118eec49bffa4b22"
Unverified commit 95c5ce4e, authored by hlky and committed by GitHub

PyTorch/XLA support (#10498)


Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent c0964571
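
Every file in this diff applies the same two-part pattern: a module-level guard that detects `torch_xla` at import time, and an `xm.mark_step()` call at the end of each denoising iteration. PyTorch/XLA tensors are lazy, so `mark_step()` cuts the pending graph and triggers compilation and execution once per scheduler step instead of deferring the entire loop into one oversized graph. A condensed sketch of the recurring pattern (the loop below is schematic, not any single pipeline's exact body):

```python
from diffusers.utils import is_torch_xla_available

if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False


def denoising_loop(scheduler, timesteps, latents, predict_noise):
    # Schematic stand-in for each pipeline's __call__ loop in the hunks below.
    for i, t in enumerate(timesteps):
        noise_pred = predict_noise(latents, t)
        latents = scheduler.step(noise_pred, t, latents).prev_sample

        if XLA_AVAILABLE:
            # Materialize the lazy XLA graph accumulated during this step.
            xm.mark_step()
    return latents
```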
@@ -23,15 +23,23 @@ from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...loaders import HunyuanVideoLoraLoaderMixin
 from ...models import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
-from ...utils import logging, replace_example_docstring
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
 from ...utils.torch_utils import randn_tensor
 from ...video_processor import VideoProcessor
 from ..pipeline_utils import DiffusionPipeline
 from .pipeline_output import HunyuanVideoPipelineOutput

+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

 EXAMPLE_DOC_STRING = """
     Examples:
         ```python
@@ -667,6 +675,9 @@ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()

+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         if not output_type == "latent":
             latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor
             video = self.vae.decode(latents, return_dict=False)[0]
...
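For context, a minimal usage sketch (not part of this diff) showing how a pipeline touched by this commit would run on an XLA device; the checkpoint, dtype, and prompt are illustrative assumptions:

```python
# Hypothetical usage sketch: assumes torch_xla is installed and an XLA device
# (e.g. a TPU VM) is available; the checkpoint name is illustrative only.
import torch
import torch_xla.core.xla_model as xm
from diffusers import DiffusionPipeline

device = xm.xla_device()  # the current process's XLA device
pipe = DiffusionPipeline.from_pretrained(
    "SimianLuo/LCM_Dreamshaper_v7", torch_dtype=torch.bfloat16
).to(device)

# With this change, every scheduler step ends in xm.mark_step(), so the lazy
# graph is compiled and executed per step instead of at the end of the loop.
image = pipe(
    prompt="a photo of an astronaut riding a horse",
    num_inference_steps=4,
    guidance_scale=8.0,
).images[0]
```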
@@ -27,6 +27,7 @@ from ...models.unets.unet_i2vgen_xl import I2VGenXLUNet
 from ...schedulers import DDIMScheduler
 from ...utils import (
     BaseOutput,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
 )
@@ -35,8 +36,16 @@ from ...video_processor import VideoProcessor
 from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin

+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -711,6 +720,9 @@ class I2VGenXLPipeline(
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()

+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         # 8. Post processing
         if output_type == "latent":
             video = latents
...
@@ -22,6 +22,7 @@ from transformers import (
 from ...models import UNet2DConditionModel, VQModel
 from ...schedulers import DDIMScheduler, DDPMScheduler
 from ...utils import (
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
 )
@@ -30,8 +31,16 @@ from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 from .text_encoder import MultilingualCLIP

+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -385,6 +394,9 @@ class KandinskyPipeline(DiffusionPipeline):
                     step_idx = i // getattr(self.scheduler, "order", 1)
                     callback(step_idx, t, latents)

+            if XLA_AVAILABLE:
+                xm.mark_step()
+
         # post-processing
         image = self.movq.decode(latents, force_not_quantize=True)["sample"]
...
@@ -25,6 +25,7 @@ from transformers import (
 from ...models import UNet2DConditionModel, VQModel
 from ...schedulers import DDIMScheduler
 from ...utils import (
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
 )
@@ -33,8 +34,16 @@ from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 from .text_encoder import MultilingualCLIP

+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -478,6 +487,9 @@ class KandinskyImg2ImgPipeline(DiffusionPipeline):
                     step_idx = i // getattr(self.scheduler, "order", 1)
                     callback(step_idx, t, latents)

+            if XLA_AVAILABLE:
+                xm.mark_step()
+
         # 7. post-processing
         image = self.movq.decode(latents, force_not_quantize=True)["sample"]
...
@@ -29,6 +29,7 @@ from ... import __version__
 from ...models import UNet2DConditionModel, VQModel
 from ...schedulers import DDIMScheduler
 from ...utils import (
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
 )
@@ -37,8 +38,16 @@ from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 from .text_encoder import MultilingualCLIP

+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -613,6 +622,9 @@ class KandinskyInpaintPipeline(DiffusionPipeline):
                     step_idx = i // getattr(self.scheduler, "order", 1)
                     callback(step_idx, t, latents)

+            if XLA_AVAILABLE:
+                xm.mark_step()
+
         # post-processing
         image = self.movq.decode(latents, force_not_quantize=True)["sample"]
...
@@ -24,6 +24,7 @@ from ...models import PriorTransformer
 from ...schedulers import UnCLIPScheduler
 from ...utils import (
     BaseOutput,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
 )
@@ -31,8 +32,16 @@ from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline

+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -519,6 +528,9 @@ class KandinskyPriorPipeline(DiffusionPipeline):
                 prev_timestep=prev_timestep,
             ).prev_sample

+            if XLA_AVAILABLE:
+                xm.mark_step()
+
         latents = self.prior.post_process_latents(latents)

         image_embeddings = latents
...
@@ -18,13 +18,21 @@ import torch
 from ...models import UNet2DConditionModel, VQModel
 from ...schedulers import DDPMScheduler
-from ...utils import deprecate, logging, replace_example_docstring
+from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput

+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -296,6 +304,9 @@ class KandinskyV22Pipeline(DiffusionPipeline):
                     step_idx = i // getattr(self.scheduler, "order", 1)
                     callback(step_idx, t, latents)

+            if XLA_AVAILABLE:
+                xm.mark_step()
+
         if output_type not in ["pt", "np", "pil", "latent"]:
             raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}")
...
@@ -19,14 +19,23 @@ import torch
 from ...models import UNet2DConditionModel, VQModel
 from ...schedulers import DDPMScheduler
 from ...utils import (
+    is_torch_xla_available,
     logging,
 )
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput

+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -297,6 +306,10 @@ class KandinskyV22ControlnetPipeline(DiffusionPipeline):
             if callback is not None and i % callback_steps == 0:
                 step_idx = i // getattr(self.scheduler, "order", 1)
                 callback(step_idx, t, latents)
+
+            if XLA_AVAILABLE:
+                xm.mark_step()
+
         # post-processing
         image = self.movq.decode(latents, force_not_quantize=True)["sample"]
...
@@ -22,14 +22,23 @@ from PIL import Image
 from ...models import UNet2DConditionModel, VQModel
 from ...schedulers import DDPMScheduler
 from ...utils import (
+    is_torch_xla_available,
     logging,
 )
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput

+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -358,6 +367,9 @@ class KandinskyV22ControlnetImg2ImgPipeline(DiffusionPipeline):
                     step_idx = i // getattr(self.scheduler, "order", 1)
                     callback(step_idx, t, latents)

+            if XLA_AVAILABLE:
+                xm.mark_step()
+
         # post-processing
         image = self.movq.decode(latents, force_not_quantize=True)["sample"]
...
@@ -21,13 +21,21 @@ from PIL import Image
 from ...models import UNet2DConditionModel, VQModel
 from ...schedulers import DDPMScheduler
-from ...utils import deprecate, logging
+from ...utils import deprecate, is_torch_xla_available, logging
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput

+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -372,6 +380,9 @@ class KandinskyV22Img2ImgPipeline(DiffusionPipeline):
                     step_idx = i // getattr(self.scheduler, "order", 1)
                     callback(step_idx, t, latents)

+            if XLA_AVAILABLE:
+                xm.mark_step()
+
         if output_type not in ["pt", "np", "pil", "latent"]:
             raise ValueError(
                 f"Only the output types `pt`, `pil` ,`np` and `latent` are supported not output_type={output_type}"
...
@@ -25,13 +25,21 @@ from PIL import Image
 from ... import __version__
 from ...models import UNet2DConditionModel, VQModel
 from ...schedulers import DDPMScheduler
-from ...utils import deprecate, logging
+from ...utils import deprecate, is_torch_xla_available, logging
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput

+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -526,6 +534,9 @@ class KandinskyV22InpaintPipeline(DiffusionPipeline):
                     step_idx = i // getattr(self.scheduler, "order", 1)
                     callback(step_idx, t, latents)

+            if XLA_AVAILABLE:
+                xm.mark_step()
+
         # post-processing
         latents = mask_image[:1] * image[:1] + (1 - mask_image[:1]) * latents
...
@@ -7,6 +7,7 @@ from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTo
 from ...models import PriorTransformer
 from ...schedulers import UnCLIPScheduler
 from ...utils import (
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
 )
@@ -15,8 +16,16 @@ from ..kandinsky import KandinskyPriorPipelineOutput
 from ..pipeline_utils import DiffusionPipeline

+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -524,6 +533,9 @@ class KandinskyV22PriorPipeline(DiffusionPipeline):
                 )

                 text_mask = callback_outputs.pop("text_mask", text_mask)

+            if XLA_AVAILABLE:
+                xm.mark_step()
+
         latents = self.prior.post_process_latents(latents)
         image_embeddings = latents
...
@@ -7,6 +7,7 @@ from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTo
 from ...models import PriorTransformer
 from ...schedulers import UnCLIPScheduler
 from ...utils import (
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
 )
@@ -15,8 +16,16 @@ from ..kandinsky import KandinskyPriorPipelineOutput
 from ..pipeline_utils import DiffusionPipeline

+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -538,6 +547,9 @@ class KandinskyV22PriorEmb2EmbPipeline(DiffusionPipeline):
                 prev_timestep=prev_timestep,
             ).prev_sample

+            if XLA_AVAILABLE:
+                xm.mark_step()
+
         latents = self.prior.post_process_latents(latents)
         image_embeddings = latents
...
@@ -8,6 +8,7 @@ from ...models import Kandinsky3UNet, VQModel
 from ...schedulers import DDPMScheduler
 from ...utils import (
     deprecate,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
 )
@@ -15,8 +16,16 @@ from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput

+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -549,6 +558,9 @@ class Kandinsky3Pipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixin):
                     step_idx = i // getattr(self.scheduler, "order", 1)
                     callback(step_idx, t, latents)

+            if XLA_AVAILABLE:
+                xm.mark_step()
+
         # post-processing
         if output_type not in ["pt", "np", "pil", "latent"]:
             raise ValueError(
...
@@ -12,6 +12,7 @@ from ...models import Kandinsky3UNet, VQModel
 from ...schedulers import DDPMScheduler
 from ...utils import (
     deprecate,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
 )
@@ -19,8 +20,16 @@ from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput

+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -617,6 +626,9 @@ class Kandinsky3Img2ImgPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixi
                     step_idx = i // getattr(self.scheduler, "order", 1)
                     callback(step_idx, t, latents)

+            if XLA_AVAILABLE:
+                xm.mark_step()
+
         # post-processing
         if output_type not in ["pt", "np", "pil", "latent"]:
             raise ValueError(
...
@@ -30,6 +30,7 @@ from ...schedulers import LCMScheduler
 from ...utils import (
     USE_PEFT_BACKEND,
     deprecate,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
     scale_lora_layers,
@@ -40,6 +41,13 @@ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from ..stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker

+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -952,6 +960,9 @@ class LatentConsistencyModelImg2ImgPipeline(
                     step_idx = i // getattr(self.scheduler, "order", 1)
                     callback(step_idx, t, latents)

+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         denoised = denoised.to(prompt_embeds.dtype)
         if not output_type == "latent":
             image = self.vae.decode(denoised / self.vae.config.scaling_factor, return_dict=False)[0]
...
@@ -29,6 +29,7 @@ from ...schedulers import LCMScheduler
 from ...utils import (
     USE_PEFT_BACKEND,
     deprecate,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
     scale_lora_layers,
@@ -39,8 +40,16 @@ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from ..stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker

+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -881,6 +890,9 @@ class LatentConsistencyModelPipeline(
                     step_idx = i // getattr(self.scheduler, "order", 1)
                     callback(step_idx, t, latents)

+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         denoised = denoised.to(prompt_embeds.dtype)
         if not output_type == "latent":
             image = self.vae.decode(denoised / self.vae.config.scaling_factor, return_dict=False)[0]
...
@@ -25,10 +25,19 @@ from transformers.utils import logging
 from ...models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel
 from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from ...utils import is_torch_xla_available
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput

+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+

 class LDMTextToImagePipeline(DiffusionPipeline):
     r"""
     Pipeline for text-to-image generation using latent diffusion.
@@ -202,6 +211,9 @@ class LDMTextToImagePipeline(DiffusionPipeline):
             # compute the previous noisy sample x_t -> x_t-1
             latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs).prev_sample

+            if XLA_AVAILABLE:
+                xm.mark_step()
+
         # scale and decode the image latents with vae
         latents = 1 / self.vqvae.config.scaling_factor * latents
         image = self.vqvae.decode(latents).sample
...
@@ -15,11 +15,19 @@ from ...schedulers import (
     LMSDiscreteScheduler,
     PNDMScheduler,
 )
-from ...utils import PIL_INTERPOLATION
+from ...utils import PIL_INTERPOLATION, is_torch_xla_available
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput

+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+

 def preprocess(image):
     w, h = image.size
     w, h = (x - x % 32 for x in (w, h))  # resize to integer multiple of 32
@@ -174,6 +182,9 @@ class LDMSuperResolutionPipeline(DiffusionPipeline):
             # compute the previous noisy sample x_t -> x_t-1
             latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs).prev_sample

+            if XLA_AVAILABLE:
+                xm.mark_step()
+
         # decode the image latents with the VQVAE
         image = self.vqvae.decode(latents).sample
         image = torch.clamp(image, -1.0, 1.0)
...
@@ -32,6 +32,7 @@ from ...utils import (
     BaseOutput,
     is_bs4_available,
     is_ftfy_available,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
 )
@@ -39,8 +40,16 @@ from ...utils.torch_utils import is_compiled_module, randn_tensor
 from ...video_processor import VideoProcessor

+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

 if is_bs4_available():
     from bs4 import BeautifulSoup
@@ -836,6 +845,9 @@ class LattePipeline(DiffusionPipeline):
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()

+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         if not output_type == "latents":
             video = self.decode_latents(latents, video_length, decode_chunk_size=14)
             video = self.video_processor.postprocess_video(video=video, output_type=output_type)
...