Unverified Commit 209858ea authored by jimchen90's avatar jimchen90 Committed by GitHub
Browse files

Update variable names in wavernn model (#797)



* Change the name of  n_output and n_hidden

* Replace the mode by n_classes and sample_rate

* Change the definition of n_output and n_hidden
Co-authored-by: default avatarJi Chen <jimchen90@devfair0160.h2.fair>
parent 47eb1e6a
...@@ -74,7 +74,12 @@ class TestUpsampleNetwork(common_utils.TorchaudioTestCase): ...@@ -74,7 +74,12 @@ class TestUpsampleNetwork(common_utils.TorchaudioTestCase):
for upsample_scale in upsample_scales: for upsample_scale in upsample_scales:
total_scale *= upsample_scale total_scale *= upsample_scale
model = _UpsampleNetwork(upsample_scales, n_res_block, n_freq, n_hidden, n_output, kernel_size) model = _UpsampleNetwork(upsample_scales,
n_res_block,
n_freq,
n_hidden,
n_output,
kernel_size)
x = torch.rand(n_batch, n_freq, n_time) x = torch.rand(n_batch, n_freq, n_time)
out1, out2 = model(x) out1, out2 = model(x)
...@@ -86,14 +91,13 @@ class TestUpsampleNetwork(common_utils.TorchaudioTestCase): ...@@ -86,14 +91,13 @@ class TestUpsampleNetwork(common_utils.TorchaudioTestCase):
class TestWaveRNN(common_utils.TorchaudioTestCase): class TestWaveRNN(common_utils.TorchaudioTestCase):
def test_waveform(self): def test_waveform(self):
"""Validate the output dimensions of a _WaveRNN model in waveform mode. """Validate the output dimensions of a _WaveRNN model.
""" """
upsample_scales = [5, 5, 8] upsample_scales = [5, 5, 8]
n_rnn = 512 n_rnn = 512
n_fc = 512 n_fc = 512
n_bits = 9 n_classes = 512
sample_rate = 24000
hop_length = 200 hop_length = 200
n_batch = 2 n_batch = 2
n_time = 200 n_time = 200
...@@ -102,41 +106,12 @@ class TestWaveRNN(common_utils.TorchaudioTestCase): ...@@ -102,41 +106,12 @@ class TestWaveRNN(common_utils.TorchaudioTestCase):
n_res_block = 10 n_res_block = 10
n_hidden = 128 n_hidden = 128
kernel_size = 5 kernel_size = 5
mode = 'waveform'
model = _WaveRNN(upsample_scales, n_bits, sample_rate, hop_length, n_res_block, model = _WaveRNN(upsample_scales, n_classes, hop_length, n_res_block,
n_rnn, n_fc, kernel_size, n_freq, n_hidden, n_output, mode) n_rnn, n_fc, kernel_size, n_freq, n_hidden, n_output)
x = torch.rand(n_batch, 1, hop_length * (n_time - kernel_size + 1)) x = torch.rand(n_batch, 1, hop_length * (n_time - kernel_size + 1))
mels = torch.rand(n_batch, 1, n_freq, n_time) mels = torch.rand(n_batch, 1, n_freq, n_time)
out = model(x, mels) out = model(x, mels)
assert out.size() == (n_batch, 1, hop_length * (n_time - kernel_size + 1), 2 ** n_bits) assert out.size() == (n_batch, 1, hop_length * (n_time - kernel_size + 1), n_classes)
def test_mol(self):
"""Validate the output dimensions of a _WaveRNN model in mol mode.
"""
upsample_scales = [5, 5, 8]
n_rnn = 512
n_fc = 512
n_bits = 9
sample_rate = 24000
hop_length = 200
n_batch = 2
n_time = 200
n_freq = 100
n_output = 256
n_res_block = 10
n_hidden = 128
kernel_size = 5
mode = 'mol'
model = _WaveRNN(upsample_scales, n_bits, sample_rate, hop_length, n_res_block,
n_rnn, n_fc, kernel_size, n_freq, n_hidden, n_output, mode)
x = torch.rand(n_batch, 1, hop_length * (n_time - kernel_size + 1))
mels = torch.rand(n_batch, 1, n_freq, n_time)
out = model(x, mels)
assert out.size() == (n_batch, 1, hop_length * (n_time - kernel_size + 1), 30)
...@@ -50,8 +50,8 @@ class _MelResNet(nn.Module): ...@@ -50,8 +50,8 @@ class _MelResNet(nn.Module):
Args: Args:
n_res_block: the number of ResBlock in stack (default=10) n_res_block: the number of ResBlock in stack (default=10)
n_freq: the number of bins in a spectrogram (default=128) n_freq: the number of bins in a spectrogram (default=128)
n_hidden: the number of hidden dimensions (default=128) n_hidden: the number of hidden dimensions of resblock (default=128)
n_output: the number of output dimensions (default=128) n_output: the number of output dimensions of melresnet (default=128)
kernel_size: the number of kernel size in the first Conv1d layer (default=5) kernel_size: the number of kernel size in the first Conv1d layer (default=5)
Examples Examples
...@@ -132,8 +132,8 @@ class _UpsampleNetwork(nn.Module): ...@@ -132,8 +132,8 @@ class _UpsampleNetwork(nn.Module):
upsample_scales: the list of upsample scales upsample_scales: the list of upsample scales
n_res_block: the number of ResBlock in stack (default=10) n_res_block: the number of ResBlock in stack (default=10)
n_freq: the number of bins in a spectrogram (default=128) n_freq: the number of bins in a spectrogram (default=128)
n_hidden: the number of hidden dimensions (default=128) n_hidden: the number of hidden dimensions of resblock (default=128)
n_output: the number of output dimensions (default=128) n_output: the number of output dimensions of melresnet (default=128)
kernel_size: the number of kernel size in the first Conv1d layer (default=5) kernel_size: the number of kernel size in the first Conv1d layer (default=5)
Examples Examples
...@@ -205,31 +205,28 @@ class _WaveRNN(nn.Module): ...@@ -205,31 +205,28 @@ class _WaveRNN(nn.Module):
Args: Args:
upsample_scales: the list of upsample scales upsample_scales: the list of upsample scales
n_bits: the bits of output waveform n_classes: the number of output classes
sample_rate: the rate of audio dimensions (samples per second)
hop_length: the number of samples between the starts of consecutive frames hop_length: the number of samples between the starts of consecutive frames
n_res_block: the number of ResBlock in stack (default=10) n_res_block: the number of ResBlock in stack (default=10)
n_rnn: the dimension of RNN layer (default=512) n_rnn: the dimension of RNN layer (default=512)
n_fc: the dimension of fully connected layer (default=512) n_fc: the dimension of fully connected layer (default=512)
kernel_size: the number of kernel size in the first Conv1d layer (default=5) kernel_size: the number of kernel size in the first Conv1d layer (default=5)
n_freq: the number of bins in a spectrogram (default=128) n_freq: the number of bins in a spectrogram (default=128)
n_hidden: the number of hidden dimensions (default=128) n_hidden: the number of hidden dimensions of resblock (default=128)
n_output: the number of output dimensions (default=128) n_output: the number of output dimensions of melresnet (default=128)
mode: the mode of waveform in ['waveform', 'mol'] (default='waveform')
Example Example
>>> wavernn = _waveRNN(upsample_scales=[5,5,8], n_bits=9, sample_rate=24000, hop_length=200) >>> wavernn = _waveRNN(upsample_scales=[5,5,8], n_classes=512, hop_length=200)
>>> waveform, sample_rate = torchaudio.load(file) >>> waveform, sample_rate = torchaudio.load(file)
>>> # waveform shape: (n_batch, n_channel, (n_time - kernel_size + 1) * hop_length) >>> # waveform shape: (n_batch, n_channel, (n_time - kernel_size + 1) * hop_length)
>>> specgram = MelSpectrogram(sample_rate)(waveform) # shape: (n_batch, n_channel, n_freq, n_time) >>> specgram = MelSpectrogram(sample_rate)(waveform) # shape: (n_batch, n_channel, n_freq, n_time)
>>> output = wavernn(waveform, specgram) >>> output = wavernn(waveform, specgram)
>>> # output shape: (n_batch, n_channel, (n_time - kernel_size + 1) * hop_length, 2 ** n_bits) >>> # output shape: (n_batch, n_channel, (n_time - kernel_size + 1) * hop_length, n_classes)
""" """
def __init__(self, def __init__(self,
upsample_scales: List[int], upsample_scales: List[int],
n_bits: int, n_classes: int,
sample_rate: int,
hop_length: int, hop_length: int,
n_res_block: int = 10, n_res_block: int = 10,
n_rnn: int = 512, n_rnn: int = 512,
...@@ -237,24 +234,14 @@ class _WaveRNN(nn.Module): ...@@ -237,24 +234,14 @@ class _WaveRNN(nn.Module):
kernel_size: int = 5, kernel_size: int = 5,
n_freq: int = 128, n_freq: int = 128,
n_hidden: int = 128, n_hidden: int = 128,
n_output: int = 128, n_output: int = 128) -> None:
mode: str = 'waveform') -> None:
super().__init__() super().__init__()
self.mode = mode
self.kernel_size = kernel_size self.kernel_size = kernel_size
if self.mode == 'waveform':
self.n_classes = 2 ** n_bits
elif self.mode == 'mol':
self.n_classes = 30
else:
raise ValueError(f"Expected mode: `waveform` or `mol`, but found {self.mode}")
self.n_rnn = n_rnn self.n_rnn = n_rnn
self.n_aux = n_output // 4 self.n_aux = n_output // 4
self.hop_length = hop_length self.hop_length = hop_length
self.sample_rate = sample_rate self.n_classes = n_classes
total_scale = 1 total_scale = 1
for upsample_scale in upsample_scales: for upsample_scale in upsample_scales:
...@@ -262,7 +249,12 @@ class _WaveRNN(nn.Module): ...@@ -262,7 +249,12 @@ class _WaveRNN(nn.Module):
if total_scale != self.hop_length: if total_scale != self.hop_length:
raise ValueError(f"Expected: total_scale == hop_length, but found {total_scale} != {hop_length}") raise ValueError(f"Expected: total_scale == hop_length, but found {total_scale} != {hop_length}")
self.upsample = _UpsampleNetwork(upsample_scales, n_res_block, n_freq, n_hidden, n_output, kernel_size) self.upsample = _UpsampleNetwork(upsample_scales,
n_res_block,
n_freq,
n_hidden,
n_output,
kernel_size)
self.fc = nn.Linear(n_freq + self.n_aux + 1, n_rnn) self.fc = nn.Linear(n_freq + self.n_aux + 1, n_rnn)
self.rnn1 = nn.GRU(n_rnn, n_rnn, batch_first=True) self.rnn1 = nn.GRU(n_rnn, n_rnn, batch_first=True)
...@@ -283,7 +275,7 @@ class _WaveRNN(nn.Module): ...@@ -283,7 +275,7 @@ class _WaveRNN(nn.Module):
specgram: the input spectrogram to the _WaveRNN layer (n_batch, 1, n_freq, n_time) specgram: the input spectrogram to the _WaveRNN layer (n_batch, 1, n_freq, n_time)
Return: Return:
Tensor shape: (n_batch, 1, (n_time - kernel_size + 1) * hop_length, 2 ** n_bits) Tensor shape: (n_batch, 1, (n_time - kernel_size + 1) * hop_length, n_classes)
""" """
assert waveform.size(1) == 1, 'Require the input channel of waveform is 1' assert waveform.size(1) == 1, 'Require the input channel of waveform is 1'
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment