"server/text_generation/models/t5.py" did not exist on "1f570d181f4836a430f9b92f001a7b834ea561e3"
text_preprocessing.py 6.88 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
# *****************************************************************************
# Copyright (c) 2017 Keith Ito
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
# *****************************************************************************
"""
Modified from https://github.com/keithito/tacotron
"""

import re
28
from typing import List, Optional, Union
29

30
from torchaudio.datasets import CMUDict
31
from unidecode import unidecode
32
33
34
35
36

from .numbers import normalize_numbers


# Any run of one or more whitespace characters (spaces, tabs, newlines, ...).
# Text normalization replaces each such run with a single space.
_whitespace_re = re.compile(r"\s+")


# (compiled pattern, expansion) pairs used to spell out common abbreviations.
# Each pattern matches the abbreviation at a word boundary followed by a
# literal period, case-insensitively (e.g. "Dr." -> "doctor").
_abbreviations = [
    (re.compile(rf"\b{abbreviation}\.", re.IGNORECASE), expansion)
    for abbreviation, expansion in (
        ("mrs", "misess"),
        ("mr", "mister"),
        ("dr", "doctor"),
        ("st", "saint"),
        ("co", "company"),
        ("jr", "junior"),
        ("maj", "major"),
        ("gen", "general"),
        ("drs", "doctors"),
        ("rev", "reverend"),
        ("lt", "lieutenant"),
        ("hon", "honorable"),
        ("sgt", "sergeant"),
        ("capt", "captain"),
        ("esq", "esquire"),
        ("ltd", "limited"),
        ("col", "colonel"),
        ("ft", "fort"),
    )
]

# Building blocks of the character-level symbol inventory.
_pad = "_"
_punctuation = "!'(),.:;? "
_special = "-"
_letters = "abcdefghijklmnopqrstuvwxyz"

# Character vocabulary: padding, special, punctuation, then lowercase letters.
# Same contents and order as get_symbol_list("english_characters").
symbols = [_pad] + list(_special) + list(_punctuation) + list(_letters)

# Phonemizer model loaded lazily by ``word_to_phonemes`` and reused across calls.
_phonemizer = None


# Names accepted for the ``symbol_list`` and ``phonemizer`` arguments.
available_symbol_set = {"english_characters", "english_phonemes"}
available_phonemizers = {"DeepPhonemizer"}


def get_symbol_list(symbol_list: str = "english_characters", cmudict_root: Optional[str] = "./") -> List[str]:
    r"""Return the ordered symbol inventory named by ``symbol_list``.

    Args:
        symbol_list (str, optional): Either "english_characters" or "english_phonemes".
            (Default: "english_characters")
        cmudict_root (str or None, optional): Root directory passed to ``CMUDict``. Only used when
            ``symbol_list`` is "english_phonemes". (Default: "./")

    Returns:
        List[str]: The padding, special and punctuation symbols, followed by either the lowercase
        letters ("english_characters") or the ``CMUDict`` symbols ("english_phonemes").
        A new list is built on every call.

    Raises:
        ValueError: If ``symbol_list`` is not one of the supported names.
    """
    if symbol_list == "english_characters":
        return [_pad] + list(_special) + list(_punctuation) + list(_letters)
    elif symbol_list == "english_phonemes":
        # NOTE(review): a CMUDict instance is constructed on every call; callers on a hot
        # path may want to cache the resulting list.
        return [_pad] + list(_special) + list(_punctuation) + CMUDict(cmudict_root).symbols
    else:
        raise ValueError(
            f"The `symbol_list` {symbol_list} is not supported. "
            f"Supported `symbol_list` includes {available_symbol_set}."
        )


# Checkpoint path that ``_phonemizer`` was loaded from, so ``word_to_phonemes`` can
# reload the model when a caller asks for a different checkpoint.
_phonemizer_checkpoint: Optional[str] = None


def word_to_phonemes(sent: str, phonemizer: str, checkpoint: str) -> List[str]:
    r"""Convert a sentence into a list of phoneme and punctuation tokens.

    Args:
        sent (str): The (already normalized) input sentence.
        phonemizer (str): Name of the phonemizer backend. Only "DeepPhonemizer" is supported.
        checkpoint (str): Path to the phonemizer checkpoint. The loaded model is kept in a
            module-level global and only reloaded when a different path is requested.

    Returns:
        List[str]: Phoneme names with their surrounding brackets removed (e.g. "HH"), interleaved
        with any special/punctuation characters (including spaces) found in the phonemizer output.

    Raises:
        ValueError: If ``phonemizer`` is not "DeepPhonemizer".
    """
    global _phonemizer, _phonemizer_checkpoint

    if phonemizer != "DeepPhonemizer":
        raise ValueError(
            f"The `phonemizer` {phonemizer} is not supported. " "Supported `phonemizer` includes `'DeepPhonemizer'`."
        )

    # Imported lazily so the DeepPhonemizer package is only required for phoneme input.
    from dp.phonemizer import Phonemizer

    # A token is either a bracketed upper-case phoneme such as "[HH]" or one special/punctuation
    # character. re.escape keeps characters like "-", "(" and "?" literal inside the class.
    other_symbols = re.escape("".join(list(_special) + list(_punctuation)))
    phone_symbols_re = r"(\[[A-Z]+?\]|[" + other_symbols + "])"

    if _phonemizer is None or _phonemizer_checkpoint != checkpoint:
        # Cache the model globally so the checkpoint is not reloaded on every call,
        # but still honour a change of checkpoint path.
        _phonemizer = Phonemizer.from_checkpoint(checkpoint)
        _phonemizer_checkpoint = checkpoint

    # Example:
    # sent = "hello world!"
    # '[HH][AH][L][OW] [W][ER][L][D]!'
    phonemized = _phonemizer(sent, lang="en_us")

    # ['[HH]', '[AH]', '[L]', '[OW]', ' ', '[W]', '[ER]', '[L]', '[D]', '!']
    tokens = re.findall(phone_symbols_re, phonemized)

    # ['HH', 'AH', 'L', 'OW', ' ', 'W', 'ER', 'L', 'D', '!']
    return [token.replace("[", "").replace("]", "") for token in tokens]


def text_to_sequence(
    sent: str,
    symbol_list: Union[str, List[str]] = "english_characters",
    phonemizer: Optional[str] = "DeepPhonemizer",
    checkpoint: Optional[str] = "./en_us_cmudict_forward.pt",
    cmudict_root: Optional[str] = "./",
) -> List[int]:
    r"""Converts a string of text to a sequence of IDs corresponding to the symbols in the text.

    The text is normalized first: transliterated to ASCII, lower-cased, with numbers and
    abbreviations expanded and whitespace runs collapsed to single spaces. Each symbol is then
    mapped to its index in the symbol list. Symbols not present in the list are silently dropped.

    Args:
        sent (str): The input sentence to convert to a sequence.
        symbol_list (str or list/tuple of str, optional): When the input is a string, available options
            include "english_characters" and "english_phonemes". When the input is a list or tuple of
            strings, ``symbol_list`` will directly be used as the symbols to encode, and the normalized
            sentence is encoded character by character. (Default: "english_characters")
        phonemizer (str or None, optional): The phonemizer to use. Only used when ``symbol_list`` is "english_phonemes".
            Available options include "DeepPhonemizer". (Default: "DeepPhonemizer")
        checkpoint (str or None, optional): The path to the checkpoint of the phonemizer. Only used when
            ``symbol_list`` is "english_phonemes". (Default: "./en_us_cmudict_forward.pt")
        cmudict_root (str or None, optional): The path to the directory where the CMUDict dataset is found or
            downloaded. Only used when ``symbol_list`` is "english_phonemes". (Default: "./")

    Returns:
        List of integers corresponding to the symbols in the sentence.

    Raises:
        TypeError: If ``symbol_list`` is neither a string nor a list/tuple of strings.
        ValueError: If ``symbol_list`` is "english_phonemes" and any of ``phonemizer``, ``checkpoint``
            or ``cmudict_root`` is ``None``; if ``symbol_list`` is an unsupported name; or if the
            requested phonemizer is unsupported.

    Examples:
        >>> text_to_sequence("hello world!", "english_characters")
        [19, 16, 23, 23, 26, 11, 34, 26, 29, 23, 15, 2]
        >>> text_to_sequence("hello world!", "english_phonemes")
        [54, 20, 65, 69, 11, 92, 44, 65, 38, 2]
    """
    # Reject unusable ``symbol_list`` values up front, before any normalization work is done.
    if not isinstance(symbol_list, (str, list, tuple)):
        raise TypeError(f"`symbol_list` must be a str or a list of str, but got {type(symbol_list).__name__}.")

    if symbol_list == "english_phonemes":
        if any(param is None for param in (phonemizer, checkpoint, cmudict_root)):
            raise ValueError(
                "When `symbol_list` is 'english_phonemes', "
                "all of `phonemizer`, `checkpoint`, and `cmudict_root` must be provided."
            )

    sent = unidecode(sent)  # convert to ascii
    sent = sent.lower()  # lower case
    sent = normalize_numbers(sent)  # expand numbers
    for regex, replacement in _abbreviations:  # expand abbreviations
        sent = regex.sub(replacement, sent)
    sent = _whitespace_re.sub(" ", sent)  # collapse whitespace

    # Items to look up: characters of the string, or phoneme tokens for "english_phonemes".
    tokens: Union[str, List[str]] = sent
    if isinstance(symbol_list, str):
        vocabulary = get_symbol_list(symbol_list, cmudict_root=cmudict_root)
        if symbol_list == "english_phonemes":
            tokens = word_to_phonemes(sent, phonemizer=phonemizer, checkpoint=checkpoint)
    else:
        vocabulary = symbol_list

    # If a symbol appears more than once in ``vocabulary``, its last index wins.
    symbol_to_id = {symbol: index for index, symbol in enumerate(vocabulary)}

    return [symbol_to_id[token] for token in tokens if token in symbol_to_id]