Commit cdf58a4e authored by patil-suraj's avatar patil-suraj
Browse files

fix BDDMPipeline

parent b96c6ce1
...@@ -269,13 +269,14 @@ class BDDMPipeline(DiffusionPipeline): ...@@ -269,13 +269,14 @@ class BDDMPipeline(DiffusionPipeline):
self.register_modules(diffwave=diffwave, noise_scheduler=noise_scheduler) self.register_modules(diffwave=diffwave, noise_scheduler=noise_scheduler)
@torch.no_grad() @torch.no_grad()
def __call__(self, mel_spectrogram, generator): def __call__(self, mel_spectrogram, generator, torch_device=None):
if torch_device is None: if torch_device is None:
torch_device = "cuda" if torch.cuda.is_available() else "cpu" torch_device = "cuda" if torch.cuda.is_available() else "cpu"
self.diffwave.to(torch_device) self.diffwave.to(torch_device)
audio_length = mel_spectrogram.size(-1) * self.config.hop_len mel_spectrogram = mel_spectrogram.to(torch_device)
audio_length = mel_spectrogram.size(-1) * 256
audio_size = (1, 1, audio_length) audio_size = (1, 1, audio_length)
# Sample gaussian noise to begin loop # Sample gaussian noise to begin loop
...@@ -285,9 +286,8 @@ class BDDMPipeline(DiffusionPipeline): ...@@ -285,9 +286,8 @@ class BDDMPipeline(DiffusionPipeline):
num_prediction_steps = len(self.noise_scheduler) num_prediction_steps = len(self.noise_scheduler)
for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps): for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps):
# 1. predict noise residual # 1. predict noise residual
with torch.no_grad(): ts = (torch.tensor(timestep_values[t]) * torch.ones((1, 1))).to(torch_device)
t = (torch.tensor(timestep_values[t]) * torch.ones((1, 1))).to(torch_device) residual = self.diffwave((audio, mel_spectrogram, ts))
residual = self.diffwave(audio, mel_spectrogram, t)
# 2. predict previous mean of audio x_t-1 # 2. predict previous mean of audio x_t-1
pred_prev_audio = self.noise_scheduler.step(residual, audio, t) pred_prev_audio = self.noise_scheduler.step(residual, audio, t)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment