Commit a82d2592 authored by anton-l
Browse files

Merge remote-tracking branch 'origin/main'

# Conflicts:
#	src/diffusers/__init__.py
#	src/diffusers/pipelines/__init__.py
#	src/diffusers/schedulers/scheduling_ddim.py
parents ba21735c 61dc11c7
...@@ -9,6 +9,6 @@ from .models.unet import UNetModel ...@@ -9,6 +9,6 @@ from .models.unet import UNetModel
from .models.unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel from .models.unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
from .models.unet_ldm import UNetLDMModel from .models.unet_ldm import UNetLDMModel
from .pipeline_utils import DiffusionPipeline from .pipeline_utils import DiffusionPipeline
from .pipelines import DDIM, DDPM, GLIDE, BDDMPipeline, LatentDiffusion from .pipelines import DDIM, DDPM, GLIDE, LatentDiffusion, BDDM
from .schedulers import DDIMScheduler, DDPMScheduler, SchedulerMixin from .schedulers import DDIMScheduler, DDPMScheduler, SchedulerMixin
from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler
from .pipeline_bddm import BDDMPipeline
from .pipeline_ddim import DDIM from .pipeline_ddim import DDIM
from .pipeline_ddpm import DDPM from .pipeline_ddpm import DDPM
from .pipeline_glide import GLIDE from .pipeline_glide import GLIDE
from .pipeline_latent_diffusion import LatentDiffusion from .pipeline_latent_diffusion import LatentDiffusion
from .pipeline_bddm import BDDM
...@@ -271,20 +271,21 @@ class DiffWave(ModelMixin, ConfigMixin): ...@@ -271,20 +271,21 @@ class DiffWave(ModelMixin, ConfigMixin):
return self.final_conv(x) return self.final_conv(x)
class BDDMPipeline(DiffusionPipeline): class BDDM(DiffusionPipeline):
def __init__(self, diffwave, noise_scheduler): def __init__(self, diffwave, noise_scheduler):
super().__init__() super().__init__()
noise_scheduler = noise_scheduler.set_format("pt") noise_scheduler = noise_scheduler.set_format("pt")
self.register_modules(diffwave=diffwave, noise_scheduler=noise_scheduler) self.register_modules(diffwave=diffwave, noise_scheduler=noise_scheduler)
@torch.no_grad() @torch.no_grad()
def __call__(self, mel_spectrogram, generator): def __call__(self, mel_spectrogram, generator, torch_device=None):
if torch_device is None: if torch_device is None:
torch_device = "cuda" if torch.cuda.is_available() else "cpu" torch_device = "cuda" if torch.cuda.is_available() else "cpu"
self.diffwave.to(torch_device) self.diffwave.to(torch_device)
audio_length = mel_spectrogram.size(-1) * self.config.hop_len mel_spectrogram = mel_spectrogram.to(torch_device)
audio_length = mel_spectrogram.size(-1) * 256
audio_size = (1, 1, audio_length) audio_size = (1, 1, audio_length)
# Sample gaussian noise to begin loop # Sample gaussian noise to begin loop
...@@ -294,9 +295,8 @@ class BDDMPipeline(DiffusionPipeline): ...@@ -294,9 +295,8 @@ class BDDMPipeline(DiffusionPipeline):
num_prediction_steps = len(self.noise_scheduler) num_prediction_steps = len(self.noise_scheduler)
for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps): for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps):
# 1. predict noise residual # 1. predict noise residual
with torch.no_grad(): ts = (torch.tensor(timestep_values[t]) * torch.ones((1, 1))).to(torch_device)
t = (torch.tensor(timestep_values[t]) * torch.ones((1, 1))).to(torch_device) residual = self.diffwave((audio, mel_spectrogram, ts))
residual = self.diffwave(audio, mel_spectrogram, t)
# 2. predict previous mean of audio x_t-1 # 2. predict previous mean of audio x_t-1
pred_prev_audio = self.noise_scheduler.step(residual, audio, t) pred_prev_audio = self.noise_scheduler.step(residual, audio, t)
......
...@@ -42,9 +42,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): ...@@ -42,9 +42,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
self.timestep_values = timestep_values # save the fixed timestep values for BDDM self.timestep_values = timestep_values # save the fixed timestep values for BDDM
self.clip_image = clip_predicted_image self.clip_image = clip_predicted_image
if trained_betas is not None: if beta_schedule == "linear":
self.betas = np.asarray(trained_betas)
elif beta_schedule == "linear":
self.betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end) self.betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
elif beta_schedule == "squaredcos_cap_v2": elif beta_schedule == "squaredcos_cap_v2":
# GLIDE cosine schedule # GLIDE cosine schedule
......
...@@ -26,6 +26,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): ...@@ -26,6 +26,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
beta_start=0.0001, beta_start=0.0001,
beta_end=0.02, beta_end=0.02,
beta_schedule="linear", beta_schedule="linear",
trained_betas=None,
timestep_values=None,
variance_type="fixed_small", variance_type="fixed_small",
clip_predicted_image=True, clip_predicted_image=True,
tensor_format="np", tensor_format="np",
...@@ -36,14 +38,19 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): ...@@ -36,14 +38,19 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
beta_start=beta_start, beta_start=beta_start,
beta_end=beta_end, beta_end=beta_end,
beta_schedule=beta_schedule, beta_schedule=beta_schedule,
trained_betas=trained_betas,
timestep_values=timestep_values,
variance_type=variance_type, variance_type=variance_type,
clip_predicted_image=clip_predicted_image, clip_predicted_image=clip_predicted_image,
) )
self.timesteps = int(timesteps) self.timesteps = int(timesteps)
self.timestep_values = timestep_values # save the fixed timestep values for BDDM
self.clip_image = clip_predicted_image self.clip_image = clip_predicted_image
self.variance_type = variance_type self.variance_type = variance_type
if beta_schedule == "linear": if trained_betas is not None:
self.betas = np.asarray(trained_betas)
elif beta_schedule == "linear":
self.betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end) self.betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
elif beta_schedule == "squaredcos_cap_v2": elif beta_schedule == "squaredcos_cap_v2":
# GLIDE cosine schedule # GLIDE cosine schedule
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment