Unverified Commit 16a056a7 authored by Kashif Rasul's avatar Kashif Rasul Committed by GitHub
Browse files

Wuerstchen fixes (#4942)



* fix arguments and make example code work

* change arguments in combined test

* Add default timesteps

* style

* fixed test

* fix broken test

* formatting

* fix docstrings

* fix  num_images_per_prompt

* fix doc styles

* please dont change this

* fix tests

* rename to DEFAULT_STAGE_C_TIMESTEPS

---------
Co-authored-by: default avatarDominic Rampas <d6582533@gmail.com>
parent 6c6a2464
...@@ -17,6 +17,7 @@ After the initial paper release, we have improved numerous things in the archite ...@@ -17,6 +17,7 @@ After the initial paper release, we have improved numerous things in the archite
- Multi Aspect Resolution Sampling - Multi Aspect Resolution Sampling
- Better quality - Better quality
We are releasing 3 checkpoints for the text-conditional image generation model (Stage C). Those are: We are releasing 3 checkpoints for the text-conditional image generation model (Stage C). Those are:
- v2-base - v2-base
...@@ -35,27 +36,18 @@ For the sake of usability Würstchen can be used with a single pipeline. This pi ...@@ -35,27 +36,18 @@ For the sake of usability Würstchen can be used with a single pipeline. This pi
```python ```python
import torch import torch
from diffusers import AutoPipelineForText2Image from diffusers import AutoPipelineForText2Image
from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS
device = "cuda" pipe = AutoPipelineForText2Image.from_pretrained("warp-ai/wuerstchen", torch_dtype=torch.float16).to("cuda")
dtype = torch.float16
num_images_per_prompt = 2
pipeline = AutoPipelineForText2Image.from_pretrained(
"warp-diffusion/wuerstchen", torch_dtype=dtype
).to(device)
caption = "Anthropomorphic cat dressed as a fire fighter" caption = "Anthropomorphic cat dressed as a fire fighter"
negative_prompt = "" images = pipe(
caption,
output = pipeline(
prompt=caption,
height=1024,
width=1024, width=1024,
negative_prompt=negative_prompt, height=1536,
prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,
prior_guidance_scale=4.0, prior_guidance_scale=4.0,
decoder_guidance_scale=0.0, num_images_per_prompt=2,
num_images_per_prompt=num_images_per_prompt,
output_type="pil",
).images ).images
``` ```
...@@ -64,25 +56,27 @@ For explanation purposes, we can also initialize the two main pipelines of Würs ...@@ -64,25 +56,27 @@ For explanation purposes, we can also initialize the two main pipelines of Würs
```python ```python
import torch import torch
from diffusers import WuerstchenDecoderPipeline, WuerstchenPriorPipeline from diffusers import WuerstchenDecoderPipeline, WuerstchenPriorPipeline
from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS
device = "cuda" device = "cuda"
dtype = torch.float16 dtype = torch.float16
num_images_per_prompt = 2 num_images_per_prompt = 2
prior_pipeline = WuerstchenPriorPipeline.from_pretrained( prior_pipeline = WuerstchenPriorPipeline.from_pretrained(
"warp-diffusion/wuerstchen-prior", torch_dtype=dtype "warp-ai/wuerstchen-prior", torch_dtype=dtype
).to(device) ).to(device)
decoder_pipeline = WuerstchenDecoderPipeline.from_pretrained( decoder_pipeline = WuerstchenDecoderPipeline.from_pretrained(
"warp-diffusion/wuerstchen", torch_dtype=dtype "warp-ai/wuerstchen", torch_dtype=dtype
).to(device) ).to(device)
caption = "A captivating artwork of a mysterious stone golem" caption = "Anthropomorphic cat dressed as a fire fighter"
negative_prompt = "" negative_prompt = ""
prior_output = prior_pipeline( prior_output = prior_pipeline(
prompt=caption, prompt=caption,
height=1024, height=1024,
width=1024, width=1536,
timesteps=DEFAULT_STAGE_C_TIMESTEPS,
negative_prompt=negative_prompt, negative_prompt=negative_prompt,
guidance_scale=4.0, guidance_scale=4.0,
num_images_per_prompt=num_images_per_prompt, num_images_per_prompt=num_images_per_prompt,
...@@ -115,7 +109,6 @@ after 1024x1024 is 1152x1152 ...@@ -115,7 +109,6 @@ after 1024x1024 is 1152x1152
- The model often does not achieve photorealism - The model often does not achieve photorealism
- Difficult compositional prompts are hard for the model - Difficult compositional prompts are hard for the model
The original codebase, as well as experimental ideas, can be found at [dome272/Wuerstchen](https://github.com/dome272/Wuerstchen). The original codebase, as well as experimental ideas, can be found at [dome272/Wuerstchen](https://github.com/dome272/Wuerstchen).
## WuerschenPipeline ## WuerschenPipeline
......
...@@ -91,12 +91,12 @@ prior_pipeline = WuerstchenPriorPipeline( ...@@ -91,12 +91,12 @@ prior_pipeline = WuerstchenPriorPipeline(
prior=prior_model, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=scheduler prior=prior_model, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=scheduler
) )
prior_pipeline.save_pretrained("warp-diffusion/wuerstchen-prior") prior_pipeline.save_pretrained("warp-ai/wuerstchen-prior")
decoder_pipeline = WuerstchenDecoderPipeline( decoder_pipeline = WuerstchenDecoderPipeline(
text_encoder=gen_text_encoder, tokenizer=gen_tokenizer, vqgan=vqmodel, decoder=deocder, scheduler=scheduler text_encoder=gen_text_encoder, tokenizer=gen_tokenizer, vqgan=vqmodel, decoder=deocder, scheduler=scheduler
) )
decoder_pipeline.save_pretrained("warp-diffusion/wuerstchen") decoder_pipeline.save_pretrained("warp-ai/wuerstchen")
# Wuerstchen pipeline # Wuerstchen pipeline
wuerstchen_pipeline = WuerstchenCombinedPipeline( wuerstchen_pipeline = WuerstchenCombinedPipeline(
...@@ -112,4 +112,4 @@ wuerstchen_pipeline = WuerstchenCombinedPipeline( ...@@ -112,4 +112,4 @@ wuerstchen_pipeline = WuerstchenCombinedPipeline(
prior=prior_model, prior=prior_model,
prior_scheduler=scheduler, prior_scheduler=scheduler,
) )
wuerstchen_pipeline.save_pretrained("warp-diffusion/WuerstchenCombinedPipeline") wuerstchen_pipeline.save_pretrained("warp-ai/WuerstchenCombinedPipeline")
...@@ -24,7 +24,7 @@ else: ...@@ -24,7 +24,7 @@ else:
_import_structure["modeling_wuerstchen_prior"] = ["WuerstchenPrior"] _import_structure["modeling_wuerstchen_prior"] = ["WuerstchenPrior"]
_import_structure["pipeline_wuerstchen"] = ["WuerstchenDecoderPipeline"] _import_structure["pipeline_wuerstchen"] = ["WuerstchenDecoderPipeline"]
_import_structure["pipeline_wuerstchen_combined"] = ["WuerstchenCombinedPipeline"] _import_structure["pipeline_wuerstchen_combined"] = ["WuerstchenCombinedPipeline"]
_import_structure["pipeline_wuerstchen_prior"] = ["WuerstchenPriorPipeline"] _import_structure["pipeline_wuerstchen_prior"] = ["DEFAULT_STAGE_C_TIMESTEPS", "WuerstchenPriorPipeline"]
import sys import sys
......
...@@ -35,11 +35,11 @@ EXAMPLE_DOC_STRING = """ ...@@ -35,11 +35,11 @@ EXAMPLE_DOC_STRING = """
>>> from diffusers import WuerstchenPriorPipeline, WuerstchenDecoderPipeline >>> from diffusers import WuerstchenPriorPipeline, WuerstchenDecoderPipeline
>>> prior_pipe = WuerstchenPriorPipeline.from_pretrained( >>> prior_pipe = WuerstchenPriorPipeline.from_pretrained(
... "warp-diffusion/wuerstchen-prior", torch_dtype=torch.float16 ... "warp-ai/wuerstchen-prior", torch_dtype=torch.float16
... ).to("cuda")
>>> gen_pipe = WuerstchenDecoderPipeline.from_pretrain(
... "warp-diffusion/wuerstchen", torch_dtype=torch.float16
... ).to("cuda") ... ).to("cuda")
>>> gen_pipe = WuerstchenDecoderPipeline.from_pretrain("warp-ai/wuerstchen", torch_dtype=torch.float16).to(
... "cuda"
... )
>>> prompt = "an image of a shiba inu, donning a spacesuit and helmet" >>> prompt = "an image of a shiba inu, donning a spacesuit and helmet"
>>> prior_output = pipe(prompt) >>> prior_output = pipe(prompt)
......
...@@ -31,9 +31,9 @@ TEXT2IMAGE_EXAMPLE_DOC_STRING = """ ...@@ -31,9 +31,9 @@ TEXT2IMAGE_EXAMPLE_DOC_STRING = """
```py ```py
>>> from diffusions import WuerstchenCombinedPipeline >>> from diffusions import WuerstchenCombinedPipeline
>>> pipe = WuerstchenCombinedPipeline.from_pretrained( >>> pipe = WuerstchenCombinedPipeline.from_pretrained("warp-ai/Wuerstchen", torch_dtype=torch.float16).to(
... "warp-diffusion/Wuerstchen", torch_dtype=torch.float16 ... "cuda"
... ).to("cuda") ... )
>>> prompt = "an image of a shiba inu, donning a spacesuit and helmet" >>> prompt = "an image of a shiba inu, donning a spacesuit and helmet"
>>> images = pipe(prompt=prompt) >>> images = pipe(prompt=prompt)
``` ```
...@@ -145,16 +145,16 @@ class WuerstchenCombinedPipeline(DiffusionPipeline): ...@@ -145,16 +145,16 @@ class WuerstchenCombinedPipeline(DiffusionPipeline):
def __call__( def __call__(
self, self,
prompt: Union[str, List[str]], prompt: Union[str, List[str]],
negative_prompt: Optional[Union[str, List[str]]] = None,
guidance_scale: float = 4.0,
num_images_per_prompt: int = 1,
height: int = 512, height: int = 512,
width: int = 512, width: int = 512,
prior_guidance_scale: float = 4.0,
prior_num_inference_steps: int = 60, prior_num_inference_steps: int = 60,
num_inference_steps: int = 12,
prior_timesteps: Optional[List[float]] = None, prior_timesteps: Optional[List[float]] = None,
timesteps: Optional[List[float]] = None, prior_guidance_scale: float = 4.0,
num_inference_steps: int = 12,
decoder_timesteps: Optional[List[float]] = None,
decoder_guidance_scale: float = 0.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: int = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None, latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil", output_type: Optional[str] = "pil",
...@@ -182,19 +182,20 @@ class WuerstchenCombinedPipeline(DiffusionPipeline): ...@@ -182,19 +182,20 @@ class WuerstchenCombinedPipeline(DiffusionPipeline):
`prior_guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked `prior_guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked
to the text `prompt`, usually at the expense of lower image quality. to the text `prompt`, usually at the expense of lower image quality.
prior_num_inference_steps (`Union[int, Dict[float, int]]`, *optional*, defaults to 30): prior_num_inference_steps (`Union[int, Dict[float, int]]`, *optional*, defaults to 30):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the The number of prior denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference. For more specific timestep spacing, you can pass customized expense of slower inference. For more specific timestep spacing, you can pass customized
`prior_timesteps` `prior_timesteps`
num_inference_steps (`int`, *optional*, defaults to 12): num_inference_steps (`int`, *optional*, defaults to 12):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the The number of decoder denoising steps. More denoising steps usually lead to a higher quality image at
expense of slower inference. For more specific timestep spacing, you can pass customized `timesteps` the expense of slower inference. For more specific timestep spacing, you can pass customized
`timesteps`
prior_timesteps (`List[float]`, *optional*): prior_timesteps (`List[float]`, *optional*):
Custom timesteps to use for the denoising process for the prior. If not defined, equal spaced Custom timesteps to use for the denoising process for the prior. If not defined, equal spaced
`prior_num_inference_steps` timesteps are used. Must be in descending order. `prior_num_inference_steps` timesteps are used. Must be in descending order.
timesteps (`List[float]`, *optional*): decoder_timesteps (`List[float]`, *optional*):
Custom timesteps to use for the denoising process for the decoder. If not defined, equal spaced Custom timesteps to use for the denoising process for the decoder. If not defined, equal spaced
`decoder_num_inference_steps` timesteps are used. Must be in descending order. `num_inference_steps` timesteps are used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 4.0): decoder_guidance_scale (`float`, *optional*, defaults to 0.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen `guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
...@@ -221,27 +222,28 @@ class WuerstchenCombinedPipeline(DiffusionPipeline): ...@@ -221,27 +222,28 @@ class WuerstchenCombinedPipeline(DiffusionPipeline):
""" """
prior_outputs = self.prior_pipe( prior_outputs = self.prior_pipe(
prompt=prompt, prompt=prompt,
negative_prompt=negative_prompt,
width=width,
height=height, height=height,
num_images_per_prompt=num_images_per_prompt, width=width,
num_inference_steps=prior_num_inference_steps, num_inference_steps=prior_num_inference_steps,
timesteps=prior_timesteps, timesteps=prior_timesteps,
guidance_scale=prior_guidance_scale,
negative_prompt=negative_prompt,
num_images_per_prompt=num_images_per_prompt,
generator=generator, generator=generator,
latents=latents, latents=latents,
guidance_scale=prior_guidance_scale,
output_type="pt", output_type="pt",
return_dict=False, return_dict=False,
) )
image_embeddings = prior_outputs[0] image_embeddings = prior_outputs[0]
outputs = self.decoder_pipe( outputs = self.decoder_pipe(
prompt=prompt,
image_embeddings=image_embeddings, image_embeddings=image_embeddings,
prompt=prompt,
num_inference_steps=num_inference_steps, num_inference_steps=num_inference_steps,
timesteps=timesteps, timesteps=decoder_timesteps,
guidance_scale=decoder_guidance_scale,
negative_prompt=negative_prompt,
generator=generator, generator=generator,
guidance_scale=guidance_scale,
output_type=output_type, output_type=output_type,
return_dict=return_dict, return_dict=return_dict,
) )
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
from dataclasses import dataclass from dataclasses import dataclass
from math import ceil from math import ceil
from typing import List, Optional, Union from typing import Callable, List, Optional, Union
import numpy as np import numpy as np
import torch import torch
...@@ -35,6 +35,8 @@ from .modeling_wuerstchen_prior import WuerstchenPrior ...@@ -35,6 +35,8 @@ from .modeling_wuerstchen_prior import WuerstchenPrior
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
DEFAULT_STAGE_C_TIMESTEPS = list(np.linspace(1.0, 2 / 3, 20)) + list(np.linspace(2 / 3, 0.0, 11))[1:]
EXAMPLE_DOC_STRING = """ EXAMPLE_DOC_STRING = """
Examples: Examples:
```py ```py
...@@ -42,7 +44,7 @@ EXAMPLE_DOC_STRING = """ ...@@ -42,7 +44,7 @@ EXAMPLE_DOC_STRING = """
>>> from diffusers import WuerstchenPriorPipeline >>> from diffusers import WuerstchenPriorPipeline
>>> prior_pipe = WuerstchenPriorPipeline.from_pretrained( >>> prior_pipe = WuerstchenPriorPipeline.from_pretrained(
... "warp-diffusion/wuerstchen-prior", torch_dtype=torch.float16 ... "warp-ai/wuerstchen-prior", torch_dtype=torch.float16
... ).to("cuda") ... ).to("cuda")
>>> prompt = "an image of a shiba inu, donning a spacesuit and helmet" >>> prompt = "an image of a shiba inu, donning a spacesuit and helmet"
...@@ -265,7 +267,7 @@ class WuerstchenPriorPipeline(DiffusionPipeline): ...@@ -265,7 +267,7 @@ class WuerstchenPriorPipeline(DiffusionPipeline):
prompt: Union[str, List[str]] = None, prompt: Union[str, List[str]] = None,
height: int = 1024, height: int = 1024,
width: int = 1024, width: int = 1024,
num_inference_steps: int = 30, num_inference_steps: int = 60,
timesteps: List[float] = None, timesteps: List[float] = None,
guidance_scale: float = 8.0, guidance_scale: float = 8.0,
negative_prompt: Optional[Union[str, List[str]]] = None, negative_prompt: Optional[Union[str, List[str]]] = None,
...@@ -274,6 +276,8 @@ class WuerstchenPriorPipeline(DiffusionPipeline): ...@@ -274,6 +276,8 @@ class WuerstchenPriorPipeline(DiffusionPipeline):
latents: Optional[torch.FloatTensor] = None, latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pt", output_type: Optional[str] = "pt",
return_dict: bool = True, return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
): ):
""" """
Function invoked when calling the pipeline for generation. Function invoked when calling the pipeline for generation.
...@@ -314,6 +318,12 @@ class WuerstchenPriorPipeline(DiffusionPipeline): ...@@ -314,6 +318,12 @@ class WuerstchenPriorPipeline(DiffusionPipeline):
(`np.array`) or `"pt"` (`torch.Tensor`). (`np.array`) or `"pt"` (`torch.Tensor`).
return_dict (`bool`, *optional*, defaults to `True`): return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
Examples: Examples:
...@@ -365,7 +375,7 @@ class WuerstchenPriorPipeline(DiffusionPipeline): ...@@ -365,7 +375,7 @@ class WuerstchenPriorPipeline(DiffusionPipeline):
latents = self.prepare_latents(effnet_features_shape, dtype, device, generator, latents, self.scheduler) latents = self.prepare_latents(effnet_features_shape, dtype, device, generator, latents, self.scheduler)
# 6. Run denoising loop # 6. Run denoising loop
for t in self.progress_bar(timesteps[:-1]): for i, t in enumerate(self.progress_bar(timesteps[:-1])):
ratio = t.expand(latents.size(0)).to(dtype) ratio = t.expand(latents.size(0)).to(dtype)
# 7. Denoise image embeddings # 7. Denoise image embeddings
...@@ -390,6 +400,9 @@ class WuerstchenPriorPipeline(DiffusionPipeline): ...@@ -390,6 +400,9 @@ class WuerstchenPriorPipeline(DiffusionPipeline):
generator=generator, generator=generator,
).prev_sample ).prev_sample
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
# 10. Denormalize the latents # 10. Denormalize the latents
latents = latents * self.config.latent_mean - self.config.latent_std latents = latents * self.config.latent_mean - self.config.latent_std
......
...@@ -38,7 +38,8 @@ class WuerstchenCombinedPipelineFastTests(PipelineTesterMixin, unittest.TestCase ...@@ -38,7 +38,8 @@ class WuerstchenCombinedPipelineFastTests(PipelineTesterMixin, unittest.TestCase
"height", "height",
"width", "width",
"latents", "latents",
"guidance_scale", "prior_guidance_scale",
"decoder_guidance_scale",
"negative_prompt", "negative_prompt",
"num_inference_steps", "num_inference_steps",
"return_dict", "return_dict",
...@@ -160,7 +161,7 @@ class WuerstchenCombinedPipelineFastTests(PipelineTesterMixin, unittest.TestCase ...@@ -160,7 +161,7 @@ class WuerstchenCombinedPipelineFastTests(PipelineTesterMixin, unittest.TestCase
"prompt": "horse", "prompt": "horse",
"generator": generator, "generator": generator,
"prior_guidance_scale": 4.0, "prior_guidance_scale": 4.0,
"guidance_scale": 4.0, "decoder_guidance_scale": 4.0,
"num_inference_steps": 2, "num_inference_steps": 2,
"prior_num_inference_steps": 2, "prior_num_inference_steps": 2,
"output_type": "np", "output_type": "np",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment