# tokenizer
import os
import re
from shutil import copyfile

import torch
from transformers import PreTrainedTokenizer

# `unidecode` and `inflect` are optional: only the cleaners that transliterate
# to ASCII or spell out numbers use them, so a missing package must not prevent
# this module from being imported.
try:
    from unidecode import unidecode
except ImportError:
    unidecode = None
    print("unidecode is not installed")

try:
    import inflect
except ImportError:
    inflect = None
    print("inflect is not installed")


# ARPAbet phonemes accepted in CMUdict pronunciations. The order is significant:
# symbol ids further down are derived from the position in this list.
valid_symbols = [
    'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1', 'AH2',
    'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0', 'AY1', 'AY2',
    'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0', 'ER1', 'ER2', 'EY',
    'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0', 'IH1', 'IH2', 'IY', 'IY0', 'IY1',
    'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0',
    'OY1', 'OY2', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW',
    'UW0', 'UW1', 'UW2', 'V', 'W', 'Y', 'Z', 'ZH'
]

_valid_symbol_set = set(valid_symbols)


def intersperse(lst, item):
    """Return a new list with `item` before, between and after the elements of `lst`.

    Used to add the blank symbol around every token id; the result has
    ``2 * len(lst) + 1`` elements.
    """
    result = [item] * (len(lst) * 2 + 1)
    result[1::2] = lst
    return result


class CMUDict:
    """In-memory pronunciation dictionary parsed from a CMUdict-style file."""

    def __init__(self, file_or_path, keep_ambiguous=True):
        """Load the dictionary.

        Args:
            file_or_path: path to the dictionary file (opened as latin-1), or an
                already opened iterable of text lines.
            keep_ambiguous: when False, words with more than one pronunciation
                are dropped.
        """
        if isinstance(file_or_path, str):
            with open(file_or_path, encoding='latin-1') as f:
                entries = _parse_cmudict(f)
        else:
            entries = _parse_cmudict(file_or_path)
        if not keep_ambiguous:
            entries = {word: pron for word, pron in entries.items() if len(pron) == 1}
        self._entries = entries

    def __len__(self):
        return len(self._entries)

    def lookup(self, word):
        """Return the list of pronunciations of `word` (case-insensitive), or None."""
        return self._entries.get(word.upper())


# Suffix such as "(1)" that marks an alternative pronunciation of a word.
_alt_re = re.compile(r'\([0-9]+\)')


def _parse_cmudict(file):
    """Parse dictionary lines into ``{WORD: [pronunciation, ...]}``.

    Only lines starting with an upper-case ASCII letter or an apostrophe are
    entries. The word is the first whitespace-delimited token and the rest of
    the line is its pronunciation, so both one-space and two-space separated
    files are handled. Lines without a pronunciation, or with a pronunciation
    containing an unknown symbol, are skipped.
    """
    cmudict = {}
    for line in file:
        if not line or not ('A' <= line[0] <= 'Z' or line[0] == "'"):
            continue
        parts = line.split(None, 1)
        if len(parts) < 2:
            # Word without any pronunciation: nothing to store.
            continue
        word = _alt_re.sub('', parts[0])
        pronunciation = _get_pronunciation(parts[1])
        if pronunciation:
            cmudict.setdefault(word, []).append(pronunciation)
    return cmudict


def _get_pronunciation(s):
    """Return `s` as a single-space separated phoneme string, or None if any symbol is invalid."""
    parts = s.split()
    for part in parts:
        if part not in _valid_symbol_set:
            return None
    return ' '.join(parts)


_whitespace_re = re.compile(r'\s+')

# (compiled pattern, replacement) pairs for abbreviations followed by a period.
_abbreviations = [
    (re.compile('\\b%s\\.' % abbreviation, re.IGNORECASE), expansion)
    for abbreviation, expansion in [
        ('mrs', 'misess'),
        ('mr', 'mister'),
        ('dr', 'doctor'),
        ('st', 'saint'),
        ('co', 'company'),
        ('jr', 'junior'),
        ('maj', 'major'),
        ('gen', 'general'),
        ('drs', 'doctors'),
        ('rev', 'reverend'),
        ('lt', 'lieutenant'),
        ('hon', 'honorable'),
        ('sgt', 'sergeant'),
        ('capt', 'captain'),
        ('esq', 'esquire'),
        ('ltd', 'limited'),
        ('col', 'colonel'),
        ('ft', 'fort'),
    ]
]


def expand_abbreviations(text):
    """Replace every known abbreviation (with its trailing period) by its expansion."""
    for regex, replacement in _abbreviations:
        text = regex.sub(replacement, text)
    return text


def expand_numbers(text):
    """Spell out the numbers in `text`; see `normalize_numbers`."""
    return normalize_numbers(text)


def lowercase(text):
    return text.lower()


def collapse_whitespace(text):
    """Replace every run of whitespace with a single space."""
    return _whitespace_re.sub(' ', text)


def convert_to_ascii(text):
    """Transliterate `text` to ASCII with `unidecode`.

    Raises:
        ImportError: if the optional `unidecode` package is not installed.
    """
    if unidecode is None:
        raise ImportError("unidecode is required to convert text to ASCII")
    return unidecode(text)


def basic_cleaners(text):
    """Lowercase and collapse whitespace, without transliteration."""
    text = lowercase(text)
    text = collapse_whitespace(text)
    return text


def transliteration_cleaners(text):
    """Transliterate to ASCII, lowercase and collapse whitespace."""
    text = convert_to_ascii(text)
    text = lowercase(text)
    text = collapse_whitespace(text)
    return text


def english_cleaners(text):
    """Transliterate, lowercase, spell out numbers, expand abbreviations, collapse whitespace."""
    text = convert_to_ascii(text)
    text = lowercase(text)
    text = expand_numbers(text)
    text = expand_abbreviations(text)
    text = collapse_whitespace(text)
    return text


# Built once at import when `inflect` is available; stays None otherwise so the
# module still imports and only number expansion fails.
_inflect = inflect.engine() if inflect is not None else None

_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
_pounds_re = re.compile(r'£([0-9\,]*[0-9]+)')
_dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)')
_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
_number_re = re.compile(r'[0-9]+')


def _inflect_engine():
    """Return the shared inflect engine, or raise ImportError if `inflect` is missing."""
    if _inflect is None:
        raise ImportError("inflect is required to spell out numbers")
    return _inflect


def _remove_commas(m):
    return m.group(1).replace(',', '')


def _expand_decimal_point(m):
    return m.group(1).replace('.', ' point ')


def _expand_dollars(m):
    """Turn the amount captured after a ``$`` sign into a "<n> dollars, <n> cents" phrase."""
    match = m.group(1)
    parts = match.split('.')
    if len(parts) > 2:
        return match + ' dollars'  # Unexpected format
    dollars = int(parts[0]) if parts[0] else 0
    cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
    if dollars and cents:
        dollar_unit = 'dollar' if dollars == 1 else 'dollars'
        cent_unit = 'cent' if cents == 1 else 'cents'
        return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
    elif dollars:
        dollar_unit = 'dollar' if dollars == 1 else 'dollars'
        return '%s %s' % (dollars, dollar_unit)
    elif cents:
        cent_unit = 'cent' if cents == 1 else 'cents'
        return '%s %s' % (cents, cent_unit)
    else:
        return 'zero dollars'


def _expand_ordinal(m):
    return _inflect_engine().number_to_words(m.group(0))


def _expand_number(m):
    """Spell out an integer; values strictly between 1000 and 3000 are read like years."""
    engine = _inflect_engine()
    num = int(m.group(0))
    if num > 1000 and num < 3000:
        if num == 2000:
            return 'two thousand'
        elif num > 2000 and num < 2010:
            return 'two thousand ' + engine.number_to_words(num % 100)
        elif num % 100 == 0:
            return engine.number_to_words(num // 100) + ' hundred'
        else:
            return engine.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ')
    else:
        return engine.number_to_words(num, andword='')


def normalize_numbers(text):
    """Rewrite the numeric expressions in `text` as words.

    The substitutions run in a fixed order: thousands separators are removed
    first, then pounds, dollars, decimals and ordinals are handled, and finally
    every remaining run of digits is spelled out.

    Raises:
        ImportError: if a number has to be spelled out and `inflect` is not installed.
    """
    text = _comma_number_re.sub(_remove_commas, text)
    text = _pounds_re.sub(r'\1 pounds', text)
    text = _dollars_re.sub(_expand_dollars, text)
    text = _decimal_number_re.sub(_expand_decimal_point, text)
    text = _ordinal_re.sub(_expand_ordinal, text)
    text = _number_re.sub(_expand_number, text)
    return text


# from https://github.com/keithito/tacotron

_pad = '_'
# The trailing space is itself a symbol: text_to_sequence uses its id as the
# separator between words.
_punctuation = '!\'(),.:;? '
_special = '-'
_letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'

# Prepend "@" to ARPAbet symbols to ensure uniqueness:
_arpabet = ['@' + s for s in valid_symbols]

# Export all symbols:
symbols = [_pad] + list(_special) + list(_punctuation) + list(_letters) + _arpabet

_symbol_to_id = {s: i for i, s in enumerate(symbols)}
_id_to_symbol = {i: s for i, s in enumerate(symbols)}

# Splits text into: prefix, first "{...}" group, remainder.
_curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)')


def get_arpabet(word, dictionary):
    """Return the first pronunciation of `word` wrapped in curly braces, or `word` if unknown."""
    word_arpabet = dictionary.lookup(word)
    if word_arpabet is not None:
        return "{" + word_arpabet[0] + "}"
    else:
        return word


def text_to_sequence(text, cleaner_names=None, dictionary=None):
    """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.

    The text can optionally have ARPAbet sequences enclosed in curly braces embedded
    in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."

    Args:
        text: string to convert to a sequence
        cleaner_names: iterable of cleaner functions, applied in order to the text
            outside curly braces; defaults to ``[english_cleaners]``
        dictionary: optional object with a ``lookup(word)`` method (e.g. `CMUDict`);
            when given, every space-separated word found in it is replaced by its
            first pronunciation

    Returns:
        List of integers corresponding to the symbols in the text
    """
    if cleaner_names is None:
        # Resolved per call instead of using a shared mutable default list.
        cleaner_names = [english_cleaners]

    sequence = []
    space = _symbols_to_sequence(' ')
    # Check for curly braces and treat their contents as ARPAbet:
    while len(text):
        m = _curly_re.match(text)
        if not m:
            clean_text = _clean_text(text, cleaner_names)
            if dictionary is not None:
                for word in clean_text.split(" "):
                    token = get_arpabet(word, dictionary)
                    if token.startswith("{"):
                        sequence += _arpabet_to_sequence(token[1:-1])
                    else:
                        sequence += _symbols_to_sequence(token)
                    sequence += space
            else:
                sequence += _symbols_to_sequence(clean_text)
            break
        sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names))
        sequence += _arpabet_to_sequence(m.group(2))
        text = m.group(3)

    # Remove the trailing space added after the last word. The emptiness check
    # matters: empty text leaves `sequence` empty.
    if dictionary is not None and sequence and sequence[-1] == space[0]:
        sequence = sequence[:-1]

    return sequence


def sequence_to_text(sequence):
    """Converts a sequence of IDs back to a string; unknown IDs are ignored."""
    pieces = []
    for symbol_id in sequence:
        if symbol_id in _id_to_symbol:
            s = _id_to_symbol[symbol_id]
            # Enclose ARPAbet back in curly braces:
            if len(s) > 1 and s[0] == '@':
                s = '{%s}' % s[1:]
            pieces.append(s)
    # Adjacent phonemes are merged into a single "{A B C}" group.
    return ''.join(pieces).replace('}{', ' ')


def _clean_text(text, cleaner_names):
    """Apply every cleaner function in `cleaner_names` to `text`, in order."""
    for cleaner in cleaner_names:
        text = cleaner(text)
    return text


def _symbols_to_sequence(symbols):
    """Map each symbol to its id, dropping the symbols that should not be kept."""
    return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)]


def _arpabet_to_sequence(text):
    """Map a whitespace-separated ARPAbet string to symbol ids."""
    return _symbols_to_sequence(['@' + s for s in text.split()])


def _should_keep_symbol(s):
    return s in _symbol_to_id and s != '_' and s != '~'


VOCAB_FILES_NAMES = {
    "dict_file": "dict_file.txt",
}


class GradTTSTokenizer(PreTrainedTokenizer):
    """Tokenizer that turns text into interspersed symbol ids using a CMU pronunciation dictionary."""

    vocab_files_names = VOCAB_FILES_NAMES

    def __init__(self, dict_file, **kwargs):
        """
        Args:
            dict_file: path (or open file) of the pronunciation dictionary, passed to `CMUDict`.
            **kwargs: forwarded to `PreTrainedTokenizer.__init__`.
        """
        super().__init__(**kwargs)
        self.cmu = CMUDict(dict_file)
        self.dict_file = dict_file

    def __call__(self, text):
        """Tokenize `text`.

        Returns:
            A tuple ``(x, x_lengths)``: `x` is a LongTensor with a leading
            dimension of size 1 holding the symbol ids, with the blank id
            ``len(symbols)`` before, between and after them; `x_lengths` is a
            one-element LongTensor holding the size of the last dimension of `x`.
        """
        token_ids = text_to_sequence(text, dictionary=self.cmu)
        x = torch.LongTensor(intersperse(token_ids, len(symbols)))[None]
        x_lengths = torch.LongTensor([x.shape[-1]])
        return x, x_lengths

    def save_vocabulary(self, save_directory: str, filename_prefix=None):
        """Copy the dictionary file into `save_directory` and return its new path in a 1-tuple."""
        dict_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["dict_file"]
        )

        copyfile(self.dict_file, dict_file)

        return (dict_file, )