Commit 3f2d46a1 authored by patil-suraj's avatar patil-suraj
Browse files

fix tokenizer

parent 7b55d334
# tokenizer
import re
import os
from shutil import copyfile
import torch
from transformers import PreTrainedTokenizer
......@@ -325,7 +327,7 @@ def _should_keep_symbol(s):
VOCAB_FILES_NAMES = {
"dict_file": "merges.txt",
"dict_file": "dict_file.txt",
}
class GradTTSTokenizer(PreTrainedTokenizer):
......@@ -334,8 +336,18 @@ class GradTTSTokenizer(PreTrainedTokenizer):
def __init__(self, dict_file, **kwargs):
super().__init__(**kwargs)
self.cmu = CMUDict(dict_file)
self.dict_file = dict_file
def __call__(self, text):
x = torch.LongTensor(intersperse(text_to_sequence(text, dictionary=self.cmu), len(symbols)))[None]
x_lengths = torch.LongTensor([x.shape[-1]])
return x.shape, x_lengths
return x, x_lengths
def save_vocabulary(self, save_directory: str, filename_prefix = None):
dict_file = os.path.join(
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["dict_file"]
)
copyfile(self.dict_file, dict_file)
return (dict_file, )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment