Unverified Commit af8d3b97 authored by yangarbiter's avatar yangarbiter Committed by GitHub
Browse files

Rename infer method to forward for WaveRNNInferenceWrapper (#1650)

parent 47ccabbf
...@@ -73,13 +73,13 @@ def main(args): ...@@ -73,13 +73,13 @@ def main(args):
mel_specgram = transforms(waveform) mel_specgram = transforms(waveform)
wavernn_model = wavernn(args.checkpoint_name).eval().to(device) wavernn_model = wavernn(args.checkpoint_name).eval().to(device)
wavernn_model = WaveRNNInferenceWrapper(wavernn_model) wavernn_inference_model = WaveRNNInferenceWrapper(wavernn_model)
if args.jit: if args.jit:
wavernn_model = torch.jit.script(wavernn_model) wavernn_inference_model = torch.jit.script(wavernn_inference_model)
with torch.no_grad(): with torch.no_grad():
output = wavernn_model.infer(mel_specgram.to(device), output = wavernn_inference_model(mel_specgram.to(device),
loss_name=args.loss, loss_name=args.loss,
mulaw=(not args.no_mulaw), mulaw=(not args.no_mulaw),
batched=(not args.no_batch_inference), batched=(not args.no_batch_inference),
......
...@@ -160,8 +160,7 @@ class WaveRNNInferenceWrapper(torch.nn.Module): ...@@ -160,8 +160,7 @@ class WaveRNNInferenceWrapper(torch.nn.Module):
f"Valid choices are 'both', 'before' and 'after'.") f"Valid choices are 'both', 'before' and 'after'.")
return padded return padded
@torch.jit.export def forward(self,
def infer(self,
specgram: Tensor, specgram: Tensor,
loss_name: str = "crossentropy", loss_name: str = "crossentropy",
mulaw: bool = True, mulaw: bool = True,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment