Unverified Commit 16a056a7 authored by Kashif Rasul's avatar Kashif Rasul Committed by GitHub
Browse files

Wuerstchen fixes (#4942)



* fix arguments and make example code work

* change arguments in combined test

* Add default timesteps

* style

* fixed test

* fix broken test

* formatting

* fix docstrings

* fix num_images_per_prompt

* fix doc styles

* please dont change this

* fix tests

* rename to DEFAULT_STAGE_C_TIMESTEPS

---------
Co-authored-by: default avatarDominic Rampas <d6582533@gmail.com>
parent 6c6a2464
......@@ -17,6 +17,7 @@ After the initial paper release, we have improved numerous things in the archite
- Multi Aspect Resolution Sampling
- Better quality
We are releasing 3 checkpoints for the text-conditional image generation model (Stage C). Those are:
- v2-base
......@@ -24,7 +25,7 @@ We are releasing 3 checkpoints for the text-conditional image generation model (
- v2-interpolated (50% interpolation between v2-base and v2-aesthetic)
We recommend using v2-interpolated, as it has a nice touch of both photorealism and aesthetics. Use v2-base for finetuning, as it does not have a style bias, and use v2-aesthetic for very artistic generations.
A comparison can be seen here:
A comparison can be seen here:
<img src="https://github.com/dome272/Wuerstchen/assets/61938694/2914830f-cbd3-461c-be64-d50734f4b49d" width=500>
......@@ -35,27 +36,18 @@ For the sake of usability Würstchen can be used with a single pipeline. This pi
```python
import torch
from diffusers import AutoPipelineForText2Image
from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS
device = "cuda"
dtype = torch.float16
num_images_per_prompt = 2
pipeline = AutoPipelineForText2Image.from_pretrained(
"warp-diffusion/wuerstchen", torch_dtype=dtype
).to(device)
pipe = AutoPipelineForText2Image.from_pretrained("warp-ai/wuerstchen", torch_dtype=torch.float16).to("cuda")
caption = "Anthropomorphic cat dressed as a fire fighter"
negative_prompt = ""
output = pipeline(
prompt=caption,
height=1024,
images = pipe(
caption,
width=1024,
negative_prompt=negative_prompt,
height=1536,
prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,
prior_guidance_scale=4.0,
decoder_guidance_scale=0.0,
num_images_per_prompt=num_images_per_prompt,
output_type="pil",
num_images_per_prompt=2,
).images
```
......@@ -64,27 +56,29 @@ For explanation purposes, we can also initialize the two main pipelines of Würs
```python
import torch
from diffusers import WuerstchenDecoderPipeline, WuerstchenPriorPipeline
from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS
device = "cuda"
dtype = torch.float16
num_images_per_prompt = 2
prior_pipeline = WuerstchenPriorPipeline.from_pretrained(
"warp-diffusion/wuerstchen-prior", torch_dtype=dtype
"warp-ai/wuerstchen-prior", torch_dtype=dtype
).to(device)
decoder_pipeline = WuerstchenDecoderPipeline.from_pretrained(
"warp-diffusion/wuerstchen", torch_dtype=dtype
"warp-ai/wuerstchen", torch_dtype=dtype
).to(device)
caption = "A captivating artwork of a mysterious stone golem"
caption = "Anthropomorphic cat dressed as a fire fighter"
negative_prompt = ""
prior_output = prior_pipeline(
prompt=caption,
height=1024,
width=1024,
width=1536,
timesteps=DEFAULT_STAGE_C_TIMESTEPS,
negative_prompt=negative_prompt,
guidance_scale=4.0,
guidance_scale=4.0,
num_images_per_prompt=num_images_per_prompt,
)
decoder_output = decoder_pipeline(
......@@ -109,13 +103,12 @@ pipeline.decoder = torch.compile(pipeline.decoder, mode="reduce-overhead", fullg
- Due to the high compression employed by Würstchen, generations can lack a good amount
of detail. To our human eye, this is especially noticeable in faces, hands etc.
- **Images can only be generated in 128-pixel steps**, e.g. the next higher resolution
- **Images can only be generated in 128-pixel steps**, e.g. the next higher resolution
after 1024x1024 is 1152x1152
- The model lacks the ability to render correct text in images
- The model often does not achieve photorealism
- Difficult compositional prompts are hard for the model
The original codebase, as well as experimental ideas, can be found at [dome272/Wuerstchen](https://github.com/dome272/Wuerstchen).
## WuerschenPipeline
......
......@@ -91,12 +91,12 @@ prior_pipeline = WuerstchenPriorPipeline(
prior=prior_model, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=scheduler
)
prior_pipeline.save_pretrained("warp-diffusion/wuerstchen-prior")
prior_pipeline.save_pretrained("warp-ai/wuerstchen-prior")
decoder_pipeline = WuerstchenDecoderPipeline(
text_encoder=gen_text_encoder, tokenizer=gen_tokenizer, vqgan=vqmodel, decoder=deocder, scheduler=scheduler
)
decoder_pipeline.save_pretrained("warp-diffusion/wuerstchen")
decoder_pipeline.save_pretrained("warp-ai/wuerstchen")
# Wuerstchen pipeline
wuerstchen_pipeline = WuerstchenCombinedPipeline(
......@@ -112,4 +112,4 @@ wuerstchen_pipeline = WuerstchenCombinedPipeline(
prior=prior_model,
prior_scheduler=scheduler,
)
wuerstchen_pipeline.save_pretrained("warp-diffusion/WuerstchenCombinedPipeline")
wuerstchen_pipeline.save_pretrained("warp-ai/WuerstchenCombinedPipeline")
......@@ -24,7 +24,7 @@ else:
_import_structure["modeling_wuerstchen_prior"] = ["WuerstchenPrior"]
_import_structure["pipeline_wuerstchen"] = ["WuerstchenDecoderPipeline"]
_import_structure["pipeline_wuerstchen_combined"] = ["WuerstchenCombinedPipeline"]
_import_structure["pipeline_wuerstchen_prior"] = ["WuerstchenPriorPipeline"]
_import_structure["pipeline_wuerstchen_prior"] = ["DEFAULT_STAGE_C_TIMESTEPS", "WuerstchenPriorPipeline"]
import sys
......
......@@ -35,11 +35,11 @@ EXAMPLE_DOC_STRING = """
>>> from diffusers import WuerstchenPriorPipeline, WuerstchenDecoderPipeline
>>> prior_pipe = WuerstchenPriorPipeline.from_pretrained(
... "warp-diffusion/wuerstchen-prior", torch_dtype=torch.float16
... ).to("cuda")
>>> gen_pipe = WuerstchenDecoderPipeline.from_pretrain(
... "warp-diffusion/wuerstchen", torch_dtype=torch.float16
... "warp-ai/wuerstchen-prior", torch_dtype=torch.float16
... ).to("cuda")
>>> gen_pipe = WuerstchenDecoderPipeline.from_pretrained("warp-ai/wuerstchen", torch_dtype=torch.float16).to(
... "cuda"
... )
>>> prompt = "an image of a shiba inu, donning a spacesuit and helmet"
>>> prior_output = prior_pipe(prompt)
......
......@@ -31,9 +31,9 @@ TEXT2IMAGE_EXAMPLE_DOC_STRING = """
```py
>>> from diffusers import WuerstchenCombinedPipeline
>>> pipe = WuerstchenCombinedPipeline.from_pretrained(
... "warp-diffusion/Wuerstchen", torch_dtype=torch.float16
... ).to("cuda")
>>> pipe = WuerstchenCombinedPipeline.from_pretrained("warp-ai/Wuerstchen", torch_dtype=torch.float16).to(
... "cuda"
... )
>>> prompt = "an image of a shiba inu, donning a spacesuit and helmet"
>>> images = pipe(prompt=prompt)
```
......@@ -145,16 +145,16 @@ class WuerstchenCombinedPipeline(DiffusionPipeline):
def __call__(
self,
prompt: Union[str, List[str]],
negative_prompt: Optional[Union[str, List[str]]] = None,
guidance_scale: float = 4.0,
num_images_per_prompt: int = 1,
height: int = 512,
width: int = 512,
prior_guidance_scale: float = 4.0,
prior_num_inference_steps: int = 60,
num_inference_steps: int = 12,
prior_timesteps: Optional[List[float]] = None,
timesteps: Optional[List[float]] = None,
prior_guidance_scale: float = 4.0,
num_inference_steps: int = 12,
decoder_timesteps: Optional[List[float]] = None,
decoder_guidance_scale: float = 0.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: int = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
......@@ -182,19 +182,20 @@ class WuerstchenCombinedPipeline(DiffusionPipeline):
`prior_guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked
to the text `prompt`, usually at the expense of lower image quality.
prior_num_inference_steps (`Union[int, Dict[float, int]]`, *optional*, defaults to 30):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
The number of prior denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference. For more specific timestep spacing, you can pass customized
`prior_timesteps`
num_inference_steps (`int`, *optional*, defaults to 12):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference. For more specific timestep spacing, you can pass customized `timesteps`
The number of decoder denoising steps. More denoising steps usually lead to a higher quality image at
the expense of slower inference. For more specific timestep spacing, you can pass customized
`timesteps`
prior_timesteps (`List[float]`, *optional*):
Custom timesteps to use for the denoising process for the prior. If not defined, equal spaced
`prior_num_inference_steps` timesteps are used. Must be in descending order.
timesteps (`List[float]`, *optional*):
decoder_timesteps (`List[float]`, *optional*):
Custom timesteps to use for the denoising process for the decoder. If not defined, equal spaced
`decoder_num_inference_steps` timesteps are used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 4.0):
`num_inference_steps` timesteps are used. Must be in descending order.
decoder_guidance_scale (`float`, *optional*, defaults to 0.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
......@@ -221,27 +222,28 @@ class WuerstchenCombinedPipeline(DiffusionPipeline):
"""
prior_outputs = self.prior_pipe(
prompt=prompt,
negative_prompt=negative_prompt,
width=width,
height=height,
num_images_per_prompt=num_images_per_prompt,
width=width,
num_inference_steps=prior_num_inference_steps,
timesteps=prior_timesteps,
guidance_scale=prior_guidance_scale,
negative_prompt=negative_prompt,
num_images_per_prompt=num_images_per_prompt,
generator=generator,
latents=latents,
guidance_scale=prior_guidance_scale,
output_type="pt",
return_dict=False,
)
image_embeddings = prior_outputs[0]
outputs = self.decoder_pipe(
prompt=prompt,
image_embeddings=image_embeddings,
prompt=prompt,
num_inference_steps=num_inference_steps,
timesteps=timesteps,
timesteps=decoder_timesteps,
guidance_scale=decoder_guidance_scale,
negative_prompt=negative_prompt,
generator=generator,
guidance_scale=guidance_scale,
output_type=output_type,
return_dict=return_dict,
)
......
......@@ -14,7 +14,7 @@
from dataclasses import dataclass
from math import ceil
from typing import List, Optional, Union
from typing import Callable, List, Optional, Union
import numpy as np
import torch
......@@ -35,6 +35,8 @@ from .modeling_wuerstchen_prior import WuerstchenPrior
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
DEFAULT_STAGE_C_TIMESTEPS = list(np.linspace(1.0, 2 / 3, 20)) + list(np.linspace(2 / 3, 0.0, 11))[1:]
EXAMPLE_DOC_STRING = """
Examples:
```py
......@@ -42,7 +44,7 @@ EXAMPLE_DOC_STRING = """
>>> from diffusers import WuerstchenPriorPipeline
>>> prior_pipe = WuerstchenPriorPipeline.from_pretrained(
... "warp-diffusion/wuerstchen-prior", torch_dtype=torch.float16
... "warp-ai/wuerstchen-prior", torch_dtype=torch.float16
... ).to("cuda")
>>> prompt = "an image of a shiba inu, donning a spacesuit and helmet"
......@@ -265,7 +267,7 @@ class WuerstchenPriorPipeline(DiffusionPipeline):
prompt: Union[str, List[str]] = None,
height: int = 1024,
width: int = 1024,
num_inference_steps: int = 30,
num_inference_steps: int = 60,
timesteps: List[float] = None,
guidance_scale: float = 8.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
......@@ -274,6 +276,8 @@ class WuerstchenPriorPipeline(DiffusionPipeline):
latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pt",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
):
"""
Function invoked when calling the pipeline for generation.
......@@ -314,6 +318,12 @@ class WuerstchenPriorPipeline(DiffusionPipeline):
(`np.array`) or `"pt"` (`torch.Tensor`).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
Examples:
......@@ -365,7 +375,7 @@ class WuerstchenPriorPipeline(DiffusionPipeline):
latents = self.prepare_latents(effnet_features_shape, dtype, device, generator, latents, self.scheduler)
# 6. Run denoising loop
for t in self.progress_bar(timesteps[:-1]):
for i, t in enumerate(self.progress_bar(timesteps[:-1])):
ratio = t.expand(latents.size(0)).to(dtype)
# 7. Denoise image embeddings
......@@ -390,6 +400,9 @@ class WuerstchenPriorPipeline(DiffusionPipeline):
generator=generator,
).prev_sample
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
# 10. Denormalize the latents
latents = latents * self.config.latent_mean - self.config.latent_std
......
......@@ -38,7 +38,8 @@ class WuerstchenCombinedPipelineFastTests(PipelineTesterMixin, unittest.TestCase
"height",
"width",
"latents",
"guidance_scale",
"prior_guidance_scale",
"decoder_guidance_scale",
"negative_prompt",
"num_inference_steps",
"return_dict",
......@@ -160,7 +161,7 @@ class WuerstchenCombinedPipelineFastTests(PipelineTesterMixin, unittest.TestCase
"prompt": "horse",
"generator": generator,
"prior_guidance_scale": 4.0,
"guidance_scale": 4.0,
"decoder_guidance_scale": 4.0,
"num_inference_steps": 2,
"prior_num_inference_steps": 2,
"output_type": "np",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment