Unverified Commit aea73834 authored by ssusie's avatar ssusie Committed by GitHub
Browse files

Adding PyTorch XLA support for sdxl inference (#5273)



* Added  mark_step for sdxl to run with pytorch xla. Also updated README with instructions for xla

* adding soft dependency on torch_xla

* fix some styling

---------
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
Co-authored-by: default avatarDhruv Nair <dhruv.nair@gmail.com>
parent a1392135
......@@ -44,7 +44,7 @@ from accelerate.utils import write_basic_config
write_basic_config()
```
When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
### Training
......@@ -73,10 +73,10 @@ accelerate launch train_text_to_image_sdxl.py \
--push_to_hub
```
**Notes**:
**Notes**:
* The `train_text_to_image_sdxl.py` script pre-computes text embeddings and the VAE encodings and keeps them in memory. While for smaller datasets like [`lambdalabs/pokemon-blip-captions`](https://hf.co/datasets/lambdalabs/pokemon-blip-captions), it might not be a problem, it can definitely lead to memory problems when the script is used on a larger dataset. For those purposes, you would want to serialize these pre-computed representations to disk separately and load them during the fine-tuning process. Refer to [this PR](https://github.com/huggingface/diffusers/pull/4505) for a more in-depth discussion.
* The training script is compute-intensive and may not run on a consumer GPU like Tesla T4.
* The `train_text_to_image_sdxl.py` script pre-computes text embeddings and the VAE encodings and keeps them in memory. While for smaller datasets like [`lambdalabs/pokemon-blip-captions`](https://hf.co/datasets/lambdalabs/pokemon-blip-captions), it might not be a problem, it can definitely lead to memory problems when the script is used on a larger dataset. For those purposes, you would want to serialize these pre-computed representations to disk separately and load them during the fine-tuning process. Refer to [this PR](https://github.com/huggingface/diffusers/pull/4505) for a more in-depth discussion.
* The training script is compute-intensive and may not run on a consumer GPU like Tesla T4.
* The training command shown above performs intermediate quality validation in between the training epochs and logs the results to Weights and Biases. `--report_to`, `--validation_prompt`, and `--validation_epochs` are the relevant CLI arguments here.
* SDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument namely `--pretrained_vae_model_name_or_path` that lets you specify the location of a better VAE (such as [this one](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).
......@@ -95,6 +95,35 @@ image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
image.save("pokemon.png")
```
### Inference in Pytorch XLA
```python
from diffusers import DiffusionPipeline
import torch
import torch_xla.core.xla_model as xm
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
pipe = DiffusionPipeline.from_pretrained(model_id)
device = xm.xla_device()
pipe.to(device)
prompt = "A pokemon with green eyes and red legs."
start = time()
image = pipe(prompt, num_inference_steps=inference_steps).images[0]
print(f'Compilation time is {time()-start} sec')
image.save("pokemon.png")
start = time()
image = pipe(prompt, num_inference_steps=inference_steps).images[0]
print(f'Inference time is {time()-start} sec after compilation')
```
Note: There is a warmup step in PyTorch XLA. This takes longer because of
compilation and optimization. To see the real benefits of Pytorch XLA and
speedup, we need to call the pipe again on the input with the same length
as the original prompt to reuse the optimized graph and get the performance
boost.
## LoRA training example for Stable Diffusion XL (SDXL)
Low-Rank Adaption of Large Language Models was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.
......@@ -112,7 +141,7 @@ on consumer GPUs like Tesla T4, Tesla V100.
### Training
First, you need to set up your development environment as is explained in the [installation section](#installing-the-dependencies). Make sure to set the `MODEL_NAME` and `DATASET_NAME` environment variables and, optionally, the `VAE_NAME` variable. Here, we will use [Stable Diffusion XL 1.0-base](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) and the [Pokemons dataset](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions).
First, you need to set up your development environment as is explained in the [installation section](#installing-the-dependencies). Make sure to set the `MODEL_NAME` and `DATASET_NAME` environment variables and, optionally, the `VAE_NAME` variable. Here, we will use [Stable Diffusion XL 1.0-base](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) and the [Pokemons dataset](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions).
**___Note: It is quite useful to monitor the training progress by regularly generating sample images during training. [Weights and Biases](https://docs.wandb.ai/quickstart) is a nice solution to easily see generating images during training. All you need to do is to run `pip install wandb` before training to automatically log images.___**
......@@ -122,7 +151,7 @@ export VAE_NAME="madebyollin/sdxl-vae-fp16-fix"
export DATASET_NAME="lambdalabs/pokemon-blip-captions"
```
For this example we want to directly store the trained LoRA embeddings on the Hub, so
For this example we want to directly store the trained LoRA embeddings on the Hub, so
we need to be logged in and add the `--push_to_hub` flag.
```bash
......@@ -149,7 +178,7 @@ accelerate launch train_text_to_image_lora_sdxl.py \
The above command will also run inference as fine-tuning progresses and log the results to Weights and Biases.
**Notes**:
**Notes**:
* SDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument namely `--pretrained_vae_model_name_or_path` that lets you specify the location of a better VAE (such as [this one](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).
......@@ -178,7 +207,7 @@ accelerate launch train_text_to_image_lora_sdxl.py \
### Inference
Once you have trained a model using above command, the inference can be done simply using the `DiffusionPipeline` after loading the trained LoRA weights. You
Once you have trained a model using above command, the inference can be done simply using the `DiffusionPipeline` after loading the trained LoRA weights. You
need to pass the `output_dir` for loading the LoRA weights which, in this case, is `sd-pokemon-model-lora-sdxl`.
```python
......
......@@ -35,6 +35,7 @@ from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
is_invisible_watermark_available,
is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
......@@ -48,6 +49,13 @@ from .pipeline_output import StableDiffusionXLPipelineOutput
if is_invisible_watermark_available():
from .watermark import StableDiffusionXLWatermarker
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
......@@ -860,7 +868,7 @@ class StableDiffusionXLPipeline(
# 8. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
# 7.1 Apply denoising_end
# 8.1 Apply denoising_end
if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1:
discrete_timestep_cutoff = int(
round(
......@@ -908,6 +916,9 @@ class StableDiffusionXLPipeline(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
if XLA_AVAILABLE:
xm.mark_step()
if not output_type == "latent":
# make sure the VAE is in float32 mode, as it overflows in float16
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
......
......@@ -32,6 +32,7 @@ from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
is_invisible_watermark_available,
is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
......@@ -45,6 +46,13 @@ from .pipeline_output import StableDiffusionXLPipelineOutput
if is_invisible_watermark_available():
from .watermark import StableDiffusionXLWatermarker
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
......@@ -1031,6 +1039,9 @@ class StableDiffusionXLImg2ImgPipeline(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
if XLA_AVAILABLE:
xm.mark_step()
if not output_type == "latent":
# make sure the VAE is in float32 mode, as it overflows in float16
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
......
......@@ -34,6 +34,7 @@ from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
deprecate,
is_invisible_watermark_available,
is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
......@@ -47,6 +48,13 @@ from .pipeline_output import StableDiffusionXLPipelineOutput
if is_invisible_watermark_available():
from .watermark import StableDiffusionXLWatermarker
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
......@@ -1355,6 +1363,9 @@ class StableDiffusionXLInpaintPipeline(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
if XLA_AVAILABLE:
xm.mark_step()
if not output_type == "latent":
# make sure the VAE is in float32 mode, as it overflows in float16
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
......
......@@ -33,6 +33,7 @@ from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
deprecate,
is_invisible_watermark_available,
is_torch_xla_available,
logging,
replace_example_docstring,
)
......@@ -44,6 +45,13 @@ from .pipeline_output import StableDiffusionXLPipelineOutput
if is_invisible_watermark_available():
from .watermark import StableDiffusionXLWatermarker
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
......@@ -926,6 +934,9 @@ class StableDiffusionXLInstructPix2PixPipeline(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
if XLA_AVAILABLE:
xm.mark_step()
if not output_type == "latent":
# make sure the VAE is in float32 mode, as it overflows in float16
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
......
......@@ -73,6 +73,7 @@ from .import_utils import (
is_tensorboard_available,
is_torch_available,
is_torch_version,
is_torch_xla_available,
is_torchsde_available,
is_transformers_available,
is_transformers_version,
......
......@@ -64,6 +64,14 @@ else:
logger.info("Disabling PyTorch because USE_TORCH is set")
_torch_available = False
_torch_xla_available = importlib.util.find_spec("torch_xla") is not None
if _torch_xla_available:
try:
_torch_xla_version = importlib_metadata.version("torch_xla")
logger.info(f"PyTorch XLA version {_torch_xla_version} available.")
except ImportError:
_torch_xla_available = False
_jax_version = "N/A"
_flax_version = "N/A"
if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
......@@ -281,6 +289,10 @@ def is_torch_available():
return _torch_available
def is_torch_xla_available():
return _torch_xla_available
def is_flax_available():
return _flax_available
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment