grad_tts_utils.py 10.3 KB
Newer Older
patil-suraj's avatar
patil-suraj committed
1
2
# tokenizer

patil-suraj's avatar
patil-suraj committed
3
import os
patil-suraj's avatar
style  
patil-suraj committed
4
import re
patil-suraj's avatar
patil-suraj committed
5
from shutil import copyfile
patil-suraj's avatar
patil-suraj committed
6
7

import torch
8

patil-suraj's avatar
style  
patil-suraj committed
9

10
11
12
13
try:
    from transformers import PreTrainedTokenizer
except:
    print("transformers is not installed")
patil-suraj's avatar
patil-suraj committed
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28

try:
    from unidecode import unidecode
except:
    print("unidecode is not installed")
    pass

try:
    import inflect
except:
    print("inflect is not installed")
    pass


# Phoneme symbols accepted in dictionary pronunciations. Some symbols are
# listed both bare and with a trailing digit (0, 1 or 2).
# The order matters: the "@"-prefixed symbol IDs built later in this module
# are assigned by position in this list.
valid_symbols = [
    "AA", "AA0", "AA1", "AA2",
    "AE", "AE0", "AE1", "AE2",
    "AH", "AH0", "AH1", "AH2",
    "AO", "AO0", "AO1", "AO2",
    "AW", "AW0", "AW1", "AW2",
    "AY", "AY0", "AY1", "AY2",
    "B", "CH", "D", "DH",
    "EH", "EH0", "EH1", "EH2",
    "ER", "ER0", "ER1", "ER2",
    "EY", "EY0", "EY1", "EY2",
    "F", "G", "HH",
    "IH", "IH0", "IH1", "IH2",
    "IY", "IY0", "IY1", "IY2",
    "JH", "K", "L", "M", "N", "NG",
    "OW", "OW0", "OW1", "OW2",
    "OY", "OY0", "OY1", "OY2",
    "P", "R", "S", "SH", "T", "TH",
    "UH", "UH0", "UH1", "UH2",
    "UW", "UW0", "UW1", "UW2",
    "V", "W", "Y", "Z", "ZH",
]

# Set copy of the list above, used for fast membership checks.
_valid_symbol_set = set(valid_symbols)

patil-suraj's avatar
style  
patil-suraj committed
117

patil-suraj's avatar
patil-suraj committed
118
119
120
121
122
123
124
125
126
127
def intersperse(lst, item):
    """Return a new list with ``item`` before, between and after the elements of ``lst``.

    The result holds ``2 * len(lst) + 1`` entries; the original elements end up
    at the odd indices and ``item`` (the blank symbol) at the even ones.
    """
    padded = [item]
    for element in lst:
        padded.append(element)
        padded.append(item)
    return padded


class CMUDict:
    """Word-to-pronunciation lookup table built from a CMUdict-style file.

    Args:
        file_or_path: a filesystem path (``str`` or ``os.PathLike``), which is opened
            with ``latin-1`` encoding, or an already opened iterable of text lines.
        keep_ambiguous: when ``False``, words with more than one pronunciation are dropped.
    """

    def __init__(self, file_or_path, keep_ambiguous=True):
        if isinstance(file_or_path, (str, os.PathLike)):
            with open(file_or_path, encoding="latin-1") as f:
                entries = _parse_cmudict(f)
        else:
            entries = _parse_cmudict(file_or_path)
        if not keep_ambiguous:
            entries = {word: pron for word, pron in entries.items() if len(pron) == 1}
        self._entries = entries

    def __len__(self):
        """Return the number of words stored in the dictionary."""
        return len(self._entries)

    def lookup(self, word):
        """Return the list of pronunciations for ``word``, or ``None`` if it is unknown.

        The word is upper-cased before the lookup, so matching is case-insensitive.
        """
        return self._entries.get(word.upper())


patil-suraj's avatar
style  
patil-suraj committed
143
_alt_re = re.compile(r"\([0-9]+\)")
patil-suraj's avatar
patil-suraj committed
144
145
146
147
148


