Unverified Commit 209858ea authored by jimchen90's avatar jimchen90 Committed by GitHub
Browse files

Update variable names in wavernn model (#797)



* Change the name of  n_output and n_hidden

* Replace the mode by n_classes and sample_rate

* Change the definition of n_output and n_hidden
Co-authored-by: default avatarJi Chen <jimchen90@devfair0160.h2.fair>
parent 47eb1e6a
......@@ -74,7 +74,12 @@ class TestUpsampleNetwork(common_utils.TorchaudioTestCase):
for upsample_scale in upsample_scales:
total_scale *= upsample_scale
model = _UpsampleNetwork(upsample_scales, n_res_block, n_freq, n_hidden, n_output, kernel_size)
model = _UpsampleNetwork(upsample_scales,
n_res_block,
n_freq,
n_hidden,
n_output,
kernel_size)
x = torch.rand(n_batch, n_freq, n_time)
out1, out2 = model(x)
......@@ -86,14 +91,13 @@ class TestUpsampleNetwork(common_utils.TorchaudioTestCase):
class TestWaveRNN(common_utils.TorchaudioTestCase):
def test_waveform(self):
"""Validate the output dimensions of a _WaveRNN model in waveform mode.
"""Validate the output dimensions of a _WaveRNN model.
"""
upsample_scales = [5, 5, 8]
n_rnn = 512
n_fc = 512
n_bits = 9
sample_rate = 24000
n_classes = 512
hop_length = 200
n_batch = 2
n_time = 200
......@@ -102,41 +106,12 @@ class TestWaveRNN(common_utils.TorchaudioTestCase):
n_res_block = 10
n_hidden = 128
kernel_size = 5
mode = 'waveform'
model = _WaveRNN(upsample_scales, n_bits, sample_rate, hop_length, n_res_block,
n_rnn, n_fc, kernel_size, n_freq, n_hidden, n_output, mode)
model = _WaveRNN(upsample_scales, n_classes, hop_length, n_res_block,
n_rnn, n_fc, kernel_size, n_freq, n_hidden, n_output)
x = torch.rand(n_batch, 1, hop_length * (n_time - kernel_size + 1))
mels = torch.rand(n_batch, 1, n_freq, n_time)
out = model(x, mels)
assert out.size() == (n_batch, 1, hop_length * (n_time - kernel_size + 1), 2 ** n_bits)
def test_mol(self):
"""Validate the output dimensions of a _WaveRNN model in mol mode.
"""
upsample_scales = [5, 5, 8]
n_rnn = 512
n_fc = 512
n_bits = 9
sample_rate = 24000
hop_length = 200
n_batch = 2
n_time = 200
n_freq = 100
n_output = 256
n_res_block = 10
n_hidden = 128
kernel_size = 5
mode = 'mol'
model = _WaveRNN(upsample_scales, n_bits, sample_rate, hop_length, n_res_block,
n_rnn, n_fc, kernel_size, n_freq, n_hidden, n_output, mode)
x = torch.rand(n_batch, 1, hop_length * (n_time - kernel_size + 1))
mels = torch.rand(n_batch, 1, n_freq, n_time)
out = model(x, mels)
assert out.size() == (n_batch, 1, hop_length * (n_time - kernel_size + 1), 30)
assert out.size() == (n_batch, 1, hop_length * (n_time - kernel_size + 1), n_classes)
......@@ -50,8 +50,8 @@ class _MelResNet(nn.Module):
Args:
n_res_block: the number of ResBlock in stack (default=10)
n_freq: the number of bins in a spectrogram (default=128)
n_hidden: the number of hidden dimensions (default=128)
n_output: the number of output dimensions (default=128)
n_hidden: the number of hidden dimensions of resblock (default=128)
n_output: the number of output dimensions of melresnet (default=128)
kernel_size: the number of kernel size in the first Conv1d layer (default=5)
Examples
......@@ -132,8 +132,8 @@ class _UpsampleNetwork(nn.Module):
upsample_scales: the list of upsample scales
n_res_block: the number of ResBlock in stack (default=10)
n_freq: the number of bins in a spectrogram (default=128)
n_hidden: the number of hidden dimensions (default=128)
n_output: the number of output dimensions (default=128)
n_hidden: the number of hidden dimensions of resblock (default=128)
n_output: the number of output dimensions of melresnet (default=128)
kernel_size: the number of kernel size in the first Conv1d layer (default=5)
Examples
......@@ -205,31 +205,28 @@ class _WaveRNN(nn.Module):
Args:
upsample_scales: the list of upsample scales
n_bits: the bits of output waveform
sample_rate: the rate of audio dimensions (samples per second)
n_classes: the number of output classes
hop_length: the number of samples between the starts of consecutive frames
n_res_block: the number of ResBlock in stack (default=10)
n_rnn: the dimension of RNN layer (default=512)
n_fc: the dimension of fully connected layer (default=512)
kernel_size: the number of kernel size in the first Conv1d layer (default=5)
n_freq: the number of bins in a spectrogram (default=128)
n_hidden: the number of hidden dimensions (default=128)
n_output: the number of output dimensions (default=128)
mode: the mode of waveform in ['waveform', 'mol'] (default='waveform')
n_hidden: the number of hidden dimensions of resblock (default=128)
n_output: the number of output dimensions of melresnet (default=128)
Example
>>> wavernn = _waveRNN(upsample_scales=[5,5,8], n_bits=9, sample_rate=24000, hop_length=200)
>>> wavernn = _waveRNN(upsample_scales=[5,5,8], n_classes=512, hop_length=200)
>>> waveform, sample_rate = torchaudio.load(file)
>>> # waveform shape: (n_batch, n_channel, (n_time - kernel_size + 1) * hop_length)
>>> specgram = MelSpectrogram(sample_rate)(waveform) # shape: (n_batch, n_channel, n_freq, n_time)
>>> output = wavernn(waveform, specgram)
>>> # output shape: (n_batch, n_channel, (n_time - kernel_size + 1) * hop_length, 2 ** n_bits)
>>> # output shape: (n_batch, n_channel, (n_time - kernel_size + 1) * hop_length, n_classes)
"""
def __init__(self,
upsample_scales: List[int],
n_bits: int,
sample_rate: int,
n_classes: int,
hop_length: int,
n_res_block: int = 10,
n_rnn: int = 512,
......@@ -237,24 +234,14 @@ class _WaveRNN(nn.Module):
kernel_size: int = 5,
n_freq: int = 128,
n_hidden: int = 128,
n_output: int = 128,
mode: str = 'waveform') -> None:
n_output: int = 128) -> None:
super().__init__()
self.mode = mode
self.kernel_size = kernel_size
if self.mode == 'waveform':
self.n_classes = 2 ** n_bits
elif self.mode == 'mol':
self.n_classes = 30
else:
raise ValueError(f"Expected mode: `waveform` or `mol`, but found {self.mode}")
self.n_rnn = n_rnn
self.n_aux = n_output // 4
self.hop_length = hop_length
self.sample_rate = sample_rate
self.n_classes = n_classes
total_scale = 1
for upsample_scale in upsample_scales:
......@@ -262,7 +249,12 @@ class _WaveRNN(nn.Module):
if total_scale != self.hop_length:
raise ValueError(f"Expected: total_scale == hop_length, but found {total_scale} != {hop_length}")
self.upsample = _UpsampleNetwork(upsample_scales, n_res_block, n_freq, n_hidden, n_output, kernel_size)
self.upsample = _UpsampleNetwork(upsample_scales,
n_res_block,
n_freq,
n_hidden,
n_output,
kernel_size)
self.fc = nn.Linear(n_freq + self.n_aux + 1, n_rnn)
self.rnn1 = nn.GRU(n_rnn, n_rnn, batch_first=True)
......@@ -283,7 +275,7 @@ class _WaveRNN(nn.Module):
specgram: the input spectrogram to the _WaveRNN layer (n_batch, 1, n_freq, n_time)
Return:
Tensor shape: (n_batch, 1, (n_time - kernel_size + 1) * hop_length, 2 ** n_bits)
Tensor shape: (n_batch, 1, (n_time - kernel_size + 1) * hop_length, n_classes)
"""
assert waveform.size(1) == 1, 'Require the input channel of waveform is 1'
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment