"""
Text-to-Speech with Tacotron2
=============================

**Author** `Yao-Yuan Yang <https://github.com/yangarbiter>`__,
`Moto Hira <moto@fb.com>`__

"""

######################################################################
# Overview
# --------
#
# This tutorial shows how to build a text-to-speech pipeline, using the
# pretrained Tacotron2 in torchaudio.
#
# The text-to-speech pipeline goes as follows:
#
# 1. Text preprocessing
#
#    First, the input text is encoded into a list of symbols. In this
#    tutorial, we will use English characters and phonemes as the symbols.
#
# 2. Spectrogram generation
#
#    From the encoded text, a spectrogram is generated. We use the
#    ``Tacotron2`` model for this.
#
# 3. Time-domain conversion
#
#    The last step is converting the spectrogram into the waveform. The
#    process that generates speech from a spectrogram is also called a vocoder.
#    In this tutorial, three different vocoders are used,
#    `WaveRNN <https://pytorch.org/audio/stable/models/wavernn.html>`__,
#    `Griffin-Lim <https://pytorch.org/audio/stable/transforms.html#griffinlim>`__,
#    and
#    `Nvidia's WaveGlow <https://pytorch.org/hub/nvidia_deeplearningexamples_tacotron2/>`__.
#
#
# The following figure illustrates the whole process.
#
# .. image:: https://download.pytorch.org/torchaudio/tutorial-assets/tacotron2_tts_pipeline.png
#
# All the related components are bundled in :py:class:`torchaudio.pipelines.Tacotron2TTSBundle`,
# but this tutorial will also cover the process under the hood.

######################################################################
# Preparation
# -----------
#
# First, we install the necessary dependencies. In addition to
# ``torchaudio``, ``DeepPhonemizer`` is required to perform phoneme-based
# encoding.
#

# When running this example in notebook, install DeepPhonemizer
# !pip3 install deep_phonemizer

import torch
import torchaudio
import matplotlib
import matplotlib.pyplot as plt

import IPython

# Default figure size used by the plots in this tutorial.
matplotlib.rcParams["figure.figsize"] = [16.0, 4.8]

# Seed torch's global RNG so that repeated runs are more reproducible.
torch.random.manual_seed(0)
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

print(torch.__version__)
print(torchaudio.__version__)
print(device)


######################################################################
# Text Processing
# ---------------
#


######################################################################
# Character-based encoding
# ~~~~~~~~~~~~~~~~~~~~~~~~
#
# In this section, we will go through how the character-based encoding
# works.
#
# Since the pre-trained Tacotron2 model expects a specific symbol table,
# the same functionality is available in ``torchaudio``. This section is
# mainly an explanation of the basis of encoding.
#
# Firstly, we define the set of symbols. For example, we can use
# ``'_-!\'(),.:;? abcdefghijklmnopqrstuvwxyz'``. Then, we will map
# each character of the input text into the index of the corresponding
# symbol in the table.
#
# The following is an example of such processing. In the example, symbols
# that are not in the table are ignored.
#

# Symbol table: the position of a character in this string is its integer ID.
symbols = "_-!'(),.:;? abcdefghijklmnopqrstuvwxyz"
look_up = {s: i for i, s in enumerate(symbols)}
# Rebind ``symbols`` to a set; it is only used for membership tests from here on.
symbols = set(symbols)


def text_to_sequence(text):
    """Encode ``text`` as a list of symbol-table indices.

    The text is lower-cased first. Characters that are not present in the
    symbol table are dropped from the output.

    Args:
        text (str): Input text.

    Returns:
        list of int: One index into the symbol table per retained character,
        in the order the characters appear in ``text``.
    """
    text = text.lower()
    return [look_up[s] for s in text if s in symbols]


text = "Hello world! Text to speech!"
print(text_to_sequence(text))


######################################################################
# As mentioned above, the symbol table and indices must match
# what the pretrained Tacotron2 model expects. ``torchaudio`` provides the
# transform along with the pretrained model. For example, you can
# instantiate and use such a transform as follows.
#

# Text processor that ships with the character-based Tacotron2 bundle.
processor = torchaudio.pipelines.TACOTRON2_WAVERNN_CHAR_LJSPEECH.get_text_processor()

