Commit 0244e2af authored by Patrick von Platen's avatar Patrick von Platen
Browse files

correct diffusion test

parent 3a177754
...@@ -471,7 +471,7 @@ class GradTTSPipeline(DiffusionPipeline): ...@@ -471,7 +471,7 @@ class GradTTSPipeline(DiffusionPipeline):
mu_y = mu_y.transpose(1, 2) mu_y = mu_y.transpose(1, 2)
# Sample latent representation from terminal distribution N(mu_y, I) # Sample latent representation from terminal distribution N(mu_y, I)
z = mu_y + torch.randn(mu_y.shape, device=mu_y.device, generator=generator) / temperature z = mu_y + torch.randn(mu_y.shape, generator=generator).to(mu_y.device)
xt = z * y_mask xt = z * y_mask
h = 1.0 / num_inference_steps h = 1.0 / num_inference_steps
......
...@@ -714,9 +714,9 @@ class PipelineTesterMixin(unittest.TestCase): ...@@ -714,9 +714,9 @@ class PipelineTesterMixin(unittest.TestCase):
assert mel_spec.shape == (1, 80, 143) assert mel_spec.shape == (1, 80, 143)
expected_slice = torch.tensor( expected_slice = torch.tensor(
[-6.6119, -6.5963, -6.2776, -6.7496, -6.7096, -6.5131, -6.4643, -6.4817, -6.7185] [-6.7584, -6.8347, -6.3293, -6.6437, -6.7233, -6.4684, -6.1187, -6.3172, -6.6890]
) )
assert (mel_spec[0, :3, :3].flatten() - expected_slice).abs().max() < 1e-2 assert (mel_spec[0, :3, :3].cpu().flatten() - expected_slice).abs().max() < 1e-2
def test_module_from_pipeline(self): def test_module_from_pipeline(self):
model = DiffWave(num_res_layers=4) model = DiffWave(num_res_layers=4)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment