Unverified Commit f7549730 authored by jimchen90's avatar jimchen90 Committed by GitHub
Browse files

Remove underscore of wavernn model (#810)



* Remove underscore of model name
Co-authored-by: default avatarJi Chen <jimchen90@devfair0160.h2.fair>
parent 33f762f6
......@@ -13,7 +13,7 @@ from torch import nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchaudio.datasets.utils import bg_iterator
from torchaudio.models._wavernn import _WaveRNN
from torchaudio.models.wavernn import WaveRNN
from datasets import collate_factory, split_process_ljspeech
from losses import LongCrossEntropyLoss, MoLLoss
......@@ -297,7 +297,7 @@ def main(args):
n_classes = 2 ** args.n_bits if args.loss == "crossentropy" else 30
model = _WaveRNN(
model = WaveRNN(
upsample_scales=args.upsample_scales,
n_classes=n_classes,
hop_length=args.hop_length,
......
import torch
from torchaudio.models import Wav2Letter, _MelResNet, _UpsampleNetwork, _WaveRNN
from torchaudio.models import Wav2Letter, MelResNet, UpsampleNetwork, WaveRNN
from . import common_utils
......@@ -36,7 +36,7 @@ class TestWav2Letter(common_utils.TorchaudioTestCase):
class TestMelResNet(common_utils.TorchaudioTestCase):
def test_waveform(self):
"""Validate the output dimensions of a _MelResNet block.
"""Validate the output dimensions of a MelResNet block.
"""
n_batch = 2
......@@ -47,7 +47,7 @@ class TestMelResNet(common_utils.TorchaudioTestCase):
n_hidden = 128
kernel_size = 5
model = _MelResNet(n_res_block, n_freq, n_hidden, n_output, kernel_size)
model = MelResNet(n_res_block, n_freq, n_hidden, n_output, kernel_size)
x = torch.rand(n_batch, n_freq, n_time)
out = model(x)
......@@ -58,7 +58,7 @@ class TestMelResNet(common_utils.TorchaudioTestCase):
class TestUpsampleNetwork(common_utils.TorchaudioTestCase):
def test_waveform(self):
"""Validate the output dimensions of a _UpsampleNetwork block.
"""Validate the output dimensions of a UpsampleNetwork block.
"""
upsample_scales = [5, 5, 8]
......@@ -74,12 +74,12 @@ class TestUpsampleNetwork(common_utils.TorchaudioTestCase):
for upsample_scale in upsample_scales:
total_scale *= upsample_scale
model = _UpsampleNetwork(upsample_scales,
n_res_block,
n_freq,
n_hidden,
n_output,
kernel_size)
model = UpsampleNetwork(upsample_scales,
n_res_block,
n_freq,
n_hidden,
n_output,
kernel_size)
x = torch.rand(n_batch, n_freq, n_time)
out1, out2 = model(x)
......@@ -91,7 +91,7 @@ class TestUpsampleNetwork(common_utils.TorchaudioTestCase):
class TestWaveRNN(common_utils.TorchaudioTestCase):
def test_waveform(self):
"""Validate the output dimensions of a _WaveRNN model.
"""Validate the output dimensions of a WaveRNN model.
"""
upsample_scales = [5, 5, 8]
......@@ -107,8 +107,8 @@ class TestWaveRNN(common_utils.TorchaudioTestCase):
n_hidden = 128
kernel_size = 5
model = _WaveRNN(upsample_scales, n_classes, hop_length, n_res_block,
n_rnn, n_fc, kernel_size, n_freq, n_hidden, n_output)
model = WaveRNN(upsample_scales, n_classes, hop_length, n_res_block,
n_rnn, n_fc, kernel_size, n_freq, n_hidden, n_output)
x = torch.rand(n_batch, 1, hop_length * (n_time - kernel_size + 1))
mels = torch.rand(n_batch, 1, n_freq, n_time)
......
from .wav2letter import *
from ._wavernn import *
from .wavernn import *
......@@ -4,10 +4,10 @@ import torch
from torch import Tensor
from torch import nn
__all__ = ["_ResBlock", "_MelResNet", "_Stretch2d", "_UpsampleNetwork", "_WaveRNN"]
__all__ = ["ResBlock", "MelResNet", "Stretch2d", "UpsampleNetwork", "WaveRNN"]
class _ResBlock(nn.Module):
class ResBlock(nn.Module):
r"""ResNet block based on "Deep Residual Learning for Image Recognition"
The paper link is https://arxiv.org/pdf/1512.03385.pdf.
......@@ -16,7 +16,7 @@ class _ResBlock(nn.Module):
n_freq: the number of bins in a spectrogram. (Default: ``128``)
Examples
>>> resblock = _ResBlock()
>>> resblock = ResBlock()
>>> input = torch.rand(10, 128, 512) # a random spectrogram
>>> output = resblock(input) # shape: (10, 128, 512)
"""
......@@ -33,9 +33,9 @@ class _ResBlock(nn.Module):
)
def forward(self, specgram: Tensor) -> Tensor:
r"""Pass the input through the _ResBlock layer.
r"""Pass the input through the ResBlock layer.
Args:
specgram (Tensor): the input sequence to the _ResBlock layer (n_batch, n_freq, n_time).
specgram (Tensor): the input sequence to the ResBlock layer (n_batch, n_freq, n_time).
Return:
Tensor shape: (n_batch, n_freq, n_time)
......@@ -44,7 +44,7 @@ class _ResBlock(nn.Module):
return self.resblock_model(specgram) + specgram
class _MelResNet(nn.Module):
class MelResNet(nn.Module):
r"""MelResNet layer uses a stack of ResBlocks on spectrogram.
Args:
......@@ -55,7 +55,7 @@ class _MelResNet(nn.Module):
kernel_size: the number of kernel size in the first Conv1d layer. (Default: ``5``)
Examples
>>> melresnet = _MelResNet()
>>> melresnet = MelResNet()
>>> input = torch.rand(10, 128, 512) # a random spectrogram
>>> output = melresnet(input) # shape: (10, 128, 508)
"""
......@@ -68,7 +68,7 @@ class _MelResNet(nn.Module):
kernel_size: int = 5) -> None:
super().__init__()
ResBlocks = [_ResBlock(n_hidden) for _ in range(n_res_block)]
ResBlocks = [ResBlock(n_hidden) for _ in range(n_res_block)]
self.melresnet_model = nn.Sequential(
nn.Conv1d(in_channels=n_freq, out_channels=n_hidden, kernel_size=kernel_size, bias=False),
......@@ -79,9 +79,9 @@ class _MelResNet(nn.Module):
)
def forward(self, specgram: Tensor) -> Tensor:
r"""Pass the input through the _MelResNet layer.
r"""Pass the input through the MelResNet layer.
Args:
specgram (Tensor): the input sequence to the _MelResNet layer (n_batch, n_freq, n_time).
specgram (Tensor): the input sequence to the MelResNet layer (n_batch, n_freq, n_time).
Return:
Tensor shape: (n_batch, n_output, n_time - kernel_size + 1)
......@@ -90,7 +90,7 @@ class _MelResNet(nn.Module):
return self.melresnet_model(specgram)
class _Stretch2d(nn.Module):
class Stretch2d(nn.Module):
r"""Upscale the frequency and time dimensions of a spectrogram.
Args:
......@@ -98,7 +98,7 @@ class _Stretch2d(nn.Module):
freq_scale: the scale factor in frequency dimension
Examples
>>> stretch2d = _Stretch2d(time_scale=10, freq_scale=5)
>>> stretch2d = Stretch2d(time_scale=10, freq_scale=5)
>>> input = torch.rand(10, 100, 512) # a random spectrogram
>>> output = stretch2d(input) # shape: (10, 500, 5120)
......@@ -113,10 +113,10 @@ class _Stretch2d(nn.Module):
self.time_scale = time_scale
def forward(self, specgram: Tensor) -> Tensor:
r"""Pass the input through the _Stretch2d layer.
r"""Pass the input through the Stretch2d layer.
Args:
specgram (Tensor): the input sequence to the _Stretch2d layer (..., n_freq, n_time).
specgram (Tensor): the input sequence to the Stretch2d layer (..., n_freq, n_time).
Return:
Tensor shape: (..., n_freq * freq_scale, n_time * time_scale)
......@@ -125,7 +125,7 @@ class _Stretch2d(nn.Module):
return specgram.repeat_interleave(self.freq_scale, -2).repeat_interleave(self.time_scale, -1)
class _UpsampleNetwork(nn.Module):
class UpsampleNetwork(nn.Module):
r"""Upscale the dimensions of a spectrogram.
Args:
......@@ -137,7 +137,7 @@ class _UpsampleNetwork(nn.Module):
kernel_size: the number of kernel size in the first Conv1d layer. (Default: ``5``)
Examples
>>> upsamplenetwork = _UpsampleNetwork(upsample_scales=[4, 4, 16])
>>> upsamplenetwork = UpsampleNetwork(upsample_scales=[4, 4, 16])
>>> input = torch.rand(10, 128, 10) # a random spectrogram
>>> output = upsamplenetwork(input) # shape: (10, 1536, 128), (10, 1536, 128)
"""
......@@ -156,12 +156,12 @@ class _UpsampleNetwork(nn.Module):
total_scale *= upsample_scale
self.indent = (kernel_size - 1) // 2 * total_scale
self.resnet = _MelResNet(n_res_block, n_freq, n_hidden, n_output, kernel_size)
self.resnet_stretch = _Stretch2d(total_scale, 1)
self.resnet = MelResNet(n_res_block, n_freq, n_hidden, n_output, kernel_size)
self.resnet_stretch = Stretch2d(total_scale, 1)
up_layers = []
for scale in upsample_scales:
stretch = _Stretch2d(scale, 1)
stretch = Stretch2d(scale, 1)
conv = nn.Conv2d(in_channels=1,
out_channels=1,
kernel_size=(1, scale * 2 + 1),
......@@ -173,10 +173,10 @@ class _UpsampleNetwork(nn.Module):
self.upsample_layers = nn.Sequential(*up_layers)
def forward(self, specgram: Tensor) -> Tuple[Tensor, Tensor]:
r"""Pass the input through the _UpsampleNetwork layer.
r"""Pass the input through the UpsampleNetwork layer.
Args:
specgram (Tensor): the input sequence to the _UpsampleNetwork layer (n_batch, n_freq, n_time)
specgram (Tensor): the input sequence to the UpsampleNetwork layer (n_batch, n_freq, n_time)
Return:
Tensor shape: (n_batch, n_freq, (n_time - kernel_size + 1) * total_scale),
......@@ -195,7 +195,7 @@ class _UpsampleNetwork(nn.Module):
return upsampling_output, resnet_output
class _WaveRNN(nn.Module):
class WaveRNN(nn.Module):
r"""WaveRNN model based on the implementation from `fatchord <https://github.com/fatchord/WaveRNN>`_.
The original implementation was introduced in
......@@ -216,7 +216,7 @@ class _WaveRNN(nn.Module):
n_output: the number of output dimensions of melresnet. (Default: ``128``)
Example
>>> wavernn = _waveRNN(upsample_scales=[5,5,8], n_classes=512, hop_length=200)
>>> wavernn = WaveRNN(upsample_scales=[5,5,8], n_classes=512, hop_length=200)
>>> waveform, sample_rate = torchaudio.load(file)
>>> # waveform shape: (n_batch, n_channel, (n_time - kernel_size + 1) * hop_length)
>>> specgram = MelSpectrogram(sample_rate)(waveform) # shape: (n_batch, n_channel, n_freq, n_time)
......@@ -249,12 +249,12 @@ class _WaveRNN(nn.Module):
if total_scale != self.hop_length:
raise ValueError(f"Expected: total_scale == hop_length, but found {total_scale} != {hop_length}")
self.upsample = _UpsampleNetwork(upsample_scales,
n_res_block,
n_freq,
n_hidden,
n_output,
kernel_size)
self.upsample = UpsampleNetwork(upsample_scales,
n_res_block,
n_freq,
n_hidden,
n_output,
kernel_size)
self.fc = nn.Linear(n_freq + self.n_aux + 1, n_rnn)
self.rnn1 = nn.GRU(n_rnn, n_rnn, batch_first=True)
......@@ -268,11 +268,11 @@ class _WaveRNN(nn.Module):
self.fc3 = nn.Linear(n_fc, self.n_classes)
def forward(self, waveform: Tensor, specgram: Tensor) -> Tensor:
r"""Pass the input through the _WaveRNN model.
r"""Pass the input through the WaveRNN model.
Args:
waveform: the input waveform to the _WaveRNN layer (n_batch, 1, (n_time - kernel_size + 1) * hop_length)
specgram: the input spectrogram to the _WaveRNN layer (n_batch, 1, n_freq, n_time)
waveform: the input waveform to the WaveRNN layer (n_batch, 1, (n_time - kernel_size + 1) * hop_length)
specgram: the input spectrogram to the WaveRNN layer (n_batch, 1, n_freq, n_time)
Return:
Tensor shape: (n_batch, 1, (n_time - kernel_size + 1) * hop_length, n_classes)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment