Unverified Commit 8ec6b873 authored by yangarbiter's avatar yangarbiter Committed by GitHub
Browse files

Add pretrained weights for wavernn (#1612)

parent 394d617e
......@@ -88,8 +88,15 @@ WaveRNN
.. automethod:: forward
Factory Functions
-----------------
wavernn
-------
.. autofunction:: wavernn
References
~~~~~~~~~~
.. footbibliography::
from .wav2letter import Wav2Letter
from .wavernn import WaveRNN
from .wavernn import WaveRNN, wavernn
from .conv_tasnet import ConvTasNet
from .deepspeech import DeepSpeech
from .wav2vec2 import (
......@@ -13,6 +13,7 @@ from .wav2vec2 import (
__all__ = [
'Wav2Letter',
'WaveRNN',
'wavernn',
'ConvTasNet',
'DeepSpeech',
'Wav2Vec2Model',
......
from typing import List, Tuple
from typing import List, Tuple, Dict, Any
import torch
from torch import Tensor
from torch import nn
from torch.hub import load_state_dict_from_url
__all__ = [
"ResBlock",
......@@ -10,9 +12,29 @@ __all__ = [
"Stretch2d",
"UpsampleNetwork",
"WaveRNN",
"wavernn",
]
_MODEL_CONFIG_AND_URLS: Dict[str, Tuple[str, Dict[str, Any]]] = {
'wavernn_10k_epochs_8bits_ljspeech': (
'https://download.pytorch.org/models/audio/wavernn_10k_epochs_8bits_ljspeech.pth',
{
'upsample_scales': [5, 5, 11],
'n_classes': 2 ** 8, # n_bits = 8
'hop_length': 275,
'n_res_block': 10,
'n_rnn': 512,
'n_fc': 512,
'kernel_size': 5,
'n_freq': 80,
'n_hidden': 128,
'n_output': 128
}
)
}
class ResBlock(nn.Module):
r"""ResNet block based on *Efficient Neural Audio Synthesis* [:footcite:`kalchbrenner2018efficient`].
......@@ -324,3 +346,28 @@ class WaveRNN(nn.Module):
# bring back channel dimension
return x.unsqueeze(1)
def wavernn(checkpoint_name: str) -> WaveRNN:
r"""Get pretrained WaveRNN model.
Args:
checkpoint_name (str): The name of the checkpoint to load. Available checkpoints:
- ``"wavernn_10k_epochs_8bits_ljspeech"``:
WaveRNN model trained with 10k epochs and 8 bits depth waveform on the LJSpeech dataset.
The model is trained using the default parameters and code of the
`examples/pipeline_wavernn/main.py
<https://github.com/pytorch/audio/tree/master/examples/pipeline_wavernn>`_.
"""
if checkpoint_name not in _MODEL_CONFIG_AND_URLS:
raise ValueError(
f"Unexpected checkpoint_name: '{checkpoint_name}'. "
f"Valid choices are; {list(_MODEL_CONFIG_AND_URLS.keys())}")
url, configs = _MODEL_CONFIG_AND_URLS[checkpoint_name]
model = WaveRNN(**configs)
state_dict = load_state_dict_from_url(url, progress=False)
model.load_state_dict(state_dict)
return model
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment