Unverified Commit 8ec6b873 authored by yangarbiter, committed by GitHub

Add pretrained weights for wavernn (#1612)

parent 394d617e
@@ -88,8 +88,15 @@ WaveRNN
.. automethod:: forward
Factory Functions
-----------------
wavernn
-------
.. autofunction:: wavernn
References
~~~~~~~~~~
.. footbibliography::
from .wav2letter import Wav2Letter
from .wavernn import WaveRNN, wavernn
from .conv_tasnet import ConvTasNet
from .deepspeech import DeepSpeech
from .wav2vec2 import (
@@ -13,6 +13,7 @@ from .wav2vec2 import (
__all__ = [
    'Wav2Letter',
    'WaveRNN',
    'wavernn',
    'ConvTasNet',
    'DeepSpeech',
    'Wav2Vec2Model',
...
from typing import List, Tuple, Dict, Any
import torch
from torch import Tensor
from torch import nn
from torch.hub import load_state_dict_from_url
__all__ = [
    "ResBlock",
@@ -10,9 +12,29 @@ __all__ = [
    "Stretch2d",
    "UpsampleNetwork",
    "WaveRNN",
    "wavernn",
]
_MODEL_CONFIG_AND_URLS: Dict[str, Tuple[str, Dict[str, Any]]] = {
    'wavernn_10k_epochs_8bits_ljspeech': (
        'https://download.pytorch.org/models/audio/wavernn_10k_epochs_8bits_ljspeech.pth',
        {
            'upsample_scales': [5, 5, 11],
            'n_classes': 2 ** 8,  # n_bits = 8
            'hop_length': 275,
            'n_res_block': 10,
            'n_rnn': 512,
            'n_fc': 512,
            'kernel_size': 5,
            'n_freq': 80,
            'n_hidden': 128,
            'n_output': 128
        }
    )
}
class ResBlock(nn.Module):
    r"""ResNet block based on *Efficient Neural Audio Synthesis* [:footcite:`kalchbrenner2018efficient`].
@@ -324,3 +346,28 @@ class WaveRNN(nn.Module):
        # bring back channel dimension
        return x.unsqueeze(1)
def wavernn(checkpoint_name: str) -> WaveRNN:
    r"""Get pretrained WaveRNN model.

    Args:
        checkpoint_name (str): The name of the checkpoint to load. Available checkpoints:

            - ``"wavernn_10k_epochs_8bits_ljspeech"``:
              WaveRNN model trained for 10k epochs on 8-bit waveforms from the LJSpeech dataset.
              The model is trained using the default parameters and code of
              `examples/pipeline_wavernn/main.py
              <https://github.com/pytorch/audio/tree/master/examples/pipeline_wavernn>`_.
    """
    if checkpoint_name not in _MODEL_CONFIG_AND_URLS:
        raise ValueError(
            f"Unexpected checkpoint_name: '{checkpoint_name}'. "
            f"Valid choices are: {list(_MODEL_CONFIG_AND_URLS.keys())}")

    url, configs = _MODEL_CONFIG_AND_URLS[checkpoint_name]
    model = WaveRNN(**configs)
    state_dict = load_state_dict_from_url(url, progress=False)
    model.load_state_dict(state_dict)
    return model
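For context, a minimal usage sketch of the new factory function (not part of the diff). It assumes the checkpoint URL above is reachable; the tensor shapes are illustrative only and follow the documented WaveRNN forward contract, where a spectrogram with n_time frames pairs with a waveform of (n_time - kernel_size + 1) * hop_length samples.

import torch
from torchaudio.models import wavernn

# Downloads the state dict on first use and loads it into a freshly constructed WaveRNN.
model = wavernn("wavernn_10k_epochs_8bits_ljspeech")
model.eval()

# With this checkpoint: n_freq=80, hop_length=275, kernel_size=5.
n_time = 10
specgram = torch.randn(1, 1, 80, n_time)              # (n_batch, channel, n_freq, n_time)
waveform = torch.randn(1, 1, (n_time - 5 + 1) * 275)  # (n_batch, channel, time)
with torch.no_grad():
    output = model(waveform, specgram)                # (n_batch, 1, time, n_classes)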