Commit 12b10cbe authored by Patrick von Platen

finish refactor

parent 2d97544d
......@@ -3,7 +3,7 @@
# make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!)
export PYTHONPATH = src
check_dirs := models tests src utils
check_dirs := tests src utils
modified_only_fixup:
$(eval modified_py_files := $(shell python utils/get_modified_files.py $(check_dirs)))
......
......@@ -2,15 +2,16 @@
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
__version__ = "0.0.1"
__version__ = "0.0.3"
from .modeling_utils import ModelMixin
from .models.unet import UNetModel
from .models.unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
from .models.unet_ldm import UNetLDMModel
from .pipeline_utils import DiffusionPipeline
from .pipelines import DDIM, DDPM, GLIDE, LatentDiffusion
from .schedulers import SchedulerMixin
from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler
from .schedulers.gaussian_ddpm import GaussianDDPMScheduler
from .schedulers.ddim import DDIMScheduler
from .schedulers.gaussian_ddpm import GaussianDDPMScheduler
from .schedulers.glide_ddim import GlideDDIMScheduler
from .pipelines import DDIM, DDPM, GLIDE, LatentDiffusion
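A hedged sketch of how the re-exported classes above might be combined after this refactor; the checkpoint id is a placeholder, and from_pretrained on the model is assumed from ModelMixin (see the loading code further down in this diff):

import torch

from diffusers import DDPM, GaussianDDPMScheduler, UNetModel

# "fusing/ddpm-cifar10" is a placeholder model id, not confirmed by this diff
unet = UNetModel.from_pretrained("fusing/ddpm-cifar10")
scheduler = GaussianDDPMScheduler()          # numpy by default; DDPM.__init__ calls set_format("pt")
pipeline = DDPM(unet=unet, noise_scheduler=scheduler)

images = pipeline(batch_size=1, torch_device="cuda" if torch.cuda.is_available() else "cpu")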
......@@ -213,7 +213,7 @@ class ConfigMixin:
passed_keys = set(init_dict.keys())
if len(expected_keys - passed_keys) > 0:
logger.warn(
logger.warning(
f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values."
)
......
......@@ -490,7 +490,7 @@ class ModelMixin(torch.nn.Module):
raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
if len(unexpected_keys) > 0:
logger.warning(
logger.warning(
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
......@@ -502,7 +502,7 @@ class ModelMixin(torch.nn.Module):
else:
logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
if len(missing_keys) > 0:
logger.warning(
logger.warning(
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
" TRAIN this model on a down-stream task to be able to use it for predictions and inference."
......@@ -521,7 +521,7 @@ class ModelMixin(torch.nn.Module):
for key, shape1, shape2 in mismatched_keys
]
)
logger.warning(
logger.warning(
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
......
from .pipeline_ddim import DDIM
from .pipeline_ddpm import DDPM
from .pipeline_latent_diffusion import LatentDiffusion
from .pipeline_glide import GLIDE
from .pipeline_latent_diffusion import LatentDiffusion
......@@ -123,7 +123,7 @@ class LDMBertConfig(PretrainedConfig):
scale_embedding=False,
use_cache=True,
pad_token_id=0,
**kwargs
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
......
......@@ -2,10 +2,10 @@
import math
import numpy as np
import tqdm
import torch
import torch.nn as nn
import tqdm
from diffusers import DiffusionPipeline
from diffusers.configuration_utils import ConfigMixin
from diffusers.modeling_utils import ModelMixin
......@@ -740,29 +740,30 @@ class DiagonalGaussianDistribution(object):
def kl(self, other=None):
if self.deterministic:
return torch.Tensor([0.])
return torch.Tensor([0.0])
else:
if other is None:
return 0.5 * torch.sum(torch.pow(self.mean, 2)
+ self.var - 1.0 - self.logvar,
dim=[1, 2, 3])
return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3])
else:
return 0.5 * torch.sum(
torch.pow(self.mean - other.mean, 2) / other.var
+ self.var / other.var - 1.0 - self.logvar + other.logvar,
dim=[1, 2, 3])
+ self.var / other.var
- 1.0
- self.logvar
+ other.logvar,
dim=[1, 2, 3],
)
def nll(self, sample, dims=[1,2,3]):
def nll(self, sample, dims=[1, 2, 3]):
if self.deterministic:
return torch.Tensor([0.])
return torch.Tensor([0.0])
logtwopi = np.log(2.0 * np.pi)
return 0.5 * torch.sum(
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
dim=dims)
return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims)
def mode(self):
return self.mean
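The reflowed kl above is the standard closed form for diagonal Gaussians; a small sketch (not part of the diff) that checks the other=None case against torch.distributions:

import torch
from torch.distributions import Normal, kl_divergence

def diagonal_kl(mean, logvar):
    # KL(N(mean, var) || N(0, 1)) reduced over channel and spatial dims, as in kl() above
    return 0.5 * torch.sum(torch.pow(mean, 2) + logvar.exp() - 1.0 - logvar, dim=[1, 2, 3])

mean, logvar = torch.randn(2, 4, 8, 8), torch.randn(2, 4, 8, 8)
reference = kl_divergence(
    Normal(mean, (0.5 * logvar).exp()),                        # std = exp(logvar / 2)
    Normal(torch.zeros_like(mean), torch.ones_like(mean)),
).sum(dim=[1, 2, 3])
assert torch.allclose(diagonal_kl(mean, logvar), reference, atol=1e-5)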
class AutoencoderKL(ModelMixin, ConfigMixin):
def __init__(
self,
......@@ -834,7 +835,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
give_pre_end=give_pre_end,
)
self.quant_conv = torch.nn.Conv2d(2*z_channels, 2*embed_dim, 1)
self.quant_conv = torch.nn.Conv2d(2 * z_channels, 2 * embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1)
def encode(self, x):
......@@ -855,4 +856,4 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
else:
z = posterior.mode()
dec = self.decode(z)
return dec, posterior
\ No newline at end of file
return dec, posterior
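For context on the reformatted quant_conv line: the encoder output (2 * z_channels channels) is projected to 2 * embed_dim channels, which the DiagonalGaussianDistribution above splits into a mean and a log-variance. A hedged sketch of that encode path — the body of encode is not shown in this hunk, the encoder attribute and the constructor taking the stacked moments are assumptions:

def encode_sketch(self, x):
    h = self.encoder(x)                       # (B, 2 * z_channels, H', W'), assumed encoder module
    moments = self.quant_conv(h)              # (B, 2 * embed_dim, H', W') -> [mean, logvar] along channels
    posterior = DiagonalGaussianDistribution(moments)
    return posterior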
......@@ -123,7 +123,7 @@ class LDMBertConfig(PretrainedConfig):
scale_embedding=False,
use_cache=True,
pad_token_id=0,
**kwargs
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
......
import tqdm
import torch
import tqdm
from diffusers import DiffusionPipeline
from .configuration_ldmbert import LDMBertConfig # NOQA
from .modeling_ldmbert import LDMBertModel # NOQA
# add these relative imports here, so we can load from hub
from .modeling_vae import AutoencoderKL # NOQA
from .configuration_ldmbert import LDMBertConfig # NOQA
from .modeling_ldmbert import LDMBertModel # NOQA
from .modeling_vae import AutoencoderKL # NOQA
class LatentDiffusion(DiffusionPipeline):
def __init__(self, vqvae, bert, tokenizer, unet, noise_scheduler):
......@@ -14,7 +16,16 @@ class LatentDiffusion(DiffusionPipeline):
self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, noise_scheduler=noise_scheduler)
@torch.no_grad()
def __call__(self, prompt, batch_size=1, generator=None, torch_device=None, eta=0.0, guidance_scale=1.0, num_inference_steps=50):
def __call__(
self,
prompt,
batch_size=1,
generator=None,
torch_device=None,
eta=0.0,
guidance_scale=1.0,
num_inference_steps=50,
):
# eta corresponds to η in the DDIM paper and should be between [0, 1]
if torch_device is None:
......@@ -23,16 +34,18 @@ class LatentDiffusion(DiffusionPipeline):
self.unet.to(torch_device)
self.vqvae.to(torch_device)
self.bert.to(torch_device)
# get unconditional embeddings for classifier free guidance
if guidance_scale != 1.0:
uncond_input = self.tokenizer([""], padding="max_length", max_length=77, return_tensors='pt').to(torch_device)
uncond_input = self.tokenizer([""], padding="max_length", max_length=77, return_tensors="pt").to(
torch_device
)
uncond_embeddings = self.bert(uncond_input.input_ids)[0]
# get text embedding
text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors='pt').to(torch_device)
text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt").to(torch_device)
text_embedding = self.bert(text_input.input_ids)[0]
num_trained_timesteps = self.noise_scheduler.timesteps
inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)
......@@ -41,7 +54,7 @@ class LatentDiffusion(DiffusionPipeline):
device=torch_device,
generator=generator,
)
# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
# Ideally, read the DDIM paper in detail for a full understanding
......@@ -60,7 +73,7 @@ class LatentDiffusion(DiffusionPipeline):
timesteps = torch.tensor([inference_step_times[t]] * image.shape[0], device=torch_device)
else:
# for classifier free guidance, we need to do two forward passes
# here we concatenate the conditional and unconditional embeddings into a single batch
# here we concatenate the conditional and unconditional embeddings into a single batch
# to avoid doing two forward passes
image_in = torch.cat([image] * 2)
context = torch.cat([uncond_embeddings, text_embedding])
......@@ -68,12 +81,12 @@ class LatentDiffusion(DiffusionPipeline):
# 1. predict noise residual
pred_noise_t = self.unet(image_in, timesteps, context=context)
# perform guidance
if guidance_scale != 1.0:
pred_noise_t_uncond, pred_noise_t = pred_noise_t.chunk(2)
pred_noise_t = pred_noise_t_uncond + guidance_scale * (pred_noise_t - pred_noise_t_uncond)
# 2. predict previous mean of image x_t-1
pred_prev_image = self.noise_scheduler.step(pred_noise_t, image, t, num_inference_steps, eta)
......@@ -87,8 +100,8 @@ class LatentDiffusion(DiffusionPipeline):
image = pred_prev_image + variance
# scale and decode image with vae
image = 1 / 0.18215 * image
image = 1 / 0.18215 * image
image = self.vqvae.decode(image)
image = torch.clamp((image+1.0)/2.0, min=0.0, max=1.0)
image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
return image
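A hedged usage sketch of the reformatted __call__ signature above; the model id is a placeholder and from_pretrained on the pipeline is assumed from DiffusionPipeline:

import torch

from diffusers import LatentDiffusion

# "fusing/latent-diffusion-text2im-large" is a placeholder id
pipeline = LatentDiffusion.from_pretrained("fusing/latent-diffusion-text2im-large")

generator = torch.manual_seed(0)
image = pipeline(
    prompt=["a painting of a squirrel eating a burger"],
    generator=generator,
    eta=0.3,
    guidance_scale=6.0,
    num_inference_steps=50,
)
# `image` is a tensor clamped to [0, 1], per the last lines of __call__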
......@@ -43,6 +43,7 @@ from transformers.utils import (
logging,
replace_return_docstrings,
)
from .configuration_ldmbert import LDMBertConfig
......@@ -662,7 +663,7 @@ class LDMBertModel(LDMBertPreTrainedModel):
super().__init__(config)
self.model = LDMBertEncoder(config)
self.to_logits = nn.Linear(config.hidden_size, config.vocab_size)
def forward(
self,
input_ids=None,
......@@ -674,7 +675,7 @@ class LDMBertModel(LDMBertPreTrainedModel):
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
):
outputs = self.model(
input_ids,
......@@ -689,15 +690,15 @@ class LDMBertModel(LDMBertPreTrainedModel):
sequence_output = outputs[0]
# logits = self.to_logits(sequence_output)
# outputs = (logits,) + outputs[1:]
# if labels is not None:
# loss_fct = CrossEntropyLoss()
# loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
# outputs = (loss,) + outputs
# if not return_dict:
# return outputs
return BaseModelOutput(
last_hidden_state=sequence_output,
# hidden_states=outputs[1],
......
......@@ -2,10 +2,10 @@
import math
import numpy as np
import tqdm
import torch
import torch.nn as nn
import tqdm
from diffusers import DiffusionPipeline
from diffusers.configuration_utils import ConfigMixin
from diffusers.modeling_utils import ModelMixin
......@@ -740,29 +740,30 @@ class DiagonalGaussianDistribution(object):
def kl(self, other=None):
if self.deterministic:
return torch.Tensor([0.])
return torch.Tensor([0.0])
else:
if other is None:
return 0.5 * torch.sum(torch.pow(self.mean, 2)
+ self.var - 1.0 - self.logvar,
dim=[1, 2, 3])
return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3])
else:
return 0.5 * torch.sum(
torch.pow(self.mean - other.mean, 2) / other.var
+ self.var / other.var - 1.0 - self.logvar + other.logvar,
dim=[1, 2, 3])
+ self.var / other.var
- 1.0
- self.logvar
+ other.logvar,
dim=[1, 2, 3],
)
def nll(self, sample, dims=[1,2,3]):
def nll(self, sample, dims=[1, 2, 3]):
if self.deterministic:
return torch.Tensor([0.])
return torch.Tensor([0.0])
logtwopi = np.log(2.0 * np.pi)
return 0.5 * torch.sum(
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
dim=dims)
return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims)
def mode(self):
return self.mean
class AutoencoderKL(ModelMixin, ConfigMixin):
def __init__(
self,
......@@ -834,7 +835,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
give_pre_end=give_pre_end,
)
self.quant_conv = torch.nn.Conv2d(2*z_channels, 2*embed_dim, 1)
self.quant_conv = torch.nn.Conv2d(2 * z_channels, 2 * embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1)
def encode(self, x):
......@@ -855,4 +856,4 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
else:
z = posterior.mode()
dec = self.decode(z)
return dec, posterior
\ No newline at end of file
return dec, posterior
......@@ -17,12 +17,14 @@
import torch
import tqdm
from ..pipeline_utils import DiffusionPipeline
class DDIM(DiffusionPipeline):
def __init__(self, unet, noise_scheduler):
super().__init__()
noise_scheduler = noise_scheduler.set_format("pt")
self.register_modules(unet=unet, noise_scheduler=noise_scheduler)
def __call__(self, batch_size=1, generator=None, torch_device=None, eta=0.0, num_inference_steps=50):
......@@ -36,11 +38,11 @@ class DDIM(DiffusionPipeline):
self.unet.to(torch_device)
# Sample gaussian noise to begin loop
image = self.noise_scheduler.sample_noise(
image = torch.randn(
(batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution),
device=torch_device,
generator=generator,
)
image = image.to(torch_device)
# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
# Ideally, read the DDIM paper in detail for a full understanding
......@@ -63,7 +65,7 @@ class DDIM(DiffusionPipeline):
# 3. optionally sample variance
variance = 0
if eta > 0:
noise = self.noise_scheduler.sample_noise(image.shape, device=image.device, generator=generator)
noise = torch.randn(image.shape, generator=generator).to(image.device)
variance = self.noise_scheduler.get_variance(t, num_inference_steps).sqrt() * eta * noise
# 4. set current image to prev_image: x_t -> x_t-1
......
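The removed sample_noise helper carried a "sample on CPU to be deterministic" comment; the replacement keeps that property by drawing with a CPU generator and moving the result afterwards. A minimal sketch of the pattern (shapes are illustrative):

import torch

shape = (1, 3, 32, 32)
device = "cuda" if torch.cuda.is_available() else "cpu"

generator = torch.Generator().manual_seed(0)
noise_a = torch.randn(shape, generator=generator)                # drawn on CPU

generator = torch.Generator().manual_seed(0)
noise_b = torch.randn(shape, generator=generator).to(device)     # same values, then moved

assert torch.equal(noise_a, noise_b.cpu())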
......@@ -17,12 +17,14 @@
import torch
import tqdm
from ..pipeline_utils import DiffusionPipeline
class DDPM(DiffusionPipeline):
def __init__(self, unet, noise_scheduler):
super().__init__()
noise_scheduler = noise_scheduler.set_format("pt")
self.register_modules(unet=unet, noise_scheduler=noise_scheduler)
def __call__(self, batch_size=1, generator=None, torch_device=None):
......@@ -32,11 +34,11 @@ class DDPM(DiffusionPipeline):
self.unet.to(torch_device)
# Sample gaussian noise to begin loop
image = self.noise_scheduler.sample_noise(
image = torch.randn(
(batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution),
device=torch_device,
generator=generator,
)
image = image.to(torch_device)
num_prediction_steps = len(self.noise_scheduler)
for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps):
......@@ -50,7 +52,7 @@ class DDPM(DiffusionPipeline):
# 3. optionally sample variance
variance = 0
if t > 0:
noise = self.noise_scheduler.sample_noise(image.shape, device=image.device, generator=generator)
noise = torch.randn(image.shape, generator=generator).to(image.device)
variance = self.noise_scheduler.get_variance(t).sqrt() * noise
# 4. set current image to prev_image: x_t -> x_t-1
......
......@@ -24,10 +24,6 @@ import torch.utils.checkpoint
from torch import nn
import tqdm
from ..pipeline_utils import DiffusionPipeline
from ..models import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
from ..schedulers import ClassifierFreeGuidanceScheduler, GlideDDIMScheduler
from transformers import CLIPConfig, CLIPModel, CLIPTextConfig, CLIPVisionConfig, GPT2Tokenizer
from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
......@@ -40,6 +36,10 @@ from transformers.utils import (
replace_return_docstrings,
)
from ..models import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
from ..pipeline_utils import DiffusionPipeline
from ..schedulers import ClassifierFreeGuidanceScheduler, GlideDDIMScheduler
#####################
# START OF THE CLIP MODEL COPY-PASTE (with a modified attention module)
......
......@@ -2,13 +2,14 @@
import math
import numpy as np
import tqdm
import torch
import torch.nn as nn
from ..pipeline_utils import DiffusionPipeline
import tqdm
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from ..pipeline_utils import DiffusionPipeline
def get_timestep_embedding(timesteps, embedding_dim):
......@@ -740,29 +741,30 @@ class DiagonalGaussianDistribution(object):
def kl(self, other=None):
if self.deterministic:
return torch.Tensor([0.])
return torch.Tensor([0.0])
else:
if other is None:
return 0.5 * torch.sum(torch.pow(self.mean, 2)
+ self.var - 1.0 - self.logvar,
dim=[1, 2, 3])
return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3])
else:
return 0.5 * torch.sum(
torch.pow(self.mean - other.mean, 2) / other.var
+ self.var / other.var - 1.0 - self.logvar + other.logvar,
dim=[1, 2, 3])
+ self.var / other.var
- 1.0
- self.logvar
+ other.logvar,
dim=[1, 2, 3],
)
def nll(self, sample, dims=[1,2,3]):
def nll(self, sample, dims=[1, 2, 3]):
if self.deterministic:
return torch.Tensor([0.])
return torch.Tensor([0.0])
logtwopi = np.log(2.0 * np.pi)
return 0.5 * torch.sum(
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
dim=dims)
return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims)
def mode(self):
return self.mean
class AutoencoderKL(ModelMixin, ConfigMixin):
def __init__(
self,
......@@ -834,7 +836,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
give_pre_end=give_pre_end,
)
self.quant_conv = torch.nn.Conv2d(2*z_channels, 2*embed_dim, 1)
self.quant_conv = torch.nn.Conv2d(2 * z_channels, 2 * embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1)
def encode(self, x):
......@@ -861,10 +863,20 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
class LatentDiffusion(DiffusionPipeline):
def __init__(self, vqvae, bert, tokenizer, unet, noise_scheduler):
super().__init__()
noise_scheduler = noise_scheduler.set_format("pt")
self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, noise_scheduler=noise_scheduler)
@torch.no_grad()
def __call__(self, prompt, batch_size=1, generator=None, torch_device=None, eta=0.0, guidance_scale=1.0, num_inference_steps=50):
def __call__(
self,
prompt,
batch_size=1,
generator=None,
torch_device=None,
eta=0.0,
guidance_scale=1.0,
num_inference_steps=50,
):
# eta corresponds to η in the DDIM paper and should be between [0, 1]
if torch_device is None:
......@@ -873,25 +885,26 @@ class LatentDiffusion(DiffusionPipeline):
self.unet.to(torch_device)
self.vqvae.to(torch_device)
self.bert.to(torch_device)
# get unconditional embeddings for classifier free guidance
if guidance_scale != 1.0:
uncond_input = self.tokenizer([""], padding="max_length", max_length=77, return_tensors='pt').to(torch_device)
uncond_input = self.tokenizer([""], padding="max_length", max_length=77, return_tensors="pt").to(
torch_device
)
uncond_embeddings = self.bert(uncond_input.input_ids)[0]
# get text embedding
text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors='pt').to(torch_device)
text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt").to(torch_device)
text_embedding = self.bert(text_input.input_ids)[0]
num_trained_timesteps = self.noise_scheduler.timesteps
inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)
image = self.noise_scheduler.sample_noise(
image = torch.randn(
(batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
device=torch_device,
generator=generator,
)
image = image.to(torch_device)
# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
# Ideally, read the DDIM paper in detail for a full understanding
......@@ -910,7 +923,7 @@ class LatentDiffusion(DiffusionPipeline):
timesteps = torch.tensor([inference_step_times[t]] * image.shape[0], device=torch_device)
else:
# for classifier free guidance, we need to do two forward passes
# here we concatenate the conditional and unconditional embeddings into a single batch
# here we concatenate the conditional and unconditional embeddings into a single batch
# to avoid doing two forward passes
image_in = torch.cat([image] * 2)
context = torch.cat([uncond_embeddings, text_embedding])
......@@ -918,12 +931,12 @@ class LatentDiffusion(DiffusionPipeline):
# 1. predict noise residual
pred_noise_t = self.unet(image_in, timesteps, context=context)
# perform guidance
if guidance_scale != 1.0:
pred_noise_t_uncond, pred_noise_t = pred_noise_t.chunk(2)
pred_noise_t = pred_noise_t_uncond + guidance_scale * (pred_noise_t - pred_noise_t_uncond)
# 2. get actual t and t-1
train_step = inference_step_times[t]
prev_train_step = inference_step_times[t - 1] if t > 0 else -1
......@@ -953,7 +966,11 @@ class LatentDiffusion(DiffusionPipeline):
# 5. Sample x_t-1 image optionally if η > 0.0 by adding noise to pred_prev_image
# Note: eta = 1.0 essentially corresponds to DDPM
if eta > 0.0:
noise = self.noise_scheduler.sample_noise(image.shape, device=image.device, generator=generator)
noise = torch.randn(
(batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution),
generator=generator,
)
noise = noise.to(torch_device)
prev_image = pred_prev_image + std_dev_t * noise
else:
prev_image = pred_prev_image
......@@ -962,8 +979,8 @@ class LatentDiffusion(DiffusionPipeline):
image = prev_image
# scale and decode image with vae
image = 1 / 0.18215 * image
image = 1 / 0.18215 * image
image = self.vqvae.decode(image)
image = torch.clamp((image+1.0)/2.0, min=0.0, max=1.0)
image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
return image
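A stand-alone sketch of the classifier-free guidance arithmetic used in the loop above; fake_unet is a stand-in for the real model so the snippet runs on its own, and the tensor shapes are illustrative:

import torch

def fake_unet(latents, timesteps, context=None):
    # stand-in for the real UNet forward pass
    return latents + 0.1 * torch.randn_like(latents)

guidance_scale = 6.0
latents = torch.randn(1, 4, 32, 32)
uncond_embeddings = torch.randn(1, 77, 1280)
text_embedding = torch.randn(1, 77, 1280)
timesteps = torch.tensor([10, 10])

latents_in = torch.cat([latents] * 2)                    # conditional + unconditional in one batch
context = torch.cat([uncond_embeddings, text_embedding])

noise_pred = fake_unet(latents_in, timesteps, context=context)
noise_uncond, noise_text = noise_pred.chunk(2)
# guidance_scale == 1.0 reduces this to the plain conditional prediction
noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond)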
......@@ -17,6 +17,7 @@
# limitations under the License.
from .classifier_free_guidance import ClassifierFreeGuidanceScheduler
from .gaussian_ddpm import GaussianDDPMScheduler
from .ddim import DDIMScheduler
from .gaussian_ddpm import GaussianDDPMScheduler
from .glide_ddim import GlideDDIMScheduler
from .schedulers_utils import SchedulerMixin
......@@ -13,20 +13,13 @@
# limitations under the License.
import math
import torch
from torch import nn
import numpy as np
from ..configuration_utils import ConfigMixin
from .schedulers_utils import betas_for_alpha_bar, linear_beta_schedule
from .schedulers_utils import SchedulerMixin, betas_for_alpha_bar, linear_beta_schedule
SAMPLING_CONFIG_NAME = "scheduler_config.json"
class DDIMScheduler(nn.Module, ConfigMixin):
config_name = SAMPLING_CONFIG_NAME
class DDIMScheduler(SchedulerMixin, ConfigMixin):
def __init__(
self,
timesteps=1000,
......@@ -34,6 +27,7 @@ class DDIMScheduler(nn.Module, ConfigMixin):
beta_end=0.02,
beta_schedule="linear",
clip_predicted_image=True,
tensor_format="np",
):
super().__init__()
self.register(
......@@ -46,35 +40,34 @@ class DDIMScheduler(nn.Module, ConfigMixin):
self.clip_image = clip_predicted_image
if beta_schedule == "linear":
betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
self.betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
elif beta_schedule == "squaredcos_cap_v2":
# GLIDE cosine schedule
betas = betas_for_alpha_bar(
self.betas = betas_for_alpha_bar(
timesteps,
lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
)
else:
raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
self.register_buffer("betas", betas.to(torch.float32))
self.register_buffer("alphas", alphas.to(torch.float32))
self.register_buffer("alphas_cumprod", alphas_cumprod.to(torch.float32))
# alphas_cumprod_prev = torch.nn.functional.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
# TODO(PVP) - check how much of this is actually necessary!
# LDM only uses "fixed_small"; glide seems to use a weird mix of the two, ...
# https://github.com/openai/glide-text2im/blob/69b530740eb6cef69442d6180579ef5ba9ef063e/glide_text2im/gaussian_diffusion.py#L246
# variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
# if variance_type == "fixed_small":
# log_variance = torch.log(variance.clamp(min=1e-20))
# elif variance_type == "fixed_large":
# log_variance = torch.log(torch.cat([variance[1:2], betas[1:]], dim=0))
#
#
# self.register_buffer("log_variance", log_variance.to(torch.float32))
self.alphas = 1.0 - self.betas
self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
self.one = np.array(1.0)
self.set_format(tensor_format=tensor_format)
# alphas_cumprod_prev = torch.nn.functional.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
# TODO(PVP) - check how much of this is actually necessary!
# LDM only uses "fixed_small"; glide seems to use a weird mix of the two, ...
# https://github.com/openai/glide-text2im/blob/69b530740eb6cef69442d6180579ef5ba9ef063e/glide_text2im/gaussian_diffusion.py#L246
# variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
# if variance_type == "fixed_small":
# log_variance = torch.log(variance.clamp(min=1e-20))
# elif variance_type == "fixed_large":
# log_variance = torch.log(torch.cat([variance[1:2], betas[1:]], dim=0))
#
#
# self.register_buffer("log_variance", log_variance.to(torch.float32))
def get_alpha(self, time_step):
return self.alphas[time_step]
......@@ -84,7 +77,7 @@ class DDIMScheduler(nn.Module, ConfigMixin):
def get_alpha_prod(self, time_step):
if time_step < 0:
return torch.tensor(1.0)
return self.one
return self.alphas_cumprod[time_step]
def get_orig_t(self, t, num_inference_steps):
......@@ -128,28 +121,24 @@ class DDIMScheduler(nn.Module, ConfigMixin):
# 3. compute predicted original image from predicted noise also called
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
pred_original_image = (image - beta_prod_t.sqrt() * residual) / alpha_prod_t.sqrt()
pred_original_image = (image - beta_prod_t ** (0.5) * residual) / alpha_prod_t ** (0.5)
# 4. Clip "predicted x_0"
if self.clip_image:
pred_original_image = torch.clamp(pred_original_image, -1, 1)
pred_original_image = self.clip(pred_original_image, -1, 1)
# 5. compute variance: "sigma_t(η)" -> see formula (16)
# σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
variance = self.get_variance(t, num_inference_steps)
std_dev_t = eta * variance.sqrt()
std_dev_t = eta * variance ** (0.5)
# 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
pred_image_direction = (1 - alpha_prod_t_prev - std_dev_t**2).sqrt() * residual
pred_image_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * residual
# 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
pred_prev_image = alpha_prod_t_prev.sqrt() * pred_original_image + pred_image_direction
pred_prev_image = alpha_prod_t_prev ** (0.5) * pred_original_image + pred_image_direction
return pred_prev_image
def sample_noise(self, shape, device, generator=None):
# always sample on CPU to be deterministic
return torch.randn(shape, generator=generator).to(device)
def __len__(self):
return self.timesteps
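A compact numpy restatement of the step changes above (formulas (12) and (16) of the DDIM paper); the alpha_prod values are made up for illustration:

import numpy as np

alpha_prod_t, alpha_prod_t_prev = 0.5, 0.7                     # illustrative cumulative alphas
beta_prod_t = 1.0 - alpha_prod_t

image = np.random.randn(1, 3, 32, 32).astype(np.float32)       # x_t
residual = np.random.randn(1, 3, 32, 32).astype(np.float32)    # predicted noise eps_theta

# "predicted x_0", formula (12)
pred_original_image = np.clip((image - beta_prod_t ** 0.5 * residual) / alpha_prod_t ** 0.5, -1, 1)

# sigma_t(eta), formula (16); eta = 0 gives the deterministic DDIM update
eta = 0.0
variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
std_dev_t = eta * variance ** 0.5

pred_image_direction = (1 - alpha_prod_t_prev - std_dev_t ** 2) ** 0.5 * residual
pred_prev_image = alpha_prod_t_prev ** 0.5 * pred_original_image + pred_image_direction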
......@@ -13,19 +13,13 @@
# limitations under the License.
import math
import torch
from torch import nn
import numpy as np
from ..configuration_utils import ConfigMixin
from .schedulers_utils import betas_for_alpha_bar, linear_beta_schedule
from .schedulers_utils import SchedulerMixin, betas_for_alpha_bar, linear_beta_schedule
SAMPLING_CONFIG_NAME = "scheduler_config.json"
class GaussianDDPMScheduler(nn.Module, ConfigMixin):
config_name = SAMPLING_CONFIG_NAME
class GaussianDDPMScheduler(SchedulerMixin, ConfigMixin):
def __init__(
self,
timesteps=1000,
......@@ -34,6 +28,7 @@ class GaussianDDPMScheduler(nn.Module, ConfigMixin):
beta_schedule="linear",
variance_type="fixed_small",
clip_predicted_image=True,
tensor_format="np",
):
super().__init__()
self.register(
......@@ -49,35 +44,38 @@ class GaussianDDPMScheduler(nn.Module, ConfigMixin):
self.variance_type = variance_type
if beta_schedule == "linear":
betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
self.betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
elif beta_schedule == "squaredcos_cap_v2":
# GLIDE cosine schedule
betas = betas_for_alpha_bar(
self.betas = betas_for_alpha_bar(
timesteps,
lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
)
else:
raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
self.register_buffer("betas", betas.to(torch.float32))
self.register_buffer("alphas", alphas.to(torch.float32))
self.register_buffer("alphas_cumprod", alphas_cumprod.to(torch.float32))
# alphas_cumprod_prev = torch.nn.functional.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
# TODO(PVP) - check how much of this is actually necessary!
# LDM only uses "fixed_small"; glide seems to use a weird mix of the two, ...
# https://github.com/openai/glide-text2im/blob/69b530740eb6cef69442d6180579ef5ba9ef063e/glide_text2im/gaussian_diffusion.py#L246
# variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
# if variance_type == "fixed_small":
# log_variance = torch.log(variance.clamp(min=1e-20))
# elif variance_type == "fixed_large":
# log_variance = torch.log(torch.cat([variance[1:2], betas[1:]], dim=0))
#
#
# self.register_buffer("log_variance", log_variance.to(torch.float32))
self.alphas = 1.0 - self.betas
self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
self.one = np.array(1.0)
self.set_format(tensor_format=tensor_format)
# self.register_buffer("betas", betas.to(torch.float32))
# self.register_buffer("alphas", alphas.to(torch.float32))
# self.register_buffer("alphas_cumprod", alphas_cumprod.to(torch.float32))
# alphas_cumprod_prev = torch.nn.functional.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
# TODO(PVP) - check how much of this is actually necessary!
# LDM only uses "fixed_small"; glide seems to use a weird mix of the two, ...
# https://github.com/openai/glide-text2im/blob/69b530740eb6cef69442d6180579ef5ba9ef063e/glide_text2im/gaussian_diffusion.py#L246
# variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
# if variance_type == "fixed_small":
# log_variance = torch.log(variance.clamp(min=1e-20))
# elif variance_type == "fixed_large":
# log_variance = torch.log(torch.cat([variance[1:2], betas[1:]], dim=0))
#
#
# self.register_buffer("log_variance", log_variance.to(torch.float32))
def get_alpha(self, time_step):
return self.alphas[time_step]
......@@ -87,7 +85,7 @@ class GaussianDDPMScheduler(nn.Module, ConfigMixin):
def get_alpha_prod(self, time_step):
if time_step < 0:
return torch.tensor(1.0)
return self.one
return self.alphas_cumprod[time_step]
def get_variance(self, t):
......@@ -97,11 +95,11 @@ class GaussianDDPMScheduler(nn.Module, ConfigMixin):
# For t > 0, compute the predicted variance βt (see formulas (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
# and sample from it to get the previous image
# x_{t-1} ~ N(pred_prev_image, variance) == add variance to pred_image
variance = ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.get_beta(t))
variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.get_beta(t)
# hacks - were probably added for training stability
if self.variance_type == "fixed_small":
variance = variance.clamp(min=1e-20)
variance = self.clip(variance, min_value=1e-20)
elif self.variance_type == "fixed_large":
variance = self.get_beta(t)
......@@ -116,16 +114,16 @@ class GaussianDDPMScheduler(nn.Module, ConfigMixin):
# 2. compute predicted original image from predicted noise also called
# "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
pred_original_image = (image - beta_prod_t.sqrt() * residual) / alpha_prod_t.sqrt()
pred_original_image = (image - beta_prod_t ** (0.5) * residual) / alpha_prod_t ** (0.5)
# 3. Clip "predicted x_0"
if self.clip_predicted_image:
pred_original_image = torch.clamp(pred_original_image, -1, 1)
pred_original_image = self.clip(pred_original_image, -1, 1)
# 4. Compute coefficients for pred_original_image x_0 and current image x_t
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
pred_original_image_coeff = (alpha_prod_t_prev.sqrt() * self.get_beta(t)) / beta_prod_t
current_image_coeff = self.get_alpha(t).sqrt() * beta_prod_t_prev / beta_prod_t
pred_original_image_coeff = (alpha_prod_t_prev ** (0.5) * self.get_beta(t)) / beta_prod_t
current_image_coeff = self.get_alpha(t) ** (0.5) * beta_prod_t_prev / beta_prod_t
# 5. Compute predicted previous image µ_t
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
......@@ -133,9 +131,5 @@ class GaussianDDPMScheduler(nn.Module, ConfigMixin):
return pred_prev_image
def sample_noise(self, shape, device, generator=None):
# always sample on CPU to be deterministic
return torch.randn(shape, generator=generator).to(device)
def __len__(self):
return self.timesteps
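For reference, the coefficient lines above implement the DDPM posterior mean of formula (7); a small numpy sketch with made-up schedule values:

import numpy as np

alpha_prod_t, alpha_prod_t_prev = 0.5, 0.7                     # illustrative cumulative products
alpha_t = alpha_prod_t / alpha_prod_t_prev
beta_t = 1.0 - alpha_t
beta_prod_t, beta_prod_t_prev = 1.0 - alpha_prod_t, 1.0 - alpha_prod_t_prev

x_t = np.random.randn(1, 3, 32, 32).astype(np.float32)
pred_x0 = np.random.randn(1, 3, 32, 32).astype(np.float32)     # clipped "predicted x_0"

pred_original_image_coeff = alpha_prod_t_prev ** 0.5 * beta_t / beta_prod_t
current_image_coeff = alpha_t ** 0.5 * beta_prod_t_prev / beta_prod_t
pred_prev_image = pred_original_image_coeff * pred_x0 + current_image_coeff * x_t   # formula (7)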
......@@ -11,11 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import torch
SCHEDULER_CONFIG_NAME = "scheduler_config.json"
def linear_beta_schedule(timesteps, beta_start, beta_end):
return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)
return np.linspace(beta_start, beta_end, timesteps, dtype=np.float32)
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
......@@ -35,4 +39,28 @@ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
return torch.tensor(betas, dtype=torch.float64)
return np.array(betas, dtype=np.float32)
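A quick sketch of what the two numpy schedule helpers above produce; the start/end values mirror common defaults and are illustrative only:

import math

import numpy as np

betas_linear = np.linspace(0.0001, 0.02, 1000, dtype=np.float32)    # what linear_beta_schedule computes

def alpha_bar(t):
    # the GLIDE cosine alpha_bar used for beta_schedule="squaredcos_cap_v2"
    return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

betas_cosine = np.array(
    [min(1 - alpha_bar((i + 1) / 1000) / alpha_bar(i / 1000), 0.999) for i in range(1000)],
    dtype=np.float32,
)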
class SchedulerMixin:
config_name = SCHEDULER_CONFIG_NAME
def set_format(self, tensor_format="pt"):
self.tensor_format = tensor_format
if tensor_format == "pt":
for key, value in vars(self).items():
if isinstance(value, np.ndarray):
setattr(self, key, torch.from_numpy(value))
return self
def clip(self, tensor, min_value=None, max_value=None):
tensor_format = getattr(self, "tensor_format", "pt")
if tensor_format == "np":
return np.clip(tensor, min_value, max_value)
elif tensor_format == "pt":
return torch.clamp(tensor, min_value, max_value)
raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
import os
import random
import unittest
import torch
from distutils.util import strtobool
import torch
global_rng = random.Random()
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
......