Unverified Commit af8d3b97 authored by yangarbiter's avatar yangarbiter Committed by GitHub
Browse files

Rename infer method to forward for WaveRNNInferenceWrapper (#1650)

parent 47ccabbf
......@@ -73,18 +73,18 @@ def main(args):
mel_specgram = transforms(waveform)
wavernn_model = wavernn(args.checkpoint_name).eval().to(device)
wavernn_model = WaveRNNInferenceWrapper(wavernn_model)
wavernn_inference_model = WaveRNNInferenceWrapper(wavernn_model)
if args.jit:
wavernn_model = torch.jit.script(wavernn_model)
wavernn_inference_model = torch.jit.script(wavernn_inference_model)
with torch.no_grad():
output = wavernn_model.infer(mel_specgram.to(device),
loss_name=args.loss,
mulaw=(not args.no_mulaw),
batched=(not args.no_batch_inference),
timesteps=args.batch_timesteps,
overlap=args.batch_overlap,)
output = wavernn_inference_model(mel_specgram.to(device),
loss_name=args.loss,
mulaw=(not args.no_mulaw),
batched=(not args.no_batch_inference),
timesteps=args.batch_timesteps,
overlap=args.batch_overlap,)
torchaudio.save(args.output_wav_path, output.reshape(1, -1), sample_rate=sample_rate)
......
......@@ -160,14 +160,13 @@ class WaveRNNInferenceWrapper(torch.nn.Module):
f"Valid choices are 'both', 'before' and 'after'.")
return padded
@torch.jit.export
def infer(self,
specgram: Tensor,
loss_name: str = "crossentropy",
mulaw: bool = True,
batched: bool = True,
timesteps: int = 11000,
overlap: int = 550) -> Tensor:
def forward(self,
specgram: Tensor,
loss_name: str = "crossentropy",
mulaw: bool = True,
batched: bool = True,
timesteps: int = 11000,
overlap: int = 550) -> Tensor:
r"""Inference function for WaveRNN.
Based on the implementation from
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment