Commit 5a3467e6 authored by patil-suraj

add default params for GradTTS

parent e2678275
@@ -396,7 +396,7 @@ class GradTTS(DiffusionPipeline):
        self.register_modules(unet=unet, text_encoder=text_encoder, noise_scheduler=noise_scheduler, tokenizer=tokenizer)

    @torch.no_grad()
-   def __call__(self, text, num_inference_steps, generator, temperature, length_scale, speaker_id=None, torch_device=None):
+   def __call__(self, text, num_inference_steps=50, temperature=1.3, length_scale=0.91, speaker_id=15, torch_device=None):
        if torch_device is None:
            torch_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@@ -427,7 +427,6 @@ class GradTTS(DiffusionPipeline):
        # Align encoded text and get mu_y
        mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2))
        mu_y = mu_y.transpose(1, 2)
-       encoder_outputs = mu_y[:, :, :y_max_length]

        # Sample latent representation from terminal distribution N(mu_y, I)
        z = mu_y + torch.randn_like(mu_y, device=mu_y.device) / temperature
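With these defaults in place, the pipeline can be called with only the text argument. A minimal usage sketch follows; the import path and checkpoint name are assumptions for illustration and are not part of this commit:

# Sketch of calling GradTTS after this change; checkpoint id is hypothetical.
from diffusers import GradTTS  # import path assumed; GradTTS subclasses DiffusionPipeline

pipeline = GradTTS.from_pretrained("path/to/grad-tts-checkpoint")  # hypothetical checkpoint

# The new defaults (num_inference_steps=50, temperature=1.3,
# length_scale=0.91, speaker_id=15) are used when only text is passed.
mel = pipeline("Hello world, this is a test.")

# Any default can still be overridden explicitly.
mel_fast = pipeline("Hello world.", num_inference_steps=10, temperature=1.0)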