Unverified Commit cc59b056 authored by Patrick von Platen, committed by GitHub

[ModelOutputs] Replace dict outputs with Dict/Dataclass and allow to return tuples (#334)



* add outputs for models

* add for pipelines

* finish schedulers

* better naming

* adapt tests as well

* replace dict access with . access

* make schedulers works

* finish

* correct readme

* make backwards compatible

* up

* small fix

* finish

* more fixes

* more fixes

* Apply suggestions from code review
Co-authored-by: Suraj Patil <surajp815@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update src/diffusers/models/vae.py
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Adapt model outputs

* Apply more suggestions

* finish examples

* correct
Co-authored-by: default avatarSuraj Patil <surajp815@gmail.com>
Co-authored-by: default avatarPedro Cuenca <pedro@huggingface.co>
parent daddd98b
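The user-facing effect of the change, sketched below for illustration: pipeline and model calls now return dataclass-style outputs with named fields instead of plain dicts, and every call accepts `return_dict=False` for a backwards-compatible plain tuple. The checkpoint id is taken from the README example further down in this diff.

```python
# Minimal sketch of the new output API (checkpoint id from the README example below).
from diffusers import DDPMPipeline

ddpm = DDPMPipeline.from_pretrained("google/ddpm-celebahq-256")

# new: named attribute access on a BaseOutput dataclass
image = ddpm().images          # previously: ddpm()["sample"]
image[0].save("ddpm_generated_image.png")

# backwards compatible: ask for a plain tuple instead of the dataclass
(image,) = ddpm(return_dict=False)
```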
@@ -80,7 +80,7 @@ pipe = pipe.to("cuda")
 prompt = "a photo of an astronaut riding a horse on mars"
 with autocast("cuda"):
-    image = pipe(prompt)["sample"][0]
+    image = pipe(prompt).images[0]
 ```
 **Note**: If you don't want to use the token, you can also simply download the model weights

@@ -101,7 +101,7 @@ pipe = pipe.to("cuda")
 prompt = "a photo of an astronaut riding a horse on mars"
 with autocast("cuda"):
-    image = pipe(prompt)["sample"][0]
+    image = pipe(prompt).images[0]
 ```
 If you are limited by GPU memory, you might want to consider using the model in `fp16`.

@@ -117,7 +117,7 @@ pipe = pipe.to("cuda")
 prompt = "a photo of an astronaut riding a horse on mars"
 with autocast("cuda"):
-    image = pipe(prompt)["sample"][0]
+    image = pipe(prompt).images[0]
 ```
 Finally, if you wish to use a different scheduler, you can simply instantiate

@@ -143,7 +143,7 @@ pipe = pipe.to("cuda")
 prompt = "a photo of an astronaut riding a horse on mars"
 with autocast("cuda"):
-    image = pipe(prompt)["sample"][0]
+    image = pipe(prompt).images[0]
 image.save("astronaut_rides_horse.png")
 ```

@@ -184,7 +184,7 @@ init_image = init_image.resize((768, 512))
 prompt = "A fantasy landscape, trending on artstation"
 with autocast("cuda"):
-    images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5)["sample"]
+    images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5).images
 images[0].save("fantasy_landscape.png")
 ```

@@ -228,7 +228,7 @@ pipe = pipe.to(device)
 prompt = "a cat sitting on a bench"
 with autocast("cuda"):
-    images = pipe(prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.75)["sample"]
+    images = pipe(prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.75).images
 images[0].save("cat_on_bench.png")
 ```

@@ -260,7 +260,7 @@ ldm = DiffusionPipeline.from_pretrained(model_id)
 # run pipeline in inference (sample random noise and denoise)
 prompt = "A painting of a squirrel eating a burger"
-images = ldm([prompt], num_inference_steps=50, eta=0.3, guidance_scale=6)["sample"]
+images = ldm([prompt], num_inference_steps=50, eta=0.3, guidance_scale=6).images
 # save images
 for idx, image in enumerate(images):

@@ -277,7 +277,7 @@ model_id = "google/ddpm-celebahq-256"
 ddpm = DDPMPipeline.from_pretrained(model_id)  # you can replace DDPMPipeline with DDIMPipeline or PNDMPipeline for faster inference
 # run pipeline in inference (sample random noise and denoise)
-image = ddpm()["sample"]
+image = ddpm().images
 # save image
 image[0].save("ddpm_generated_image.png")

@@ -76,7 +76,7 @@ pipe = pipe = StableDiffusionPipeline.from_pretrained(model_id,torch_dtype=torch
 prompt = "A <cat-toy> backpack"
 with autocast("cuda"):
-    image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5)["sample"][0]
+    image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
 image.save("cat-backpack.png")
 ```

@@ -498,7 +498,7 @@ def main():
     for step, batch in enumerate(train_dataloader):
         with accelerator.accumulate(text_encoder):
             # Convert images to latent space
-            latents = vae.encode(batch["pixel_values"]).sample().detach()
+            latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
             latents = latents * 0.18215
             # Sample noise that we'll add to the latents

@@ -515,7 +515,7 @@ def main():
             encoder_hidden_states = text_encoder(batch["input_ids"])[0]
             # Predict the noise residual
-            noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states)["sample"]
+            noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
             loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
             accelerator.backward(loss)

@@ -139,7 +139,7 @@ def main(args):
         with accelerator.accumulate(model):
             # Predict the noise residual
-            noise_pred = model(noisy_images, timesteps)["sample"]
+            noise_pred = model(noisy_images, timesteps).sample
             loss = F.mse_loss(noise_pred, noise)
             accelerator.backward(loss)

@@ -174,7 +174,7 @@ def main(args):
         generator = torch.manual_seed(0)
         # run pipeline in inference (sample random noise and denoise)
-        images = pipeline(generator=generator, batch_size=args.eval_batch_size, output_type="numpy")["sample"]
+        images = pipeline(generator=generator, batch_size=args.eval_batch_size, output_type="numpy").images
         # denormalize the images and save to tensorboard
         images_processed = (images * 255).round().astype("uint8")

@@ -119,7 +119,7 @@ for mod in models:
     noise = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
     time_step = torch.tensor([10] * noise.shape[0])
     with torch.no_grad():
-        logits = model(noise, time_step)["sample"]
+        logits = model(noise, time_step).sample
     assert torch.allclose(
         logits[0, 0, 0, :30], results["_".join("_".join(mod.modelId.split("/")).split("-"))], atol=1e-3

@@ -19,9 +19,9 @@ import shutil
 from pathlib import Path
 from typing import Optional

-from diffusers import DiffusionPipeline
 from huggingface_hub import HfFolder, Repository, whoami

+from .pipeline_utils import DiffusionPipeline
 from .utils import is_modelcards_available, logging

-from typing import Dict, Optional, Tuple, Union
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union

 import torch
 import torch.nn as nn

 from ..configuration_utils import ConfigMixin, register_to_config
 from ..modeling_utils import ModelMixin
+from ..utils import BaseOutput
 from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
 from .unet_blocks import UNetMidBlock2D, get_down_block, get_up_block


+@dataclass
+class UNet2DOutput(BaseOutput):
+    """
+    Args:
+        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+            Hidden states output. Output of last layer of model.
+    """
+
+    sample: torch.FloatTensor


 class UNet2DModel(ModelMixin, ConfigMixin):
     @register_to_config
     def __init__(

@@ -118,8 +131,11 @@ class UNet2DModel(ModelMixin, ConfigMixin):
         self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)

     def forward(
-        self, sample: torch.FloatTensor, timestep: Union[torch.Tensor, float, int]
-    ) -> Dict[str, torch.FloatTensor]:
+        self,
+        sample: torch.FloatTensor,
+        timestep: Union[torch.Tensor, float, int],
+        return_dict: bool = True,
+    ) -> Union[UNet2DOutput, Tuple]:
         # 0. center input if necessary
         if self.config.center_input_sample:
             sample = 2 * sample - 1.0

@@ -181,6 +197,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
             timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:]))))
             sample = sample / timesteps

-        output = {"sample": sample}
+        if not return_dict:
+            return (sample,)

-        return output
+        return UNet2DOutput(sample=sample)

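For illustration, a minimal sketch of the new access pattern for `UNet2DModel`, mirroring the model-consistency script changed earlier in this diff; the randomly initialised model and its constructor arguments are only assumptions to keep the snippet self-contained.

```python
import torch
from diffusers import UNet2DModel

# Randomly initialised model, just to show the call pattern; any pretrained UNet2DModel works the same way.
model = UNet2DModel(sample_size=64, in_channels=3, out_channels=3)
noise = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
time_step = torch.tensor([10] * noise.shape[0])

with torch.no_grad():
    sample = model(noise, time_step).sample                 # previously: model(noise, time_step)["sample"]
    (sample,) = model(noise, time_step, return_dict=False)  # backwards-compatible tuple
```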
-from typing import Dict, Optional, Tuple, Union
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union

 import torch
 import torch.nn as nn

 from ..configuration_utils import ConfigMixin, register_to_config
 from ..modeling_utils import ModelMixin
+from ..utils import BaseOutput
 from .embeddings import TimestepEmbedding, Timesteps
 from .unet_blocks import UNetMidBlock2DCrossAttn, get_down_block, get_up_block


+@dataclass
+class UNet2DConditionOutput(BaseOutput):
+    """
+    Args:
+        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+            Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model.
+    """
+
+    sample: torch.FloatTensor


 class UNet2DConditionModel(ModelMixin, ConfigMixin):
     @register_to_config
     def __init__(

@@ -125,7 +138,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
         sample: torch.FloatTensor,
         timestep: Union[torch.Tensor, float, int],
         encoder_hidden_states: torch.Tensor,
-    ) -> Dict[str, torch.FloatTensor]:
+        return_dict: bool = True,
+    ) -> Union[UNet2DConditionOutput, Tuple]:
         # 0. center input if necessary
         if self.config.center_input_sample:
             sample = 2 * sample - 1.0

@@ -183,6 +197,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
         sample = self.conv_act(sample)
         sample = self.conv_out(sample)

-        output = {"sample": sample}
+        if not return_dict:
+            return (sample,)

-        return output
+        return UNet2DConditionOutput(sample=sample)

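The conditional UNet follows the same pattern; a short sketch assuming `unet`, `noisy_latents`, `timesteps`, and `encoder_hidden_states` are already set up as in the textual-inversion training loop earlier in this diff.

```python
# previously: unet(noisy_latents, timesteps, encoder_hidden_states)["sample"]
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

# backwards-compatible tuple access
(noise_pred,) = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)
```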
-from typing import Optional, Tuple
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union

 import numpy as np
 import torch

@@ -6,9 +7,50 @@ import torch.nn as nn
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..modeling_utils import ModelMixin
+from ..utils import BaseOutput
 from .unet_blocks import UNetMidBlock2D, get_down_block, get_up_block


+@dataclass
+class DecoderOutput(BaseOutput):
+    """
+    Output of decoding method.
+
+    Args:
+        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+            Decoded output sample of the model. Output of the last layer of the model.
+    """
+
+    sample: torch.FloatTensor
+
+
+@dataclass
+class VQEncoderOutput(BaseOutput):
+    """
+    Output of VQModel encoding method.
+
+    Args:
+        latents (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+            Encoded output sample of the model. Output of the last layer of the model.
+    """
+
+    latents: torch.FloatTensor
+
+
+@dataclass
+class AutoencoderKLOutput(BaseOutput):
+    """
+    Output of AutoencoderKL encoding method.
+
+    Args:
+        latent_dist (`DiagonalGaussianDistribution`):
+            Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`.
+            `DiagonalGaussianDistribution` allows for sampling latents from the distribution.
+    """
+
+    latent_dist: "DiagonalGaussianDistribution"


 class Encoder(nn.Module):
     def __init__(
         self,

@@ -369,12 +411,18 @@ class VQModel(ModelMixin, ConfigMixin):
             act_fn=act_fn,
         )

-    def encode(self, x):
+    def encode(self, x, return_dict: bool = True):
         h = self.encoder(x)
         h = self.quant_conv(h)
-        return h

-    def decode(self, h, force_not_quantize=False):
+        if not return_dict:
+            return (h,)
+
+        return VQEncoderOutput(latents=h)
+
+    def decode(
+        self, h: torch.FloatTensor, force_not_quantize: bool = False, return_dict: bool = True
+    ) -> Union[DecoderOutput, torch.FloatTensor]:
         # also go through quantization layer
         if not force_not_quantize:
             quant, emb_loss, info = self.quantize(h)

@@ -382,13 +430,21 @@ class VQModel(ModelMixin, ConfigMixin):
             quant = h
         quant = self.post_quant_conv(quant)
         dec = self.decoder(quant)
-        return dec

-    def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
+        if not return_dict:
+            return (dec,)
+
+        return DecoderOutput(sample=dec)
+
+    def forward(self, sample: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
         x = sample
-        h = self.encode(x)
-        dec = self.decode(h)
-        return dec
+        h = self.encode(x).latents
+        dec = self.decode(h).sample
+
+        if not return_dict:
+            return (dec,)
+
+        return DecoderOutput(sample=dec)


 class AutoencoderKL(ModelMixin, ConfigMixin):

@@ -431,23 +487,37 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
         self.quant_conv = torch.nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
         self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)

-    def encode(self, x):
+    def encode(self, x, return_dict: bool = True):
         h = self.encoder(x)
         moments = self.quant_conv(h)
         posterior = DiagonalGaussianDistribution(moments)
-        return posterior

-    def decode(self, z):
+        if not return_dict:
+            return (posterior,)
+
+        return AutoencoderKLOutput(latent_dist=posterior)
+
+    def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
         z = self.post_quant_conv(z)
         dec = self.decoder(z)
-        return dec

-    def forward(self, sample: torch.FloatTensor, sample_posterior: bool = False) -> torch.FloatTensor:
+        if not return_dict:
+            return (dec,)
+
+        return DecoderOutput(sample=dec)
+
+    def forward(
+        self, sample: torch.FloatTensor, sample_posterior: bool = False, return_dict: bool = True
+    ) -> Union[DecoderOutput, torch.FloatTensor]:
         x = sample
-        posterior = self.encode(x)
+        posterior = self.encode(x).latent_dist
         if sample_posterior:
             z = posterior.sample()
         else:
             z = posterior.mode()
-        dec = self.decode(z)
-        return dec
+        dec = self.decode(z).sample
+
+        if not return_dict:
+            return (dec,)
+
+        return DecoderOutput(sample=dec)

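For illustration, how the two autoencoders are used after this change, assuming `vae` is an `AutoencoderKL`, `vqvae` a `VQModel`, and `pixel_values` an image batch scaled to `[-1, 1]`; the `0.18215` factor follows the training and pipeline code elsewhere in this diff.

```python
# AutoencoderKL: encode() now wraps the posterior in AutoencoderKLOutput
posterior = vae.encode(pixel_values).latent_dist       # previously: vae.encode(pixel_values) returned the posterior directly
latents = posterior.sample() * 0.18215
image = vae.decode(latents / 0.18215).sample           # previously: vae.decode(...) returned the tensor directly

# VQModel: encode()/decode() wrap their tensors in VQEncoderOutput / DecoderOutput
h = vqvae.encode(pixel_values).latents
reconstruction = vqvae.decode(h).sample

# Tuple variants are available everywhere via return_dict=False
(posterior,) = vae.encode(pixel_values, return_dict=False)
```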
@@ -17,16 +17,19 @@
 import importlib
 import inspect
 import os
-from typing import Optional, Union
+from dataclasses import dataclass
+from typing import List, Optional, Union

+import numpy as np
 import torch

+import PIL
 from huggingface_hub import snapshot_download
 from PIL import Image
 from tqdm.auto import tqdm

 from .configuration_utils import ConfigMixin
-from .utils import DIFFUSERS_CACHE, logging
+from .utils import DIFFUSERS_CACHE, BaseOutput, logging


 INDEX_FILE = "diffusion_pytorch_model.bin"

@@ -54,6 +57,20 @@ for library in LOADABLE_CLASSES:
         ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])


+@dataclass
+class ImagePipelineOutput(BaseOutput):
+    """
+    Output class for image pipelines.
+
+    Args:
+        images (`List[PIL.Image.Image]` or `np.ndarray`)
+            List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
+            num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
+    """
+
+    images: Union[List[PIL.Image.Image], np.ndarray]


 class DiffusionPipeline(ConfigMixin):
     config_name = "model_index.json"

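All of the unconditional image pipelines below now return this `ImagePipelineOutput`. A sketch of consuming it, assuming `pipeline` is one of them (for example a loaded `DDPMPipeline`, as in the training script above):

```python
import torch

generator = torch.manual_seed(0)

out = pipeline(generator=generator, batch_size=4, output_type="numpy")  # ImagePipelineOutput
images = out.images                                                     # previously: out["sample"]

# or skip the dataclass entirely
(images,) = pipeline(generator=generator, batch_size=4, output_type="numpy", return_dict=False)
```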
@@ -94,7 +94,7 @@ pipe = pipe.to("cuda")
 prompt = "a photo of an astronaut riding a horse on mars"
 with autocast("cuda"):
-    image = pipe(prompt)["sample"][0]
+    image = pipe(prompt).images[0]
 image.save("astronaut_rides_horse.png")
 ```

@@ -130,7 +130,7 @@ init_image = init_image.resize((768, 512))
 prompt = "A fantasy landscape, trending on artstation"
 with autocast("cuda"):
-    images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5)["sample"]
+    images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5).images
 images[0].save("fantasy_landscape.png")
 ```

@@ -174,7 +174,7 @@ pipe = StableDiffusionInpaintPipeline.from_pretrained(
 prompt = "a cat sitting on a bench"
 with autocast("cuda"):
-    images = pipe(prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.75)["sample"]
+    images = pipe(prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.75).images
 images[0].save("cat_on_bench.png")
 ```

@@ -15,10 +15,11 @@
 import warnings
+from typing import Tuple, Union

 import torch

-from ...pipeline_utils import DiffusionPipeline
+from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput


 class DDIMPipeline(DiffusionPipeline):

@@ -28,7 +29,16 @@ class DDIMPipeline(DiffusionPipeline):
         self.register_modules(unet=unet, scheduler=scheduler)

     @torch.no_grad()
-    def __call__(self, batch_size=1, generator=None, eta=0.0, num_inference_steps=50, output_type="pil", **kwargs):
+    def __call__(
+        self,
+        batch_size=1,
+        generator=None,
+        eta=0.0,
+        num_inference_steps=50,
+        output_type="pil",
+        return_dict: bool = True,
+        **kwargs,
+    ) -> Union[ImagePipelineOutput, Tuple]:
         if "torch_device" in kwargs:
             device = kwargs.pop("torch_device")

@@ -56,15 +66,18 @@ class DDIMPipeline(DiffusionPipeline):
         for t in self.progress_bar(self.scheduler.timesteps):
             # 1. predict noise model_output
-            model_output = self.unet(image, t)["sample"]
+            model_output = self.unet(image, t).sample

             # 2. predict previous mean of image x_t-1 and add variance depending on eta
             # do x_t -> x_t-1
-            image = self.scheduler.step(model_output, t, image, eta)["prev_sample"]
+            image = self.scheduler.step(model_output, t, image, eta).prev_sample

         image = (image / 2 + 0.5).clamp(0, 1)
         image = image.cpu().permute(0, 2, 3, 1).numpy()
         if output_type == "pil":
             image = self.numpy_to_pil(image)

-        return {"sample": image}
+        if not return_dict:
+            return (image,)
+
+        return ImagePipelineOutput(images=image)

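The same attribute access applies when writing the denoising loop by hand; a sketch assuming `unet` and `scheduler` come from a loaded `DDIMPipeline` (e.g. `pipe.unet` and `pipe.scheduler`):

```python
import torch

# Start from pure noise with the shape the UNet expects.
image = torch.randn(1, unet.config.in_channels, unet.config.sample_size, unet.config.sample_size)

scheduler.set_timesteps(50)
for t in scheduler.timesteps:
    model_output = unet(image, t).sample                                 # previously: unet(image, t)["sample"]
    image = scheduler.step(model_output, t, image, eta=0.0).prev_sample  # previously: [...]["prev_sample"]
```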
@@ -15,10 +15,11 @@
 import warnings
+from typing import Tuple, Union

 import torch

-from ...pipeline_utils import DiffusionPipeline
+from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput


 class DDPMPipeline(DiffusionPipeline):

@@ -28,7 +29,9 @@ class DDPMPipeline(DiffusionPipeline):
         self.register_modules(unet=unet, scheduler=scheduler)

     @torch.no_grad()
-    def __call__(self, batch_size=1, generator=None, output_type="pil", **kwargs):
+    def __call__(
+        self, batch_size=1, generator=None, output_type="pil", return_dict: bool = True, **kwargs
+    ) -> Union[ImagePipelineOutput, Tuple]:
         if "torch_device" in kwargs:
             device = kwargs.pop("torch_device")
             warnings.warn(

@@ -53,14 +56,17 @@ class DDPMPipeline(DiffusionPipeline):
         for t in self.progress_bar(self.scheduler.timesteps):
             # 1. predict noise model_output
-            model_output = self.unet(image, t)["sample"]
+            model_output = self.unet(image, t).sample

             # 2. compute previous image: x_t -> t_t-1
-            image = self.scheduler.step(model_output, t, image, generator=generator)["prev_sample"]
+            image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample

         image = (image / 2 + 0.5).clamp(0, 1)
         image = image.cpu().permute(0, 2, 3, 1).numpy()
         if output_type == "pil":
             image = self.numpy_to_pil(image)

-        return {"sample": image}
+        if not return_dict:
+            return (image,)
+
+        return ImagePipelineOutput(images=image)

@@ -12,7 +12,7 @@ from transformers.modeling_outputs import BaseModelOutput
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import logging

-from ...pipeline_utils import DiffusionPipeline
+from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput


 class LDMTextToImagePipeline(DiffusionPipeline):

@@ -32,8 +32,9 @@ class LDMTextToImagePipeline(DiffusionPipeline):
         eta: Optional[float] = 0.0,
         generator: Optional[torch.Generator] = None,
         output_type: Optional[str] = "pil",
+        return_dict: bool = True,
         **kwargs,
-    ):
+    ) -> Union[Tuple, ImagePipelineOutput]:
         # eta corresponds to η in paper and should be between [0, 1]

         if "torch_device" in kwargs:

@@ -95,25 +96,28 @@ class LDMTextToImagePipeline(DiffusionPipeline):
                 context = torch.cat([uncond_embeddings, text_embeddings])

             # predict the noise residual
-            noise_pred = self.unet(latents_input, t, encoder_hidden_states=context)["sample"]
+            noise_pred = self.unet(latents_input, t, encoder_hidden_states=context).sample
             # perform guidance
             if guidance_scale != 1.0:
                 noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
                 noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)

             # compute the previous noisy sample x_t -> x_t-1
-            latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs)["prev_sample"]
+            latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs).prev_sample

         # scale and decode the image latents with vae
         latents = 1 / 0.18215 * latents
-        image = self.vqvae.decode(latents)
+        image = self.vqvae.decode(latents).sample

         image = (image / 2 + 0.5).clamp(0, 1)
         image = image.cpu().permute(0, 2, 3, 1).numpy()
         if output_type == "pil":
             image = self.numpy_to_pil(image)

-        return {"sample": image}
+        if not return_dict:
+            return (image,)
+
+        return ImagePipelineOutput(images=image)


 ################################################################################

@@ -525,7 +529,7 @@ class LDMBertEncoder(LDMBertPreTrainedModel):
                 Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                 for more detail.
             return_dict (`bool`, *optional*):
-                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+                Whether or not to return a [`~utils.BaseModelOutput`] instead of a plain tuple.
         """
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (

 import inspect
 import warnings
-from typing import Optional
+from typing import Optional, Tuple, Union

 import torch

 from ...models import UNet2DModel, VQModel
-from ...pipeline_utils import DiffusionPipeline
+from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 from ...schedulers import DDIMScheduler

@@ -28,8 +28,9 @@ class LDMPipeline(DiffusionPipeline):
         eta: float = 0.0,
         num_inference_steps: int = 50,
         output_type: Optional[str] = "pil",
+        return_dict: bool = True,
         **kwargs,
-    ):
+    ) -> Union[Tuple, ImagePipelineOutput]:
         # eta corresponds to η in paper and should be between [0, 1]

         if "torch_device" in kwargs:

@@ -61,16 +62,19 @@ class LDMPipeline(DiffusionPipeline):
         for t in self.progress_bar(self.scheduler.timesteps):
             # predict the noise residual
-            noise_prediction = self.unet(latents, t)["sample"]
+            noise_prediction = self.unet(latents, t).sample
             # compute the previous noisy sample x_t -> x_t-1
-            latents = self.scheduler.step(noise_prediction, t, latents, **extra_kwargs)["prev_sample"]
+            latents = self.scheduler.step(noise_prediction, t, latents, **extra_kwargs).prev_sample

         # decode the image latents with the VAE
-        image = self.vqvae.decode(latents)
+        image = self.vqvae.decode(latents).sample

         image = (image / 2 + 0.5).clamp(0, 1)
         image = image.cpu().permute(0, 2, 3, 1).numpy()
         if output_type == "pil":
             image = self.numpy_to_pil(image)

-        return {"sample": image}
+        if not return_dict:
+            return (image,)
+
+        return ImagePipelineOutput(images=image)

@@ -15,12 +15,12 @@
 import warnings
-from typing import Optional
+from typing import Optional, Tuple, Union

 import torch

 from ...models import UNet2DModel
-from ...pipeline_utils import DiffusionPipeline
+from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 from ...schedulers import PNDMScheduler

@@ -40,8 +40,9 @@ class PNDMPipeline(DiffusionPipeline):
         num_inference_steps: int = 50,
         generator: Optional[torch.Generator] = None,
         output_type: Optional[str] = "pil",
+        return_dict: bool = True,
         **kwargs,
-    ):
+    ) -> Union[ImagePipelineOutput, Tuple]:
         # For more information on the sampling method you can take a look at Algorithm 2 of
         # the official paper: https://arxiv.org/pdf/2202.09778.pdf

@@ -66,13 +67,16 @@ class PNDMPipeline(DiffusionPipeline):
         self.scheduler.set_timesteps(num_inference_steps)
         for t in self.progress_bar(self.scheduler.timesteps):
-            model_output = self.unet(image, t)["sample"]
+            model_output = self.unet(image, t).sample

-            image = self.scheduler.step(model_output, t, image)["prev_sample"]
+            image = self.scheduler.step(model_output, t, image).prev_sample

         image = (image / 2 + 0.5).clamp(0, 1)
         image = image.cpu().permute(0, 2, 3, 1).numpy()
         if output_type == "pil":
             image = self.numpy_to_pil(image)

-        return {"sample": image}
+        if not return_dict:
+            return (image,)
+
+        return ImagePipelineOutput(images=image)

 #!/usr/bin/env python3
 import warnings
-from typing import Optional
+from typing import Optional, Tuple, Union

 import torch

-from diffusers import DiffusionPipeline
 from ...models import UNet2DModel
+from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 from ...schedulers import ScoreSdeVeScheduler

@@ -26,8 +25,9 @@ class ScoreSdeVePipeline(DiffusionPipeline):
         num_inference_steps: int = 2000,
         generator: Optional[torch.Generator] = None,
         output_type: Optional[str] = "pil",
+        return_dict: bool = True,
         **kwargs,
-    ):
+    ) -> Union[ImagePipelineOutput, Tuple]:
         if "torch_device" in kwargs:
             device = kwargs.pop("torch_device")
             warnings.warn(

@@ -56,18 +56,21 @@ class ScoreSdeVePipeline(DiffusionPipeline):
             # correction step
             for _ in range(self.scheduler.correct_steps):
-                model_output = self.unet(sample, sigma_t)["sample"]
-                sample = self.scheduler.step_correct(model_output, sample, generator=generator)["prev_sample"]
+                model_output = self.unet(sample, sigma_t).sample
+                sample = self.scheduler.step_correct(model_output, sample, generator=generator).prev_sample

             # prediction step
-            model_output = model(sample, sigma_t)["sample"]
+            model_output = model(sample, sigma_t).sample
             output = self.scheduler.step_pred(model_output, t, sample, generator=generator)

-            sample, sample_mean = output["prev_sample"], output["prev_sample_mean"]
+            sample, sample_mean = output.prev_sample, output.prev_sample_mean

         sample = sample_mean.clamp(0, 1)
         sample = sample.cpu().permute(0, 2, 3, 1).numpy()
         if output_type == "pil":
             sample = self.numpy_to_pil(sample)

-        return {"sample": sample}
+        if not return_dict:
+            return (sample,)
+
+        return ImagePipelineOutput(images=sample)

@@ -67,7 +67,7 @@ pipe = pipe.to("cuda")
 prompt = "a photo of an astronaut riding a horse on mars"
 with autocast("cuda"):
-    image = pipe(prompt)["sample"][0]
+    image = pipe(prompt).sample[0]
 image.save("astronaut_rides_horse.png")
 ```

@@ -89,7 +89,7 @@ pipe = StableDiffusionPipeline.from_pretrained(
 prompt = "a photo of an astronaut riding a horse on mars"
 with autocast("cuda"):
-    image = pipe(prompt)["sample"][0]
+    image = pipe(prompt).sample[0]
 image.save("astronaut_rides_horse.png")
 ```

@@ -115,7 +115,7 @@ pipe = StableDiffusionPipeline.from_pretrained(
 prompt = "a photo of an astronaut riding a horse on mars"
 with autocast("cuda"):
-    image = pipe(prompt)["sample"][0]
+    image = pipe(prompt).sample[0]
 image.save("astronaut_rides_horse.png")
 ```

 # flake8: noqa
-from ...utils import is_transformers_available
+from dataclasses import dataclass
+from typing import List, Union
+
+import numpy as np
+
+import PIL
+from PIL import Image
+
+from ...utils import BaseOutput, is_transformers_available
+
+
+@dataclass
+class StableDiffusionPipelineOutput(BaseOutput):
+    """
+    Output class for Stable Diffusion pipelines.
+
+    Args:
+        images (`List[PIL.Image.Image]` or `np.ndarray`)
+            List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
+            num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
+        nsfw_content_detected (`List[bool]`)
+            List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
+            (nsfw) content.
+    """
+
+    images: Union[List[PIL.Image.Image], np.ndarray]
+    nsfw_content_detected: List[bool]


 if is_transformers_available():

@@ -9,6 +9,7 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...pipeline_utils import DiffusionPipeline
 from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from . import StableDiffusionPipelineOutput
 from .safety_checker import StableDiffusionSafetyChecker

@@ -47,6 +48,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
         generator: Optional[torch.Generator] = None,
         latents: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
+        return_dict: bool = True,
         **kwargs,
     ):
         if "torch_device" in kwargs:

@@ -141,7 +143,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
                 latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)

             # predict the noise residual
-            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]
+            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

             # perform guidance
             if do_classifier_free_guidance:

@@ -150,13 +152,13 @@ class StableDiffusionPipeline(DiffusionPipeline):
             # compute the previous noisy sample x_t -> x_t-1
             if isinstance(self.scheduler, LMSDiscreteScheduler):
-                latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs)["prev_sample"]
+                latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample
             else:
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)["prev_sample"]
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

         # scale and decode the image latents with vae
         latents = 1 / 0.18215 * latents
-        image = self.vae.decode(latents)
+        image = self.vae.decode(latents).sample

         image = (image / 2 + 0.5).clamp(0, 1)
         image = image.cpu().permute(0, 2, 3, 1).numpy()

@@ -168,4 +170,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
         if output_type == "pil":
             image = self.numpy_to_pil(image)

-        return {"sample": image, "nsfw_content_detected": has_nsfw_concept}
+        if not return_dict:
+            return (image, has_nsfw_concept)
+
+        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)