Unverified Commit 68cc72da authored by jimchen90, committed by GitHub

Add WaveRNN Model (#735)



* upsamplenetwork

* update variable names

* update variable name

* add wavernn model

* update test

* update format

* update format

* update format

* fix conflicts and add transpose

* import update

* update transpose

* update format

* update docstring

* add n_channel in input

* add comment

* update docstring

* update docstring
Co-authored-by: Ji Chen <jimchen90@devfair0160.h2.fair>
parent ad7f43fe
import torch
from torchaudio.models import Wav2Letter, _MelResNet, _UpsampleNetwork, _WaveRNN
from . import common_utils
@@ -81,3 +81,62 @@ class TestUpsampleNetwork(common_utils.TorchaudioTestCase):
        assert out1.size() == (n_batch, n_freq, total_scale * (n_time - kernel_size + 1))
        assert out2.size() == (n_batch, n_output, total_scale * (n_time - kernel_size + 1))

class TestWaveRNN(common_utils.TorchaudioTestCase):

    def test_waveform(self):
        """Validate the output dimensions of a _WaveRNN model in waveform mode.
        """
        upsample_scales = [5, 5, 8]
        n_rnn = 512
        n_fc = 512
        n_bits = 9
        sample_rate = 24000
        hop_length = 200
        n_batch = 2
        n_time = 200
        n_freq = 100
        n_output = 256
        n_res_block = 10
        n_hidden = 128
        kernel_size = 5
        mode = 'waveform'

        model = _WaveRNN(upsample_scales, n_bits, sample_rate, hop_length, n_res_block,
                         n_rnn, n_fc, kernel_size, n_freq, n_hidden, n_output, mode)
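        # input length: hop_length * (n_time - kernel_size + 1) = 200 * 196 = 39200 samples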
        x = torch.rand(n_batch, 1, hop_length * (n_time - kernel_size + 1))
        mels = torch.rand(n_batch, 1, n_freq, n_time)
        out = model(x, mels)

        assert out.size() == (n_batch, 1, hop_length * (n_time - kernel_size + 1), 2 ** n_bits)

    def test_mol(self):
        """Validate the output dimensions of a _WaveRNN model in mol mode.
        """
        upsample_scales = [5, 5, 8]
        n_rnn = 512
        n_fc = 512
        n_bits = 9
        sample_rate = 24000
        hop_length = 200
        n_batch = 2
        n_time = 200
        n_freq = 100
        n_output = 256
        n_res_block = 10
        n_hidden = 128
        kernel_size = 5
        mode = 'mol'

        model = _WaveRNN(upsample_scales, n_bits, sample_rate, hop_length, n_res_block,
                         n_rnn, n_fc, kernel_size, n_freq, n_hidden, n_output, mode)
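        # in 'mol' mode the output dimension is fixed at 30 (commonly 10 logistic mixture components x 3 parameters)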
        x = torch.rand(n_batch, 1, hop_length * (n_time - kernel_size + 1))
        mels = torch.rand(n_batch, 1, n_freq, n_time)
        out = model(x, mels)

        assert out.size() == (n_batch, 1, hop_length * (n_time - kernel_size + 1), 30)
from typing import List

import torch
from torch import Tensor
from torch import nn

__all__ = ["_ResBlock", "_MelResNet", "_Stretch2d", "_UpsampleNetwork", "_WaveRNN"]


class _ResBlock(nn.Module):
@@ -192,3 +193,139 @@ class _UpsampleNetwork(nn.Module):
        upsampling_output = upsampling_output.squeeze(1)[:, :, self.indent:-self.indent]

        return upsampling_output, resnet_output

class _WaveRNN(nn.Module):
    r"""WaveRNN model based on the implementation from `fatchord <https://github.com/fatchord/WaveRNN>`_.

    The original implementation was introduced in
    `"Efficient Neural Audio Synthesis" <https://arxiv.org/pdf/1802.08435.pdf>`_.
    The input channels of waveform and spectrogram have to be 1. The product of
    `upsample_scales` must equal `hop_length`.

    Args:
        upsample_scales: the list of upsample scales
        n_bits: the number of bits used to encode the output waveform
        sample_rate: the sample rate of the audio waveform (samples per second)
        hop_length: the number of samples between the starts of consecutive frames
        n_res_block: the number of ResBlock in stack (default=10)
        n_rnn: the dimension of RNN layer (default=512)
        n_fc: the dimension of fully connected layer (default=512)
        kernel_size: the size of the kernel in the first Conv1d layer (default=5)
        n_freq: the number of bins in a spectrogram (default=128)
        n_hidden: the number of hidden dimensions (default=128)
        n_output: the number of output dimensions (default=128)
        mode: the output mode, one of ['waveform', 'mol'] (default='waveform')

    Example
        >>> wavernn = _WaveRNN(upsample_scales=[5, 5, 8], n_bits=9, sample_rate=24000, hop_length=200)
        >>> waveform, sample_rate = torchaudio.load(file)
        >>> # waveform shape: (n_batch, n_channel, (n_time - kernel_size + 1) * hop_length)
        >>> specgram = MelSpectrogram(sample_rate)(waveform)  # shape: (n_batch, n_channel, n_freq, n_time)
        >>> output = wavernn(waveform, specgram)
        >>> # output shape: (n_batch, n_channel, (n_time - kernel_size + 1) * hop_length, 2 ** n_bits)
    """
    def __init__(self,
                 upsample_scales: List[int],
                 n_bits: int,
                 sample_rate: int,
                 hop_length: int,
                 n_res_block: int = 10,
                 n_rnn: int = 512,
                 n_fc: int = 512,
                 kernel_size: int = 5,
                 n_freq: int = 128,
                 n_hidden: int = 128,
                 n_output: int = 128,
                 mode: str = 'waveform') -> None:
        super().__init__()

        self.mode = mode
        self.kernel_size = kernel_size

        if self.mode == 'waveform':
            self.n_classes = 2 ** n_bits
        elif self.mode == 'mol':
            self.n_classes = 30
        else:
            raise ValueError(f"Expected mode: `waveform` or `mol`, but found {self.mode}")
        self.n_rnn = n_rnn
        self.n_aux = n_output // 4
        self.hop_length = hop_length
        self.sample_rate = sample_rate
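        # the upsample scales must multiply to hop_length, e.g. 5 * 5 * 8 = 200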
        total_scale = 1
        for upsample_scale in upsample_scales:
            total_scale *= upsample_scale
        if total_scale != self.hop_length:
            raise ValueError(f"Expected: total_scale == hop_length, but found {total_scale} != {hop_length}")
        self.upsample = _UpsampleNetwork(upsample_scales, n_res_block, n_freq, n_hidden, n_output, kernel_size)
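        # self.fc consumes the upsampled spectrogram (n_freq), the first aux chunk (n_aux), and the current sample (1)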
        self.fc = nn.Linear(n_freq + self.n_aux + 1, n_rnn)

        self.rnn1 = nn.GRU(n_rnn, n_rnn, batch_first=True)
        self.rnn2 = nn.GRU(n_rnn + self.n_aux, n_rnn, batch_first=True)

        self.relu1 = nn.ReLU(inplace=True)
        self.relu2 = nn.ReLU(inplace=True)

        self.fc1 = nn.Linear(n_rnn + self.n_aux, n_fc)
        self.fc2 = nn.Linear(n_fc + self.n_aux, n_fc)
        self.fc3 = nn.Linear(n_fc, self.n_classes)

    def forward(self, waveform: Tensor, specgram: Tensor) -> Tensor:
        r"""Pass the input through the _WaveRNN model.

        Args:
            waveform: the input waveform to the _WaveRNN layer (n_batch, 1, (n_time - kernel_size + 1) * hop_length)
            specgram: the input spectrogram to the _WaveRNN layer (n_batch, 1, n_freq, n_time)

        Returns:
            Tensor shape: (n_batch, 1, (n_time - kernel_size + 1) * hop_length, 2 ** n_bits)
        """
        assert waveform.size(1) == 1, 'Require the input channel of waveform to be 1'
        assert specgram.size(1) == 1, 'Require the input channel of specgram to be 1'
        # remove channel dimension until the end
        waveform, specgram = waveform.squeeze(1), specgram.squeeze(1)

        batch_size = waveform.size(0)
        h1 = torch.zeros(1, batch_size, self.n_rnn, dtype=waveform.dtype, device=waveform.device)
        h2 = torch.zeros(1, batch_size, self.n_rnn, dtype=waveform.dtype, device=waveform.device)

        # output of upsample:
        # specgram: (n_batch, n_freq, (n_time - kernel_size + 1) * total_scale)
        # aux: (n_batch, n_output, (n_time - kernel_size + 1) * total_scale)
        specgram, aux = self.upsample(specgram)
        specgram = specgram.transpose(1, 2)
        aux = aux.transpose(1, 2)
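        # split the auxiliary features into four equal chunks of size n_aux, one per conditioning site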
        aux_idx = [self.n_aux * i for i in range(5)]
        a1 = aux[:, :, aux_idx[0]:aux_idx[1]]
        a2 = aux[:, :, aux_idx[1]:aux_idx[2]]
        a3 = aux[:, :, aux_idx[2]:aux_idx[3]]
        a4 = aux[:, :, aux_idx[3]:aux_idx[4]]

        x = torch.cat([waveform.unsqueeze(-1), specgram, a1], dim=-1)
        x = self.fc(x)
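        # two GRU layers, each wrapped in a residual connection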
        res = x
        x, _ = self.rnn1(x, h1)

        x = x + res
        res = x
        x = torch.cat([x, a2], dim=-1)
        x, _ = self.rnn2(x, h2)

        x = x + res
        x = torch.cat([x, a3], dim=-1)
        x = self.fc1(x)
        x = self.relu1(x)

        x = torch.cat([x, a4], dim=-1)
        x = self.fc2(x)
        x = self.relu2(x)
        x = self.fc3(x)

        # bring back channel dimension
        return x.unsqueeze(1)
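
For reference, a minimal smoke-test sketch of the new module, mirroring the docstring example and the shapes asserted in the tests above; the parameter and tensor sizes below are illustrative choices, not values required by the model:

import torch
from torchaudio.models import _WaveRNN

# hop_length must equal the product of upsample_scales: 5 * 5 * 8 = 200
model = _WaveRNN(upsample_scales=[5, 5, 8], n_bits=9, sample_rate=24000, hop_length=200)

# kernel_size=5 and n_freq=128 are the model defaults
n_batch, n_time, kernel_size, hop_length, n_freq = 2, 200, 5, 200, 128
waveform = torch.rand(n_batch, 1, hop_length * (n_time - kernel_size + 1))
specgram = torch.rand(n_batch, 1, n_freq, n_time)

out = model(waveform, specgram)
print(out.shape)  # torch.Size([2, 1, 39200, 512]); 39200 = 200 * 196 and 512 = 2 ** 9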