Commit 12b10cbe authored by Patrick von Platen

finish refactor

parent 2d97544d
@@ -3,7 +3,7 @@
 # make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!)
 export PYTHONPATH = src
-check_dirs := models tests src utils
+check_dirs := tests src utils
 modified_only_fixup:
 	$(eval modified_py_files := $(shell python utils/get_modified_files.py $(check_dirs)))
...
@@ -2,15 +2,16 @@
 # There's no way to ignore "F401 '...' imported but unused" warnings in this
 # module, but to preserve other warnings. So, don't check this module at all.
-__version__ = "0.0.1"
+__version__ = "0.0.3"
 from .modeling_utils import ModelMixin
 from .models.unet import UNetModel
 from .models.unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
 from .models.unet_ldm import UNetLDMModel
 from .pipeline_utils import DiffusionPipeline
+from .pipelines import DDIM, DDPM, GLIDE, LatentDiffusion
+from .schedulers import SchedulerMixin
 from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler
-from .schedulers.gaussian_ddpm import GaussianDDPMScheduler
 from .schedulers.ddim import DDIMScheduler
+from .schedulers.gaussian_ddpm import GaussianDDPMScheduler
 from .schedulers.glide_ddim import GlideDDIMScheduler
-from .pipelines import DDIM, DDPM, GLIDE, LatentDiffusion
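After this reordering the package root exposes the pipelines, the concrete schedulers, and the new SchedulerMixin side by side. A minimal import smoke test of that surface (a sketch, assuming the package is installed as `diffusers`):

    from diffusers import DDIM, DDPM, LatentDiffusion, SchedulerMixin
    from diffusers import DDIMScheduler, GaussianDDPMScheduler

    # After the refactor both schedulers are expected to carry the shared mixin.
    assert issubclass(DDIMScheduler, SchedulerMixin)
    assert issubclass(GaussianDDPMScheduler, SchedulerMixin)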
@@ -213,7 +213,7 @@ class ConfigMixin:
         passed_keys = set(init_dict.keys())
         if len(expected_keys - passed_keys) > 0:
-            logger.warn(
+            logger.warning(
                 f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values."
             )
...
@@ -490,7 +490,7 @@ class ModelMixin(torch.nn.Module):
             raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
         if len(unexpected_keys) > 0:
-            logger.warning(
+            logger.warninging(
                 f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
                 f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
                 f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
@@ -502,7 +502,7 @@ class ModelMixin(torch.nn.Module):
         else:
             logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
         if len(missing_keys) > 0:
-            logger.warning(
+            logger.warninging(
                 f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
                 f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
                 " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
@@ -521,7 +521,7 @@ class ModelMixin(torch.nn.Module):
                     for key, shape1, shape2 in mismatched_keys
                 ]
             )
-            logger.warning(
+            logger.warninging(
                 f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
                 f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
                 f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
...
 from .pipeline_ddim import DDIM
 from .pipeline_ddpm import DDPM
-from .pipeline_latent_diffusion import LatentDiffusion
 from .pipeline_glide import GLIDE
+from .pipeline_latent_diffusion import LatentDiffusion
@@ -123,7 +123,7 @@ class LDMBertConfig(PretrainedConfig):
         scale_embedding=False,
         use_cache=True,
         pad_token_id=0,
-        **kwargs
+        **kwargs,
     ):
         self.vocab_size = vocab_size
         self.max_position_embeddings = max_position_embeddings
...
@@ -2,10 +2,10 @@
 import math
 import numpy as np
-import tqdm
 import torch
 import torch.nn as nn
+import tqdm
 from diffusers import DiffusionPipeline
 from diffusers.configuration_utils import ConfigMixin
 from diffusers.modeling_utils import ModelMixin
@@ -740,29 +740,30 @@ class DiagonalGaussianDistribution(object):
     def kl(self, other=None):
         if self.deterministic:
-            return torch.Tensor([0.])
+            return torch.Tensor([0.0])
         else:
             if other is None:
-                return 0.5 * torch.sum(torch.pow(self.mean, 2)
-                                       + self.var - 1.0 - self.logvar,
-                                       dim=[1, 2, 3])
+                return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3])
             else:
                 return 0.5 * torch.sum(
                     torch.pow(self.mean - other.mean, 2) / other.var
-                    + self.var / other.var - 1.0 - self.logvar + other.logvar,
-                    dim=[1, 2, 3])
+                    + self.var / other.var
+                    - 1.0
+                    - self.logvar
+                    + other.logvar,
+                    dim=[1, 2, 3],
+                )
-    def nll(self, sample, dims=[1,2,3]):
+    def nll(self, sample, dims=[1, 2, 3]):
         if self.deterministic:
-            return torch.Tensor([0.])
+            return torch.Tensor([0.0])
         logtwopi = np.log(2.0 * np.pi)
-        return 0.5 * torch.sum(
-            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
-            dim=dims)
+        return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims)
     def mode(self):
         return self.mean
 class AutoencoderKL(ModelMixin, ConfigMixin):
     def __init__(
         self,
@@ -834,7 +835,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
             give_pre_end=give_pre_end,
         )
-        self.quant_conv = torch.nn.Conv2d(2*z_channels, 2*embed_dim, 1)
+        self.quant_conv = torch.nn.Conv2d(2 * z_channels, 2 * embed_dim, 1)
         self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1)
     def encode(self, x):
@@ -855,4 +856,4 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
         else:
             z = posterior.mode()
         dec = self.decode(z)
-        return dec, posterior
\ No newline at end of file
+        return dec, posterior
@@ -123,7 +123,7 @@ class LDMBertConfig(PretrainedConfig):
         scale_embedding=False,
         use_cache=True,
         pad_token_id=0,
-        **kwargs
+        **kwargs,
     ):
         self.vocab_size = vocab_size
         self.max_position_embeddings = max_position_embeddings
...
-import tqdm
 import torch
+import tqdm
 from diffusers import DiffusionPipeline
+from .configuration_ldmbert import LDMBertConfig  # NOQA
+from .modeling_ldmbert import LDMBertModel  # NOQA
 # add these relative imports here, so we can load from hub
 from .modeling_vae import AutoencoderKL  # NOQA
-from .configuration_ldmbert import LDMBertConfig  # NOQA
-from .modeling_ldmbert import LDMBertModel  # NOQA
 class LatentDiffusion(DiffusionPipeline):
     def __init__(self, vqvae, bert, tokenizer, unet, noise_scheduler):
@@ -14,7 +16,16 @@ class LatentDiffusion(DiffusionPipeline):
         self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, noise_scheduler=noise_scheduler)
     @torch.no_grad()
-    def __call__(self, prompt, batch_size=1, generator=None, torch_device=None, eta=0.0, guidance_scale=1.0, num_inference_steps=50):
+    def __call__(
+        self,
+        prompt,
+        batch_size=1,
+        generator=None,
+        torch_device=None,
+        eta=0.0,
+        guidance_scale=1.0,
+        num_inference_steps=50,
+    ):
         # eta corresponds to η in paper and should be between [0, 1]
         if torch_device is None:
@@ -23,16 +34,18 @@ class LatentDiffusion(DiffusionPipeline):
         self.unet.to(torch_device)
         self.vqvae.to(torch_device)
         self.bert.to(torch_device)
         # get unconditional embeddings for classifier free guidence
         if guidance_scale != 1.0:
-            uncond_input = self.tokenizer([""], padding="max_length", max_length=77, return_tensors='pt').to(torch_device)
+            uncond_input = self.tokenizer([""], padding="max_length", max_length=77, return_tensors="pt").to(
+                torch_device
+            )
             uncond_embeddings = self.bert(uncond_input.input_ids)[0]
         # get text embedding
-        text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors='pt').to(torch_device)
+        text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt").to(torch_device)
         text_embedding = self.bert(text_input.input_ids)[0]
         num_trained_timesteps = self.noise_scheduler.timesteps
         inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)
@@ -41,7 +54,7 @@ class LatentDiffusion(DiffusionPipeline):
             device=torch_device,
             generator=generator,
         )
         # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
         # Ideally, read DDIM paper in-detail understanding
@@ -60,7 +73,7 @@ class LatentDiffusion(DiffusionPipeline):
                 timesteps = torch.tensor([inference_step_times[t]] * image.shape[0], device=torch_device)
             else:
                 # for classifier free guidance, we need to do two forward passes
                 # here we concanate embedding and unconditioned embedding in a single batch
                 # to avoid doing two forward passes
                 image_in = torch.cat([image] * 2)
                 context = torch.cat([uncond_embeddings, text_embedding])
@@ -68,12 +81,12 @@ class LatentDiffusion(DiffusionPipeline):
             # 1. predict noise residual
             pred_noise_t = self.unet(image_in, timesteps, context=context)
             # perform guidance
             if guidance_scale != 1.0:
                 pred_noise_t_uncond, pred_noise_t = pred_noise_t.chunk(2)
                 pred_noise_t = pred_noise_t_uncond + guidance_scale * (pred_noise_t - pred_noise_t_uncond)
             # 2. predict previous mean of image x_t-1
             pred_prev_image = self.noise_scheduler.step(pred_noise_t, image, t, num_inference_steps, eta)
@@ -87,8 +100,8 @@ class LatentDiffusion(DiffusionPipeline):
             image = pred_prev_image + variance
         # scale and decode image with vae
         image = 1 / 0.18215 * image
         image = self.vqvae.decode(image)
-        image = torch.clamp((image+1.0)/2.0, min=0.0, max=1.0)
+        image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
         return image
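A hedged usage sketch of the expanded `__call__` signature above; `ldm` stands in for an already-constructed LatentDiffusion pipeline (vqvae, bert, tokenizer, unet, noise_scheduler as registered in `__init__`), and the prompt and seed are purely illustrative:

    import torch

    generator = torch.manual_seed(0)  # CPU generator keeps the run reproducible
    images = ldm(
        "a painting of a squirrel eating a burger",
        batch_size=1,
        generator=generator,
        eta=0.3,
        guidance_scale=6.0,       # any value != 1.0 enables classifier-free guidance
        num_inference_steps=50,
    )
    # images come back scaled to [0, 1] by the final torch.clamp above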
@@ -43,6 +43,7 @@ from transformers.utils import (
     logging,
     replace_return_docstrings,
 )
+from .configuration_ldmbert import LDMBertConfig
@@ -662,7 +663,7 @@ class LDMBertModel(LDMBertPreTrainedModel):
         super().__init__(config)
         self.model = LDMBertEncoder(config)
         self.to_logits = nn.Linear(config.hidden_size, config.vocab_size)
     def forward(
         self,
         input_ids=None,
@@ -674,7 +675,7 @@ class LDMBertModel(LDMBertPreTrainedModel):
         output_attentions=None,
         output_hidden_states=None,
         return_dict=None,
     ):
         outputs = self.model(
             input_ids,
@@ -689,15 +690,15 @@ class LDMBertModel(LDMBertPreTrainedModel):
         sequence_output = outputs[0]
         # logits = self.to_logits(sequence_output)
         # outputs = (logits,) + outputs[1:]
         # if labels is not None:
         # loss_fct = CrossEntropyLoss()
         # loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
         # outputs = (loss,) + outputs
         # if not return_dict:
         # return outputs
         return BaseModelOutput(
             last_hidden_state=sequence_output,
             # hidden_states=outputs[1],
...
@@ -2,10 +2,10 @@
 import math
 import numpy as np
-import tqdm
 import torch
 import torch.nn as nn
+import tqdm
 from diffusers import DiffusionPipeline
 from diffusers.configuration_utils import ConfigMixin
 from diffusers.modeling_utils import ModelMixin
@@ -740,29 +740,30 @@ class DiagonalGaussianDistribution(object):
     def kl(self, other=None):
         if self.deterministic:
-            return torch.Tensor([0.])
+            return torch.Tensor([0.0])
         else:
             if other is None:
-                return 0.5 * torch.sum(torch.pow(self.mean, 2)
-                                       + self.var - 1.0 - self.logvar,
-                                       dim=[1, 2, 3])
+                return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3])
             else:
                 return 0.5 * torch.sum(
                     torch.pow(self.mean - other.mean, 2) / other.var
-                    + self.var / other.var - 1.0 - self.logvar + other.logvar,
-                    dim=[1, 2, 3])
+                    + self.var / other.var
+                    - 1.0
+                    - self.logvar
+                    + other.logvar,
+                    dim=[1, 2, 3],
+                )
-    def nll(self, sample, dims=[1,2,3]):
+    def nll(self, sample, dims=[1, 2, 3]):
         if self.deterministic:
-            return torch.Tensor([0.])
+            return torch.Tensor([0.0])
         logtwopi = np.log(2.0 * np.pi)
-        return 0.5 * torch.sum(
-            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
-            dim=dims)
+        return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims)
     def mode(self):
         return self.mean
 class AutoencoderKL(ModelMixin, ConfigMixin):
     def __init__(
         self,
@@ -834,7 +835,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
             give_pre_end=give_pre_end,
         )
-        self.quant_conv = torch.nn.Conv2d(2*z_channels, 2*embed_dim, 1)
+        self.quant_conv = torch.nn.Conv2d(2 * z_channels, 2 * embed_dim, 1)
         self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1)
     def encode(self, x):
@@ -855,4 +856,4 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
         else:
             z = posterior.mode()
         dec = self.decode(z)
-        return dec, posterior
\ No newline at end of file
+        return dec, posterior
@@ -17,12 +17,14 @@
 import torch
 import tqdm
 from ..pipeline_utils import DiffusionPipeline
 class DDIM(DiffusionPipeline):
     def __init__(self, unet, noise_scheduler):
         super().__init__()
+        noise_scheduler = noise_scheduler.set_format("pt")
         self.register_modules(unet=unet, noise_scheduler=noise_scheduler)
     def __call__(self, batch_size=1, generator=None, torch_device=None, eta=0.0, num_inference_steps=50):
@@ -36,11 +38,11 @@ class DDIM(DiffusionPipeline):
         self.unet.to(torch_device)
         # Sample gaussian noise to begin loop
-        image = self.noise_scheduler.sample_noise(
+        image = torch.randn(
             (batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution),
-            device=torch_device,
             generator=generator,
         )
+        image = image.to(torch_device)
         # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
         # Ideally, read DDIM paper in-detail understanding
@@ -63,7 +65,7 @@ class DDIM(DiffusionPipeline):
             # 3. optionally sample variance
             variance = 0
             if eta > 0:
-                noise = self.noise_scheduler.sample_noise(image.shape, device=image.device, generator=generator)
+                noise = torch.randn(image.shape, generator=generator).to(image.device)
                 variance = self.noise_scheduler.get_variance(t, num_inference_steps).sqrt() * eta * noise
             # 4. set current image to prev_image: x_t -> x_t-1
...
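The refactor replaces the scheduler's `sample_noise` helper with plain `torch.randn` inside the pipelines: noise is drawn on the CPU with the caller's generator and only then moved to the target device, so a fixed seed gives the same sample regardless of device. A small sketch of that pattern (shape and device are illustrative, not taken from the diff):

    import torch

    generator = torch.Generator().manual_seed(0)
    shape = (1, 3, 256, 256)  # (batch_size, in_channels, resolution, resolution) — example sizes

    image = torch.randn(shape, generator=generator)   # sampled on CPU, deterministic for a given seed
    image = image.to("cuda" if torch.cuda.is_available() else "cpu")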
@@ -17,12 +17,14 @@
 import torch
 import tqdm
 from ..pipeline_utils import DiffusionPipeline
 class DDPM(DiffusionPipeline):
     def __init__(self, unet, noise_scheduler):
         super().__init__()
+        noise_scheduler = noise_scheduler.set_format("pt")
         self.register_modules(unet=unet, noise_scheduler=noise_scheduler)
     def __call__(self, batch_size=1, generator=None, torch_device=None):
@@ -32,11 +34,11 @@ class DDPM(DiffusionPipeline):
         self.unet.to(torch_device)
         # Sample gaussian noise to begin loop
-        image = self.noise_scheduler.sample_noise(
+        image = torch.randn(
            (batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution),
-            device=torch_device,
             generator=generator,
         )
+        image = image.to(torch_device)
         num_prediction_steps = len(self.noise_scheduler)
         for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps):
@@ -50,7 +52,7 @@ class DDPM(DiffusionPipeline):
             # 3. optionally sample variance
             variance = 0
             if t > 0:
-                noise = self.noise_scheduler.sample_noise(image.shape, device=image.device, generator=generator)
+                noise = torch.randn(image.shape, generator=generator).to(image.device)
                 variance = self.noise_scheduler.get_variance(t).sqrt() * noise
             # 4. set current image to prev_image: x_t -> x_t-1
...
@@ -24,10 +24,6 @@ import torch.utils.checkpoint
 from torch import nn
 import tqdm
-from ..pipeline_utils import DiffusionPipeline
-from ..models import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
-from ..schedulers import ClassifierFreeGuidanceScheduler, GlideDDIMScheduler
 from transformers import CLIPConfig, CLIPModel, CLIPTextConfig, CLIPVisionConfig, GPT2Tokenizer
 from transformers.activations import ACT2FN
 from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
@@ -40,6 +36,10 @@ from transformers.utils import (
     replace_return_docstrings,
 )
+from ..models import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
+from ..pipeline_utils import DiffusionPipeline
+from ..schedulers import ClassifierFreeGuidanceScheduler, GlideDDIMScheduler
 #####################
 # START OF THE CLIP MODEL COPY-PASTE (with a modified attention module)
...
@@ -2,13 +2,14 @@
 import math
 import numpy as np
-import tqdm
 import torch
 import torch.nn as nn
+import tqdm
-from ..pipeline_utils import DiffusionPipeline
 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
+from ..pipeline_utils import DiffusionPipeline
 def get_timestep_embedding(timesteps, embedding_dim):
@@ -740,29 +741,30 @@ class DiagonalGaussianDistribution(object):
     def kl(self, other=None):
         if self.deterministic:
-            return torch.Tensor([0.])
+            return torch.Tensor([0.0])
         else:
             if other is None:
-                return 0.5 * torch.sum(torch.pow(self.mean, 2)
-                                       + self.var - 1.0 - self.logvar,
-                                       dim=[1, 2, 3])
+                return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3])
             else:
                 return 0.5 * torch.sum(
                     torch.pow(self.mean - other.mean, 2) / other.var
-                    + self.var / other.var - 1.0 - self.logvar + other.logvar,
-                    dim=[1, 2, 3])
+                    + self.var / other.var
+                    - 1.0
+                    - self.logvar
+                    + other.logvar,
+                    dim=[1, 2, 3],
+                )
-    def nll(self, sample, dims=[1,2,3]):
+    def nll(self, sample, dims=[1, 2, 3]):
         if self.deterministic:
-            return torch.Tensor([0.])
+            return torch.Tensor([0.0])
         logtwopi = np.log(2.0 * np.pi)
-        return 0.5 * torch.sum(
-            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
-            dim=dims)
+        return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims)
     def mode(self):
         return self.mean
 class AutoencoderKL(ModelMixin, ConfigMixin):
     def __init__(
         self,
@@ -834,7 +836,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
             give_pre_end=give_pre_end,
         )
-        self.quant_conv = torch.nn.Conv2d(2*z_channels, 2*embed_dim, 1)
+        self.quant_conv = torch.nn.Conv2d(2 * z_channels, 2 * embed_dim, 1)
         self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1)
     def encode(self, x):
@@ -861,10 +863,20 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
 class LatentDiffusion(DiffusionPipeline):
     def __init__(self, vqvae, bert, tokenizer, unet, noise_scheduler):
         super().__init__()
+        noise_scheduler = noise_scheduler.set_format("pt")
         self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, noise_scheduler=noise_scheduler)
     @torch.no_grad()
-    def __call__(self, prompt, batch_size=1, generator=None, torch_device=None, eta=0.0, guidance_scale=1.0, num_inference_steps=50):
+    def __call__(
+        self,
+        prompt,
+        batch_size=1,
+        generator=None,
+        torch_device=None,
+        eta=0.0,
+        guidance_scale=1.0,
+        num_inference_steps=50,
+    ):
         # eta corresponds to η in paper and should be between [0, 1]
         if torch_device is None:
@@ -873,25 +885,26 @@ class LatentDiffusion(DiffusionPipeline):
         self.unet.to(torch_device)
         self.vqvae.to(torch_device)
         self.bert.to(torch_device)
         # get unconditional embeddings for classifier free guidence
         if guidance_scale != 1.0:
-            uncond_input = self.tokenizer([""], padding="max_length", max_length=77, return_tensors='pt').to(torch_device)
+            uncond_input = self.tokenizer([""], padding="max_length", max_length=77, return_tensors="pt").to(
+                torch_device
+            )
             uncond_embeddings = self.bert(uncond_input.input_ids)[0]
         # get text embedding
-        text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors='pt').to(torch_device)
+        text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt").to(torch_device)
         text_embedding = self.bert(text_input.input_ids)[0]
         num_trained_timesteps = self.noise_scheduler.timesteps
         inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)
-        image = self.noise_scheduler.sample_noise(
+        image = torch.randn(
             (batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
-            device=torch_device,
             generator=generator,
         )
+        image = image.to(torch_device)
         # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
         # Ideally, read DDIM paper in-detail understanding
@@ -910,7 +923,7 @@ class LatentDiffusion(DiffusionPipeline):
                 timesteps = torch.tensor([inference_step_times[t]] * image.shape[0], device=torch_device)
             else:
                 # for classifier free guidance, we need to do two forward passes
                 # here we concanate embedding and unconditioned embedding in a single batch
                 # to avoid doing two forward passes
                 image_in = torch.cat([image] * 2)
                 context = torch.cat([uncond_embeddings, text_embedding])
@@ -918,12 +931,12 @@ class LatentDiffusion(DiffusionPipeline):
             # 1. predict noise residual
             pred_noise_t = self.unet(image_in, timesteps, context=context)
             # perform guidance
             if guidance_scale != 1.0:
                 pred_noise_t_uncond, pred_noise_t = pred_noise_t.chunk(2)
                 pred_noise_t = pred_noise_t_uncond + guidance_scale * (pred_noise_t - pred_noise_t_uncond)
             # 2. get actual t and t-1
             train_step = inference_step_times[t]
             prev_train_step = inference_step_times[t - 1] if t > 0 else -1
@@ -953,7 +966,11 @@ class LatentDiffusion(DiffusionPipeline):
             # 5. Sample x_t-1 image optionally if η > 0.0 by adding noise to pred_prev_image
             # Note: eta = 1.0 essentially corresponds to DDPM
             if eta > 0.0:
-                noise = self.noise_scheduler.sample_noise(image.shape, device=image.device, generator=generator)
+                noise = torch.randn(
+                    (batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution),
+                    generator=generator,
+                )
+                noise = noise.to(torch_device)
                 prev_image = pred_prev_image + std_dev_t * noise
             else:
                 prev_image = pred_prev_image
@@ -962,8 +979,8 @@ class LatentDiffusion(DiffusionPipeline):
             image = prev_image
         # scale and decode image with vae
         image = 1 / 0.18215 * image
         image = self.vqvae.decode(image)
-        image = torch.clamp((image+1.0)/2.0, min=0.0, max=1.0)
+        image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
         return image
@@ -17,6 +17,7 @@
 # limitations under the License.
 from .classifier_free_guidance import ClassifierFreeGuidanceScheduler
-from .gaussian_ddpm import GaussianDDPMScheduler
 from .ddim import DDIMScheduler
+from .gaussian_ddpm import GaussianDDPMScheduler
 from .glide_ddim import GlideDDIMScheduler
+from .schedulers_utils import SchedulerMixin
@@ -13,20 +13,13 @@
 # limitations under the License.
 import math
-import torch
-from torch import nn
+import numpy as np
 from ..configuration_utils import ConfigMixin
-from .schedulers_utils import betas_for_alpha_bar, linear_beta_schedule
+from .schedulers_utils import SchedulerMixin, betas_for_alpha_bar, linear_beta_schedule
-SAMPLING_CONFIG_NAME = "scheduler_config.json"
-class DDIMScheduler(nn.Module, ConfigMixin):
-    config_name = SAMPLING_CONFIG_NAME
+class DDIMScheduler(SchedulerMixin, ConfigMixin):
     def __init__(
         self,
         timesteps=1000,
@@ -34,6 +27,7 @@ class DDIMScheduler(nn.Module, ConfigMixin):
         beta_end=0.02,
         beta_schedule="linear",
         clip_predicted_image=True,
+        tensor_format="np",
     ):
         super().__init__()
         self.register(
@@ -46,35 +40,34 @@ class DDIMScheduler(nn.Module, ConfigMixin):
         self.clip_image = clip_predicted_image
         if beta_schedule == "linear":
-            betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
+            self.betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
         elif beta_schedule == "squaredcos_cap_v2":
             # GLIDE cosine schedule
-            betas = betas_for_alpha_bar(
+            self.betas = betas_for_alpha_bar(
                 timesteps,
                 lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
             )
         else:
             raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
-        alphas = 1.0 - betas
-        alphas_cumprod = torch.cumprod(alphas, axis=0)
+        self.alphas = 1.0 - self.betas
+        self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
+        self.one = np.array(1.0)
-        self.register_buffer("betas", betas.to(torch.float32))
-        self.register_buffer("alphas", alphas.to(torch.float32))
-        self.register_buffer("alphas_cumprod", alphas_cumprod.to(torch.float32))
+        self.set_format(tensor_format=tensor_format)
         # alphas_cumprod_prev = torch.nn.functional.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
         # TODO(PVP) - check how much of these is actually necessary!
         # LDM only uses "fixed_small"; glide seems to use a weird mix of the two, ...
         # https://github.com/openai/glide-text2im/blob/69b530740eb6cef69442d6180579ef5ba9ef063e/glide_text2im/gaussian_diffusion.py#L246
         # variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
         # if variance_type == "fixed_small":
         # log_variance = torch.log(variance.clamp(min=1e-20))
         # elif variance_type == "fixed_large":
         # log_variance = torch.log(torch.cat([variance[1:2], betas[1:]], dim=0))
         #
         #
         # self.register_buffer("log_variance", log_variance.to(torch.float32))
     def get_alpha(self, time_step):
         return self.alphas[time_step]
@@ -84,7 +77,7 @@ class DDIMScheduler(nn.Module, ConfigMixin):
     def get_alpha_prod(self, time_step):
         if time_step < 0:
-            return torch.tensor(1.0)
+            return self.one
         return self.alphas_cumprod[time_step]
     def get_orig_t(self, t, num_inference_steps):
@@ -128,28 +121,24 @@ class DDIMScheduler(nn.Module, ConfigMixin):
         # 3. compute predicted original image from predicted noise also called
         # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
-        pred_original_image = (image - beta_prod_t.sqrt() * residual) / alpha_prod_t.sqrt()
+        pred_original_image = (image - beta_prod_t ** (0.5) * residual) / alpha_prod_t ** (0.5)
         # 4. Clip "predicted x_0"
         if self.clip_image:
-            pred_original_image = torch.clamp(pred_original_image, -1, 1)
+            pred_original_image = self.clip(pred_original_image, -1, 1)
         # 5. compute variance: "sigma_t(η)" -> see formula (16)
         # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
         variance = self.get_variance(t, num_inference_steps)
-        std_dev_t = eta * variance.sqrt()
+        std_dev_t = eta * variance ** (0.5)
         # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
-        pred_image_direction = (1 - alpha_prod_t_prev - std_dev_t**2).sqrt() * residual
+        pred_image_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * residual
         # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
-        pred_prev_image = alpha_prod_t_prev.sqrt() * pred_original_image + pred_image_direction
+        pred_prev_image = alpha_prod_t_prev ** (0.5) * pred_original_image + pred_image_direction
         return pred_prev_image
-    def sample_noise(self, shape, device, generator=None):
-        # always sample on CPU to be deterministic
-        return torch.randn(shape, generator=generator).to(device)
     def __len__(self):
         return self.timesteps
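Under this refactor the scheduler keeps its beta/alpha tables as NumPy arrays by default and only converts them to torch tensors on request. A hedged construction sketch (argument values mirror the defaults in the `__init__` shown above):

    from diffusers import DDIMScheduler

    scheduler = DDIMScheduler(timesteps=1000, beta_schedule="linear", tensor_format="np")
    # Pipelines opt into torch tensors explicitly; set_format returns the scheduler itself.
    scheduler = scheduler.set_format("pt")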
@@ -13,19 +13,13 @@
 # limitations under the License.
 import math
-import torch
-from torch import nn
+import numpy as np
 from ..configuration_utils import ConfigMixin
-from .schedulers_utils import betas_for_alpha_bar, linear_beta_schedule
+from .schedulers_utils import SchedulerMixin, betas_for_alpha_bar, linear_beta_schedule
-SAMPLING_CONFIG_NAME = "scheduler_config.json"
-class GaussianDDPMScheduler(nn.Module, ConfigMixin):
-    config_name = SAMPLING_CONFIG_NAME
+class GaussianDDPMScheduler(SchedulerMixin, ConfigMixin):
     def __init__(
         self,
         timesteps=1000,
@@ -34,6 +28,7 @@ class GaussianDDPMScheduler(nn.Module, ConfigMixin):
         beta_schedule="linear",
         variance_type="fixed_small",
         clip_predicted_image=True,
+        tensor_format="np",
     ):
         super().__init__()
         self.register(
@@ -49,35 +44,38 @@ class GaussianDDPMScheduler(nn.Module, ConfigMixin):
         self.variance_type = variance_type
         if beta_schedule == "linear":
-            betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
+            self.betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
         elif beta_schedule == "squaredcos_cap_v2":
             # GLIDE cosine schedule
-            betas = betas_for_alpha_bar(
+            self.betas = betas_for_alpha_bar(
                 timesteps,
                 lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
             )
         else:
             raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
-        alphas = 1.0 - betas
-        alphas_cumprod = torch.cumprod(alphas, axis=0)
+        self.alphas = 1.0 - self.betas
+        self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
+        self.one = np.array(1.0)
-        self.register_buffer("betas", betas.to(torch.float32))
-        self.register_buffer("alphas", alphas.to(torch.float32))
-        self.register_buffer("alphas_cumprod", alphas_cumprod.to(torch.float32))
+        self.set_format(tensor_format=tensor_format)
+        # self.register_buffer("betas", betas.to(torch.float32))
+        # self.register_buffer("alphas", alphas.to(torch.float32))
+        # self.register_buffer("alphas_cumprod", alphas_cumprod.to(torch.float32))
         # alphas_cumprod_prev = torch.nn.functional.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
         # TODO(PVP) - check how much of these is actually necessary!
         # LDM only uses "fixed_small"; glide seems to use a weird mix of the two, ...
         # https://github.com/openai/glide-text2im/blob/69b530740eb6cef69442d6180579ef5ba9ef063e/glide_text2im/gaussian_diffusion.py#L246
         # variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
         # if variance_type == "fixed_small":
         # log_variance = torch.log(variance.clamp(min=1e-20))
         # elif variance_type == "fixed_large":
         # log_variance = torch.log(torch.cat([variance[1:2], betas[1:]], dim=0))
         #
         #
         # self.register_buffer("log_variance", log_variance.to(torch.float32))
     def get_alpha(self, time_step):
         return self.alphas[time_step]
@@ -87,7 +85,7 @@ class GaussianDDPMScheduler(nn.Module, ConfigMixin):
     def get_alpha_prod(self, time_step):
         if time_step < 0:
-            return torch.tensor(1.0)
+            return self.one
         return self.alphas_cumprod[time_step]
     def get_variance(self, t):
@@ -97,11 +95,11 @@ class GaussianDDPMScheduler(nn.Module, ConfigMixin):
         # For t > 0, compute predicted variance βt (see formala (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
         # and sample from it to get previous image
         # x_{t-1} ~ N(pred_prev_image, variance) == add variane to pred_image
-        variance = ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.get_beta(t))
+        variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.get_beta(t)
         # hacks - were probs added for training stability
         if self.variance_type == "fixed_small":
-            variance = variance.clamp(min=1e-20)
+            variance = self.clip(variance, min_value=1e-20)
         elif self.variance_type == "fixed_large":
             variance = self.get_beta(t)
@@ -116,16 +114,16 @@ class GaussianDDPMScheduler(nn.Module, ConfigMixin):
         # 2. compute predicted original image from predicted noise also called
         # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
-        pred_original_image = (image - beta_prod_t.sqrt() * residual) / alpha_prod_t.sqrt()
+        pred_original_image = (image - beta_prod_t ** (0.5) * residual) / alpha_prod_t ** (0.5)
         # 3. Clip "predicted x_0"
         if self.clip_predicted_image:
-            pred_original_image = torch.clamp(pred_original_image, -1, 1)
+            pred_original_image = self.clip(pred_original_image, -1, 1)
         # 4. Compute coefficients for pred_original_image x_0 and current image x_t
         # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
-        pred_original_image_coeff = (alpha_prod_t_prev.sqrt() * self.get_beta(t)) / beta_prod_t
+        pred_original_image_coeff = (alpha_prod_t_prev ** (0.5) * self.get_beta(t)) / beta_prod_t
-        current_image_coeff = self.get_alpha(t).sqrt() * beta_prod_t_prev / beta_prod_t
+        current_image_coeff = self.get_alpha(t) ** (0.5) * beta_prod_t_prev / beta_prod_t
         # 5. Compute predicted previous image µ_t
         # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
@@ -133,9 +131,5 @@ class GaussianDDPMScheduler(nn.Module, ConfigMixin):
         return pred_prev_image
-    def sample_noise(self, shape, device, generator=None):
-        # always sample on CPU to be deterministic
-        return torch.randn(shape, generator=generator).to(device)
     def __len__(self):
         return self.timesteps
@@ -11,11 +11,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import numpy as np
 import torch
+SCHEDULER_CONFIG_NAME = "scheduler_config.json"
 def linear_beta_schedule(timesteps, beta_start, beta_end):
-    return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)
+    return np.linspace(beta_start, beta_end, timesteps, dtype=np.float32)
 def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
@@ -35,4 +39,28 @@ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
         t1 = i / num_diffusion_timesteps
         t2 = (i + 1) / num_diffusion_timesteps
         betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
-    return torch.tensor(betas, dtype=torch.float64)
+    return np.array(betas, dtype=np.float32)
+class SchedulerMixin:
+    config_name = SCHEDULER_CONFIG_NAME
+    def set_format(self, tensor_format="pt"):
+        self.tensor_format = tensor_format
+        if tensor_format == "pt":
+            for key, value in vars(self).items():
+                if isinstance(value, np.ndarray):
+                    setattr(self, key, torch.from_numpy(value))
+        return self
+    def clip(self, tensor, min_value=None, max_value=None):
+        tensor_format = getattr(self, "tensor_format", "pt")
+        if tensor_format == "np":
+            return np.clip(tensor, min_value, max_value)
+        elif tensor_format == "pt":
+            return torch.clamp(tensor, min_value, max_value)
+        raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
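A minimal sketch of how the new SchedulerMixin behaves, assuming the module layout in this diff (`diffusers.schedulers`); the ToyScheduler class below is a hypothetical stand-in, not part of the commit:

    import numpy as np
    import torch

    from diffusers.schedulers import SchedulerMixin


    class ToyScheduler(SchedulerMixin):
        def __init__(self):
            self.betas = np.linspace(0.0001, 0.02, 10, dtype=np.float32)


    s = ToyScheduler().set_format("pt")          # numpy arrays become torch tensors
    assert isinstance(s.betas, torch.Tensor)
    clipped = s.clip(s.betas, min_value=0.001)   # dispatches to torch.clamp in "pt" mode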
 import os
 import random
 import unittest
-import torch
 from distutils.util import strtobool
+import torch
 global_rng = random.Random()
 torch_device = "cuda" if torch.cuda.is_available() else "cpu"
...