def _parse_cmudict(file):
    """Parse CMUdict-formatted lines into a ``{word: [pronunciation, ...]}`` mapping.

    Only lines whose first character is an uppercase ASCII letter or an apostrophe
    are treated as entries. Each entry is split on a double space into the word and
    its pronunciation; a ``(N)`` alternate marker is stripped from the word, so all
    variants of a word are collected in one list. Entries whose pronunciation is
    rejected by ``_get_pronunciation`` are skipped.

    Args:
        file: iterable of text lines.

    Returns:
        dict mapping each word to the list of its accepted pronunciations.
    """
    cmudict = {}
    for line in file:
        if not line or not ("A" <= line[0] <= "Z" or line[0] == "'"):
            continue
        parts = line.split("  ")
        if len(parts) < 2:
            # No double-space separator: malformed entry, skip it instead of
            # failing with an IndexError on parts[1].
            continue
        word = _alt_re.sub("", parts[0])
        pronunciation = _get_pronunciation(parts[1])
        if pronunciation:
            cmudict.setdefault(word, []).append(pronunciation)
    return cmudict


def _get_pronunciation(s):
    """Validate a raw pronunciation string.

    The string is stripped and split on single spaces. If every token is a known
    symbol (a member of ``_valid_symbol_set``) the tokens are returned joined by
    single spaces; otherwise ``None`` is returned.
    """
    parts = s.strip().split(" ")
    if any(part not in _valid_symbol_set for part in parts):
        return None
    return " ".join(parts)


# Matches any run of one or more whitespace characters.
_whitespace_re = re.compile(r"\s+")

# (compiled pattern, expansion) pairs. Each pattern matches the abbreviation at a
# word boundary followed by a literal period, ignoring case.
_abbreviations = [
    (re.compile("\\b%s\\." % abbreviation, re.IGNORECASE), expansion)
    for abbreviation, expansion in (
        ("mrs", "misess"),
        ("mr", "mister"),
        ("dr", "doctor"),
        ("st", "saint"),
        ("co", "company"),
        ("jr", "junior"),
        ("maj", "major"),
        ("gen", "general"),
        ("drs", "doctors"),
        ("rev", "reverend"),
        ("lt", "lieutenant"),
        ("hon", "honorable"),
        ("sgt", "sergeant"),
        ("capt", "captain"),
        ("esq", "esquire"),
        ("ltd", "limited"),
        ("col", "colonel"),
        ("ft", "fort"),
    )
]
patil-suraj's avatar
patil-suraj committed
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210


def expand_abbreviations(text):
    """Replace every abbreviation listed in ``_abbreviations`` with its expansion.

    The substitutions are applied one after another, in table order.
    """
    expanded = text
    for pattern, full_form in _abbreviations:
        expanded = pattern.sub(full_form, expanded)
    return expanded


def expand_numbers(text):
    """Return ``text`` with numbers rewritten by ``normalize_numbers``."""
    normalized = normalize_numbers(text)
    return normalized


def lowercase(text):
    """Return a lower-cased copy of ``text``."""
    lowered = text.lower()
    return lowered


def collapse_whitespace(text):
    """Replace every match of ``_whitespace_re`` in ``text`` with a single space."""
    return _whitespace_re.sub(" ", text)
patil-suraj's avatar
patil-suraj committed
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238


def convert_to_ascii(text):
    """Return the result of passing ``text`` through ``unidecode``."""
    ascii_text = unidecode(text)
    return ascii_text


def basic_cleaners(text):
    """Lower-case ``text`` and then collapse its whitespace."""
    return collapse_whitespace(lowercase(text))


def transliteration_cleaners(text):
    """Convert ``text`` to ASCII, lower-case it and collapse its whitespace, in that order."""
    ascii_text = convert_to_ascii(text)
    return collapse_whitespace(lowercase(ascii_text))


def english_cleaners(text):
    """Run the full English cleaning pipeline over ``text``.

    Steps, in order: ASCII conversion, lower-casing, number expansion,
    abbreviation expansion and whitespace collapsing.
    """
    pipeline = (
        convert_to_ascii,
        lowercase,
        expand_numbers,
        expand_abbreviations,
        collapse_whitespace,
    )
    for step in pipeline:
        text = step(text)
    return text

