Unverified Commit a17d6d68 authored by Dhruv Nair, committed by GitHub

Update Cascade documentation (#7257)



* updates

* update

* update

* Update docs/source/en/api/pipelines/stable_cascade.md
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* update

* update

* update

* update

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
parent 8efd9ce7
@@ -37,6 +37,147 @@ a compression factor of 42. This encodes a 1024 x 1024 image to 24 x 24, while b…
image. This comes with the great benefit of cheaper training and inference. Furthermore, Stage C is responsible
for generating the small 24 x 24 latents given a text prompt.
The Stage C model operates on the small 24 x 24 latents and denoises them conditioned on text prompts. It is also the largest component in the Cascade pipeline and is meant to be used with the `StableCascadePriorPipeline`.
The Stage B and Stage A models are used with the `StableCascadeDecoderPipeline` and are responsible for generating the final image given the small 24 x 24 latents.
<Tip warning={true}>
There are some restrictions on the data types that can be used with the Stable Cascade models. The official checkpoints for the `StableCascadePriorPipeline` do not support the `torch.float16` data type; use `torch.bfloat16` instead.
To use the `torch.bfloat16` data type with the `StableCascadeDecoderPipeline`, you need PyTorch 2.2.0 or higher installed. This also means that using the `StableCascadeCombinedPipeline` with `torch.bfloat16` requires PyTorch 2.2.0 or higher, since it calls the `StableCascadeDecoderPipeline` internally.
If it is not possible to install PyTorch 2.2.0 or higher in your environment, the `StableCascadeDecoderPipeline` can be used on its own with the `torch.float16` data type: download the full precision or `bf16` variant weights for the pipeline and cast them to `torch.float16`.
</Tip>
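For example, on PyTorch versions older than 2.2.0 the decoder can still run in half precision by downloading the `bf16` variant weights and casting them to `torch.float16` at load time. A minimal sketch of this workaround:

```python
import torch
from diffusers import StableCascadeDecoderPipeline

# Workaround for torch < 2.2.0: download the bf16 variant weights and
# cast them to float16 while loading the decoder pipeline.
decoder = StableCascadeDecoderPipeline.from_pretrained(
    "stabilityai/stable-cascade", variant="bf16", torch_dtype=torch.float16
)
```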
## Usage example
```python
import torch
from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline

prompt = "an image of a shiba inu, donning a spacesuit and helmet"
negative_prompt = ""

# The official prior (Stage C) checkpoint does not support float16, so load it in bfloat16.
prior = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", variant="bf16", torch_dtype=torch.bfloat16)
# The decoder (Stage B) can run in float16 by downloading the bf16 variant and casting at load time.
decoder = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", variant="bf16", torch_dtype=torch.float16)

prior.enable_model_cpu_offload()
prior_output = prior(
    prompt=prompt,
    height=1024,
    width=1024,
    negative_prompt=negative_prompt,
    guidance_scale=4.0,
    num_images_per_prompt=1,
    num_inference_steps=20,
)

decoder.enable_model_cpu_offload()
decoder_output = decoder(
    # Cast the prior's bfloat16 embeddings to match the decoder's float16 weights.
    image_embeddings=prior_output.image_embeddings.to(torch.float16),
    prompt=prompt,
    negative_prompt=negative_prompt,
    guidance_scale=0.0,
    output_type="pil",
    num_inference_steps=10,
).images[0]
decoder_output.save("cascade.png")
```
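Note that the prior pipeline runs in `torch.bfloat16` while the decoder runs in `torch.float16`, which is why the image embeddings returned by the prior are cast with `.to(torch.float16)` before being passed to the decoder.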
## Using the Lite Versions of the Stage B and Stage C models
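The Stable Cascade repositories also ship smaller "lite" variants of the Stage B and Stage C models in the `decoder_lite` and `prior_lite` subfolders. Load them with `StableCascadeUNet` and pass them to the pipelines in place of the default models: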
```python
import torch
from diffusers import (
    StableCascadeDecoderPipeline,
    StableCascadePriorPipeline,
    StableCascadeUNet,
)

prompt = "an image of a shiba inu, donning a spacesuit and helmet"
negative_prompt = ""

# Load the lite (smaller) variants of the Stage C prior and Stage B decoder UNets.
prior_unet = StableCascadeUNet.from_pretrained("stabilityai/stable-cascade-prior", subfolder="prior_lite")
decoder_unet = StableCascadeUNet.from_pretrained("stabilityai/stable-cascade", subfolder="decoder_lite")

# Pass the lite UNets to the pipelines in place of the default models.
prior = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", prior=prior_unet)
decoder = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", decoder=decoder_unet)

prior.enable_model_cpu_offload()
prior_output = prior(
    prompt=prompt,
    height=1024,
    width=1024,
    negative_prompt=negative_prompt,
    guidance_scale=4.0,
    num_images_per_prompt=1,
    num_inference_steps=20,
)

decoder.enable_model_cpu_offload()
decoder_output = decoder(
    image_embeddings=prior_output.image_embeddings,
    prompt=prompt,
    negative_prompt=negative_prompt,
    guidance_scale=0.0,
    output_type="pil",
    num_inference_steps=10,
).images[0]
decoder_output.save("cascade.png")
```
## Loading original checkpoints with `from_single_file`
Loading the original format checkpoints is supported via the `from_single_file` method of `StableCascadeUNet`.
```python
import torch
from diffusers import (
    StableCascadeDecoderPipeline,
    StableCascadePriorPipeline,
    StableCascadeUNet,
)

prompt = "an image of a shiba inu, donning a spacesuit and helmet"
negative_prompt = ""

# Load the Stage C and Stage B UNets directly from the original single-file checkpoints.
prior_unet = StableCascadeUNet.from_single_file(
    "https://huggingface.co/stabilityai/stable-cascade/resolve/main/stage_c_bf16.safetensors",
    torch_dtype=torch.bfloat16,
)
decoder_unet = StableCascadeUNet.from_single_file(
    "https://huggingface.co/stabilityai/stable-cascade/resolve/main/stage_b_bf16.safetensors",
    torch_dtype=torch.bfloat16,
)

# Use the single-file UNets with the regular pipelines.
prior = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", prior=prior_unet, torch_dtype=torch.bfloat16)
decoder = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", decoder=decoder_unet, torch_dtype=torch.bfloat16)

prior.enable_model_cpu_offload()
prior_output = prior(
    prompt=prompt,
    height=1024,
    width=1024,
    negative_prompt=negative_prompt,
    guidance_scale=4.0,
    num_images_per_prompt=1,
    num_inference_steps=20,
)

decoder.enable_model_cpu_offload()
decoder_output = decoder(
    image_embeddings=prior_output.image_embeddings,
    prompt=prompt,
    negative_prompt=negative_prompt,
    guidance_scale=0.0,
    output_type="pil",
    num_inference_steps=10,
).images[0]
decoder_output.save("cascade-single-file.png")
```
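`from_single_file` also accepts checkpoints that have already been downloaded to disk. A minimal sketch, assuming a hypothetical local path:

```python
import torch
from diffusers import StableCascadeUNet

# Hypothetical local path; point this at wherever the checkpoint was saved.
prior_unet = StableCascadeUNet.from_single_file(
    "./checkpoints/stage_c_bf16.safetensors",
    torch_dtype=torch.bfloat16,
)
```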
## Uses

### Direct Use

…
@@ -19,7 +19,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
 from ...models import StableCascadeUNet
 from ...schedulers import DDPMWuerstchenScheduler
-from ...utils import replace_example_docstring
+from ...utils import is_torch_version, replace_example_docstring
 from ..pipeline_utils import DiffusionPipeline
 from ..wuerstchen.modeling_paella_vq_model import PaellaVQModel
 from .pipeline_stable_cascade import StableCascadeDecoderPipeline
@@ -29,11 +29,10 @@ from .pipeline_stable_cascade_prior import StableCascadePriorPipeline
 TEXT2IMAGE_EXAMPLE_DOC_STRING = """
     Examples:
         ```py
-        >>> from diffusions import StableCascadeCombinedPipeline
-        >>> pipe = StableCascadeCombinedPipeline.from_pretrained("stabilityai/stable-cascade-combined", torch_dtype=torch.bfloat16).to(
-        ...     "cuda"
-        ... )
+        >>> import torch
+        >>> from diffusers import StableCascadeCombinedPipeline
+        >>> pipe = StableCascadeCombinedPipeline.from_pretrained("stabilityai/stable-cascade", variant="bf16", torch_dtype=torch.bfloat16)
+        >>> pipe.enable_model_cpu_offload()
         >>> prompt = "an image of a shiba inu, donning a spacesuit and helmet"
         >>> images = pipe(prompt=prompt)
         ```
@@ -260,6 +259,12 @@ class StableCascadeCombinedPipeline(DiffusionPipeline):
             [`~pipelines.ImagePipelineOutput`] or `tuple` [`~pipelines.ImagePipelineOutput`] if `return_dict` is True,
             otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
         """
+        dtype = self.decoder_pipe.decoder.dtype
+        if is_torch_version("<", "2.2.0") and dtype == torch.bfloat16:
+            raise ValueError(
+                "`StableCascadeCombinedPipeline` requires torch>=2.2.0 when using `torch.bfloat16` dtype."
+            )
+
         prior_outputs = self.prior_pipe(
             prompt=prompt if prompt_embeds is None else None,
             images=images,
…