text = "Hello world! Text to speech!"
# The processor returns two values: the encoded text and its length(s).
processed, lengths = processor(text)

print(processed)
print(lengths)


######################################################################
# The ``processor`` object takes either a text or a list of texts as input.
# When a list of texts is provided, the returned ``lengths`` variable
# represents the valid length of each processed sequence of tokens in the
# output batch.
#
# The intermediate representation can be retrieved as follows.
#

# Look up the symbol for each ID in the first row, limited to its valid length.
print([processor.tokens[i] for i in processed[0, : lengths[0]]])


######################################################################
# Phoneme-based encoding
# ~~~~~~~~~~~~~~~~~~~~~~
#
# Phoneme-based encoding is similar to character-based encoding, but it
# uses a symbol table based on phonemes and a G2P (Grapheme-to-Phoneme)
# model.
#
# The details of the G2P model are out of the scope of this tutorial; we
# will just look at what the conversion looks like.
#
# Similar to the case of character-based encoding, the encoding process is
# expected to match what a pretrained Tacotron2 model is trained on.
# ``torchaudio`` has an interface to create the process.
#
# The following code illustrates how to make and use the process. Behind
# the scenes, a G2P model is created using the ``DeepPhonemizer`` package,
# and the pretrained weights published by the author of ``DeepPhonemizer``
# are fetched.
#

bundle = torchaudio.pipelines.TACOTRON2_WAVERNN_PHONE_LJSPEECH

processor = bundle.get_text_processor()

text = "Hello world! Text to speech!"
# Encoding is wrapped in inference mode; no gradients are needed here.
with torch.inference_mode():
    processed, lengths = processor(text)

print(processed)
print(lengths)


######################################################################
# Notice that the encoded values are different from the example of
# character-based encoding.
#
# The intermediate representation looks like the following.
#

# Look up the symbol for each ID in the first row, limited to its valid length.
print([processor.tokens[i] for i in processed[0, : lengths[0]]])


######################################################################
# Spectrogram Generation
# ----------------------
#
# ``Tacotron2`` is the model we use to generate a spectrogram from the
# encoded text. For the details of the model, please refer to `the
# paper <https://arxiv.org/abs/1712.05884>`__.
#
# It is easy to instantiate a Tacotron2 model with pretrained weights;
# however, note that the input to Tacotron2 models needs to be processed
# by the matching text processor.
#
# :py:class:`torchaudio.pipelines.Tacotron2TTSBundle` bundles the matching
# models and processors together so that it is easy to create the pipeline.
#
# For the available bundles, and their usage, please refer to :py:mod:`torchaudio.pipelines`.
#

bundle = torchaudio.pipelines.TACOTRON2_WAVERNN_PHONE_LJSPEECH
processor = bundle.get_text_processor()
tacotron2 = bundle.get_tacotron2().to(device)

text = "Hello world! Text to speech!"

with torch.inference_mode():
    processed, lengths = processor(text)
    # Move the encoded inputs to the device the model was moved to.
    processed = processed.to(device)
    lengths = lengths.to(device)
    # Only the first returned value (the spectrogram) is used here.
    spec, _, _ = tacotron2.infer(processed, lengths)


plt.imshow(spec[0].cpu().detach())


######################################################################
# Note that the ``Tacotron2.infer`` method performs multinomial sampling;
# therefore, the process of generating the spectrogram incurs randomness.
#

# Run inference three times on the same input and plot each result on its own row.
fig, ax = plt.subplots(3, 1, figsize=(16, 4.3 * 3))
for i in range(3):
    with torch.inference_mode():
        spec, spec_lengths, _ = tacotron2.infer(processed, lengths)
    print(spec[0].shape)
    ax[i].imshow(spec[0].cpu().detach())
plt.show()


######################################################################
# Waveform Generation
# -------------------
#
# Once the spectrogram is generated, the last process is to recover the
# waveform from the spectrogram.
#
# ``torchaudio`` provides vocoders based on ``GriffinLim`` and
# ``WaveRNN``.
#


######################################################################
# WaveRNN
# ~~~~~~~
#
# Continuing from the previous section, we can instantiate the matching
# WaveRNN model from the same bundle.
#

bundle = torchaudio.pipelines.TACOTRON2_WAVERNN_PHONE_LJSPEECH

processor = bundle.get_text_processor()
tacotron2 = bundle.get_tacotron2().to(device)
vocoder = bundle.get_vocoder().to(device)

text = "Hello world! Text to speech!"

with torch.inference_mode():
    processed, lengths = processor(text)
    processed = processed.to(device)
    lengths = lengths.to(device)
    spec, spec_lengths, _ = tacotron2.infer(processed, lengths)
    # NOTE: ``lengths`` is rebound here to the second value returned by the vocoder.
    waveforms, lengths = vocoder(spec, spec_lengths)

# Top row: spectrogram. Bottom row: generated waveform.
fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(16, 9))
ax1.imshow(spec[0].cpu().detach())
ax2.plot(waveforms[0].cpu().detach())

# NOTE(review): assumes the ``_assets`` directory already exists - confirm.
torchaudio.save(
    "_assets/output_wavernn.wav", waveforms[0:1].cpu(), sample_rate=vocoder.sample_rate
)
IPython.display.Audio("_assets/output_wavernn.wav")


######################################################################
# Griffin-Lim
# ~~~~~~~~~~~
#
# Using the Griffin-Lim vocoder is the same as WaveRNN. You can instantiate
# the vocoder object with the ``get_vocoder`` method and pass the spectrogram.
#

bundle = torchaudio.pipelines.TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH

processor = bundle.get_text_processor()
tacotron2 = bundle.get_tacotron2().to(device)
vocoder = bundle.get_vocoder().to(device)

# ``text`` is not reassigned in this section; the existing value is reused.
with torch.inference_mode():
    processed, lengths = processor(text)
    processed = processed.to(device)
    lengths = lengths.to(device)
    spec, spec_lengths, _ = tacotron2.infer(processed, lengths)
# NOTE(review): the vocoder is called outside ``torch.inference_mode()`` here.
# This looks deliberate - confirm before moving it into the block above.
waveforms, lengths = vocoder(spec, spec_lengths)

# Top row: spectrogram. Bottom row: generated waveform.
fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(16, 9))
ax1.imshow(spec[0].cpu().detach())
ax2.plot(waveforms[0].cpu().detach())

# NOTE(review): assumes the ``_assets`` directory already exists - confirm.
torchaudio.save(
    "_assets/output_griffinlim.wav",
    waveforms[0:1].cpu(),
    sample_rate=vocoder.sample_rate,
)
IPython.display.Audio("_assets/output_griffinlim.wav")


######################################################################
# Waveglow
# ~~~~~~~~
#
# Waveglow is a vocoder published by Nvidia. The pretrained weights are
# published on Torch Hub. One can instantiate the model using the
# ``torch.hub`` module.
#

# Workaround to load model mapped on GPU
# https://stackoverflow.com/a/61840832
# The model is created without pretrained weights; the checkpoint is then
# downloaded separately so that ``map_location`` can be set to ``device``.
waveglow = torch.hub.load(
    "NVIDIA/DeepLearningExamples:torchhub",
    "nvidia_waveglow",
    model_math="fp32",
    pretrained=False,
)
checkpoint = torch.hub.load_state_dict_from_url(
    "https://api.ngc.nvidia.com/v2/models/nvidia/waveglowpyt_fp32/versions/1/files/nvidia_waveglowpyt_fp32_20190306.pth",  # noqa: E501
    progress=False,
    map_location=device,
)
# Remove every "module." substring from the checkpoint's parameter names.
state_dict = {
    key.replace("module.", ""): value for key, value in checkpoint["state_dict"].items()
}

waveglow.load_state_dict(state_dict)
waveglow = waveglow.remove_weightnorm(waveglow)
waveglow = waveglow.to(device)
waveglow.eval()

# ``spec`` is not recomputed in this section; the existing value is reused.
with torch.no_grad():
    waveforms = waveglow.infer(spec)

# Top row: spectrogram. Bottom row: generated waveform.
fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(16, 9))
ax1.imshow(spec[0].cpu().detach())
ax2.plot(waveforms[0].cpu().detach())

# NOTE(review): assumes the ``_assets`` directory already exists - confirm.
torchaudio.save("_assets/output_waveglow.wav", waveforms[0:1].cpu(), sample_rate=22050)
IPython.display.Audio("_assets/output_waveglow.wav")