Commit b09b152f authored by anton-l

Merge branch 'main' of github.com:huggingface/diffusers

parents a2117cb7 4497e78d
# References
[GLIDE: Towards Photorealistic Image Generation and Editing with Text-Guided Diffusion Models](https://arxiv.org/pdf/2112.10741.pdf)
[Diffusion Models Beat GANs on Image Synthesis](https://arxiv.org/pdf/2105.05233.pdf)
import torch
from torch import nn
from diffusers import ClassifierFreeGuidanceScheduler, DDIMScheduler, GLIDESuperResUNetModel, GLIDETextToImageUNetModel
from modeling_glide import GLIDE, CLIPTextModel
from transformers import CLIPTextConfig, GPT2Tokenizer
# wget https://openaipublic.blob.core.windows.net/diffusion/dec-2021/base.pt
state_dict = torch.load("base.pt", map_location="cpu")
state_dict = {k: nn.Parameter(v) for k, v in state_dict.items()}
### Convert the text encoder
config = CLIPTextConfig(
    vocab_size=50257,
    max_position_embeddings=128,
    hidden_size=512,
    intermediate_size=2048,
    num_hidden_layers=16,
    num_attention_heads=8,
    use_padding_embeddings=True,
)
model = CLIPTextModel(config).eval()
tokenizer = GPT2Tokenizer(
    "./glide-base/tokenizer/vocab.json", "./glide-base/tokenizer/merges.txt", pad_token="<|endoftext|>"
)
hf_encoder = model.text_model
hf_encoder.embeddings.token_embedding.weight = state_dict["token_embedding.weight"]
hf_encoder.embeddings.position_embedding.weight.data = state_dict["positional_embedding"]
hf_encoder.embeddings.padding_embedding.weight.data = state_dict["padding_embedding"]
hf_encoder.final_layer_norm.weight = state_dict["final_ln.weight"]
hf_encoder.final_layer_norm.bias = state_dict["final_ln.bias"]
for layer_idx in range(config.num_hidden_layers):
    hf_layer = hf_encoder.encoder.layers[layer_idx]
    hf_layer.self_attn.qkv_proj.weight = state_dict[f"transformer.resblocks.{layer_idx}.attn.c_qkv.weight"]
    hf_layer.self_attn.qkv_proj.bias = state_dict[f"transformer.resblocks.{layer_idx}.attn.c_qkv.bias"]
    hf_layer.self_attn.out_proj.weight = state_dict[f"transformer.resblocks.{layer_idx}.attn.c_proj.weight"]
    hf_layer.self_attn.out_proj.bias = state_dict[f"transformer.resblocks.{layer_idx}.attn.c_proj.bias"]

    hf_layer.layer_norm1.weight = state_dict[f"transformer.resblocks.{layer_idx}.ln_1.weight"]
    hf_layer.layer_norm1.bias = state_dict[f"transformer.resblocks.{layer_idx}.ln_1.bias"]
    hf_layer.layer_norm2.weight = state_dict[f"transformer.resblocks.{layer_idx}.ln_2.weight"]
    hf_layer.layer_norm2.bias = state_dict[f"transformer.resblocks.{layer_idx}.ln_2.bias"]

    hf_layer.mlp.fc1.weight = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_fc.weight"]
    hf_layer.mlp.fc1.bias = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_fc.bias"]
    hf_layer.mlp.fc2.weight = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_proj.weight"]
    hf_layer.mlp.fc2.bias = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_proj.bias"]
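
# Optional sanity check (an illustrative addition, not part of the original script): run the
# converted encoder on a padded dummy prompt. With the config above, the hidden states
# returned as the first output should have shape (1, 128, 512).
with torch.no_grad():
    check_input = tokenizer("a corgi", padding="max_length", max_length=128, return_tensors="pt")
    check_hidden_states = model(check_input.input_ids)[0]
    assert check_hidden_states.shape == (1, 128, 512)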
### Convert the Text-to-Image UNet
text2im_model = GLIDETextToImageUNetModel(
    in_channels=3,
    model_channels=192,
    out_channels=6,
    num_res_blocks=3,
    attention_resolutions=(2, 4, 8),
    dropout=0.1,
    channel_mult=(1, 2, 3, 4),
    num_heads=1,
    num_head_channels=64,
    num_heads_upsample=1,
    use_scale_shift_norm=True,
    resblock_updown=True,
    transformer_dim=512,
)
text2im_model.load_state_dict(state_dict, strict=False)
text_scheduler = ClassifierFreeGuidanceScheduler(timesteps=1000, beta_schedule="squaredcos_cap_v2")
### Convert the Super-Resolution UNet
# wget https://openaipublic.blob.core.windows.net/diffusion/dec-2021/upsample.pt
ups_state_dict = torch.load("upsample.pt", map_location="cpu")
superres_model = GLIDESuperResUNetModel(
    in_channels=6,
    model_channels=192,
    out_channels=6,
    num_res_blocks=2,
    attention_resolutions=(8, 16, 32),
    dropout=0.1,
    channel_mult=(1, 1, 2, 2, 4, 4),
    num_heads=1,
    num_head_channels=64,
    num_heads_upsample=1,
    use_scale_shift_norm=True,
    resblock_updown=True,
)
superres_model.load_state_dict(ups_state_dict, strict=False)
upscale_scheduler = DDIMScheduler(timesteps=1000, beta_schedule="linear")
glide = GLIDE(
    text_unet=text2im_model,
    text_noise_scheduler=text_scheduler,
    text_encoder=model,
    tokenizer=tokenizer,
    upscale_unet=superres_model,
    upscale_noise_scheduler=upscale_scheduler,
)
glide.save_pretrained("./glide-base")
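
# Optional round-trip check (illustrative, not part of the original script): the saved
# pipeline should load back from the local directory, assuming the custom GLIDE classes
# in modeling_glide are importable from the working directory.
glide_reloaded = GLIDE.from_pretrained("./glide-base")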
import torch
import PIL.Image
from diffusers import DiffusionPipeline
generator = torch.Generator()
generator = generator.manual_seed(0)
model_id = "fusing/glide-base"
# load model and scheduler
pipeline = DiffusionPipeline.from_pretrained(model_id)
# run inference (text-conditioned denoising + upscaling)
img = pipeline("a crayon drawing of a corgi", generator)
# process image to PIL (assumes the pipeline returns a (1, H, W, 3) tensor with values in [-1, 1])
img = img.squeeze(0)
img = ((img + 1) * 127.5).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
image_pil = PIL.Image.fromarray(img)
# save image
image_pil.save("test.png")
# coding=utf-8
# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" LDMBERT model configuration"""
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
logger = logging.get_logger(__name__)
LDMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "ldm-bert": "https://huggingface.co/ldm-bert/resolve/main/config.json",
}
class LDMBertConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`LDMBertModel`]. It is used to instantiate a
    LDMBERT model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the LDMBERT
    [facebook/ldmbert-large](https://huggingface.co/facebook/ldmbert-large) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 30522):
            Vocabulary size of the LDMBERT model. Defines the number of different tokens that can be represented by
            the `inputs_ids` passed when calling [`LDMBertModel`].
        d_model (`int`, *optional*, defaults to 1280):
            Dimensionality of the layers and the pooler layer.
        encoder_layers (`int`, *optional*, defaults to 32):
            Number of encoder layers.
        encoder_attention_heads (`int`, *optional*, defaults to 8):
            Number of attention heads for each attention layer in the Transformer encoder.
        encoder_ffn_dim (`int`, *optional*, defaults to 5120):
            Dimensionality of the "intermediate" (often named feed-forward) layer in the encoder.
        head_dim (`int`, *optional*, defaults to 64):
            Dimensionality of each attention head.
        activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"silu"` and `"gelu_new"` are supported.
        dropout (`float`, *optional*, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        activation_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for activations inside the fully connected layer.
        classifier_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the classifier.
        max_position_embeddings (`int`, *optional*, defaults to 77):
            The maximum sequence length that this model might ever be used with.
        init_std (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        encoder_layerdrop (`float`, *optional*, defaults to 0.0):
            The LayerDrop probability for the encoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
            for more details.
        scale_embedding (`bool`, *optional*, defaults to `False`):
            Scale embeddings by dividing by sqrt(d_model).
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models).
        pad_token_id (`int`, *optional*, defaults to 0):
            The id of the padding token.

    Example:

    ```python
    >>> from transformers import LDMBertModel, LDMBertConfig

    >>> # Initializing a LDMBERT facebook/ldmbert-large style configuration
    >>> configuration = LDMBertConfig()

    >>> # Initializing a model from the facebook/ldmbert-large style configuration
    >>> model = LDMBertModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "ldmbert"
    keys_to_ignore_at_inference = ["past_key_values"]
    attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}
    def __init__(
        self,
        vocab_size=30522,
        max_position_embeddings=77,
        encoder_layers=32,
        encoder_ffn_dim=5120,
        encoder_attention_heads=8,
        head_dim=64,
        encoder_layerdrop=0.0,
        activation_function="gelu",
        d_model=1280,
        dropout=0.1,
        attention_dropout=0.0,
        activation_dropout=0.0,
        init_std=0.02,
        classifier_dropout=0.0,
        scale_embedding=False,
        use_cache=True,
        pad_token_id=0,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.d_model = d_model
        self.encoder_ffn_dim = encoder_ffn_dim
        self.encoder_layers = encoder_layers
        self.encoder_attention_heads = encoder_attention_heads
        self.head_dim = head_dim
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.activation_dropout = activation_dropout
        self.activation_function = activation_function
        self.init_std = init_std
        self.encoder_layerdrop = encoder_layerdrop
        self.classifier_dropout = classifier_dropout
        self.use_cache = use_cache
        self.num_hidden_layers = encoder_layers
        self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True

        super().__init__(pad_token_id=pad_token_id, **kwargs)
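

# Illustrative check (not part of the original file): thanks to `attribute_map` above,
# generic transformers code can read `hidden_size` and `num_attention_heads` through the
# LDMBERT-specific attribute names.
if __name__ == "__main__":
    _cfg = LDMBertConfig()
    assert _cfg.hidden_size == _cfg.d_model == 1280
    assert _cfg.num_attention_heads == _cfg.encoder_attention_heads == 8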
import torch
import tqdm
from diffusers import DiffusionPipeline
from .configuration_ldmbert import LDMBertConfig # NOQA
from .modeling_ldmbert import LDMBertModel # NOQA
# add these relative imports here, so we can load from hub
from .modeling_vae import AutoencoderKL # NOQA
class LatentDiffusion(DiffusionPipeline):
    def __init__(self, vqvae, bert, tokenizer, unet, noise_scheduler):
        super().__init__()
        self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, noise_scheduler=noise_scheduler)

    @torch.no_grad()
    def __call__(
        self,
        prompt,
        batch_size=1,
        generator=None,
        torch_device=None,
        eta=0.0,
        guidance_scale=1.0,
        num_inference_steps=50,
    ):
        # eta corresponds to η in the DDIM paper and should be between [0, 1]
        if torch_device is None:
            torch_device = "cuda" if torch.cuda.is_available() else "cpu"

        self.unet.to(torch_device)
        self.vqvae.to(torch_device)
        self.bert.to(torch_device)

        # get unconditional embeddings for classifier-free guidance
        if guidance_scale != 1.0:
            uncond_input = self.tokenizer([""], padding="max_length", max_length=77, return_tensors="pt").to(
                torch_device
            )
            uncond_embeddings = self.bert(uncond_input.input_ids)[0]

        # get text embedding
        text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt").to(torch_device)
        text_embedding = self.bert(text_input.input_ids)[0]

        num_trained_timesteps = self.noise_scheduler.timesteps
        inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)

        image = self.noise_scheduler.sample_noise(
            (batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
            device=torch_device,
            generator=generator,
        )

        # See formulas (12) and (16) of the DDIM paper https://arxiv.org/pdf/2010.02502.pdf
        # Ideally, read the DDIM paper in detail to understand the loop below
        # Notation (<variable name> -> <name in paper>):
        # - pred_noise_t -> e_theta(x_t, t)
        # - pred_original_image -> f_theta(x_t, t) or x_0
        # - std_dev_t -> sigma_t
        # - eta -> η
        # - pred_image_direction -> "direction pointing to x_t"
        # - pred_prev_image -> "x_t-1"
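        # For reference, the scheduler step below computes the standard DDIM update
        # (formula (12) of the paper):
        #   x_{t-1} = sqrt(alpha_prod_{t-1}) * pred_original_image
        #             + sqrt(1 - alpha_prod_{t-1} - std_dev_t**2) * pred_noise_t
        #             + std_dev_t * noise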
        for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):
            # guidance_scale of 1 means no guidance
            if guidance_scale == 1.0:
                image_in = image
                context = text_embedding
                timesteps = torch.tensor([inference_step_times[t]] * image.shape[0], device=torch_device)
            else:
                # for classifier-free guidance, we need two forward passes;
                # here we concatenate the conditional and unconditional embeddings into
                # a single batch to avoid running the UNet twice
                image_in = torch.cat([image] * 2)
                context = torch.cat([uncond_embeddings, text_embedding])
                timesteps = torch.tensor([inference_step_times[t]] * image.shape[0], device=torch_device)

            # 1. predict noise residual
            pred_noise_t = self.unet(image_in, timesteps, context=context)

            # perform guidance
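            # Classifier-free guidance combines the chunked predictions as
            #   pred_noise = e_uncond + guidance_scale * (e_cond - e_uncond),
            # extrapolating away from the unconditional prediction toward the text-conditional one.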
            if guidance_scale != 1.0:
                pred_noise_t_uncond, pred_noise_t = pred_noise_t.chunk(2)
                pred_noise_t = pred_noise_t_uncond + guidance_scale * (pred_noise_t - pred_noise_t_uncond)

            # 2. predict previous mean of image x_t-1
            pred_prev_image = self.noise_scheduler.step(pred_noise_t, image, t, num_inference_steps, eta)

            # 3. optionally sample variance
            variance = 0
            if eta > 0:
                noise = self.noise_scheduler.sample_noise(image.shape, device=image.device, generator=generator)
                variance = self.noise_scheduler.get_variance(t, num_inference_steps).sqrt() * eta * noise

            # 4. set current image to prev_image: x_t -> x_t-1
            image = pred_prev_image + variance

        # scale and decode image with vae
        image = 1 / 0.18215 * image
        image = self.vqvae.decode(image)

        image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
        return image
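

# Hypothetical usage sketch (not part of the original file): the checkpoint path is a
# placeholder, and `guidance_scale > 1.0` switches on the classifier-free guidance
# branch above.
if __name__ == "__main__":
    ldm = LatentDiffusion.from_pretrained("./ldm-text2im")  # placeholder path
    images = ldm("a painting of a squirrel eating a burger", guidance_scale=6.0, num_inference_steps=50)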
@@ -291,7 +291,7 @@ class BDDM(DiffusionPipeline):
         # Sample gaussian noise to begin loop
         audio = torch.normal(0, 1, size=audio_size, generator=generator).to(torch_device)

-        timestep_values = self.noise_scheduler.timestep_values
+        timestep_values = self.noise_scheduler.config.timestep_values
         num_prediction_steps = len(self.noise_scheduler)
         for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps):
             # 1. predict noise residual
@@ -32,7 +32,7 @@ class DDIM(DiffusionPipeline):
         if torch_device is None:
             torch_device = "cuda" if torch.cuda.is_available() else "cpu"

-        num_trained_timesteps = self.noise_scheduler.timesteps
+        num_trained_timesteps = self.noise_scheduler.config.timesteps
         inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)

         self.unet.to(torch_device)
@@ -24,17 +24,11 @@
 import torch.utils.checkpoint
 from torch import nn

 import tqdm

-try:
-    from transformers import CLIPConfig, CLIPModel, CLIPTextConfig, CLIPVisionConfig, GPT2Tokenizer
-    from transformers.activations import ACT2FN
-    from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
-    from transformers.modeling_utils import PreTrainedModel
-    from transformers.utils import ModelOutput, add_start_docstrings_to_model_forward, replace_return_docstrings
-except:
-    print("Transformers is not installed")
-    pass
+from transformers import CLIPConfig, CLIPModel, CLIPTextConfig, CLIPVisionConfig, GPT2Tokenizer
+from transformers.activations import ACT2FN
+from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
+from transformers.modeling_utils import PreTrainedModel
+from transformers.utils import ModelOutput, add_start_docstrings_to_model_forward, replace_return_docstrings

 from ..models import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
 from ..pipeline_utils import DiffusionPipeline
@@ -472,7 +472,7 @@ class GradTTS(DiffusionPipeline):
             t = (1.0 - (t + 0.5) * h) * torch.ones(z.shape[0], dtype=z.dtype, device=z.device)
             time = t.unsqueeze(-1).unsqueeze(-1)

-            residual = self.unet(xt, y_mask, mu_y, t, speaker_id)
+            residual = self.unet(xt, t, mu_y, y_mask, speaker_id)
             xt = self.noise_scheduler.step(xt, residual, mu_y, h, time)
             xt = xt * y_mask
@@ -897,7 +897,7 @@ class LatentDiffusion(DiffusionPipeline):
         text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt").to(torch_device)
         text_embedding = self.bert(text_input.input_ids)[0]

-        num_trained_timesteps = self.noise_scheduler.timesteps
+        num_trained_timesteps = self.noise_scheduler.config.timesteps
         inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)

         image = torch.randn(
@@ -42,9 +42,9 @@ class PNDM(DiffusionPipeline):
         )
         image = image.to(torch_device)

-        warmup_time_steps = self.noise_scheduler.get_warmup_time_steps(num_inference_steps)
-        for t in tqdm.tqdm(range(len(warmup_time_steps))):
-            t_orig = warmup_time_steps[t]
+        prk_time_steps = self.noise_scheduler.get_prk_time_steps(num_inference_steps)
+        for t in tqdm.tqdm(range(len(prk_time_steps))):
+            t_orig = prk_time_steps[t]
             residual = self.unet(image, t_orig)
             image = self.noise_scheduler.step_prk(residual, image, t, num_inference_steps)
@@ -61,7 +61,6 @@ class ClassifierFreeGuidanceScheduler(nn.Module, ConfigMixin):
             timesteps=timesteps,
             beta_schedule=beta_schedule,
         )
-        self.timesteps = int(timesteps)

         if beta_schedule == "squaredcos_cap_v2":
             # GLIDE cosine schedule
@@ -94,4 +93,4 @@ class ClassifierFreeGuidanceScheduler(nn.Module, ConfigMixin):
         return torch.randn(shape, generator=generator).to(device)

     def __len__(self):
-        return self.timesteps
+        return self.config.timesteps
-# Copyright 2022 The HuggingFace Team. All rights reserved.
+# Copyright 2022 Stanford University Team and The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -11,12 +11,40 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
+# and https://github.com/hojonathanho/diffusion

 import math

 import numpy as np

 from ..configuration_utils import ConfigMixin
-from .scheduling_utils import SchedulerMixin, betas_for_alpha_bar, linear_beta_schedule
+from .scheduling_utils import SchedulerMixin
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
+    """
+    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+    (1-beta) over time from t = [0,1].
+
+    :param num_diffusion_timesteps: the number of betas to produce.
+    :param max_beta: the maximum beta to use; use values lower than 1 to prevent singularities.
+    """
+
+    def alpha_bar(time_step):
+        return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
+
+    betas = []
+    for i in range(num_diffusion_timesteps):
+        t1 = i / num_diffusion_timesteps
+        t2 = (i + 1) / num_diffusion_timesteps
+        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+    return np.array(betas, dtype=np.float32)


 class DDIMScheduler(SchedulerMixin, ConfigMixin):
@@ -37,19 +65,16 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
             beta_start=beta_start,
             beta_end=beta_end,
             beta_schedule=beta_schedule,
+            trained_betas=trained_betas,
+            timestep_values=timestep_values,
+            clip_sample=clip_sample,
         )
-        self.timesteps = int(timesteps)
-        self.timestep_values = timestep_values  # save the fixed timestep values for BDDM
-        self.clip_sample = clip_sample

         if beta_schedule == "linear":
-            self.betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
+            self.betas = np.linspace(beta_start, beta_end, timesteps, dtype=np.float32)
         elif beta_schedule == "squaredcos_cap_v2":
             # GLIDE cosine schedule
-            self.betas = betas_for_alpha_bar(
-                timesteps,
-                lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
-            )
+            self.betas = betas_for_alpha_bar(timesteps)
         else:
             raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
@@ -59,51 +84,12 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         self.set_format(tensor_format=tensor_format)

-        # alphas_cumprod_prev = torch.nn.functional.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
-        # TODO(PVP) - check how much of this is actually necessary!
-        # LDM only uses "fixed_small"; glide seems to use a weird mix of the two, ...
-        # https://github.com/openai/glide-text2im/blob/69b530740eb6cef69442d6180579ef5ba9ef063e/glide_text2im/gaussian_diffusion.py#L246
-        # variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
-        # if variance_type == "fixed_small":
-        #     log_variance = torch.log(variance.clamp(min=1e-20))
-        # elif variance_type == "fixed_large":
-        #     log_variance = torch.log(torch.cat([variance[1:2], betas[1:]], dim=0))
-        #
-        # self.register_buffer("log_variance", log_variance.to(torch.float32))
-
-    # def rescale_betas(self, num_timesteps):
-    #     # GLIDE scaling
-    #     if self.beta_schedule == "linear":
-    #         scale = self.timesteps / num_timesteps
-    #         self.betas = linear_beta_schedule(
-    #             num_timesteps, beta_start=self.beta_start * scale, beta_end=self.beta_end * scale
-    #         )
-    #     self.alphas = 1.0 - self.betas
-    #     self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
-
-    def get_alpha(self, time_step):
-        return self.alphas[time_step]
-
-    def get_beta(self, time_step):
-        return self.betas[time_step]
-
-    def get_alpha_prod(self, time_step):
-        if time_step < 0:
-            return self.one
-        return self.alphas_cumprod[time_step]
-
-    def get_orig_t(self, t, num_inference_steps):
-        if t < 0:
-            return -1
-        return self.timesteps // num_inference_steps * t
-
     def get_variance(self, t, num_inference_steps):
-        orig_t = self.get_orig_t(t, num_inference_steps)
-        orig_prev_t = self.get_orig_t(t - 1, num_inference_steps)
+        orig_t = self.config.timesteps // num_inference_steps * t
+        orig_prev_t = self.config.timesteps // num_inference_steps * (t - 1) if t > 0 else -1

-        alpha_prod_t = self.get_alpha_prod(orig_t)
-        alpha_prod_t_prev = self.get_alpha_prod(orig_prev_t)
+        alpha_prod_t = self.alphas_cumprod[orig_t]
+        alpha_prod_t_prev = self.alphas_cumprod[orig_prev_t] if orig_prev_t >= 0 else self.one
         beta_prod_t = 1 - alpha_prod_t
         beta_prod_t_prev = 1 - alpha_prod_t_prev
@@ -124,12 +110,12 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         # - pred_prev_sample -> "x_t-1"

         # 1. get actual t and t-1
-        orig_t = self.get_orig_t(t, num_inference_steps)
-        orig_prev_t = self.get_orig_t(t - 1, num_inference_steps)
+        orig_t = self.config.timesteps // num_inference_steps * t
+        orig_prev_t = self.config.timesteps // num_inference_steps * (t - 1) if t > 0 else -1

         # 2. compute alphas, betas
-        alpha_prod_t = self.get_alpha_prod(orig_t)
-        alpha_prod_t_prev = self.get_alpha_prod(orig_prev_t)
+        alpha_prod_t = self.alphas_cumprod[orig_t]
+        alpha_prod_t_prev = self.alphas_cumprod[orig_prev_t] if orig_prev_t >= 0 else self.one
         beta_prod_t = 1 - alpha_prod_t

         # 3. compute predicted original sample from predicted noise also called
@@ -137,7 +123,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         pred_original_sample = (sample - beta_prod_t ** (0.5) * residual) / alpha_prod_t ** (0.5)

         # 4. Clip "predicted x_0"
-        if self.clip_sample:
+        if self.config.clip_sample:
             pred_original_sample = self.clip(pred_original_sample, -1, 1)

         # 5. compute variance: "sigma_t(η)" -> see formula (16)
@@ -158,4 +144,4 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         return pred_prev_sample

     def __len__(self):
-        return self.timesteps
+        return self.config.timesteps
-# Copyright 2022 The HuggingFace Team. All rights reserved.
+# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -11,12 +11,39 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim

 import math

 import numpy as np

 from ..configuration_utils import ConfigMixin
-from .scheduling_utils import SchedulerMixin, betas_for_alpha_bar, linear_beta_schedule
+from .scheduling_utils import SchedulerMixin
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
+    """
+    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+    (1-beta) over time from t = [0,1].
+
+    :param num_diffusion_timesteps: the number of betas to produce.
+    :param max_beta: the maximum beta to use; use values lower than 1 to prevent singularities.
+    """

+    def alpha_bar(time_step):
+        return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
+
+    betas = []
+    for i in range(num_diffusion_timesteps):
+        t1 = i / num_diffusion_timesteps
+        t2 = (i + 1) / num_diffusion_timesteps
+        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+    return np.array(betas, dtype=np.float32)


 class DDPMScheduler(SchedulerMixin, ConfigMixin):
@@ -43,21 +70,14 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
             variance_type=variance_type,
             clip_sample=clip_sample,
         )
-        self.timesteps = int(timesteps)
-        self.timestep_values = timestep_values  # save the fixed timestep values for BDDM
-        self.clip_sample = clip_sample
-        self.variance_type = variance_type

         if trained_betas is not None:
             self.betas = np.asarray(trained_betas)
         elif beta_schedule == "linear":
-            self.betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
+            self.betas = np.linspace(beta_start, beta_end, timesteps, dtype=np.float32)
         elif beta_schedule == "squaredcos_cap_v2":
             # GLIDE cosine schedule
-            self.betas = betas_for_alpha_bar(
-                timesteps,
-                lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
-            )
+            self.betas = betas_for_alpha_bar(timesteps)
         else:
             raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
@@ -67,70 +87,48 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         self.set_format(tensor_format=tensor_format)

-        # self.register_buffer("betas", betas.to(torch.float32))
-        # self.register_buffer("alphas", alphas.to(torch.float32))
-        # self.register_buffer("alphas_cumprod", alphas_cumprod.to(torch.float32))
-
-        # alphas_cumprod_prev = torch.nn.functional.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
-        # TODO(PVP) - check how much of this is actually necessary!
-        # LDM only uses "fixed_small"; glide seems to use a weird mix of the two, ...
-        # https://github.com/openai/glide-text2im/blob/69b530740eb6cef69442d6180579ef5ba9ef063e/glide_text2im/gaussian_diffusion.py#L246
-        # variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
-        # if variance_type == "fixed_small":
-        #     log_variance = torch.log(variance.clamp(min=1e-20))
-        # elif variance_type == "fixed_large":
-        #     log_variance = torch.log(torch.cat([variance[1:2], betas[1:]], dim=0))
-        #
-        # self.register_buffer("log_variance", log_variance.to(torch.float32))
-
-    def get_alpha(self, time_step):
-        return self.alphas[time_step]
-
-    def get_beta(self, time_step):
-        return self.betas[time_step]
-
-    def get_alpha_prod(self, time_step):
-        if time_step < 0:
-            return self.one
-        return self.alphas_cumprod[time_step]
-
     def get_variance(self, t):
-        alpha_prod_t = self.get_alpha_prod(t)
-        alpha_prod_t_prev = self.get_alpha_prod(t - 1)
+        alpha_prod_t = self.alphas_cumprod[t]
+        alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one

         # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
         # and sample from it to get previous sample
         # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample
-        variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.get_beta(t)
+        variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[t]

         # hacks - were probs added for training stability
-        if self.variance_type == "fixed_small":
+        if self.config.variance_type == "fixed_small":
             variance = self.clip(variance, min_value=1e-20)
-        elif self.variance_type == "fixed_large":
-            variance = self.get_beta(t)
+        # for rl-diffuser https://arxiv.org/abs/2205.09991
+        elif self.config.variance_type == "fixed_small_log":
+            variance = self.log(self.clip(variance, min_value=1e-20))
+        elif self.config.variance_type == "fixed_large":
+            variance = self.betas[t]

         return variance

-    def step(self, residual, sample, t):
+    def step(self, residual, sample, t, predict_epsilon=True):
         # 1. compute alphas, betas
-        alpha_prod_t = self.get_alpha_prod(t)
-        alpha_prod_t_prev = self.get_alpha_prod(t - 1)
+        alpha_prod_t = self.alphas_cumprod[t]
+        alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one
         beta_prod_t = 1 - alpha_prod_t
         beta_prod_t_prev = 1 - alpha_prod_t_prev

         # 2. compute predicted original sample from predicted noise also called
         # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
-        pred_original_sample = (sample - beta_prod_t ** (0.5) * residual) / alpha_prod_t ** (0.5)
+        if predict_epsilon:
+            pred_original_sample = (sample - beta_prod_t ** (0.5) * residual) / alpha_prod_t ** (0.5)
+        else:
+            pred_original_sample = residual

         # 3. Clip "predicted x_0"
-        if self.clip_sample:
+        if self.config.clip_sample:
             pred_original_sample = self.clip(pred_original_sample, -1, 1)

         # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
         # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
-        pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * self.get_beta(t)) / beta_prod_t
-        current_sample_coeff = self.get_alpha(t) ** (0.5) * beta_prod_t_prev / beta_prod_t
+        pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * self.betas[t]) / beta_prod_t
+        current_sample_coeff = self.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t

         # 5. Compute predicted previous sample µ_t
         # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
@@ -139,10 +137,10 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         return pred_prev_sample

     def forward_step(self, original_sample, noise, t):
-        sqrt_alpha_prod = self.get_alpha_prod(t) ** 0.5
-        sqrt_one_minus_alpha_prod = (1 - self.get_alpha_prod(t)) ** 0.5
+        sqrt_alpha_prod = self.alphas_cumprod[t] ** 0.5
+        sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[t]) ** 0.5
         noisy_sample = sqrt_alpha_prod * original_sample + sqrt_one_minus_alpha_prod * noise
         return noisy_sample

     def __len__(self):
-        return self.timesteps
+        return self.config.timesteps