patil-suraj's avatar
patil-suraj committed
239
240
241
242
243
# Shared inflect engine used to spell out numbers. ``inflect`` is an optional
# dependency: when its import failed earlier in this module the name is
# undefined, so fall back to ``None`` here. Only NameError is caught so that
# unrelated failures (and KeyboardInterrupt/SystemExit) are not swallowed.
try:
    _inflect = inflect.engine()
except NameError:
    print("inflect is not installed")
    _inflect = None
patil-suraj's avatar
patil-suraj committed
244

patil-suraj's avatar
style  
patil-suraj committed
245
246
247
248
249
250
# A digit, then one or more digits or commas, then a digit (e.g. "1,234").
_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
# Digits, a literal dot, digits (e.g. "3.14").
_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
# A pound sign followed by digits/commas ending in a digit; the amount is group 1.
_pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
# A dollar sign followed by digits/dots/commas ending in a digit; the amount is group 1.
_dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
# Digits immediately followed by an ordinal suffix (st, nd, rd or th).
_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
# Any run of digits.
_number_re = re.compile(r"[0-9]+")
patil-suraj's avatar
patil-suraj committed
251
252
253


def _remove_commas(m):
    """Regex substitution callback: return group 1 of match ``m`` with all commas removed."""
    return m.group(1).replace(",", "")
patil-suraj's avatar
patil-suraj committed
255
256
257


def _expand_decimal_point(m):
    """Regex substitution callback: return group 1 of match ``m`` with "." replaced by " point "."""
    return m.group(1).replace(".", " point ")
patil-suraj's avatar
patil-suraj committed
259
260
261
262


def _expand_dollars(m):
    """Regex substitution callback that spells out a dollar amount.

    Group 1 of ``m`` is the amount without the "$" sign. It is split on "." into a
    dollars part and an optional cents part; both are read as plain integers.

    Returns:
        - ``"<amount> dollars"`` unchanged when the amount has more than one dot;
        - ``"D dollar(s), C cent(s)"`` when both parts are non-zero;
        - ``"D dollar(s)"`` or ``"C cent(s)"`` when only one part is non-zero;
        - ``"zero dollars"`` when both are zero.
    """
    match = m.group(1)
    parts = match.split(".")
    if len(parts) > 2:
        return match + " dollars"  # Unexpected format
    # The dollars pattern also admits commas; drop them so int() cannot fail on
    # inputs such as ",5".
    dollars_text = parts[0].replace(",", "")
    cents_text = parts[1].replace(",", "") if len(parts) > 1 else ""
    dollars = int(dollars_text) if dollars_text else 0
    cents = int(cents_text) if cents_text else 0
    dollar_unit = "dollar" if dollars == 1 else "dollars"
    cent_unit = "cent" if cents == 1 else "cents"
    if dollars and cents:
        return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit)
    if dollars:
        return "%s %s" % (dollars, dollar_unit)
    if cents:
        return "%s %s" % (cents, cent_unit)
    return "zero dollars"
patil-suraj's avatar
patil-suraj committed
280
281
282
283
284
285
286
287
288
289


def _expand_ordinal(m):
    """Regex substitution callback: pass the whole matched text to ``_inflect.number_to_words``."""
    ordinal_text = m.group(0)
    return _inflect.number_to_words(ordinal_text)


def _expand_number(m):
    """Regex substitution callback that spells out an integer.

    Values strictly between 1000 and 3000 get special handling:
      - 2000 becomes "two thousand";
      - 2001-2009 become "two thousand <n>";
      - exact multiples of 100 become "<n> hundred";
      - everything else is spelled by ``_inflect`` in groups of two digits, with
        zero written as "oh" and the ", " group separator replaced by a space.
    All other values are spelled by ``_inflect`` without the word "and".
    """
    num = int(m.group(0))
    if not 1000 < num < 3000:
        return _inflect.number_to_words(num, andword="")
    if num == 2000:
        return "two thousand"
    if 2000 < num < 2010:
        return "two thousand " + _inflect.number_to_words(num % 100)
    if num % 100 == 0:
        return _inflect.number_to_words(num // 100) + " hundred"
    return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ")
patil-suraj's avatar
patil-suraj committed
299
300
301
302


def normalize_numbers(text):
    """Rewrite the numeric expressions in ``text`` as words.

    The substitutions run in a fixed order, each on the output of the previous one:
    commas inside numbers are removed, pound amounts get a " pounds" suffix, dollar
    amounts are expanded, decimal points become " point ", ordinals are spelled
    out, and finally every remaining run of digits is spelled out.
    """
    text = _comma_number_re.sub(_remove_commas, text)
    text = _pounds_re.sub(r"\1 pounds", text)
    text = _dollars_re.sub(_expand_dollars, text)
    text = _decimal_number_re.sub(_expand_decimal_point, text)
    text = _ordinal_re.sub(_expand_ordinal, text)
    text = _number_re.sub(_expand_number, text)
    return text

patil-suraj's avatar
style  
patil-suraj committed
310

patil-suraj's avatar
patil-suraj committed
311
312
313
""" from https://github.com/keithito/tacotron """


patil-suraj's avatar
style  
patil-suraj committed
314
315
316
317
_pad = "_"
_punctuation = "!'(),.:;? "
_special = "-"
_letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"

# Prepend "@" to ARPAbet symbols to ensure uniqueness:
_arpabet = ["@" + s for s in valid_symbols]

# Export all symbols:
symbols = [_pad] + list(_special) + list(_punctuation) + list(_letters) + _arpabet

# Lookup tables between a symbol and its position in ``symbols``.
_symbol_to_id = {s: i for i, s in enumerate(symbols)}
_id_to_symbol = dict(enumerate(symbols))

patil-suraj's avatar
style  
patil-suraj committed
329
# Captures three groups: the text before the first "{...}" span (lazy), the
# contents of the braces, and everything after the closing brace.
_curly_re = re.compile(r"(.*?)\{(.+?)\}(.*)")
patil-suraj's avatar
patil-suraj committed
330
331
332
333
334
335
336
337
338
339
340


def get_arpabet(word, dictionary):
    """Return the first pronunciation of ``word`` wrapped in curly braces.

    ``dictionary.lookup(word)`` is consulted; when it returns ``None`` the word is
    returned unchanged.
    """
    pronunciations = dictionary.lookup(word)
    if pronunciations is None:
        return word
    return "{" + pronunciations[0] + "}"


def text_to_sequence(text, cleaner_names=(english_cleaners,), dictionary=None):
    """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.

    The text can optionally have ARPAbet sequences enclosed in curly braces embedded
    in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."

    Args:
      text: string to convert to a sequence
      cleaner_names: iterable of cleaner callables; each is applied to the text in order
      dictionary: object with a ``lookup(word)`` method (e.g. ``CMUDict``). When given,
        each space-separated word of the brace-free remainder of the text is replaced
        by its first pronunciation if the dictionary knows it.

    Returns:
      List of integers corresponding to the symbols in the text
    """
    sequence = []
    space = _symbols_to_sequence(" ")
    # Check for curly braces and treat their contents as ARPAbet:
    while len(text):
        m = _curly_re.match(text)
        if not m:
            clean_text = _clean_text(text, cleaner_names)
            if dictionary is not None:
                for word in clean_text.split(" "):
                    token = get_arpabet(word, dictionary)
                    if token.startswith("{"):
                        sequence += _arpabet_to_sequence(token[1:-1])
                    else:
                        sequence += _symbols_to_sequence(token)
                    sequence += space
            else:
                sequence += _symbols_to_sequence(clean_text)
            break
        sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names))
        sequence += _arpabet_to_sequence(m.group(2))
        text = m.group(3)

    # Remove the trailing space added after the last word. The emptiness check
    # keeps an empty input from failing with an IndexError on sequence[-1].
    if dictionary is not None and sequence and sequence[-1] == space[0]:
        sequence = sequence[:-1]
    return sequence


def sequence_to_text(sequence):
    """Converts a sequence of IDs back to a string.

    IDs that are not in ``_id_to_symbol`` are skipped. Symbols stored with a
    leading "@" are written back inside curly braces, and adjacent brace groups
    are merged into one group separated by spaces.
    """
    pieces = []
    for symbol_id in sequence:
        if symbol_id not in _id_to_symbol:
            continue
        s = _id_to_symbol[symbol_id]
        # Enclose ARPAbet back in curly braces:
        if len(s) > 1 and s[0] == "@":
            s = "{%s}" % s[1:]
        pieces.append(s)
    return "".join(pieces).replace("}{", " ")
patil-suraj's avatar
patil-suraj committed
394
395
396
397
398
399
400
401
402
403
404
405
406


def _clean_text(text, cleaner_names):
    """Apply each callable in ``cleaner_names`` to ``text`` in order and return the result."""
    cleaned = text
    for apply_cleaner in cleaner_names:
        cleaned = apply_cleaner(cleaned)
    return cleaned


def _symbols_to_sequence(symbols):
    """Map each kept symbol in ``symbols`` to its ID, dropping those rejected by ``_should_keep_symbol``."""
    ids = []
    for symbol in symbols:
        if _should_keep_symbol(symbol):
            ids.append(_symbol_to_id[symbol])
    return ids


def _arpabet_to_sequence(text):
    """Convert a whitespace-separated ARPAbet string to symbol IDs.

    Each token is prefixed with "@" before lookup, matching how the ARPAbet
    symbols are stored in the symbol table.
    """
    return _symbols_to_sequence(["@" + s for s in text.split()])
patil-suraj's avatar
patil-suraj committed
408
409
410


def _should_keep_symbol(s):
    """Return True when ``s`` is a known symbol other than "_" and "~"."""
    return s in _symbol_to_id and s != "_" and s != "~"
patil-suraj's avatar
patil-suraj committed
412
413
414


# File name under which the tokenizer's pronunciation dictionary is saved.
VOCAB_FILES_NAMES = {
    "dict_file": "dict_file.txt",
}

patil-suraj's avatar
style  
patil-suraj committed
418

patil-suraj's avatar
patil-suraj committed
419
420
class GradTTSTokenizer(PreTrainedTokenizer):
    """Tokenizer that converts text into blank-interspersed symbol-ID tensors.

    Args:
        dict_file: path or file object handed to ``CMUDict`` to build the
            pronunciation dictionary; also the file copied by ``save_vocabulary``.
        **kwargs: forwarded to ``PreTrainedTokenizer.__init__``.
    """

    vocab_files_names = VOCAB_FILES_NAMES

    def __init__(self, dict_file, **kwargs):
        super().__init__(**kwargs)
        self.cmu = CMUDict(dict_file)
        self.dict_file = dict_file

    def __call__(self, text):
        """Tokenize ``text``.

        Returns:
            A tuple ``(x, x_lengths)``: ``x`` is a ``LongTensor`` with a leading
            dimension of size 1 holding the symbol IDs interspersed with the blank
            ID ``len(symbols)``; ``x_lengths`` is a one-element ``LongTensor``
            holding the size of the last dimension of ``x``.
        """
        token_ids = text_to_sequence(text, dictionary=self.cmu)
        x = torch.LongTensor(intersperse(token_ids, len(symbols)))[None]
        x_lengths = torch.LongTensor([x.shape[-1]])
        return x, x_lengths

    def save_vocabulary(self, save_directory: str, filename_prefix=None):
        """Copy the dictionary file into ``save_directory``.

        Args:
            save_directory: target directory.
            filename_prefix: optional prefix, joined to the file name with "-".

        Returns:
            One-element tuple with the path of the saved dictionary file.
        """
        dict_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["dict_file"]
        )

        # copyfile raises SameFileError when source and destination are the same
        # file, which happens when saving into the directory the file was loaded from.
        if os.path.abspath(self.dict_file) != os.path.abspath(dict_file):
            copyfile(self.dict_file, dict_file)

        return (dict_file,)