Commit 498722b5 authored by moto's avatar moto
Browse files

Store n_bits in WaveRNN (#1847)

Move the computation of `#classes -> #bits` to the constructor of WaveRNN and attach it to the instance, so that it can be reused elsewhere.
parent 202bc4f2
...@@ -193,12 +193,10 @@ class WaveRNNInferenceWrapper(torch.nn.Module): ...@@ -193,12 +193,10 @@ class WaveRNNInferenceWrapper(torch.nn.Module):
if batched: if batched:
specgram = _fold_with_overlap(specgram, timesteps, overlap) specgram = _fold_with_overlap(specgram, timesteps, overlap)
n_bits = int(torch.log2(torch.ones(1) * self.wavernn_model.n_classes))
output = self.wavernn_model.infer(specgram).cpu() output = self.wavernn_model.infer(specgram).cpu()
if mulaw: if mulaw:
output = normalized_waveform_to_bits(output, n_bits) output = normalized_waveform_to_bits(output, self.wavernn_model.n_bits)
output = torchaudio.functional.mu_law_decoding(output, self.wavernn_model.n_classes) output = torchaudio.functional.mu_law_decoding(output, self.wavernn_model.n_classes)
if batched: if batched:
......
...@@ -6,6 +6,7 @@ from parameterized import parameterized ...@@ -6,6 +6,7 @@ from parameterized import parameterized
from torchaudio.models import ConvTasNet, DeepSpeech, Wav2Letter, WaveRNN from torchaudio.models import ConvTasNet, DeepSpeech, Wav2Letter, WaveRNN
from torchaudio.models.wavernn import MelResNet, UpsampleNetwork from torchaudio.models.wavernn import MelResNet, UpsampleNetwork
from torchaudio_unittest import common_utils from torchaudio_unittest import common_utils
from torchaudio_unittest.common_utils import torch_script
class TestWav2Letter(common_utils.TorchaudioTestCase): class TestWav2Letter(common_utils.TorchaudioTestCase):
...@@ -145,6 +146,32 @@ class TestWaveRNN(common_utils.TorchaudioTestCase): ...@@ -145,6 +146,32 @@ class TestWaveRNN(common_utils.TorchaudioTestCase):
assert out.size() == (n_batch, 1, hop_length * (n_time - kernel_size + 1)) assert out.size() == (n_batch, 1, hop_length * (n_time - kernel_size + 1))
def test_torchscript_infer(self):
    """TorchScript-compiled WaveRNN.infer matches eager-mode output bit-for-bit."""
    n_batch, n_freq, n_time = 2, 100, 200
    # NOTE(review): parameter names assumed to match torchaudio's WaveRNN
    # signature; values are identical to the original positional call.
    model = WaveRNN(
        upsample_scales=[5, 5, 8],
        n_classes=512,
        hop_length=200,
        n_res_block=10,
        n_rnn=512,
        n_fc=512,
        kernel_size=5,
        n_freq=n_freq,
        n_hidden=128,
        n_output=256,
    )
    model.eval()

    specgram = torch.rand(n_batch, n_freq, n_time)

    # infer() samples from a multinomial; reseed before each run so the
    # eager and scripted paths draw the same random numbers.
    torch.random.manual_seed(0)
    reference = model.infer(specgram)
    torch.random.manual_seed(0)
    scripted_out = torch_script(model).infer(specgram)

    self.assertEqual(reference, scripted_out)
_ConvTasNetParams = namedtuple( _ConvTasNetParams = namedtuple(
'_ConvTasNetParams', '_ConvTasNetParams',
......
from typing import List, Tuple, Dict, Any from typing import List, Tuple, Dict, Any
import math
import torch import torch
from torch import Tensor from torch import Tensor
...@@ -268,6 +269,7 @@ class WaveRNN(nn.Module): ...@@ -268,6 +269,7 @@ class WaveRNN(nn.Module):
self.n_aux = n_output // 4 self.n_aux = n_output // 4
self.hop_length = hop_length self.hop_length = hop_length
self.n_classes = n_classes self.n_classes = n_classes
self.n_bits: int = int(math.log2(self.n_classes))
total_scale = 1 total_scale = 1
for upsample_scale in upsample_scales: for upsample_scale in upsample_scales:
...@@ -365,8 +367,6 @@ class WaveRNN(nn.Module): ...@@ -365,8 +367,6 @@ class WaveRNN(nn.Module):
device = specgram.device device = specgram.device
dtype = specgram.dtype dtype = specgram.dtype
# make it compatible with torchscript
n_bits = int(torch.log2(torch.ones(1) * self.n_classes))
specgram, aux = self.upsample(specgram) specgram, aux = self.upsample(specgram)
...@@ -406,7 +406,7 @@ class WaveRNN(nn.Module): ...@@ -406,7 +406,7 @@ class WaveRNN(nn.Module):
x = torch.multinomial(posterior, 1).float() x = torch.multinomial(posterior, 1).float()
# Transform label [0, 2 ** n_bits - 1] to waveform [-1, 1] # Transform label [0, 2 ** n_bits - 1] to waveform [-1, 1]
x = 2 * x / (2 ** n_bits - 1.0) - 1.0 x = 2 * x / (2 ** self.n_bits - 1.0) - 1.0
output.append(x) output.append(x)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment