Commit 11631e81 authored by Patrick von Platen

merge

parents 13c5a065 b8a67640
@@ -164,7 +164,7 @@ image_pil = PIL.Image.fromarray(image_processed[0])
image_pil.save("test.png")
```
#### **Text to Image generation with Latent Diffusion**
```python
from diffusers import DiffusionPipeline
@@ -184,59 +184,98 @@ image_pil = PIL.Image.fromarray(image_processed[0])
# save image
image_pil.save("test.png")
```
#### **Text to speech with BDDM**
_Follow the instructions [here](https://pytorch.org/hub/nvidia_deeplearningexamples_tacotron2/) to load the tacotron2 model._
```python
import torch
from diffusers import BDDM, DiffusionPipeline
torch_device = "cuda"
# load the BDDM pipeline
bddm = DiffusionPipeline.from_pretrained("fusing/diffwave-vocoder")
# load tacotron2 to get the mel spectrograms
tacotron2 = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_tacotron2', model_math='fp16')
tacotron2 = tacotron2.to(torch_device).eval()
text = "Hello world, I missed you so much."
utils = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_tts_utils')
sequences, lengths = utils.prepare_input_sequence([text])
# generate mel spectrograms using text
with torch.no_grad():
    mel_spec, _, _ = tacotron2.infer(sequences, lengths)

# generate the speech by passing mel spectrograms to the BDDM pipeline
generator = torch.manual_seed(0)
audio = bddm(mel_spec, generator, torch_device)
# save generated audio
from scipy.io.wavfile import write as wavwrite
sampling_rate = 22050
wavwrite("generated_audio.wav", sampling_rate, audio.squeeze().cpu().numpy())
```
## Library structure:
```
├── LICENSE
├── Makefile
├── README.md
├── pyproject.toml
├── setup.cfg
├── setup.py
├── src
│   └── diffusers
│       ├── __init__.py
│       ├── configuration_utils.py
│       ├── dependency_versions_check.py
│       ├── dependency_versions_table.py
│       ├── dynamic_modules_utils.py
│       ├── modeling_utils.py
│       ├── models
│       │   ├── __init__.py
│       │   ├── unet.py
│       │   ├── unet_glide.py
│       │   └── unet_ldm.py
│       ├── pipeline_utils.py
│       ├── pipelines
│       │   ├── __init__.py
│       │   ├── configuration_ldmbert.py
│       │   ├── conversion_glide.py
│       │   ├── modeling_vae.py
│       │   ├── pipeline_bddm.py
│       │   ├── pipeline_ddim.py
│       │   ├── pipeline_ddpm.py
│       │   ├── pipeline_glide.py
│       │   └── pipeline_latent_diffusion.py
│       ├── schedulers
│       │   ├── __init__.py
│       │   ├── classifier_free_guidance.py
│       │   ├── scheduling_ddim.py
│       │   ├── scheduling_ddpm.py
│       │   ├── scheduling_plms.py
│       │   └── scheduling_utils.py
│       ├── testing_utils.py
│       └── utils
│           ├── __init__.py
│           └── logging.py
├── tests
│   ├── __init__.py
│   ├── test_modeling_utils.py
│   └── test_scheduler.py
└── utils
    ├── check_config_docstrings.py
    ├── check_copies.py
    ├── check_dummies.py
    ├── check_inits.py
    ├── check_repo.py
    ├── check_table.py
    └── check_tf_ops.py
```
@@ -9,6 +9,6 @@ from .models.unet import UNetModel
from .models.unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
from .models.unet_ldm import UNetLDMModel
from .pipeline_utils import DiffusionPipeline
from .pipelines import DDIM, DDPM, GLIDE, LatentDiffusion, PNDM, BDDM
from .schedulers import DDIMScheduler, DDPMScheduler, SchedulerMixin, PNDMScheduler
from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler
@@ -225,11 +225,11 @@ class ConfigMixin:
        text = reader.read()
        return json.loads(text)

    # def __eq__(self, other):
    #     return self.__dict__ == other.__dict__

    # def __repr__(self):
    #     return f"{self.__class__.__name__} {self.to_json_string()}"

    @property
    def config(self) -> Dict[str, Any]:
...
@@ -3,3 +3,4 @@ from .pipeline_ddpm import DDPM
from .pipeline_pndm import PNDM
from .pipeline_glide import GLIDE
from .pipeline_latent_diffusion import LatentDiffusion
from .pipeline_bddm import BDDM
@@ -97,7 +97,9 @@ superres_model = GLIDESuperResUNetModel(
superres_model.load_state_dict(ups_state_dict, strict=False)

upscale_scheduler = DDIMScheduler(
    timesteps=1000, beta_schedule="linear", beta_start=0.0001, beta_end=0.02, tensor_format="pt"
)
glide = GLIDE(
    text_unet=text2im_model,
...
@@ -13,11 +13,18 @@
import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm

from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from ..pipeline_utils import DiffusionPipeline


def calc_diffusion_step_embedding(diffusion_steps, diffusion_step_embed_dim_in):
    """
@@ -41,8 +48,7 @@ def calc_diffusion_step_embedding(diffusion_steps, diffusion_step_embed_dim_in):
    _embed = np.log(10000) / (half_dim - 1)
    _embed = torch.exp(torch.arange(half_dim) * -_embed).cuda()
    _embed = diffusion_steps * _embed
    diffusion_step_embed = torch.cat((torch.sin(_embed), torch.cos(_embed)), 1)

    return diffusion_step_embed
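As a reading aid (not part of the commit): the lines above are the standard sinusoidal step embedding. Writing `h` for `half_dim` (presumably set to half of `diffusion_step_embed_dim_in` in the collapsed lines), the embedding of a step `t` is:

```latex
% Sinusoidal step embedding computed by calc_diffusion_step_embedding, i = 0, ..., h-1
e(t) = \big[\sin(t\,\omega_0), \dots, \sin(t\,\omega_{h-1}),\ \cos(t\,\omega_0), \dots, \cos(t\,\omega_{h-1})\big],
\qquad
\omega_i = e^{-\,i\,\ln(10000)/(h-1)} = 10000^{-i/(h-1)}
```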
@@ -62,8 +68,7 @@ class Conv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1):
        super().__init__()
        self.padding = dilation * (kernel_size - 1) // 2
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, dilation=dilation, padding=self.padding)
        self.conv = nn.utils.weight_norm(self.conv)
        nn.init.kaiming_normal_(self.conv.weight)
@@ -89,8 +94,7 @@ class ZeroConv1d(nn.Module):
# every residual block (named residual layer in paper)
# contains one noncausal dilated conv
class ResidualBlock(nn.Module):
    def __init__(self, res_channels, skip_channels, dilation, diffusion_step_embed_dim_out):
        super().__init__()
        self.res_channels = res_channels
@@ -98,15 +102,12 @@ class ResidualBlock(nn.Module):
        self.fc_t = nn.Linear(diffusion_step_embed_dim_out, self.res_channels)

        # Dilated conv layer
        self.dilated_conv_layer = Conv(self.res_channels, 2 * self.res_channels, kernel_size=3, dilation=dilation)

        # Add mel spectrogram upsampler and conditioner conv1x1 layer
        self.upsample_conv2d = nn.ModuleList()
        for s in [16, 16]:
            conv_trans2d = nn.ConvTranspose2d(1, 1, (3, 2 * s), padding=(1, s // 2), stride=(1, s))
            conv_trans2d = nn.utils.weight_norm(conv_trans2d)
            nn.init.kaiming_normal_(conv_trans2d.weight)
            self.upsample_conv2d.append(conv_trans2d)
@@ -152,7 +153,7 @@ class ResidualBlock(nn.Module):
        h += mel_spec

        # Gated-tanh nonlinearity
        out = torch.tanh(h[:, : self.res_channels, :]) * torch.sigmoid(h[:, self.res_channels :, :])

        # Residual and skip outputs
        res = self.res_conv(out)
@@ -164,10 +165,16 @@ class ResidualBlock(nn.Module):
class ResidualGroup(nn.Module):
    def __init__(
        self,
        res_channels,
        skip_channels,
        num_res_layers,
        dilation_cycle,
        diffusion_step_embed_dim_in,
        diffusion_step_embed_dim_mid,
        diffusion_step_embed_dim_out,
    ):
        super().__init__()
        self.num_res_layers = num_res_layers
        self.diffusion_step_embed_dim_in = diffusion_step_embed_dim_in
@@ -180,16 +187,19 @@ class ResidualGroup(nn.Module):
        self.residual_blocks = nn.ModuleList()
        for n in range(self.num_res_layers):
            self.residual_blocks.append(
                ResidualBlock(
                    res_channels,
                    skip_channels,
                    dilation=2 ** (n % dilation_cycle),
                    diffusion_step_embed_dim_out=diffusion_step_embed_dim_out,
                )
            )

    def forward(self, input_data):
        x, mel_spectrogram, diffusion_steps = input_data

        # Embed diffusion step t
        diffusion_step_embed = calc_diffusion_step_embedding(diffusion_steps, self.diffusion_step_embed_dim_in)
        diffusion_step_embed = swish(self.fc_t1(diffusion_step_embed))
        diffusion_step_embed = swish(self.fc_t2(diffusion_step_embed))
@@ -206,27 +216,52 @@ class ResidualGroup(nn.Module):
        return skip * math.sqrt(1.0 / self.num_res_layers)
class DiffWave(ModelMixin, ConfigMixin):
    def __init__(
        self,
        in_channels=1,
        res_channels=128,
        skip_channels=128,
        out_channels=1,
        num_res_layers=30,
        dilation_cycle=10,
        diffusion_step_embed_dim_in=128,
        diffusion_step_embed_dim_mid=512,
        diffusion_step_embed_dim_out=512,
    ):
        super().__init__()

        # register all init arguments with self.register
        self.register(
            in_channels=in_channels,
            res_channels=res_channels,
            skip_channels=skip_channels,
            out_channels=out_channels,
            num_res_layers=num_res_layers,
            dilation_cycle=dilation_cycle,
            diffusion_step_embed_dim_in=diffusion_step_embed_dim_in,
            diffusion_step_embed_dim_mid=diffusion_step_embed_dim_mid,
            diffusion_step_embed_dim_out=diffusion_step_embed_dim_out,
        )

        # Initial conv1x1 with relu
        self.init_conv = nn.Sequential(Conv(in_channels, res_channels, kernel_size=1), nn.ReLU(inplace=False))

        # All residual layers
        self.residual_layer = ResidualGroup(
            res_channels,
            skip_channels,
            num_res_layers,
            dilation_cycle,
            diffusion_step_embed_dim_in,
            diffusion_step_embed_dim_mid,
            diffusion_step_embed_dim_out,
        )

        # Final conv1x1 -> relu -> zeroconv1x1
        self.final_conv = nn.Sequential(
            Conv(skip_channels, skip_channels, kernel_size=1),
            nn.ReLU(inplace=False),
            ZeroConv1d(skip_channels, out_channels),
        )

    def forward(self, input_data):
        audio, mel_spectrogram, diffusion_steps = input_data
@@ -234,3 +269,45 @@ class DiffWave(nn.Module):
        x = self.init_conv(x).clone()
        x = self.residual_layer((x, mel_spectrogram, diffusion_steps))
        return self.final_conv(x)
class BDDM(DiffusionPipeline):
    def __init__(self, diffwave, noise_scheduler):
        super().__init__()
        noise_scheduler = noise_scheduler.set_format("pt")
        self.register_modules(diffwave=diffwave, noise_scheduler=noise_scheduler)

    @torch.no_grad()
    def __call__(self, mel_spectrogram, generator, torch_device=None):
        if torch_device is None:
            torch_device = "cuda" if torch.cuda.is_available() else "cpu"

        self.diffwave.to(torch_device)

        mel_spectrogram = mel_spectrogram.to(torch_device)
        audio_length = mel_spectrogram.size(-1) * 256
        audio_size = (1, 1, audio_length)

        # Sample gaussian noise to begin loop
        audio = torch.normal(0, 1, size=audio_size, generator=generator).to(torch_device)

        timestep_values = self.noise_scheduler.timestep_values
        num_prediction_steps = len(self.noise_scheduler)
        for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps):
            # 1. predict noise residual
            ts = (torch.tensor(timestep_values[t]) * torch.ones((1, 1))).to(torch_device)
            residual = self.diffwave((audio, mel_spectrogram, ts))

            # 2. predict previous mean of audio x_t-1
            pred_prev_audio = self.noise_scheduler.step(residual, audio, t)

            # 3. optionally sample variance
            variance = 0
            if t > 0:
                noise = torch.normal(0, 1, size=audio_size, generator=generator).to(torch_device)
                variance = self.noise_scheduler.get_variance(t).sqrt() * noise

            # 4. set current audio to prev_audio: x_t -> x_t-1
            audio = pred_prev_audio + variance

        return audio
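As a reading aid (not part of the commit): steps 1–4 of the loop above are the usual ancestral-sampling update of the reverse diffusion process, where the previous-mean prediction returned by `noise_scheduler.step` plays the role of the model mean and the scheduler variance at step `t` gives the noise scale:

```latex
% One reverse step of the sampling loop in BDDM.__call__
x_{t-1} = \mu_\theta(x_t, t) + \sigma_t z, \qquad
z \sim \mathcal{N}(0, I) \ \text{for } t > 0, \quad z = 0 \ \text{for } t = 0
```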
@@ -28,13 +28,7 @@ from transformers import CLIPConfig, CLIPModel, CLIPTextConfig, CLIPVisionConfig
from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import ModelOutput, add_start_docstrings_to_model_forward, logging, replace_return_docstrings

from ..models import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
from ..pipeline_utils import DiffusionPipeline
@@ -872,31 +866,31 @@ class GLIDE(DiffusionPipeline):
        # Sample gaussian noise to begin loop
        image = torch.randn(
            (
                batch_size,
                self.upscale_unet.in_channels // 2,
                self.upscale_unet.resolution,
                self.upscale_unet.resolution,
            ),
            generator=generator,
        )
        image = image.to(torch_device) * upsample_temp

        num_trained_timesteps = self.upscale_noise_scheduler.timesteps
        inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps_upscale)

        # self.upscale_noise_scheduler.rescale_betas(num_inference_steps_upscale)

        for t in tqdm.tqdm(reversed(range(num_inference_steps_upscale)), total=num_inference_steps_upscale):
            # 1. predict noise residual
            with torch.no_grad():
                time_input = torch.tensor([inference_step_times[t]] * image.shape[0], device=torch_device)
                model_output = self.upscale_unet(image, time_input, low_res)
                noise_residual, pred_variance = torch.split(model_output, 3, dim=1)

            # 2. predict previous mean of image x_t-1
            pred_prev_image = self.upscale_noise_scheduler.step(
                noise_residual, image, t, num_inference_steps_upscale, eta, use_clipped_residual=True
            )

            # 3. optionally sample variance
@@ -910,6 +904,6 @@ class GLIDE(DiffusionPipeline):
            # 4. set current image to prev_image: x_t -> x_t-1
            image = pred_prev_image + variance

        image = image.clamp(-1, 1).permute(0, 2, 3, 1)

        return image
@@ -26,6 +26,8 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
        beta_start=0.0001,
        beta_end=0.02,
        beta_schedule="linear",
        trained_betas=None,
        timestep_values=None,
        clip_predicted_image=True,
        tensor_format="np",
    ):
@@ -37,6 +39,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
            beta_schedule=beta_schedule,
        )
        self.timesteps = int(timesteps)
        self.timestep_values = timestep_values  # save the fixed timestep values for BDDM
        self.clip_image = clip_predicted_image

        if beta_schedule == "linear":
@@ -69,14 +72,15 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
        #
        # self.register_buffer("log_variance", log_variance.to(torch.float32))

    # def rescale_betas(self, num_timesteps):
    #     # GLIDE scaling
    #     if self.beta_schedule == "linear":
    #         scale = self.timesteps / num_timesteps
    #         self.betas = linear_beta_schedule(
    #             num_timesteps, beta_start=self.beta_start * scale, beta_end=self.beta_end * scale
    #         )
    #         self.alphas = 1.0 - self.betas
    #         self.alphas_cumprod = np.cumprod(self.alphas, axis=0)

    def get_alpha(self, time_step):
        return self.alphas[time_step]
@@ -107,7 +111,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
        return variance

    def step(self, residual, image, t, num_inference_steps, eta, use_clipped_residual=False):
        # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
        # Ideally, read DDIM paper in-detail understanding
@@ -141,6 +145,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
        variance = self.get_variance(t, num_inference_steps)
        std_dev_t = eta * variance ** (0.5)

        if use_clipped_residual:
            # the residual is always re-derived from the clipped x_0 in GLIDE
            residual = (image - alpha_prod_t ** (0.5) * pred_original_image) / beta_prod_t ** (0.5)

        # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        pred_image_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * residual
...
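For readers following the comments in `step` (added here as a reading aid, not part of the commit): the update corresponds to formula (12) of the DDIM paper, with `pred_original_image` as the predicted clean image, `pred_image_direction` as the middle term, and `std_dev_t` as the eta-scaled standard deviation; when `use_clipped_residual=True`, the noise term in the direction component is re-derived from the clipped prediction, as the new code above does:

```latex
% DDIM update (Eq. 12, https://arxiv.org/abs/2010.02502)
\hat{x}_0 = \frac{x_t - \sqrt{1-\bar\alpha_t}\,\epsilon_\theta(x_t, t)}{\sqrt{\bar\alpha_t}}, \qquad
x_{t-1} = \sqrt{\bar\alpha_{t-1}}\,\hat{x}_0
          + \sqrt{1-\bar\alpha_{t-1}-\sigma_t^2}\,\epsilon_\theta(x_t, t)
          + \sigma_t z, \qquad \sigma_t = \eta\,\sqrt{\mathrm{variance}(t)}
```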
@@ -26,6 +26,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
        beta_start=0.0001,
        beta_end=0.02,
        beta_schedule="linear",
        trained_betas=None,
        timestep_values=None,
        variance_type="fixed_small",
        clip_predicted_image=True,
        tensor_format="np",
@@ -36,14 +38,19 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
            beta_start=beta_start,
            beta_end=beta_end,
            beta_schedule=beta_schedule,
            trained_betas=trained_betas,
            timestep_values=timestep_values,
            variance_type=variance_type,
            clip_predicted_image=clip_predicted_image,
        )
        self.timesteps = int(timesteps)
        self.timestep_values = timestep_values  # save the fixed timestep values for BDDM
        self.clip_image = clip_predicted_image
        self.variance_type = variance_type

        if trained_betas is not None:
            self.betas = np.asarray(trained_betas)
        elif beta_schedule == "linear":
            self.betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
        elif beta_schedule == "squaredcos_cap_v2":
            # GLIDE cosine schedule
@@ -56,6 +63,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
        self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = np.sqrt(1 - self.alphas_cumprod)
        self.one = np.array(1.0)

        self.set_format(tensor_format=tensor_format)
@@ -131,5 +140,9 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
        return pred_prev_image

    def forward_step(self, original_image, noise, t):
        noisy_image = self.sqrt_alphas_cumprod[t] * original_image + self.sqrt_one_minus_alphas_cumprod[t] * noise
        return noisy_image

    def __len__(self):
        return self.timesteps
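A note on the new `forward_step` (added here as a reading aid, not part of the diff): it is the closed-form forward (noising) process q(x_t | x_0), which the training script further down uses to build the noisy inputs:

```latex
% Closed-form forward diffusion implemented by DDPMScheduler.forward_step
x_t = \sqrt{\bar\alpha_t}\, x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon, \qquad \epsilon \sim \mathcal{N}(0, I)
```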
import random
import numpy as np
import torch
import torch.nn.functional as F
import PIL.Image
from accelerate import Accelerator
from datasets import load_dataset
from diffusers import DDPM, DDPMScheduler, UNetModel
from torchvision.transforms import CenterCrop, Compose, Lambda, RandomHorizontalFlip, Resize, ToTensor
from tqdm.auto import tqdm
from transformers import get_linear_schedule_with_warmup
def set_seed(seed):
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
set_seed(0)
accelerator = Accelerator(mixed_precision="fp16")
model = UNetModel(ch=128, ch_mult=(1, 2, 4, 8), resolution=64)
noise_scheduler = DDPMScheduler(timesteps=1000)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
num_epochs = 100
batch_size = 8
gradient_accumulation_steps = 8
augmentations = Compose(
    [
        Resize(64),
        CenterCrop(64),
        RandomHorizontalFlip(),
        ToTensor(),
        Lambda(lambda x: x * 2 - 1),
    ]
)
dataset = load_dataset("huggan/pokemon", split="train")
def transforms(examples):
    images = [augmentations(image.convert("RGB")) for image in examples["image"]]
    return {"input": images}
dataset = dataset.shuffle(seed=0)
dataset.set_transform(transforms)
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)
lr_scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=1000,
    num_training_steps=(len(train_dataloader) * num_epochs) // gradient_accumulation_steps,
)

model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
    model, optimizer, train_dataloader, lr_scheduler
)
for epoch in range(num_epochs):
    model.train()
    pbar = tqdm(total=len(train_dataloader), unit="ba")
    pbar.set_description(f"Epoch {epoch}")
    for step, batch in enumerate(train_dataloader):
        clean_images = batch["input"]
        noisy_images = torch.empty_like(clean_images)
        bsz = clean_images.shape[0]

        timesteps = torch.randint(0, noise_scheduler.timesteps, (bsz,), device=clean_images.device).long()
        for idx in range(bsz):
            noise = torch.randn_like(clean_images[0]).to(clean_images.device)
            noisy_images[idx] = noise_scheduler.forward_step(clean_images[idx], noise, timesteps[idx])

        if step % gradient_accumulation_steps == 0:
            with accelerator.no_sync(model):
                output = model(noisy_images, timesteps)
                loss = F.l1_loss(output, clean_images)
                accelerator.backward(loss)
        else:
            output = model(noisy_images, timesteps)
            loss = F.l1_loss(output, clean_images)
            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
        pbar.update(1)
        pbar.set_postfix(loss=loss.detach().item(), lr=optimizer.param_groups[0]["lr"])
    optimizer.step()

    # eval
    model.eval()
    with torch.no_grad():
        pipeline = DDPM(unet=model, noise_scheduler=noise_scheduler)
        generator = torch.Generator()
        generator = generator.manual_seed(0)
        # run pipeline in inference (sample random noise and denoise)
        image = pipeline(generator=generator)

    # process image to PIL
    image_processed = image.cpu().permute(0, 2, 3, 1)
    image_processed = (image_processed + 1.0) * 127.5
    image_processed = image_processed.type(torch.uint8).numpy()
    image_pil = PIL.Image.fromarray(image_processed[0])

    # save image
    pipeline.save_pretrained("./poke-ddpm")
    image_pil.save(f"./poke-ddpm/test_{epoch}.png")