Commit 77c80489 authored by anton-l

Merge remote-tracking branch 'origin/main'

parents bff9746d 86da45bc
@@ -9,6 +9,6 @@ from .models.unet import UNetModel
 from .models.unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
 from .models.unet_ldm import UNetLDMModel
 from .pipeline_utils import DiffusionPipeline
-from .pipelines import DDIM, DDPM, GLIDE, LatentDiffusion
+from .pipelines import DDIM, DDPM, GLIDE, LatentDiffusion, BDDMPipeline
 from .schedulers import DDIMScheduler, DDPMScheduler, SchedulerMixin
 from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler
@@ -2,3 +2,4 @@ from .pipeline_ddim import DDIM
 from .pipeline_ddpm import DDPM
 from .pipeline_glide import GLIDE
 from .pipeline_latent_diffusion import LatentDiffusion
+from .pipeline_bddm import BDDMPipeline
@@ -17,6 +17,9 @@ import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import tqdm
+
+from ..pipeline_utils import DiffusionPipeline


 def calc_diffusion_step_embedding(diffusion_steps, diffusion_step_embed_dim_in):
@@ -234,3 +237,45 @@ class DiffWave(nn.Module):
         x = self.init_conv(x).clone()
         x = self.residual_layer((x, mel_spectrogram, diffusion_steps))
         return self.final_conv(x)
+
+
+class BDDMPipeline(DiffusionPipeline):
+    def __init__(self, diffwave, noise_scheduler):
+        super().__init__()
+        noise_scheduler = noise_scheduler.set_format("pt")
+        self.register_modules(diffwave=diffwave, noise_scheduler=noise_scheduler)
+
+    @torch.no_grad()
+    def __call__(self, mel_spectrogram, generator, torch_device=None):
+        if torch_device is None:
+            torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+
+        self.diffwave.to(torch_device)
+
+        audio_length = mel_spectrogram.size(-1) * self.config.hop_len
+        audio_size = (1, 1, audio_length)
+
+        # Sample gaussian noise to begin loop
+        audio = torch.normal(0, 1, size=audio_size, generator=generator).to(torch_device)
+
+        timestep_values = self.noise_scheduler.timestep_values
+        num_prediction_steps = len(self.noise_scheduler)
+        for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps):
+            # 1. predict noise residual
+            with torch.no_grad():
+                ts = (torch.tensor(timestep_values[t]) * torch.ones((1, 1))).to(torch_device)
+                residual = self.diffwave(audio, mel_spectrogram, ts)
+
+            # 2. predict previous mean of audio x_t-1
+            pred_prev_audio = self.noise_scheduler.step(residual, audio, t)
+
+            # 3. optionally sample variance
+            variance = 0
+            if t > 0:
+                noise = torch.normal(0, 1, size=audio_size, generator=generator).to(torch_device)
+                variance = self.noise_scheduler.get_variance(t).sqrt() * noise
+
+            # 4. set current audio to prev_audio: x_t -> x_t-1
+            audio = pred_prev_audio + variance
+
+        return audio
\ No newline at end of file
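For reference, a hypothetical usage sketch of the new pipeline (not part of this commit and not verified against this revision). It assumes `diffwave` and `mel_spectrogram` already exist as a DiffWave model and a conditioning mel-spectrogram tensor, uses made-up placeholder schedule values rather than a trained BDDM noise schedule, and infers the `timesteps` argument name from `self.timesteps = int(timesteps)` in the scheduler changes of this same commit.

import torch

from diffusers import BDDMPipeline, DDIMScheduler

# Made-up 12-step schedule; a real BDDM setup would use learned betas and the
# matching timestep values produced by the trained schedule network.
betas = [1e-4 * (i + 1) for i in range(12)]
timestep_values = [0, 10, 25, 50, 90, 150, 230, 330, 450, 600, 780, 980]

noise_scheduler = DDIMScheduler(
    timesteps=len(betas),
    trained_betas=betas,
    timestep_values=timestep_values,
)

# `diffwave` is an already-loaded DiffWave model; `mel_spectrogram` has shape (1, n_mels, frames).
pipeline = BDDMPipeline(diffwave=diffwave, noise_scheduler=noise_scheduler)

generator = torch.manual_seed(0)
audio = pipeline(mel_spectrogram, generator)  # waveform tensor of shape (1, 1, audio_length)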
@@ -26,6 +26,8 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         beta_start=0.0001,
         beta_end=0.02,
         beta_schedule="linear",
+        trained_betas=None,
+        timestep_values=None,
         clip_predicted_image=True,
         tensor_format="np",
     ):
@@ -37,9 +39,12 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
             beta_schedule=beta_schedule,
         )
         self.timesteps = int(timesteps)
+        self.timestep_values = timestep_values  # save the fixed timestep values for BDDM
         self.clip_image = clip_predicted_image

-        if beta_schedule == "linear":
+        if trained_betas is not None:
+            self.betas = np.asarray(trained_betas)
+        elif beta_schedule == "linear":
             self.betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
         elif beta_schedule == "squaredcos_cap_v2":
             # GLIDE cosine schedule
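A small illustrative sketch of the new scheduler arguments above (not part of this commit): passing `trained_betas` skips the linear/cosine `beta_schedule` branches and stores the betas as given, while `timestep_values` is simply kept on the scheduler for BDDMPipeline to read back. The `timesteps` argument name is an assumption based on `self.timesteps = int(timesteps)` in the hunk above.

import numpy as np

from diffusers import DDIMScheduler

# Hypothetical hand-picked betas; with trained_betas set, beta_schedule is ignored.
custom_betas = [1e-4, 5e-4, 1e-3, 5e-3, 1e-2]

scheduler = DDIMScheduler(
    timesteps=len(custom_betas),
    trained_betas=custom_betas,
    timestep_values=[0, 250, 500, 750, 999],  # fixed values later read by BDDMPipeline
)

# The betas come straight from the list instead of the linear schedule.
assert np.allclose(scheduler.betas, custom_betas)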