Unverified Commit aea73834 authored by ssusie, committed by GitHub

Adding PyTorch XLA support for SDXL inference (#5273)



* Added mark_step for SDXL to run with PyTorch XLA; also updated the README with instructions for XLA

* Added a soft dependency on torch_xla

* Fixed some styling

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
parent a1392135
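The heart of the change is a per-step `xm.mark_step()` call in each pipeline's denoising loop: XLA records tensor operations lazily, and `mark_step` cuts and executes the recorded graph, so calling it once per step keeps each iteration's graph small and reusable. A minimal runnable sketch of the pattern, with a toy update standing in for the pipelines' actual denoising math:

```python
import torch

try:
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
except ImportError:
    XLA_AVAILABLE = False

device = xm.xla_device() if XLA_AVAILABLE else "cpu"

# Toy stand-in for a denoising loop: `latents` plays the role of the
# pipeline's latent tensor; the update below is not real scheduler math.
latents = torch.randn(1, 4, 64, 64, device=device)
for _ in range(30):
    latents = latents - 0.01 * latents
    if XLA_AVAILABLE:
        # Flush the lazily recorded graph once per step (as the diffs below
        # do) instead of accumulating one huge graph for the whole loop.
        xm.mark_step()
```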
@@ -95,6 +95,35 @@ image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
image.save("pokemon.png")
```
### Inference in PyTorch XLA

```python
from time import time

import torch
import torch_xla.core.xla_model as xm

from diffusers import DiffusionPipeline

model_id = "stabilityai/stable-diffusion-xl-base-1.0"
pipe = DiffusionPipeline.from_pretrained(model_id)

device = xm.xla_device()
pipe.to(device)

prompt = "A pokemon with green eyes and red legs."
inference_steps = 30  # number of denoising steps

# First call: includes XLA graph compilation, so it is slow.
start = time()
image = pipe(prompt, num_inference_steps=inference_steps).images[0]
print(f"Compilation time is {time()-start} sec")
image.save("pokemon.png")

# Second call: reuses the compiled graph, so it is much faster.
start = time()
image = pipe(prompt, num_inference_steps=inference_steps).images[0]
print(f"Inference time is {time()-start} sec after compilation")
```
Note: PyTorch XLA has a warmup step: the first pipeline call takes longer
because the computation graph is compiled and optimized. To see the real
benefit of PyTorch XLA, call the pipeline again on an input with the same
shape as the original prompt so that the optimized graph is reused.
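To illustrate the reuse further (a hedged continuation of the example above; the second prompt is invented), a different prompt of the same shape should also hit the compiled graph, since prompts are padded to a fixed token length before encoding:

```python
# Continuing the example above: a new prompt with the same batch shape and
# step count. Because prompts are padded to a fixed token length, the
# compiled graph should be reused and this call should run at
# post-compilation speed.
start = time()
image = pipe("A pokemon with blue eyes and yellow wings.", num_inference_steps=inference_steps).images[0]
print(f"Inference time is {time()-start} sec with a new prompt")
```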
## LoRA training example for Stable Diffusion XL (SDXL)

Low-Rank Adaptation of Large Language Models was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.
...
@@ -35,6 +35,7 @@ from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
    is_invisible_watermark_available,
    is_torch_xla_available,
    logging,
    replace_example_docstring,
    scale_lora_layers,
@@ -48,6 +49,13 @@ from .pipeline_output import StableDiffusionXLPipelineOutput
if is_invisible_watermark_available():
    from .watermark import StableDiffusionXLWatermarker

if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -860,7 +868,7 @@ class StableDiffusionXLPipeline(
        # 8. Denoising loop
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

-        # 7.1 Apply denoising_end
+        # 8.1 Apply denoising_end
        if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1:
            discrete_timestep_cutoff = int(
                round(
@@ -908,6 +916,9 @@ class StableDiffusionXLPipeline(
                        step_idx = i // getattr(self.scheduler, "order", 1)
                        callback(step_idx, t, latents)

                if XLA_AVAILABLE:
                    xm.mark_step()

        if not output_type == "latent":
            # make sure the VAE is in float32 mode, as it overflows in float16
            needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
...
@@ -32,6 +32,7 @@ from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
    is_invisible_watermark_available,
    is_torch_xla_available,
    logging,
    replace_example_docstring,
    scale_lora_layers,
@@ -45,6 +46,13 @@ from .pipeline_output import StableDiffusionXLPipelineOutput
if is_invisible_watermark_available():
    from .watermark import StableDiffusionXLWatermarker

if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -1031,6 +1039,9 @@ class StableDiffusionXLImg2ImgPipeline(
                        step_idx = i // getattr(self.scheduler, "order", 1)
                        callback(step_idx, t, latents)

                if XLA_AVAILABLE:
                    xm.mark_step()

        if not output_type == "latent":
            # make sure the VAE is in float32 mode, as it overflows in float16
            needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
...
@@ -34,6 +34,7 @@ from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
    deprecate,
    is_invisible_watermark_available,
    is_torch_xla_available,
    logging,
    replace_example_docstring,
    scale_lora_layers,
@@ -47,6 +48,13 @@ from .pipeline_output import StableDiffusionXLPipelineOutput
if is_invisible_watermark_available():
    from .watermark import StableDiffusionXLWatermarker

if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -1355,6 +1363,9 @@ class StableDiffusionXLInpaintPipeline(
                        step_idx = i // getattr(self.scheduler, "order", 1)
                        callback(step_idx, t, latents)

                if XLA_AVAILABLE:
                    xm.mark_step()

        if not output_type == "latent":
            # make sure the VAE is in float32 mode, as it overflows in float16
            needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
...
@@ -33,6 +33,7 @@ from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
    deprecate,
    is_invisible_watermark_available,
    is_torch_xla_available,
    logging,
    replace_example_docstring,
)
@@ -44,6 +45,13 @@ from .pipeline_output import StableDiffusionXLPipelineOutput
if is_invisible_watermark_available():
    from .watermark import StableDiffusionXLWatermarker

if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -926,6 +934,9 @@ class StableDiffusionXLInstructPix2PixPipeline(
                        step_idx = i // getattr(self.scheduler, "order", 1)
                        callback(step_idx, t, latents)

                if XLA_AVAILABLE:
                    xm.mark_step()

        if not output_type == "latent":
            # make sure the VAE is in float32 mode, as it overflows in float16
            needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
...
@@ -73,6 +73,7 @@ from .import_utils import (
    is_tensorboard_available,
    is_torch_available,
    is_torch_version,
    is_torch_xla_available,
    is_torchsde_available,
    is_transformers_available,
    is_transformers_version,
...
@@ -64,6 +64,14 @@ else:
    logger.info("Disabling PyTorch because USE_TORCH is set")
    _torch_available = False

_torch_xla_available = importlib.util.find_spec("torch_xla") is not None
if _torch_xla_available:
    try:
        _torch_xla_version = importlib_metadata.version("torch_xla")
        logger.info(f"PyTorch XLA version {_torch_xla_version} available.")
    except ImportError:
        _torch_xla_available = False

_jax_version = "N/A"
_flax_version = "N/A"
if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
@@ -281,6 +289,10 @@ def is_torch_available():
    return _torch_available


def is_torch_xla_available():
    return _torch_xla_available


def is_flax_available():
    return _flax_available
...
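For downstream code, the new helper can guard optional XLA work the same way the pipelines above do. A minimal sketch, assuming only the `is_torch_xla_available` helper and import pattern added in this PR (`sync_if_xla` is a hypothetical name for illustration):

```python
from diffusers.utils import is_torch_xla_available

if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False


def sync_if_xla():
    # Hypothetical helper: flush the lazily built XLA graph; a no-op when
    # torch_xla is not installed.
    if XLA_AVAILABLE:
        xm.mark_step()
```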