Unverified Commit af8d3b97 authored by yangarbiter's avatar yangarbiter Committed by GitHub
Browse files

Rename infer method to forward for WaveRNNInferenceWrapper (#1650)

parent 47ccabbf
...@@ -73,18 +73,18 @@ def main(args): ...@@ -73,18 +73,18 @@ def main(args):
mel_specgram = transforms(waveform) mel_specgram = transforms(waveform)
wavernn_model = wavernn(args.checkpoint_name).eval().to(device) wavernn_model = wavernn(args.checkpoint_name).eval().to(device)
wavernn_model = WaveRNNInferenceWrapper(wavernn_model) wavernn_inference_model = WaveRNNInferenceWrapper(wavernn_model)
if args.jit: if args.jit:
wavernn_model = torch.jit.script(wavernn_model) wavernn_inference_model = torch.jit.script(wavernn_inference_model)
with torch.no_grad(): with torch.no_grad():
output = wavernn_model.infer(mel_specgram.to(device), output = wavernn_inference_model(mel_specgram.to(device),
loss_name=args.loss, loss_name=args.loss,
mulaw=(not args.no_mulaw), mulaw=(not args.no_mulaw),
batched=(not args.no_batch_inference), batched=(not args.no_batch_inference),
timesteps=args.batch_timesteps, timesteps=args.batch_timesteps,
overlap=args.batch_overlap,) overlap=args.batch_overlap,)
torchaudio.save(args.output_wav_path, output.reshape(1, -1), sample_rate=sample_rate) torchaudio.save(args.output_wav_path, output.reshape(1, -1), sample_rate=sample_rate)
......
...@@ -160,14 +160,13 @@ class WaveRNNInferenceWrapper(torch.nn.Module): ...@@ -160,14 +160,13 @@ class WaveRNNInferenceWrapper(torch.nn.Module):
f"Valid choices are 'both', 'before' and 'after'.") f"Valid choices are 'both', 'before' and 'after'.")
return padded return padded
@torch.jit.export def forward(self,
def infer(self, specgram: Tensor,
specgram: Tensor, loss_name: str = "crossentropy",
loss_name: str = "crossentropy", mulaw: bool = True,
mulaw: bool = True, batched: bool = True,
batched: bool = True, timesteps: int = 11000,
timesteps: int = 11000, overlap: int = 550) -> Tensor:
overlap: int = 550) -> Tensor:
r"""Inference function for WaveRNN. r"""Inference function for WaveRNN.
Based on the implementation from Based on the implementation from
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment