Commit cdf58a4e authored by patil-suraj's avatar patil-suraj
Browse files

fix BDDMPipeline

parent b96c6ce1
......@@ -269,13 +269,14 @@ class BDDMPipeline(DiffusionPipeline):
self.register_modules(diffwave=diffwave, noise_scheduler=noise_scheduler)
@torch.no_grad()
def __call__(self, mel_spectrogram, generator):
def __call__(self, mel_spectrogram, generator, torch_device=None):
if torch_device is None:
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
self.diffwave.to(torch_device)
audio_length = mel_spectrogram.size(-1) * self.config.hop_len
mel_spectrogram = mel_spectrogram.to(torch_device)
audio_length = mel_spectrogram.size(-1) * 256
audio_size = (1, 1, audio_length)
# Sample gaussian noise to begin loop
......@@ -285,9 +286,8 @@ class BDDMPipeline(DiffusionPipeline):
num_prediction_steps = len(self.noise_scheduler)
for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps):
# 1. predict noise residual
with torch.no_grad():
t = (torch.tensor(timestep_values[t]) * torch.ones((1, 1))).to(torch_device)
residual = self.diffwave(audio, mel_spectrogram, t)
ts = (torch.tensor(timestep_values[t]) * torch.ones((1, 1))).to(torch_device)
residual = self.diffwave((audio, mel_spectrogram, ts))
# 2. predict previous mean of audio x_t-1
pred_prev_audio = self.noise_scheduler.step(residual, audio, t)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment