Unverified Commit cc59b056 authored by Patrick von Platen, committed by GitHub

[ModelOutputs] Replace dict outputs with Dict/Dataclass and allow to return tuples (#334)



* add outputs for models

* add for pipelines

* finish schedulers

* better naming

* adapt tests as well

* replace dict access with . access

* make schedulers work

* finish

* correct readme

* make backwards compatible

* up

* small fix

* finish

* more fixes

* more fixes

* Apply suggestions from code review
Co-authored-by: Suraj Patil <surajp815@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update src/diffusers/models/vae.py
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Adapt model outputs

* Apply more suggestions

* finish examples

* correct
Co-authored-by: Suraj Patil <surajp815@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
parent daddd98b
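Before the diff, a minimal sketch of the access-pattern change this commit makes, using the DDPM pipeline from the README below; the dict-style lookup shown as "before" is exactly the pattern being removed, and `return_dict=False` is the new escape hatch back to plain tuples:

```python
from diffusers import DDPMPipeline

pipe = DDPMPipeline.from_pretrained("google/ddpm-celebahq-256")

# before this commit: pipeline outputs were plain dicts
# image = pipe()["sample"][0]

# after this commit: outputs are dataclasses with named fields
image = pipe().images[0]

# return_dict=False returns a plain tuple instead of the dataclass
images = pipe(return_dict=False)[0]
image = images[0]
```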
......@@ -80,7 +80,7 @@ pipe = pipe.to("cuda")
prompt = "a photo of an astronaut riding a horse on mars"
with autocast("cuda"):
image = pipe(prompt)["sample"][0]
image = pipe(prompt).images[0]
```
**Note**: If you don't want to use the token, you can also simply download the model weights
......@@ -101,7 +101,7 @@ pipe = pipe.to("cuda")
prompt = "a photo of an astronaut riding a horse on mars"
with autocast("cuda"):
image = pipe(prompt)["sample"][0]
image = pipe(prompt).images[0]
```
If you are limited by GPU memory, you might want to consider using the model in `fp16`.
......@@ -117,7 +117,7 @@ pipe = pipe.to("cuda")
prompt = "a photo of an astronaut riding a horse on mars"
with autocast("cuda"):
image = pipe(prompt)["sample"][0]
image = pipe(prompt).images[0]
```
Finally, if you wish to use a different scheduler, you can simply instantiate
......@@ -143,7 +143,7 @@ pipe = pipe.to("cuda")
prompt = "a photo of an astronaut riding a horse on mars"
with autocast("cuda"):
image = pipe(prompt)["sample"][0]
image = pipe(prompt).images[0]
image.save("astronaut_rides_horse.png")
```
......@@ -184,7 +184,7 @@ init_image = init_image.resize((768, 512))
prompt = "A fantasy landscape, trending on artstation"
with autocast("cuda"):
images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5)["sample"]
images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5).images
images[0].save("fantasy_landscape.png")
```
......@@ -228,7 +228,7 @@ pipe = pipe.to(device)
prompt = "a cat sitting on a bench"
with autocast("cuda"):
images = pipe(prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.75)["sample"]
images = pipe(prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.75).images
images[0].save("cat_on_bench.png")
```
......@@ -260,7 +260,7 @@ ldm = DiffusionPipeline.from_pretrained(model_id)
# run pipeline in inference (sample random noise and denoise)
prompt = "A painting of a squirrel eating a burger"
images = ldm([prompt], num_inference_steps=50, eta=0.3, guidance_scale=6)["sample"]
images = ldm([prompt], num_inference_steps=50, eta=0.3, guidance_scale=6).images
# save images
for idx, image in enumerate(images):
......@@ -277,7 +277,7 @@ model_id = "google/ddpm-celebahq-256"
ddpm = DDPMPipeline.from_pretrained(model_id) # you can replace DDPMPipeline with DDIMPipeline or PNDMPipeline for faster inference
# run pipeline in inference (sample random noise and denoise)
image = ddpm()["sample"]
image = ddpm().images
# save image
image[0].save("ddpm_generated_image.png")
......
......@@ -76,7 +76,7 @@ pipe = pipe = StableDiffusionPipeline.from_pretrained(model_id,torch_dtype=torch
prompt = "A <cat-toy> backpack"
with autocast("cuda"):
image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5)["sample"][0]
image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
image.save("cat-backpack.png")
```
......@@ -498,7 +498,7 @@ def main():
for step, batch in enumerate(train_dataloader):
with accelerator.accumulate(text_encoder):
# Convert images to latent space
latents = vae.encode(batch["pixel_values"]).sample().detach()
latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
latents = latents * 0.18215
# Sample noise that we'll add to the latents
......@@ -515,7 +515,7 @@ def main():
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
# Predict the noise residual
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states)["sample"]
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
accelerator.backward(loss)
......
......@@ -139,7 +139,7 @@ def main(args):
with accelerator.accumulate(model):
# Predict the noise residual
noise_pred = model(noisy_images, timesteps)["sample"]
noise_pred = model(noisy_images, timesteps).sample
loss = F.mse_loss(noise_pred, noise)
accelerator.backward(loss)
......@@ -174,7 +174,7 @@ def main(args):
generator = torch.manual_seed(0)
# run pipeline in inference (sample random noise and denoise)
images = pipeline(generator=generator, batch_size=args.eval_batch_size, output_type="numpy")["sample"]
images = pipeline(generator=generator, batch_size=args.eval_batch_size, output_type="numpy").images
# denormalize the images and save to tensorboard
images_processed = (images * 255).round().astype("uint8")
......
......@@ -119,7 +119,7 @@ for mod in models:
noise = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
time_step = torch.tensor([10] * noise.shape[0])
with torch.no_grad():
logits = model(noise, time_step)["sample"]
logits = model(noise, time_step).sample
assert torch.allclose(
logits[0, 0, 0, :30], results["_".join("_".join(mod.modelId.split("/")).split("-"))], atol=1e-3
......
......@@ -19,9 +19,9 @@ import shutil
from pathlib import Path
from typing import Optional
from diffusers import DiffusionPipeline
from huggingface_hub import HfFolder, Repository, whoami
from .pipeline_utils import DiffusionPipeline
from .utils import is_modelcards_available, logging
......
from typing import Dict, Optional, Tuple, Union
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..modeling_utils import ModelMixin
from ..utils import BaseOutput
from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
from .unet_blocks import UNetMidBlock2D, get_down_block, get_up_block
@dataclass
class UNet2DOutput(BaseOutput):
"""
Args:
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Hidden states output. Output of last layer of model.
"""
sample: torch.FloatTensor
class UNet2DModel(ModelMixin, ConfigMixin):
@register_to_config
def __init__(
......@@ -118,8 +131,11 @@ class UNet2DModel(ModelMixin, ConfigMixin):
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
def forward(
self, sample: torch.FloatTensor, timestep: Union[torch.Tensor, float, int]
) -> Dict[str, torch.FloatTensor]:
self,
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
return_dict: bool = True,
) -> Union[UNet2DOutput, Tuple]:
# 0. center input if necessary
if self.config.center_input_sample:
sample = 2 * sample - 1.0
......@@ -181,6 +197,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:]))))
sample = sample / timesteps
output = {"sample": sample}
if not return_dict:
return (sample,)
return output
return UNet2DOutput(sample=sample)
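A short, hedged sketch of how the new `UNet2DOutput` and `return_dict` flag are consumed; the default config and shapes are illustrative (mirroring the conversion-script snippet above), not tied to any particular checkpoint:

```python
import torch
from diffusers import UNet2DModel

model = UNet2DModel()  # default config, purely illustrative
noise = torch.randn(1, model.config.in_channels, 32, 32)
time_step = torch.tensor([10] * noise.shape[0])

with torch.no_grad():
    # dataclass access: the prediction lives under .sample
    sample = model(noise, time_step).sample
    # tuple access: return_dict=False yields (sample,)
    sample_tuple = model(noise, time_step, return_dict=False)[0]

assert torch.allclose(sample, sample_tuple)
```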
from typing import Dict, Optional, Tuple, Union
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..modeling_utils import ModelMixin
from ..utils import BaseOutput
from .embeddings import TimestepEmbedding, Timesteps
from .unet_blocks import UNetMidBlock2DCrossAttn, get_down_block, get_up_block
@dataclass
class UNet2DConditionOutput(BaseOutput):
"""
Args:
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model.
"""
sample: torch.FloatTensor
class UNet2DConditionModel(ModelMixin, ConfigMixin):
@register_to_config
def __init__(
......@@ -125,7 +138,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
) -> Dict[str, torch.FloatTensor]:
return_dict: bool = True,
) -> Union[UNet2DConditionOutput, Tuple]:
# 0. center input if necessary
if self.config.center_input_sample:
sample = 2 * sample - 1.0
......@@ -183,6 +197,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
sample = self.conv_act(sample)
sample = self.conv_out(sample)
output = {"sample": sample}
if not return_dict:
return (sample,)
return output
return UNet2DConditionOutput(sample=sample)
from typing import Optional, Tuple
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import numpy as np
import torch
......@@ -6,9 +7,50 @@ import torch.nn as nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..modeling_utils import ModelMixin
from ..utils import BaseOutput
from .unet_blocks import UNetMidBlock2D, get_down_block, get_up_block
@dataclass
class DecoderOutput(BaseOutput):
"""
Output of decoding method.
Args:
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Decoded output sample of the model. Output of the last layer of the model.
"""
sample: torch.FloatTensor
@dataclass
class VQEncoderOutput(BaseOutput):
"""
Output of VQModel encoding method.
Args:
latents (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Encoded output sample of the model. Output of the last layer of the model.
"""
latents: torch.FloatTensor
@dataclass
class AutoencoderKLOutput(BaseOutput):
"""
Output of AutoencoderKL encoding method.
Args:
latent_dist (`DiagonalGaussianDistribution`):
Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`.
`DiagonalGaussianDistribution` allows for sampling latents from the distribution.
"""
latent_dist: "DiagonalGaussianDistribution"
class Encoder(nn.Module):
def __init__(
self,
......@@ -369,12 +411,18 @@ class VQModel(ModelMixin, ConfigMixin):
act_fn=act_fn,
)
def encode(self, x):
def encode(self, x, return_dict: bool = True):
h = self.encoder(x)
h = self.quant_conv(h)
return h
def decode(self, h, force_not_quantize=False):
if not return_dict:
return (h,)
return VQEncoderOutput(latents=h)
def decode(
self, h: torch.FloatTensor, force_not_quantize: bool = False, return_dict: bool = True
) -> Union[DecoderOutput, torch.FloatTensor]:
# also go through quantization layer
if not force_not_quantize:
quant, emb_loss, info = self.quantize(h)
......@@ -382,13 +430,21 @@ class VQModel(ModelMixin, ConfigMixin):
quant = h
quant = self.post_quant_conv(quant)
dec = self.decoder(quant)
return dec
def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
def forward(self, sample: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
x = sample
h = self.encode(x)
dec = self.decode(h)
return dec
h = self.encode(x).latents
dec = self.decode(h).sample
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
class AutoencoderKL(ModelMixin, ConfigMixin):
......@@ -431,23 +487,37 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
self.quant_conv = torch.nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
def encode(self, x):
def encode(self, x, return_dict: bool = True):
h = self.encoder(x)
moments = self.quant_conv(h)
posterior = DiagonalGaussianDistribution(moments)
return posterior
def decode(self, z):
if not return_dict:
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
z = self.post_quant_conv(z)
dec = self.decoder(z)
return dec
def forward(self, sample: torch.FloatTensor, sample_posterior: bool = False) -> torch.FloatTensor:
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
def forward(
self, sample: torch.FloatTensor, sample_posterior: bool = False, return_dict: bool = True
) -> Union[DecoderOutput, torch.FloatTensor]:
x = sample
posterior = self.encode(x)
posterior = self.encode(x).latent_dist
if sample_posterior:
z = posterior.sample()
else:
z = posterior.mode()
dec = self.decode(z)
return dec
dec = self.decode(z).sample
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
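For the VAE changes, a hedged sketch of the new `encode`/`decode` access patterns on `AutoencoderKL`; the default config and random input are illustrative only, real use would load pretrained weights as in the textual-inversion example above:

```python
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL()  # default config, illustrative only
pixel_values = torch.randn(1, 3, 32, 32)

# encode() now returns AutoencoderKLOutput; the posterior lives under .latent_dist
posterior = vae.encode(pixel_values).latent_dist
latents = posterior.sample()

# decode() returns DecoderOutput; the reconstruction lives under .sample
reconstruction = vae.decode(latents).sample

# both methods also accept return_dict=False and hand back plain tuples
(posterior,) = vae.encode(pixel_values, return_dict=False)
(reconstruction,) = vae.decode(latents, return_dict=False)
```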
......@@ -17,16 +17,19 @@
import importlib
import inspect
import os
from typing import Optional, Union
from dataclasses import dataclass
from typing import List, Optional, Union
import numpy as np
import torch
import PIL
from huggingface_hub import snapshot_download
from PIL import Image
from tqdm.auto import tqdm
from .configuration_utils import ConfigMixin
from .utils import DIFFUSERS_CACHE, logging
from .utils import DIFFUSERS_CACHE, BaseOutput, logging
INDEX_FILE = "diffusion_pytorch_model.bin"
......@@ -54,6 +57,20 @@ for library in LOADABLE_CLASSES:
ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
@dataclass
class ImagePipelineOutput(BaseOutput):
"""
Output class for image pipelines.
Args:
images (`List[PIL.Image.Image]` or `np.ndarray`)
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
num_channels)`. PIL images or NumPy array representing the denoised images of the diffusion pipeline.
"""
images: Union[List[PIL.Image.Image], np.ndarray]
class DiffusionPipeline(ConfigMixin):
config_name = "model_index.json"
......
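`ImagePipelineOutput` above is built on the new `BaseOutput`; assuming `BaseOutput` follows the usual `ModelOutput` convention from `transformers` (attribute, key, and index access all resolve to the same field — an assumption, since its definition is not part of this diff), the three spellings below are equivalent:

```python
from diffusers import DDIMPipeline

pipe = DDIMPipeline.from_pretrained("google/ddpm-celebahq-256")
out = pipe(num_inference_steps=50)

image = out.images[0]      # dataclass attribute, the style used throughout this commit
image = out["images"][0]   # dict-style key access (assumed BaseOutput behaviour)
image = out[0][0]          # tuple-style index access (assumed BaseOutput behaviour)
```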
......@@ -94,7 +94,7 @@ pipe = pipe.to("cuda")
prompt = "a photo of an astronaut riding a horse on mars"
with autocast("cuda"):
image = pipe(prompt)["sample"][0]
image = pipe(prompt).images[0]
image.save("astronaut_rides_horse.png")
```
......@@ -130,7 +130,7 @@ init_image = init_image.resize((768, 512))
prompt = "A fantasy landscape, trending on artstation"
with autocast("cuda"):
images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5)["sample"]
images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5).images
images[0].save("fantasy_landscape.png")
```
......@@ -174,7 +174,7 @@ pipe = StableDiffusionInpaintPipeline.from_pretrained(
prompt = "a cat sitting on a bench"
with autocast("cuda"):
images = pipe(prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.75)["sample"]
images = pipe(prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.75).images
images[0].save("cat_on_bench.png")
```
......
......@@ -15,10 +15,11 @@
import warnings
from typing import Tuple, Union
import torch
from ...pipeline_utils import DiffusionPipeline
from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
class DDIMPipeline(DiffusionPipeline):
......@@ -28,7 +29,16 @@ class DDIMPipeline(DiffusionPipeline):
self.register_modules(unet=unet, scheduler=scheduler)
@torch.no_grad()
def __call__(self, batch_size=1, generator=None, eta=0.0, num_inference_steps=50, output_type="pil", **kwargs):
def __call__(
self,
batch_size=1,
generator=None,
eta=0.0,
num_inference_steps=50,
output_type="pil",
return_dict: bool = True,
**kwargs,
) -> Union[ImagePipelineOutput, Tuple]:
if "torch_device" in kwargs:
device = kwargs.pop("torch_device")
......@@ -56,15 +66,18 @@ class DDIMPipeline(DiffusionPipeline):
for t in self.progress_bar(self.scheduler.timesteps):
# 1. predict noise model_output
model_output = self.unet(image, t)["sample"]
model_output = self.unet(image, t).sample
# 2. predict previous mean of image x_t-1 and add variance depending on eta
# do x_t -> x_t-1
image = self.scheduler.step(model_output, t, image, eta)["prev_sample"]
image = self.scheduler.step(model_output, t, image, eta).prev_sample
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
if output_type == "pil":
image = self.numpy_to_pil(image)
return {"sample": image}
if not return_dict:
return (image,)
return ImagePipelineOutput(images=image)
......@@ -15,10 +15,11 @@
import warnings
from typing import Tuple, Union
import torch
from ...pipeline_utils import DiffusionPipeline
from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
class DDPMPipeline(DiffusionPipeline):
......@@ -28,7 +29,9 @@ class DDPMPipeline(DiffusionPipeline):
self.register_modules(unet=unet, scheduler=scheduler)
@torch.no_grad()
def __call__(self, batch_size=1, generator=None, output_type="pil", **kwargs):
def __call__(
self, batch_size=1, generator=None, output_type="pil", return_dict: bool = True, **kwargs
) -> Union[ImagePipelineOutput, Tuple]:
if "torch_device" in kwargs:
device = kwargs.pop("torch_device")
warnings.warn(
......@@ -53,14 +56,17 @@ class DDPMPipeline(DiffusionPipeline):
for t in self.progress_bar(self.scheduler.timesteps):
# 1. predict noise model_output
model_output = self.unet(image, t)["sample"]
model_output = self.unet(image, t).sample
# 2. compute previous image: x_t -> t_t-1
image = self.scheduler.step(model_output, t, image, generator=generator)["prev_sample"]
image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
if output_type == "pil":
image = self.numpy_to_pil(image)
return {"sample": image}
if not return_dict:
return (image,)
return ImagePipelineOutput(images=image)
......@@ -12,7 +12,7 @@ from transformers.modeling_outputs import BaseModelOutput
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from ...pipeline_utils import DiffusionPipeline
from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
class LDMTextToImagePipeline(DiffusionPipeline):
......@@ -32,8 +32,9 @@ class LDMTextToImagePipeline(DiffusionPipeline):
eta: Optional[float] = 0.0,
generator: Optional[torch.Generator] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
**kwargs,
):
) -> Union[Tuple, ImagePipelineOutput]:
# eta corresponds to η in paper and should be between [0, 1]
if "torch_device" in kwargs:
......@@ -95,25 +96,28 @@ class LDMTextToImagePipeline(DiffusionPipeline):
context = torch.cat([uncond_embeddings, text_embeddings])
# predict the noise residual
noise_pred = self.unet(latents_input, t, encoder_hidden_states=context)["sample"]
noise_pred = self.unet(latents_input, t, encoder_hidden_states=context).sample
# perform guidance
if guidance_scale != 1.0:
noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs)["prev_sample"]
latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs).prev_sample
# scale and decode the image latents with vae
latents = 1 / 0.18215 * latents
image = self.vqvae.decode(latents)
image = self.vqvae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
if output_type == "pil":
image = self.numpy_to_pil(image)
return {"sample": image}
if not return_dict:
return (image,)
return ImagePipelineOutput(images=image)
################################################################################
......@@ -525,7 +529,7 @@ class LDMBertEncoder(LDMBertPreTrainedModel):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
Whether or not to return a [`~utils.BaseModelOutput`] instead of a plain tuple.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
......
import inspect
import warnings
from typing import Optional
from typing import Optional, Tuple, Union
import torch
from ...models import UNet2DModel, VQModel
from ...pipeline_utils import DiffusionPipeline
from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from ...schedulers import DDIMScheduler
......@@ -28,8 +28,9 @@ class LDMPipeline(DiffusionPipeline):
eta: float = 0.0,
num_inference_steps: int = 50,
output_type: Optional[str] = "pil",
return_dict: bool = True,
**kwargs,
):
) -> Union[Tuple, ImagePipelineOutput]:
# eta corresponds to η in paper and should be between [0, 1]
if "torch_device" in kwargs:
......@@ -61,16 +62,19 @@ class LDMPipeline(DiffusionPipeline):
for t in self.progress_bar(self.scheduler.timesteps):
# predict the noise residual
noise_prediction = self.unet(latents, t)["sample"]
noise_prediction = self.unet(latents, t).sample
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_prediction, t, latents, **extra_kwargs)["prev_sample"]
latents = self.scheduler.step(noise_prediction, t, latents, **extra_kwargs).prev_sample
# decode the image latents with the VAE
image = self.vqvae.decode(latents)
image = self.vqvae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
if output_type == "pil":
image = self.numpy_to_pil(image)
return {"sample": image}
if not return_dict:
return (image,)
return ImagePipelineOutput(images=image)
......@@ -15,12 +15,12 @@
import warnings
from typing import Optional
from typing import Optional, Tuple, Union
import torch
from ...models import UNet2DModel
from ...pipeline_utils import DiffusionPipeline
from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from ...schedulers import PNDMScheduler
......@@ -40,8 +40,9 @@ class PNDMPipeline(DiffusionPipeline):
num_inference_steps: int = 50,
generator: Optional[torch.Generator] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
**kwargs,
):
) -> Union[ImagePipelineOutput, Tuple]:
# For more information on the sampling method you can take a look at Algorithm 2 of
# the official paper: https://arxiv.org/pdf/2202.09778.pdf
......@@ -66,13 +67,16 @@ class PNDMPipeline(DiffusionPipeline):
self.scheduler.set_timesteps(num_inference_steps)
for t in self.progress_bar(self.scheduler.timesteps):
model_output = self.unet(image, t)["sample"]
model_output = self.unet(image, t).sample
image = self.scheduler.step(model_output, t, image)["prev_sample"]
image = self.scheduler.step(model_output, t, image).prev_sample
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
if output_type == "pil":
image = self.numpy_to_pil(image)
return {"sample": image}
if not return_dict:
return (image,)
return ImagePipelineOutput(images=image)
#!/usr/bin/env python3
import warnings
from typing import Optional
from typing import Optional, Tuple, Union
import torch
from diffusers import DiffusionPipeline
from ...models import UNet2DModel
from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from ...schedulers import ScoreSdeVeScheduler
......@@ -26,8 +25,9 @@ class ScoreSdeVePipeline(DiffusionPipeline):
num_inference_steps: int = 2000,
generator: Optional[torch.Generator] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
**kwargs,
):
) -> Union[ImagePipelineOutput, Tuple]:
if "torch_device" in kwargs:
device = kwargs.pop("torch_device")
warnings.warn(
......@@ -56,18 +56,21 @@ class ScoreSdeVePipeline(DiffusionPipeline):
# correction step
for _ in range(self.scheduler.correct_steps):
model_output = self.unet(sample, sigma_t)["sample"]
sample = self.scheduler.step_correct(model_output, sample, generator=generator)["prev_sample"]
model_output = self.unet(sample, sigma_t).sample
sample = self.scheduler.step_correct(model_output, sample, generator=generator).prev_sample
# prediction step
model_output = model(sample, sigma_t)["sample"]
model_output = model(sample, sigma_t).sample
output = self.scheduler.step_pred(model_output, t, sample, generator=generator)
sample, sample_mean = output["prev_sample"], output["prev_sample_mean"]
sample, sample_mean = output.prev_sample, output.prev_sample_mean
sample = sample_mean.clamp(0, 1)
sample = sample.cpu().permute(0, 2, 3, 1).numpy()
if output_type == "pil":
sample = self.numpy_to_pil(sample)
return {"sample": sample}
if not return_dict:
return (sample,)
return ImagePipelineOutput(images=sample)
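The schedulers get the same treatment: `step()`, `step_correct()` and `step_pred()` now return output objects whose fields (`prev_sample`, `prev_sample_mean`) replace the old dict keys. A minimal denoising-loop sketch with a DDPM scheduler; the default configs are illustrative, and whether the scheduler's `step()` also accepts `return_dict` is not shown in this excerpt:

```python
import torch
from diffusers import DDPMScheduler, UNet2DModel

unet = UNet2DModel()        # default config, illustrative only
scheduler = DDPMScheduler()
scheduler.set_timesteps(10)

sample = torch.randn(1, unet.config.in_channels, 32, 32)
with torch.no_grad():
    for t in scheduler.timesteps:
        model_output = unet(sample, t).sample
        # .prev_sample replaces the old ["prev_sample"] lookup
        sample = scheduler.step(model_output, t, sample).prev_sample
```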
......@@ -67,7 +67,7 @@ pipe = pipe.to("cuda")
prompt = "a photo of an astronaut riding a horse on mars"
with autocast("cuda"):
image = pipe(prompt)["sample"][0]
image = pipe(prompt).images[0]
image.save("astronaut_rides_horse.png")
```
......@@ -89,7 +89,7 @@ pipe = StableDiffusionPipeline.from_pretrained(
prompt = "a photo of an astronaut riding a horse on mars"
with autocast("cuda"):
image = pipe(prompt)["sample"][0]
image = pipe(prompt).images[0]
image.save("astronaut_rides_horse.png")
```
......@@ -115,7 +115,7 @@ pipe = StableDiffusionPipeline.from_pretrained(
prompt = "a photo of an astronaut riding a horse on mars"
with autocast("cuda"):
image = pipe(prompt)["sample"][0]
image = pipe(prompt).images[0]
image.save("astronaut_rides_horse.png")
```
# flake8: noqa
from ...utils import is_transformers_available
from dataclasses import dataclass
from typing import List, Union
import numpy as np
import PIL
from PIL import Image
from ...utils import BaseOutput, is_transformers_available
@dataclass
class StableDiffusionPipelineOutput(BaseOutput):
"""
Output class for Stable Diffusion pipelines.
Args:
images (`List[PIL.Image.Image]` or `np.ndarray`)
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
num_channels)`. PIL images or NumPy array representing the denoised images of the diffusion pipeline.
nsfw_content_detected (`List[bool]`)
List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content.
"""
images: Union[List[PIL.Image.Image], np.ndarray]
nsfw_content_detected: List[bool]
if is_transformers_available():
......
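With `StableDiffusionPipelineOutput`, the safety-checker result now travels next to the images instead of hiding behind a dict key. A hedged usage sketch; the checkpoint id is assumed from the upstream README, which is elided in this excerpt:

```python
from torch import autocast
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True)
pipe = pipe.to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"
with autocast("cuda"):
    output = pipe(prompt)

image = output.images[0]
if output.nsfw_content_detected[0]:
    print("safety checker flagged this image")

# return_dict=False yields the plain tuple (images, nsfw_content_detected)
with autocast("cuda"):
    images, nsfw_flags = pipe(prompt, return_dict=False)
```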
......@@ -9,6 +9,7 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker
......@@ -47,6 +48,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
generator: Optional[torch.Generator] = None,
latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
**kwargs,
):
if "torch_device" in kwargs:
......@@ -141,7 +143,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
# perform guidance
if do_classifier_free_guidance:
......@@ -150,13 +152,13 @@ class StableDiffusionPipeline(DiffusionPipeline):
# compute the previous noisy sample x_t -> x_t-1
if isinstance(self.scheduler, LMSDiscreteScheduler):
latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs)["prev_sample"]
latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample
else:
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)["prev_sample"]
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# scale and decode the image latents with vae
latents = 1 / 0.18215 * latents
image = self.vae.decode(latents)
image = self.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
......@@ -168,4 +170,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
if output_type == "pil":
image = self.numpy_to_pil(image)
return {"sample": image, "nsfw_content_detected": has_nsfw_concept}
if not return_dict:
return (image, has_nsfw_concept)
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)