Unverified Commit e885204e authored by moto, committed by GitHub

Add TTS bundle/pipelines (#1872)

Future work items:
- length computation of GriffinLim
- better way to make InverseMelScale work in inference_mode
parent 6b8f378b
@@ -167,6 +167,69 @@ HUBERT_ASR_XLARGE
.. container:: py attribute

   .. autodata:: HUBERT_ASR_XLARGE

Tacotron2 Text-To-Speech
------------------------

Tacotron2TTSBundle
~~~~~~~~~~~~~~~~~~

.. autoclass:: Tacotron2TTSBundle

   .. automethod:: get_text_processor

   .. automethod:: get_tacotron2

   .. automethod:: get_vocoder

Tacotron2TTSBundle - TextProcessor
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: torchaudio.pipelines::Tacotron2TTSBundle.TextProcessor
   :members: tokens
   :special-members: __call__

Tacotron2TTSBundle - Vocoder
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: torchaudio.pipelines::Tacotron2TTSBundle.Vocoder
   :members: sample_rate
   :special-members: __call__

TACOTRON2_WAVERNN_PHONE_LJSPEECH
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. container:: py attribute

   .. autodata:: TACOTRON2_WAVERNN_PHONE_LJSPEECH
      :no-value:

TACOTRON2_WAVERNN_CHAR_LJSPEECH
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. container:: py attribute

   .. autodata:: TACOTRON2_WAVERNN_CHAR_LJSPEECH
      :no-value:

TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. container:: py attribute

   .. autodata:: TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH
      :no-value:

TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. container:: py attribute

   .. autodata:: TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH
      :no-value:
References
----------
@misc{ljspeech17,
author = {Keith Ito and Linda Johnson},
title = {The LJ Speech Dataset},
howpublished = {\url{https://keithito.com/LJ-Speech-Dataset/}},
year = {2017}
}
@misc{conneau2020unsupervised,
title={Unsupervised Cross-lingual Representation Learning for Speech Recognition},
author={Alexis Conneau and Alexei Baevski and Ronan Collobert and Abdelrahman Mohamed and Michael Auli},
......
from torchaudio.pipelines import (
TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH,
TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH,
TACOTRON2_WAVERNN_CHAR_LJSPEECH,
TACOTRON2_WAVERNN_PHONE_LJSPEECH,
)
import pytest
@pytest.mark.parametrize(
'bundle',
[
TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH,
TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH,
TACOTRON2_WAVERNN_CHAR_LJSPEECH,
TACOTRON2_WAVERNN_PHONE_LJSPEECH,
]
)
def test_tts_models(bundle):
"""Smoke test of TTS pipeline"""
text = "Hello world! Text to Speech!"
processor = bundle.get_text_processor()
tacotron2 = bundle.get_tacotron2()
vocoder = bundle.get_vocoder()
processed, lengths = processor(text)
mel_spec, lengths, _ = tacotron2.infer(processed, lengths)
waveforms, lengths = vocoder(mel_spec, lengths)
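# Shape notes (illustrative, per the bundle interfaces): `processed` is
# `(batch, max token length)`, `mel_spec` is `(batch, n_mels=80, time)`,
# `waveforms` is `(batch, time)`, and each `lengths` tensor is `(batch,)`.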
@@ -20,6 +20,13 @@ from ._wav2vec2 import (
HUBERT_ASR_LARGE,
HUBERT_ASR_XLARGE,
)
from ._tts import (
Tacotron2TTSBundle,
TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH,
TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH,
TACOTRON2_WAVERNN_CHAR_LJSPEECH,
TACOTRON2_WAVERNN_PHONE_LJSPEECH,
)
__all__ = [
'Wav2Vec2Bundle',
@@ -42,4 +49,9 @@ __all__ = [
'HUBERT_XLARGE',
'HUBERT_ASR_LARGE',
'HUBERT_ASR_XLARGE',
'Tacotron2TTSBundle',
'TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH',
'TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH',
'TACOTRON2_WAVERNN_CHAR_LJSPEECH',
'TACOTRON2_WAVERNN_PHONE_LJSPEECH',
]
from .interface import Tacotron2TTSBundle
from .impl import (
TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH,
TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH,
TACOTRON2_WAVERNN_CHAR_LJSPEECH,
TACOTRON2_WAVERNN_PHONE_LJSPEECH,
)
__all__ = [
'Tacotron2TTSBundle',
'TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH',
'TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH',
'TACOTRON2_WAVERNN_CHAR_LJSPEECH',
'TACOTRON2_WAVERNN_PHONE_LJSPEECH',
]
from dataclasses import dataclass
import re
from typing import Union, Optional, Dict, Any, Tuple, List
import torch
from torch import Tensor
from torch.hub import load_state_dict_from_url
from torchaudio.models import Tacotron2, WaveRNN
from torchaudio.functional import mu_law_decoding
from torchaudio.transforms import InverseMelScale, GriffinLim
from . import utils
from .interface import Tacotron2TTSBundle
__all__ = []
_BASE_URL = 'https://download.pytorch.org/torchaudio/models'
################################################################################
# Pipeline implementation - Text Processor
################################################################################
class _EnglishCharProcessor(Tacotron2TTSBundle.TextProcessor):
def __init__(self):
super().__init__()
self._tokens = utils._get_chars()
self._mapping = {s: i for i, s in enumerate(self._tokens)}
@property
def tokens(self):
return self._tokens
def __call__(self, texts: Union[str, List[str]]) -> Tuple[Tensor, Tensor]:
    if isinstance(texts, str):
        texts = [texts]
    # Lowercase the input and drop characters that are not in the token set.
    indices = [[self._mapping[c] for c in t.lower() if c in self._mapping] for t in texts]
    return utils._to_tensor(indices)
class _EnglishPhoneProcessor(Tacotron2TTSBundle.TextProcessor):
def __init__(self, *, dl_kwargs=None):
super().__init__()
self._tokens = utils._get_phones()
self._mapping = {p: i for i, p in enumerate(self._tokens)}
self._phonemizer = utils._load_phonemizer(
'en_us_cmudict_forward.pt', dl_kwargs=dl_kwargs)
self._pattern = r"(\[[A-Z]+?\]|[_!'(),.:;? -])"
@property
def tokens(self):
return self._tokens
def __call__(self, texts: Union[str, List[str]]) -> Tuple[Tensor, Tensor]:
if isinstance(texts, str):
texts = [texts]
indices = []
for phones in self._phonemizer(texts, lang='en_us'):
# '[F][UW][B][AA][R]!' -> ['F', 'UW', 'B', 'AA', 'R', '!']
ret = [re.sub(r'[\[\]]', '', r) for r in re.findall(self._pattern, phones)]
indices.append([self._mapping[p] for p in ret])
return utils._to_tensor(indices)
################################################################################
# Pipeline implementation - Vocoder
################################################################################
class _WaveRNNVocoder(torch.nn.Module, Tacotron2TTSBundle.Vocoder):
def __init__(
self,
model: WaveRNN,
min_level_db: Optional[float] = -100
):
super().__init__()
self._sample_rate = 22050
self._model = model
self._min_level_db = min_level_db
@property
def sample_rate(self):
return self._sample_rate
def forward(self, mel_spec, lengths):
    # Tacotron2 produces log-scale mel spectrograms; convert them back to amplitude.
    mel_spec = torch.exp(mel_spec)
    # Convert amplitude to decibels, clamping to avoid taking the log of zero.
    mel_spec = 20 * torch.log10(torch.clamp(mel_spec, min=1e-5))
    if self._min_level_db is not None:
        # Normalize: map [min_level_db, 0] dB onto [0, 1].
        mel_spec = (self._min_level_db - mel_spec) / self._min_level_db
        mel_spec = torch.clamp(mel_spec, min=0, max=1)
    waveform, lengths = self._model.infer(mel_spec, lengths)
    # WaveRNN outputs mu-law-companded samples in [-1, 1]; convert them to
    # integer labels, then expand back to linear amplitude.
    waveform = utils._unnormalize_waveform(waveform, self._model.n_bits)
    waveform = mu_law_decoding(waveform, self._model.n_classes)
    waveform = waveform.squeeze(1)
    return waveform, lengths
class _GriffinLimVocoder(torch.nn.Module, Tacotron2TTSBundle.Vocoder):
def __init__(self):
super().__init__()
self._sample_rate = 22050
self._inv_mel = InverseMelScale(
n_stft=(1024 // 2 + 1),
n_mels=80,
sample_rate=self.sample_rate,
f_min=0.,
f_max=8000.,
mel_scale="slaney",
norm='slaney',
)
self._griffin_lim = GriffinLim(
n_fft=1024,
power=1,
hop_length=256,
win_length=1024,
)
@property
def sample_rate(self):
return self._sample_rate
def forward(self, mel_spec, lengths):
    # Tacotron2 produces log-scale mel spectrograms; convert them back to amplitude.
    mel_spec = torch.exp(mel_spec)
    # InverseMelScale estimates the linear spectrogram through gradient-based
    # optimization, so it needs a detached input with grad enabled. A better way
    # to make this work in inference_mode is listed as future work.
    mel_spec = mel_spec.clone().detach().requires_grad_(True)
    spec = self._inv_mel(mel_spec)
    spec = spec.detach().requires_grad_(False)
    waveforms = self._griffin_lim(spec)
    # The input lengths are returned unchanged; length computation for
    # GriffinLim is also listed as future work.
    return waveforms, lengths
################################################################################
# Bundle classes mixins
################################################################################
class _CharMixin:
def get_text_processor(self) -> Tacotron2TTSBundle.TextProcessor:
return _EnglishCharProcessor()
class _PhoneMixin:
def get_text_processor(self, *, dl_kwargs=None) -> Tacotron2TTSBundle.TextProcessor:
return _EnglishPhoneProcessor(dl_kwargs=dl_kwargs)
@dataclass
class _Tacotron2Mixin:
_tacotron2_path: str
_tacotron2_params: Dict[str, Any]
def get_tacotron2(self, *, dl_kwargs=None) -> Tacotron2:
model = Tacotron2(**self._tacotron2_params)
url = f'{_BASE_URL}/{self._tacotron2_path}'
dl_kwargs = {} if dl_kwargs is None else dl_kwargs
state_dict = load_state_dict_from_url(url, **dl_kwargs)
model.load_state_dict(state_dict)
model.eval()
return model
@dataclass
class _WaveRNNMixin:
_wavernn_path: Optional[str]
_wavernn_params: Optional[Dict[str, Any]]
def get_vocoder(self, *, dl_kwargs=None):
wavernn = self._get_wavernn(dl_kwargs=dl_kwargs)
return _WaveRNNVocoder(wavernn)
def _get_wavernn(self, *, dl_kwargs=None):
model = WaveRNN(**self._wavernn_params)
url = f'{_BASE_URL}/{self._wavernn_path}'
dl_kwargs = {} if dl_kwargs is None else dl_kwargs
state_dict = load_state_dict_from_url(url, **dl_kwargs)
model.load_state_dict(state_dict)
model.eval()
return model
class _GriffinLimMixin:
def get_vocoder(self, **_):
return _GriffinLimVocoder()
################################################################################
# Bundle classes
################################################################################
@dataclass
class _Tacotron2WaveRNNCharBundle(_WaveRNNMixin, _Tacotron2Mixin, _CharMixin, Tacotron2TTSBundle):
pass
@dataclass
class _Tacotron2WaveRNNPhoneBundle(_WaveRNNMixin, _Tacotron2Mixin, _PhoneMixin, Tacotron2TTSBundle):
pass
@dataclass
class _Tacotron2GriffinLimCharBundle(_GriffinLimMixin, _Tacotron2Mixin, _CharMixin, Tacotron2TTSBundle):
pass
@dataclass
class _Tacotron2GriffinLimPhoneBundle(_GriffinLimMixin, _Tacotron2Mixin, _PhoneMixin, Tacotron2TTSBundle):
pass
################################################################################
# Instantiate bundle objects
################################################################################
TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH = _Tacotron2GriffinLimCharBundle(
_tacotron2_path='tacotron2_english_characters_1500_epochs_ljspeech.pth',
_tacotron2_params=utils._get_taco_params(n_symbols=38),
)
TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH.__doc__ = (
'''Character-based TTS pipeline with :py:class:`torchaudio.models.Tacotron2` and
:py:class:`torchaudio.transforms.GriffinLim`.

The text processor encodes the input texts character-by-character.

Tacotron2 was trained on *LJSpeech* [:footcite:`ljspeech17`] for 1,500 epochs.
You can find the training script `here <https://github.com/pytorch/audio/tree/main/examples/pipeline_tacotron2>`__.
The default parameters were used.

The vocoder is based on :py:class:`torchaudio.transforms.GriffinLim`.

Please refer to :py:class:`torchaudio.pipelines.Tacotron2TTSBundle` for usage.

Example - "Hello world! T T S stands for Text to Speech!"

.. image:: https://download.pytorch.org/torchaudio/doc-assets/TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH.png
   :alt: Spectrogram generated by Tacotron2

.. raw:: html

   <audio controls="controls">
      <source src="https://download.pytorch.org/torchaudio/doc-assets/TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH.wav" type="audio/wav">
      Your browser does not support the <code>audio</code> element.
   </audio>
''') # noqa: E501
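# Usage sketch (illustrative, not part of the module; it mirrors the example in
# the ``Tacotron2TTSBundle`` docstring and applies analogously to the other
# three bundles defined below):
#
#     import torchaudio
#     bundle = torchaudio.pipelines.TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH
#     processor = bundle.get_text_processor()
#     tacotron2 = bundle.get_tacotron2()
#     vocoder = bundle.get_vocoder()
#     inputs, lengths = processor("Hello world!")
#     specgram, lengths, _ = tacotron2.infer(inputs, lengths)
#     waveforms, lengths = vocoder(specgram, lengths)
#     torchaudio.save('hello-tts.wav', waveforms[0:1], vocoder.sample_rate)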
TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH = _Tacotron2GriffinLimPhoneBundle(
_tacotron2_path='tacotron2_english_phonemes_1500_epochs_ljspeech.pth',
_tacotron2_params=utils._get_taco_params(n_symbols=96),
)
TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH.__doc__ = (
'''Phoneme-based TTS pipeline with :py:class:`torchaudio.models.Tacotron2` and
:py:class:`torchaudio.transforms.GriffinLim`.

The text processor encodes the input texts based on phonemes.
It uses `DeepPhonemizer <https://github.com/as-ideas/DeepPhonemizer>`__ to convert
graphemes to phonemes.
The model (*en_us_cmudict_forward*) was trained on
`CMUDict <http://www.speech.cs.cmu.edu/cgi-bin/cmudict>`__.

Tacotron2 was trained on *LJSpeech* [:footcite:`ljspeech17`] for 1,500 epochs.
You can find the training script `here <https://github.com/pytorch/audio/tree/main/examples/pipeline_tacotron2>`__.
The text processor was set to *"english_phonemes"* during training.

The vocoder is based on :py:class:`torchaudio.transforms.GriffinLim`.

Please refer to :py:class:`torchaudio.pipelines.Tacotron2TTSBundle` for usage.

Example - "Hello world! T T S stands for Text to Speech!"

.. image:: https://download.pytorch.org/torchaudio/doc-assets/TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH.png
   :alt: Spectrogram generated by Tacotron2

.. raw:: html

   <audio controls="controls">
      <source src="https://download.pytorch.org/torchaudio/doc-assets/TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH.wav" type="audio/wav">
      Your browser does not support the <code>audio</code> element.
   </audio>
''') # noqa: E501
TACOTRON2_WAVERNN_CHAR_LJSPEECH = _Tacotron2WaveRNNCharBundle(
_tacotron2_path='tacotron2_english_characters_1500_epochs_wavernn_ljspeech.pth',
_tacotron2_params=utils._get_taco_params(n_symbols=38),
_wavernn_path='wavernn_10k_epochs_8bits_ljspeech.pth',
_wavernn_params=utils._get_wrnn_params(),
)
TACOTRON2_WAVERNN_CHAR_LJSPEECH.__doc__ = (
'''Character-based TTS pipeline with :py:class:`torchaudio.models.Tacotron2` and
:py:class:`torchaudio.models.WaveRNN`.

The text processor encodes the input texts character-by-character.

Tacotron2 was trained on *LJSpeech* [:footcite:`ljspeech17`] for 1,500 epochs.
You can find the training script `here <https://github.com/pytorch/audio/tree/main/examples/pipeline_tacotron2>`__.
The following parameters were used: ``win_length=1100``, ``hop_length=275``, ``n_fft=2048``,
``mel_fmin=40``, and ``mel_fmax=11025``.

The vocoder is based on :py:class:`torchaudio.models.WaveRNN`.
It was trained on the 8-bit depth waveforms of *LJSpeech* [:footcite:`ljspeech17`] for 10,000 epochs.
You can find the training script `here <https://github.com/pytorch/audio/tree/main/examples/pipeline_wavernn>`__.

Please refer to :py:class:`torchaudio.pipelines.Tacotron2TTSBundle` for usage.

Example - "Hello world! T T S stands for Text to Speech!"

.. image:: https://download.pytorch.org/torchaudio/doc-assets/TACOTRON2_WAVERNN_CHAR_LJSPEECH.png
   :alt: Spectrogram generated by Tacotron2

.. raw:: html

   <audio controls="controls">
      <source src="https://download.pytorch.org/torchaudio/doc-assets/TACOTRON2_WAVERNN_CHAR_LJSPEECH.wav" type="audio/wav">
      Your browser does not support the <code>audio</code> element.
   </audio>
''') # noqa: E501
TACOTRON2_WAVERNN_PHONE_LJSPEECH = _Tacotron2WaveRNNPhoneBundle(
_tacotron2_path='tacotron2_english_phonemes_1500_epochs_wavernn_ljspeech.pth',
_tacotron2_params=utils._get_taco_params(n_symbols=96),
_wavernn_path='wavernn_10k_epochs_8bits_ljspeech.pth',
_wavernn_params=utils._get_wrnn_params(),
)
TACOTRON2_WAVERNN_PHONE_LJSPEECH.__doc__ = (
'''Phoneme-based TTS pipeline with :py:class:`torchaudio.models.Tacotron2` and
:py:class:`torchaudio.models.WaveRNN`.

The text processor encodes the input texts based on phonemes.
It uses `DeepPhonemizer <https://github.com/as-ideas/DeepPhonemizer>`__ to convert
graphemes to phonemes.
The model (*en_us_cmudict_forward*) was trained on
`CMUDict <http://www.speech.cs.cmu.edu/cgi-bin/cmudict>`__.

Tacotron2 was trained on *LJSpeech* [:footcite:`ljspeech17`] for 1,500 epochs.
You can find the training script `here <https://github.com/pytorch/audio/tree/main/examples/pipeline_tacotron2>`__.
The following parameters were used: ``win_length=1100``, ``hop_length=275``, ``n_fft=2048``,
``mel_fmin=40``, and ``mel_fmax=11025``.

The vocoder is based on :py:class:`torchaudio.models.WaveRNN`.
It was trained on the 8-bit depth waveforms of *LJSpeech* [:footcite:`ljspeech17`] for 10,000 epochs.
You can find the training script `here <https://github.com/pytorch/audio/tree/main/examples/pipeline_wavernn>`__.

Please refer to :py:class:`torchaudio.pipelines.Tacotron2TTSBundle` for usage.

Example - "Hello world! T T S stands for Text to Speech!"

.. image:: https://download.pytorch.org/torchaudio/doc-assets/TACOTRON2_WAVERNN_PHONE_LJSPEECH.png
   :alt: Spectrogram generated by Tacotron2

.. raw:: html

   <audio controls="controls">
      <source src="https://download.pytorch.org/torchaudio/doc-assets/TACOTRON2_WAVERNN_PHONE_LJSPEECH.wav" type="audio/wav">
      Your browser does not support the <code>audio</code> element.
   </audio>
''') # noqa: E501
from abc import ABC, abstractmethod
from typing import Union, List, Tuple, Optional
from torch import Tensor
from torchaudio.models import Tacotron2
class _TextProcessor(ABC):
"""Interface of the text processing part of Tacotron2TTS pipeline"""
@property
@abstractmethod
def tokens(self):
"""The tokens that the each value in the processed tensor represent.
:type: List[str]
"""
@abstractmethod
def __call__(self, texts: Union[str, List[str]]) -> Tuple[Tensor, Tensor]:
"""Encode the given (batch of) texts into numerical tensors
Args:
text (str or list of str): The input texts.
Returns:
Tensor and Tensor:
Tensor:
The encoded texts. Shape: `(batch, max length)`
Tensor:
The valid length of each sample in the batch. Shape: `(batch, )`.
"""
class _Vocoder(ABC):
"""Interface of the vocoder part of Tacotron2TTS pipeline"""
@property
@abstractmethod
def sample_rate(self):
"""The sample rate of the resulting waveform
:type: float
"""
@abstractmethod
def __call__(self, specgrams: Tensor, lengths: Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]]:
"""Generate waveform from the given input, such as spectrogram
Args:
specgrams (Tensor):
The input spectrogram. Shape: `(batch, frequency bins, time)`.
The expected shape depends on the implementation.
lengths (Tensor, or None, optional):
The valid length of each sample in the batch. Shape: `(batch, )`.
Returns:
Tensor and optional Tensor:
Tensor:
The generated waveform. Shape: `(batch, max length)`
Tensor or None:
The valid length of each sample in the batch. Shape: `(batch, )`.
"""
class Tacotron2TTSBundle(ABC):
"""Data class that bundles associated information to use pretrained Tacotron2 and vocoder.
This class provides interfaces for instantiating the pretrained model along with
the information necessary to retrieve pretrained weights and additional data
to be used with the model.
Torchaudio library instantiates objects of this class, each of which represents
a different pretrained model. Client code should access pretrained models via these
instances.
Please see below for the usage and the available values.
Example - Character-based TTS pipeline with Tacotron2 and WaveRNN
>>> import torchaudio
>>>
>>> text = "Hello, T T S !"
>>> bundle = torchaudio.pipelines.TACOTRON2_WAVERNN_CHAR_LJSPEECH
>>>
>>> # Build processor, Tacotron2 and WaveRNN model
>>> processor = bundle.get_text_processor()
>>> tacotron2 = bundle.get_tacotron2()
Downloading:
100%|███████████████████████████████| 107M/107M [00:01<00:00, 87.9MB/s]
>>> vocoder = bundle.get_vocoder()
Downloading:
100%|███████████████████████████████| 16.7M/16.7M [00:00<00:00, 78.1MB/s]
>>>
>>> # Encode text
>>> input, lengths = processor(text)
>>>
>>> # Generate (mel-scale) spectrogram
>>> specgram, lengths, _ = tacotron2.infer(input, lengths)
>>>
>>> # Convert spectrogram to waveform
>>> waveforms, lengths = vocoder(specgram, lengths)
>>>
>>> torchaudio.save('hello-tts.wav', waveforms[0:1], vocoder.sample_rate)
Example - Phoneme-based TTS pipeline with Tacotron2 and WaveRNN
>>>
>>> # Note:
>>> # This bundle uses pre-trained DeepPhonemizer as
>>> # the text pre-processor.
>>> # Please install deep-phonemizer.
>>> # See https://github.com/as-ideas/DeepPhonemizer
>>> # The pretrained weight is automatically downloaded.
>>>
>>> import torchaudio
>>>
>>> text = "Hello, TTS!"
>>> bundle = torchaudio.pipelines.TACOTRON2_WAVERNN_PHONE_LJSPEECH
>>>
>>> # Build processor, Tacotron2 and WaveRNN model
>>> processor = bundle.get_text_processor()
Downloading:
100%|███████████████████████████████| 63.6M/63.6M [00:04<00:00, 15.3MB/s]
>>> tacotron2 = bundle.get_tacotron2()
Downloading:
100%|███████████████████████████████| 107M/107M [00:01<00:00, 87.9MB/s]
>>> vocoder = bundle.get_vocoder()
Downloading:
100%|███████████████████████████████| 16.7M/16.7M [00:00<00:00, 78.1MB/s]
>>>
>>> # Encode text
>>> input, lengths = processor(text)
>>>
>>> # Generate (mel-scale) spectrogram
>>> specgram, lengths, _ = tacotron2.infer(input, lengths)
>>>
>>> # Convert spectrogram to waveform
>>> waveforms, lengths = vocoder(specgram, lengths)
>>>
>>> torchaudio.save('hello-tts.wav', waveforms[0:1], vocoder.sample_rate)
"""
# Using inner classes so that these interfaces are not directly exposed on
# `torchaudio.pipelines`, but are still listed in the documentation.
# Text processing and vocoding are generic, and we do not know what kinds of new
# text processors and vocoders will be added in the future, so we keep these
# interfaces specific to this Tacotron2TTS pipeline.
class TextProcessor(_TextProcessor):
pass
class Vocoder(_Vocoder):
pass
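    # Note (illustrative): the processors and vocoders returned by the bundles
    # implement these inner interfaces, so client code can type-check against
    # them, e.g.:
    #
    #     vocoder = bundle.get_vocoder()
    #     assert isinstance(vocoder, Tacotron2TTSBundle.Vocoder)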
@abstractmethod
def get_text_processor(self, *, dl_kwargs=None) -> TextProcessor:
"""get_text_processor(self, *, dl_kwargs=None) -> Tacotron2TTSBundle.TextProcessor:
Create a text processor
For character-based pipeline, this processor splits the input text by character.
For phoneme-based pipeline, this processor converts the input text (grapheme) to
phonemes.
If a pre-trained weight file is necessary,
:func:`torch.hub.download_url_to_file` is used to downloaded it.
Args:
dl_kwargs (dictionary of keyword arguments,):
Passed to :func:`torch.hub.download_url_to_file`.
Returns:
TTSTextProcessor:
A callable which takes a string or a list of strings as input and
returns Tensor of encoded texts and Tensor of valid lengths.
The object also has ``tokens`` property, which allows to recover the
tokenized form.
Example - Character-based
>>> text = [
>>> "Hello, T T S !",
>>> "Text-to-speech!",
>>> ]
>>> bundle = torchaudio.pipelines.TACOTRON2_WAVERNN_CHAR_LJSPEECH
>>> processor = bundle.get_text_processor()
>>> input, lengths = processor(text)
>>>
>>> print(input)
tensor([[19, 16, 23, 23, 26, 11, 34, 26, 29, 23, 15, 2, 0, 0, 0],
[31, 16, 35, 31, 1, 31, 26, 1, 30, 27, 16, 16, 14, 19, 2]],
dtype=torch.int32)
>>>
>>> print(lengths)
tensor([12, 15], dtype=torch.int32)
>>>
>>> print([processor.tokens[i] for i in input[0]])
['h', 'e', 'l', 'l', 'o', ' ', 'w', 'o', 'r', 'l', 'd', '!', '_', '_', '_']
>>> print([processor.tokens[i] for i in input[1]])
['t', 'e', 'x', 't', '-', 't', 'o', '-', 's', 'p', 'e', 'e', 'c', 'h', '!']
Example - Phoneme-based
>>> text = [
>>> "Hello, T T S !",
>>> "Text-to-speech!",
>>> ]
>>> bundle = torchaudio.pipelines.TACOTRON2_WAVERNN_PHONE_LJSPEECH
>>> processor = bundle.get_text_processor()
Downloading:
100%|███████████████████████████████| 63.6M/63.6M [00:04<00:00, 15.3MB/s]
>>> input, lengths = processor(text)
>>>
>>> print(input)
tensor([[54, 20, 65, 69, 11, 92, 44, 65, 38, 2, 0, 0, 0, 0],
[81, 40, 64, 79, 81, 1, 81, 20, 1, 79, 77, 59, 37, 2]],
dtype=torch.int32)
>>>
>>> print(lengths)
tensor([10, 14], dtype=torch.int32)
>>>
>>> print([processor.tokens[i] for i in input[0]])
['HH', 'AH', 'L', 'OW', ' ', 'W', 'ER', 'L', 'D', '!', '_', '_', '_', '_']
>>> print([processor.tokens[i] for i in input[1]])
['T', 'EH', 'K', 'S', 'T', '-', 'T', 'AH', '-', 'S', 'P', 'IY', 'CH', '!']
"""
@abstractmethod
def get_vocoder(self, *, dl_kwargs=None) -> Vocoder:
"""get_vocoder(self, *, dl_kwargs=None) -> Tacotron2TTSBundle.Vocoder:
Create a vocoder module, based off of either WaveRNN or GriffinLim.
If a pre-trained weight file is necessary,
:func:`torch.hub.load_state_dict_from_url` is used to downloaded it.
Args:
dl_kwargs (dictionary of keyword arguments):
Passed to :func:`torch.hub.load_state_dict_from_url`.
Returns:
Callable[[Tensor, Optional[Tensor]], Tuple[Tensor, Optional[Tensor]]]:
A vocoder module, which takes spectrogram Tensor and an optional
length Tensor, then returns resulting waveform Tensor and an optional
length Tensor.
"""
@abstractmethod
def get_tacotron2(self, *, dl_kwargs=None) -> Tacotron2:
"""Create a Tacotron2 model with pre-trained weight.
Args:
dl_kwargs (dictionary of keyword arguments):
Passed to :func:`torch.hub.load_state_dict_from_url`.
Returns:
Tacotron2:
The resulting model.
"""
import os
import logging
import torch
from torchaudio._internal import module_utils as _mod_utils
def _get_chars():
return (
'_',
'-',
'!',
"'",
'(',
')',
',',
'.',
':',
';',
'?',
' ',
'a',
'b',
'c',
'd',
'e',
'f',
'g',
'h',
'i',
'j',
'k',
'l',
'm',
'n',
'o',
'p',
'q',
'r',
's',
't',
'u',
'v',
'w',
'x',
'y',
'z',
)
def _get_phones():
return (
"_",
"-",
"!",
"'",
"(",
")",
",",
".",
":",
";",
"?",
" ",
"AA",
"AA0",
"AA1",
"AA2",
"AE",
"AE0",
"AE1",
"AE2",
"AH",
"AH0",
"AH1",
"AH2",
"AO",
"AO0",
"AO1",
"AO2",
"AW",
"AW0",
"AW1",
"AW2",
"AY",
"AY0",
"AY1",
"AY2",
"B",
"CH",
"D",
"DH",
"EH",
"EH0",
"EH1",
"EH2",
"ER",
"ER0",
"ER1",
"ER2",
"EY",
"EY0",
"EY1",
"EY2",
"F",
"G",
"HH",
"IH",
"IH0",
"IH1",
"IH2",
"IY",
"IY0",
"IY1",
"IY2",
"JH",
"K",
"L",
"M",
"N",
"NG",
"OW",
"OW0",
"OW1",
"OW2",
"OY",
"OY0",
"OY1",
"OY2",
"P",
"R",
"S",
"SH",
"T",
"TH",
"UH",
"UH0",
"UH1",
"UH2",
"UW",
"UW0",
"UW1",
"UW2",
"V",
"W",
"Y",
"Z",
"ZH"
)
def _to_tensor(indices):
lengths = torch.tensor([len(i) for i in indices], dtype=torch.int32)
values = [torch.tensor(i) for i in indices]
values = torch.nn.utils.rnn.pad_sequence(values, batch_first=True)
return values, lengths
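# Example (illustrative):
#
#     _to_tensor([[1, 2, 3], [4, 5]])
#     # -> (tensor([[1, 2, 3],
#     #             [4, 5, 0]]), tensor([3, 2], dtype=torch.int32))
#
# Shorter sequences are right-padded with zeros by ``pad_sequence``.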
def _load_phonemizer(file, dl_kwargs):
if not _mod_utils.is_module_available('dp'):
raise RuntimeError('DeepPhonemizer is not installed. Please install it.')
from dp.phonemizer import Phonemizer
# By default, dp emits DEBUG-level logs, so temporarily raise the log level.
logger = logging.getLogger('dp')
orig_level = logger.level
logger.setLevel(logging.INFO)
try:
url = f'https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/DeepPhonemizer/{file}'
path = os.path.join(torch.hub.get_dir(), 'checkpoints', file)
if not os.path.exists(path):
dl_kwargs = {} if dl_kwargs is None else dl_kwargs
torch.hub.download_url_to_file(url, path, **dl_kwargs)
return Phonemizer.from_checkpoint(path)
finally:
logger.setLevel(orig_level)
def _unnormalize_waveform(waveform: torch.Tensor, bits: int) -> torch.Tensor:
r"""Transform waveform [-1, 1] to label [0, 2 ** bits - 1]"""
waveform = torch.clamp(waveform, -1, 1)
waveform = (waveform + 1.0) * (2 ** bits - 1) / 2
return torch.clamp(waveform, 0, 2 ** bits - 1).int()
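# Worked example (illustrative, bits=8):
#
#     _unnormalize_waveform(torch.tensor([-1.0, 0.0, 1.0]), 8)
#     # -> tensor([  0, 127, 255], dtype=torch.int32)
#
# (0.0 maps to (0.0 + 1.0) * 255 / 2 = 127.5, which ``.int()`` truncates to 127.)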
def _get_taco_params(n_symbols):
return {
'mask_padding': False,
'n_mels': 80,
'n_frames_per_step': 1,
'symbol_embedding_dim': 512,
'encoder_embedding_dim': 512,
'encoder_n_convolution': 3,
'encoder_kernel_size': 5,
'decoder_rnn_dim': 1024,
'decoder_max_step': 2000,
'decoder_dropout': 0.1,
'decoder_early_stopping': True,
'attention_rnn_dim': 1024,
'attention_hidden_dim': 128,
'attention_location_n_filter': 32,
'attention_location_kernel_size': 31,
'attention_dropout': 0.1,
'prenet_dim': 256,
'postnet_n_convolution': 5,
'postnet_kernel_size': 5,
'postnet_embedding_dim': 512,
'gate_threshold': 0.5,
'n_symbol': n_symbols,
}
def _get_wrnn_params():
return {
'upsample_scales': [5, 5, 11],
'n_classes': 2 ** 8, # n_bits = 8
'hop_length': 275,
'n_res_block': 10,
'n_rnn': 512,
'n_fc': 512,
'kernel_size': 5,
'n_freq': 80,
'n_hidden': 128,
'n_output': 128
}