Commit d3cea8c9 authored by sunxx1's avatar sunxx1
Browse files

Merge branch 'main' into 'main'

增加了pytorch框架下的音频处理模型FastSpeech和ECAPA-TDNN的测试代码

See merge request dcutoolkit/deeplearing/dlexamples_new!31
parents 13a50bfe eb779cd5
import os
import librosa
import numpy as np
from scipy.io import wavfile
from tqdm import tqdm
from text import _clean_text
def prepare_align(config):
    """Convert the LibriTTS corpus into wav/lab pairs for forced alignment.

    Walks in_dir/<speaker>/<chapter>/ and, for every wav whose transcript
    exists as *.normalized.txt, writes a peak-normalized 16-bit wav and a
    cleaned-text .lab file under out_dir/<speaker>/.
    """
    in_dir = config["path"]["corpus_path"]
    out_dir = config["path"]["raw_path"]
    sampling_rate = config["preprocessing"]["audio"]["sampling_rate"]
    max_wav_value = config["preprocessing"]["audio"]["max_wav_value"]
    cleaners = config["preprocessing"]["text"]["text_cleaners"]

    for speaker in tqdm(os.listdir(in_dir)):
        for chapter in os.listdir(os.path.join(in_dir, speaker)):
            chapter_dir = os.path.join(in_dir, speaker, chapter)
            for file_name in os.listdir(chapter_dir):
                if not file_name.endswith(".wav"):
                    continue
                base_name = file_name[:-4]
                text_path = os.path.join(
                    chapter_dir, "{}.normalized.txt".format(base_name)
                )
                wav_path = os.path.join(chapter_dir, "{}.wav".format(base_name))

                # First line of the .normalized.txt file is the transcript.
                with open(text_path) as f:
                    text = f.readline().strip("\n")
                text = _clean_text(text, cleaners)

                os.makedirs(os.path.join(out_dir, speaker), exist_ok=True)
                wav, _ = librosa.load(wav_path, sampling_rate)
                # Peak-normalize, then scale to the 16-bit PCM range.
                wav = wav / max(abs(wav)) * max_wav_value
                wavfile.write(
                    os.path.join(out_dir, speaker, "{}.wav".format(base_name)),
                    sampling_rate,
                    wav.astype(np.int16),
                )
                with open(
                    os.path.join(out_dir, speaker, "{}.lab".format(base_name)),
                    "w",
                ) as lab_file:
                    lab_file.write(text)
\ No newline at end of file
import os
import librosa
import numpy as np
from scipy.io import wavfile
from tqdm import tqdm
from text import _clean_text
def prepare_align(config):
    """Convert the LJSpeech corpus into wav/lab pairs for forced alignment.

    Reads metadata.csv (pipe-separated: id|raw text|normalized text),
    cleans the normalized transcript, peak-normalizes each wav, and writes
    the .wav/.lab pairs under out_dir/LJSpeech/.
    """
    in_dir = config["path"]["corpus_path"]
    out_dir = config["path"]["raw_path"]
    sampling_rate = config["preprocessing"]["audio"]["sampling_rate"]
    max_wav_value = config["preprocessing"]["audio"]["max_wav_value"]
    cleaners = config["preprocessing"]["text"]["text_cleaners"]
    speaker = "LJSpeech"

    with open(os.path.join(in_dir, "metadata.csv"), encoding="utf-8") as f:
        for line in tqdm(f):
            fields = line.strip().split("|")
            base_name = fields[0]
            # Third column holds the normalized transcript.
            text = _clean_text(fields[2], cleaners)

            wav_path = os.path.join(in_dir, "wavs", "{}.wav".format(base_name))
            if not os.path.exists(wav_path):
                continue

            os.makedirs(os.path.join(out_dir, speaker), exist_ok=True)
            wav, _ = librosa.load(wav_path, sampling_rate)
            # Peak-normalize, then scale to the 16-bit PCM range.
            wav = wav / max(abs(wav)) * max_wav_value
            wavfile.write(
                os.path.join(out_dir, speaker, "{}.wav".format(base_name)),
                sampling_rate,
                wav.astype(np.int16),
            )
            with open(
                os.path.join(out_dir, speaker, "{}.lab".format(base_name)),
                "w",
            ) as lab_file:
                lab_file.write(text)
\ No newline at end of file
import os
import random
import json
import tgt
import librosa
import numpy as np
import pyworld as pw
from scipy.interpolate import interp1d
from sklearn.preprocessing import StandardScaler
from tqdm import tqdm
import audio as Audio
class Preprocessor:
    """Compute FastSpeech2 training features from aligned speech data.

    For each utterance that has a forced-alignment TextGrid, extracts a
    mel-spectrogram, pitch contour (pyworld DIO + StoneMask), energy
    contour, and per-phone frame durations; optionally averages pitch and
    energy per phoneme; normalizes them corpus-wide; and writes feature
    files plus train/val metadata under the preprocessed path.
    """

    def __init__(self, config):
        self.config = config
        self.in_dir = config["path"]["raw_path"]
        self.out_dir = config["path"]["preprocessed_path"]
        self.val_size = config["preprocessing"]["val_size"]
        self.sampling_rate = config["preprocessing"]["audio"]["sampling_rate"]
        self.hop_length = config["preprocessing"]["stft"]["hop_length"]

        # Pitch/energy features are stored either per frame or averaged
        # per phoneme; any other setting is a configuration error.
        assert config["preprocessing"]["pitch"]["feature"] in [
            "phoneme_level",
            "frame_level",
        ]
        assert config["preprocessing"]["energy"]["feature"] in [
            "phoneme_level",
            "frame_level",
        ]
        self.pitch_phoneme_averaging = (
            config["preprocessing"]["pitch"]["feature"] == "phoneme_level"
        )
        self.energy_phoneme_averaging = (
            config["preprocessing"]["energy"]["feature"] == "phoneme_level"
        )

        self.pitch_normalization = config["preprocessing"]["pitch"]["normalization"]
        self.energy_normalization = config["preprocessing"]["energy"]["normalization"]

        # Mel-spectrogram extractor shared by all utterances.
        self.STFT = Audio.stft.TacotronSTFT(
            config["preprocessing"]["stft"]["filter_length"],
            config["preprocessing"]["stft"]["hop_length"],
            config["preprocessing"]["stft"]["win_length"],
            config["preprocessing"]["mel"]["n_mel_channels"],
            config["preprocessing"]["audio"]["sampling_rate"],
            config["preprocessing"]["mel"]["mel_fmin"],
            config["preprocessing"]["mel"]["mel_fmax"],
        )

    def build_from_path(self):
        """Process every speaker directory under in_dir and write features.

        Returns:
            The shuffled metadata lines ("basename|speaker|{phones}|raw text"),
            which are also split into val.txt (first val_size) and train.txt.
        """
        os.makedirs((os.path.join(self.out_dir, "mel")), exist_ok=True)
        os.makedirs((os.path.join(self.out_dir, "pitch")), exist_ok=True)
        os.makedirs((os.path.join(self.out_dir, "energy")), exist_ok=True)
        os.makedirs((os.path.join(self.out_dir, "duration")), exist_ok=True)

        print("Processing Data ...")
        out = list()
        n_frames = 0
        # Incremental mean/std over all frames for corpus-level normalization.
        pitch_scaler = StandardScaler()
        energy_scaler = StandardScaler()

        # Compute pitch, energy, duration, and mel-spectrogram
        speakers = {}
        for i, speaker in enumerate(tqdm(os.listdir(self.in_dir))):
            speakers[speaker] = i
            for wav_name in os.listdir(os.path.join(self.in_dir, speaker)):
                if ".wav" not in wav_name:
                    continue

                basename = wav_name.split(".")[0]
                tg_path = os.path.join(
                    self.out_dir, "TextGrid", speaker, "{}.TextGrid".format(basename)
                )
                # Only utterances with an alignment TextGrid are usable.
                # NOTE(review): if the very first file has no TextGrid,
                # pitch/energy/n below are referenced before assignment —
                # presumably every wav is expected to have one; confirm.
                if os.path.exists(tg_path):
                    ret = self.process_utterance(speaker, basename)
                    if ret is None:
                        continue
                    else:
                        info, pitch, energy, n = ret
                    out.append(info)

                if len(pitch) > 0:
                    pitch_scaler.partial_fit(pitch.reshape((-1, 1)))
                if len(energy) > 0:
                    energy_scaler.partial_fit(energy.reshape((-1, 1)))

                n_frames += n

        print("Computing statistic quantities ...")
        # Perform normalization if necessary
        if self.pitch_normalization:
            pitch_mean = pitch_scaler.mean_[0]
            pitch_std = pitch_scaler.scale_[0]
        else:
            # A numerical trick to avoid normalization: mean 0 / std 1
            # makes the z-score transform the identity.
            pitch_mean = 0
            pitch_std = 1
        if self.energy_normalization:
            energy_mean = energy_scaler.mean_[0]
            energy_std = energy_scaler.scale_[0]
        else:
            energy_mean = 0
            energy_std = 1

        # Rewrite all saved pitch/energy files as z-scores; capture extrema.
        pitch_min, pitch_max = self.normalize(
            os.path.join(self.out_dir, "pitch"), pitch_mean, pitch_std
        )
        energy_min, energy_max = self.normalize(
            os.path.join(self.out_dir, "energy"), energy_mean, energy_std
        )

        # Save files
        with open(os.path.join(self.out_dir, "speakers.json"), "w") as f:
            f.write(json.dumps(speakers))

        with open(os.path.join(self.out_dir, "stats.json"), "w") as f:
            stats = {
                "pitch": [
                    float(pitch_min),
                    float(pitch_max),
                    float(pitch_mean),
                    float(pitch_std),
                ],
                "energy": [
                    float(energy_min),
                    float(energy_max),
                    float(energy_mean),
                    float(energy_std),
                ],
            }
            f.write(json.dumps(stats))

        print(
            "Total time: {} hours".format(
                n_frames * self.hop_length / self.sampling_rate / 3600
            )
        )

        random.shuffle(out)
        out = [r for r in out if r is not None]

        # Write metadata
        with open(os.path.join(self.out_dir, "train.txt"), "w", encoding="utf-8") as f:
            for m in out[self.val_size :]:
                f.write(m + "\n")
        with open(os.path.join(self.out_dir, "val.txt"), "w", encoding="utf-8") as f:
            for m in out[: self.val_size]:
                f.write(m + "\n")

        return out

    def process_utterance(self, speaker, basename):
        """Extract and save features for one utterance.

        Returns:
            (metadata line, outlier-free pitch, outlier-free energy,
            mel frame count), or None when the alignment window is empty
            or the utterance has at most one voiced frame.
        """
        wav_path = os.path.join(self.in_dir, speaker, "{}.wav".format(basename))
        text_path = os.path.join(self.in_dir, speaker, "{}.lab".format(basename))
        tg_path = os.path.join(
            self.out_dir, "TextGrid", speaker, "{}.TextGrid".format(basename)
        )

        # Get alignments
        textgrid = tgt.io.read_textgrid(tg_path)
        phone, duration, start, end = self.get_alignment(
            textgrid.get_tier_by_name("phones")
        )
        text = "{" + " ".join(phone) + "}"
        if start >= end:
            return None

        # Read and trim wav files
        # NOTE(review): librosa.load is called without sr, so the audio is
        # loaded at librosa's default rate, while the slice below indexes
        # with self.sampling_rate — assumes the two match; confirm.
        wav, _ = librosa.load(wav_path)
        wav = wav[
            int(self.sampling_rate * start) : int(self.sampling_rate * end)
        ].astype(np.float32)

        # Read raw text
        with open(text_path, "r") as f:
            raw_text = f.readline().strip("\n")

        # Compute fundamental frequency (frame period in milliseconds).
        pitch, t = pw.dio(
            wav.astype(np.float64),
            self.sampling_rate,
            frame_period=self.hop_length / self.sampling_rate * 1000,
        )
        pitch = pw.stonemask(wav.astype(np.float64), pitch, t, self.sampling_rate)

        pitch = pitch[: sum(duration)]
        # Need at least two voiced frames for the interpolation below.
        if np.sum(pitch != 0) <= 1:
            return None

        # Compute mel-scale spectrogram and energy
        mel_spectrogram, energy = Audio.tools.get_mel_from_wav(wav, self.STFT)
        mel_spectrogram = mel_spectrogram[:, : sum(duration)]
        energy = energy[: sum(duration)]

        if self.pitch_phoneme_averaging:
            # perform linear interpolation over unvoiced (zero) frames
            nonzero_ids = np.where(pitch != 0)[0]
            interp_fn = interp1d(
                nonzero_ids,
                pitch[nonzero_ids],
                fill_value=(pitch[nonzero_ids[0]], pitch[nonzero_ids[-1]]),
                bounds_error=False,
            )
            pitch = interp_fn(np.arange(0, len(pitch)))

            # Phoneme-level average (reuses the frame array in place)
            pos = 0
            for i, d in enumerate(duration):
                if d > 0:
                    pitch[i] = np.mean(pitch[pos : pos + d])
                else:
                    pitch[i] = 0
                pos += d
            pitch = pitch[: len(duration)]

        if self.energy_phoneme_averaging:
            # Phoneme-level average
            pos = 0
            for i, d in enumerate(duration):
                if d > 0:
                    energy[i] = np.mean(energy[pos : pos + d])
                else:
                    energy[i] = 0
                pos += d
            energy = energy[: len(duration)]

        # Save files
        dur_filename = "{}-duration-{}.npy".format(speaker, basename)
        np.save(os.path.join(self.out_dir, "duration", dur_filename), duration)

        pitch_filename = "{}-pitch-{}.npy".format(speaker, basename)
        np.save(os.path.join(self.out_dir, "pitch", pitch_filename), pitch)

        energy_filename = "{}-energy-{}.npy".format(speaker, basename)
        np.save(os.path.join(self.out_dir, "energy", energy_filename), energy)

        mel_filename = "{}-mel-{}.npy".format(speaker, basename)
        np.save(
            os.path.join(self.out_dir, "mel", mel_filename),
            mel_spectrogram.T,
        )

        return (
            "|".join([basename, speaker, text, raw_text]),
            self.remove_outlier(pitch),
            self.remove_outlier(energy),
            mel_spectrogram.shape[1],
        )

    def get_alignment(self, tier):
        """Convert a TextGrid phone tier to (phones, frame durations,
        start seconds, end seconds), trimming leading/trailing silences."""
        sil_phones = ["sil", "sp", "spn"]

        phones = []
        durations = []
        start_time = 0
        end_time = 0
        end_idx = 0
        for t in tier._objects:
            s, e, p = t.start_time, t.end_time, t.text

            # Trim leading silences
            if phones == []:
                if p in sil_phones:
                    continue
                else:
                    start_time = s

            if p not in sil_phones:
                # For ordinary phones
                phones.append(p)
                end_time = e
                end_idx = len(phones)
            else:
                # For silent phones
                phones.append(p)

            # Duration in hop-sized frames, rounded at each boundary.
            durations.append(
                int(
                    np.round(e * self.sampling_rate / self.hop_length)
                    - np.round(s * self.sampling_rate / self.hop_length)
                )
            )

        # Trim tailing silences
        phones = phones[:end_idx]
        durations = durations[:end_idx]

        return phones, durations, start_time, end_time

    def remove_outlier(self, values):
        """Drop values outside 1.5x the interquartile range (Tukey fences)."""
        values = np.array(values)
        p25 = np.percentile(values, 25)
        p75 = np.percentile(values, 75)
        lower = p25 - 1.5 * (p75 - p25)
        upper = p75 + 1.5 * (p75 - p25)
        normal_indices = np.logical_and(values > lower, values < upper)

        return values[normal_indices]

    def normalize(self, in_dir, mean, std):
        """Z-normalize every .npy file in *in_dir* in place.

        Returns:
            (min, max) over all normalized values, for stats.json.
        """
        max_value = np.finfo(np.float64).min
        min_value = np.finfo(np.float64).max
        for filename in os.listdir(in_dir):
            filename = os.path.join(in_dir, filename)
            values = (np.load(filename) - mean) / std
            np.save(filename, values)

            max_value = max(max_value, max(values))
            min_value = min(min_value, min(values))

        return min_value, max_value
g2p-en == 2.1.0
inflect == 4.1.0
librosa == 0.7.2
matplotlib == 3.2.2
numba == 0.48
numpy == 1.19.0
pypinyin==0.39.0
pyworld == 0.2.10
PyYAML==5.4.1
scikit-learn==0.23.2
scipy == 1.5.0
soundfile==0.10.3.post1
tensorboard == 2.2.2
tgt == 1.4.4
tqdm==4.46.1
unidecode == 1.1.1
\ No newline at end of file
import re
import argparse
from string import punctuation
import torch
import yaml
import numpy as np
from torch.utils.data import DataLoader
from g2p_en import G2p
from pypinyin import pinyin, Style
from utils.model import get_model, get_vocoder
from utils.tools import to_device, synth_samples
from dataset import TextDataset
from text import text_to_sequence
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def read_lexicon(lex_path):
    """Load a pronunciation lexicon: one `WORD PH1 PH2 ...` entry per line.

    Returns a dict mapping the lower-cased word to its phone list; if a
    word appears more than once, only the first pronunciation is kept.
    """
    lexicon = {}
    with open(lex_path) as f:
        for line in f:
            word, *phones = re.split(r"\s+", line.strip("\n"))
            key = word.lower()
            if key not in lexicon:
                lexicon[key] = phones
    return lexicon
def preprocess_english(text, preprocess_config):
    """Convert raw English text to a phoneme ID sequence.

    Words found in the lexicon use their listed pronunciation; everything
    else falls back to the g2p_en grapheme-to-phoneme model.
    """
    text = text.rstrip(punctuation)
    lexicon = read_lexicon(preprocess_config["path"]["lexicon_path"])
    g2p = G2p()

    phones = []
    for word in re.split(r"([,;.\-\?\!\s+])", text):
        key = word.lower()
        if key in lexicon:
            phones.extend(lexicon[key])
        else:
            # Out-of-vocabulary word: run g2p and drop the separators.
            phones.extend(p for p in g2p(word) if p != " ")

    phones = "{" + "}{".join(phones) + "}"
    # Lone punctuation tokens become short pauses ("sp").
    phones = re.sub(r"\{[^\w\s]?\}", "{sp}", phones)
    phones = phones.replace("}{", " ")

    print("Raw Text Sequence: {}".format(text))
    print("Phoneme Sequence: {}".format(phones))
    sequence = np.array(
        text_to_sequence(
            phones, preprocess_config["preprocessing"]["text"]["text_cleaners"]
        )
    )

    return np.array(sequence)
def preprocess_mandarin(text, preprocess_config):
    """Convert raw Mandarin text to a phoneme ID sequence.

    Each character is converted to a TONE3-style pinyin syllable; syllables
    missing from the lexicon are replaced by a short pause ("sp").
    """
    lexicon = read_lexicon(preprocess_config["path"]["lexicon_path"])

    phones = []
    pinyins = [
        p[0]
        for p in pinyin(
            text, style=Style.TONE3, strict=False, neutral_tone_with_five=True
        )
    ]
    for syllable in pinyins:
        phones.extend(lexicon[syllable] if syllable in lexicon else ["sp"])

    phones = "{" + " ".join(phones) + "}"
    print("Raw Text Sequence: {}".format(text))
    print("Phoneme Sequence: {}".format(phones))
    sequence = np.array(
        text_to_sequence(
            phones, preprocess_config["preprocessing"]["text"]["text_cleaners"]
        )
    )

    return np.array(sequence)
def synthesize(model, step, configs, vocoder, batchs, control_values):
    """Run inference on each batch and write the synthesized samples.

    control_values are (pitch, energy, duration) multipliers forwarded to
    the variance adaptor; results go to train_config's result_path.
    """
    preprocess_config, model_config, train_config = configs
    pitch_control, energy_control, duration_control = control_values

    for batch in batchs:
        batch = to_device(batch, device)
        with torch.no_grad():
            # Forward pass: batch[0:2] hold ids/raw text, the rest are
            # the model inputs.
            output = model(
                *(batch[2:]),
                p_control=pitch_control,
                e_control=energy_control,
                d_control=duration_control
            )
            synth_samples(
                batch,
                output,
                vocoder,
                model_config,
                preprocess_config,
                train_config["path"]["result_path"],
            )
# CLI entry point: synthesize a whole dataset ("batch") or one sentence
# ("single") from a restored checkpoint.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--restore_step", type=int, required=True)
    parser.add_argument(
        "--mode",
        type=str,
        choices=["batch", "single"],
        required=True,
        help="Synthesize a whole dataset or a single sentence",
    )
    parser.add_argument(
        "--source",
        type=str,
        default=None,
        help="path to a source file with format like train.txt and val.txt, for batch mode only",
    )
    parser.add_argument(
        "--text",
        type=str,
        default=None,
        help="raw text to synthesize, for single-sentence mode only",
    )
    parser.add_argument(
        "--speaker_id",
        type=int,
        default=0,
        help="speaker ID for multi-speaker synthesis, for single-sentence mode only",
    )
    parser.add_argument(
        "-p",
        "--preprocess_config",
        type=str,
        required=True,
        help="path to preprocess.yaml",
    )
    parser.add_argument(
        "-m", "--model_config", type=str, required=True, help="path to model.yaml"
    )
    parser.add_argument(
        "-t", "--train_config", type=str, required=True, help="path to train.yaml"
    )
    parser.add_argument(
        "--pitch_control",
        type=float,
        default=1.0,
        help="control the pitch of the whole utterance, larger value for higher pitch",
    )
    parser.add_argument(
        "--energy_control",
        type=float,
        default=1.0,
        help="control the energy of the whole utterance, larger value for larger volume",
    )
    parser.add_argument(
        "--duration_control",
        type=float,
        default=1.0,
        help="control the speed of the whole utterance, larger value for slower speaking rate",
    )
    args = parser.parse_args()

    # Check source texts: each mode requires exactly one of --source/--text.
    if args.mode == "batch":
        assert args.source is not None and args.text is None
    if args.mode == "single":
        assert args.source is None and args.text is not None

    # Read Config
    preprocess_config = yaml.load(
        open(args.preprocess_config, "r"), Loader=yaml.FullLoader
    )
    model_config = yaml.load(open(args.model_config, "r"), Loader=yaml.FullLoader)
    train_config = yaml.load(open(args.train_config, "r"), Loader=yaml.FullLoader)
    configs = (preprocess_config, model_config, train_config)

    # Get model
    model = get_model(args, configs, device, train=False)

    # Load vocoder
    vocoder = get_vocoder(model_config, device)

    # Preprocess texts
    if args.mode == "batch":
        # Get dataset
        dataset = TextDataset(args.source, preprocess_config)
        batchs = DataLoader(
            dataset,
            batch_size=8,
            collate_fn=dataset.collate_fn,
        )
    if args.mode == "single":
        ids = raw_texts = [args.text[:100]]
        speakers = np.array([args.speaker_id])
        # NOTE(review): a language other than "en"/"zh" leaves `texts`
        # unbound and crashes below — presumably only these two are
        # supported; confirm.
        if preprocess_config["preprocessing"]["text"]["language"] == "en":
            texts = np.array([preprocess_english(args.text, preprocess_config)])
        elif preprocess_config["preprocessing"]["text"]["language"] == "zh":
            texts = np.array([preprocess_mandarin(args.text, preprocess_config)])
        text_lens = np.array([len(texts[0])])
        batchs = [(ids, raw_texts, speakers, texts, text_lens, max(text_lens))]

    control_values = args.pitch_control, args.energy_control, args.duration_control

    synthesize(model, args.restore_step, configs, vocoder, batchs, control_values)
""" from https://github.com/keithito/tacotron """
import re
from text import cleaners
from text.symbols import symbols
# Mappings from symbol to numeric ID and vice versa:
_symbol_to_id = {s: i for i, s in enumerate(symbols)}
_id_to_symbol = {i: s for i, s in enumerate(symbols)}
# Regular expression matching text enclosed in curly braces:
_curly_re = re.compile(r"(.*?)\{(.+?)\}(.*)")
def text_to_sequence(text, cleaner_names):
    """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.

    The text can optionally have ARPAbet sequences enclosed in curly braces embedded
    in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."

    Args:
      text: string to convert to a sequence
      cleaner_names: names of the cleaner functions to run the text through

    Returns:
      List of integers corresponding to the symbols in the text
    """
    sequence = []
    remaining = text
    # Check for curly braces and treat their contents as ARPAbet:
    while remaining:
        match = _curly_re.match(remaining)
        if match is None:
            # No more braced groups: clean and encode the tail verbatim.
            sequence.extend(_symbols_to_sequence(_clean_text(remaining, cleaner_names)))
            break
        # Text before the braces is cleaned; the braced part is raw ARPAbet.
        sequence.extend(_symbols_to_sequence(_clean_text(match.group(1), cleaner_names)))
        sequence.extend(_arpabet_to_sequence(match.group(2)))
        remaining = match.group(3)

    return sequence
def sequence_to_text(sequence):
    """Converts a sequence of IDs back to a string"""
    pieces = []
    for symbol_id in sequence:
        if symbol_id not in _id_to_symbol:
            continue
        s = _id_to_symbol[symbol_id]
        # ARPAbet phones are stored with an "@" prefix; re-wrap in braces.
        if len(s) > 1 and s[0] == "@":
            s = "{%s}" % s[1:]
        pieces.append(s)
    return "".join(pieces).replace("}{", " ")
def _clean_text(text, cleaner_names):
    """Run *text* through each named cleaner from text.cleaners, in order."""
    for name in cleaner_names:
        cleaner = getattr(cleaners, name)
        if not cleaner:
            raise Exception("Unknown cleaner: %s" % name)
        text = cleaner(text)

    return text
def _symbols_to_sequence(symbols):
    """Map symbol strings to IDs, silently dropping filtered-out symbols."""
    ids = []
    for s in symbols:
        if _should_keep_symbol(s):
            ids.append(_symbol_to_id[s])
    return ids
def _arpabet_to_sequence(text):
    """Encode space-separated ARPAbet phones via their "@"-prefixed symbols."""
    prefixed = ["@" + phone for phone in text.split()]
    return _symbols_to_sequence(prefixed)
def _should_keep_symbol(s):
    """Keep any known symbol except the "_" (pad) and "~" markers."""
    return s in _symbol_to_id and s not in ("_", "~")
""" from https://github.com/keithito/tacotron """
'''
Cleaners are transformations that run over the input text at both training and eval time.
Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
hyperparameter. Some cleaners are English-specific. You'll typically want to use:
1. "english_cleaners" for English text
2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
the Unidecode library (https://pypi.python.org/pypi/Unidecode)
3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
the symbols in symbols.py to match your data).
'''
# Regular expression matching whitespace:
import re
from unidecode import unidecode
from .numbers import normalize_numbers
_whitespace_re = re.compile(r'\s+')
# List of (regular expression, replacement) pairs for abbreviations:
_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
('mrs', 'misess'),
('mr', 'mister'),
('dr', 'doctor'),
('st', 'saint'),
('co', 'company'),
('jr', 'junior'),
('maj', 'major'),
('gen', 'general'),
('drs', 'doctors'),
('rev', 'reverend'),
('lt', 'lieutenant'),
('hon', 'honorable'),
('sgt', 'sergeant'),
('capt', 'captain'),
('esq', 'esquire'),
('ltd', 'limited'),
('col', 'colonel'),
('ft', 'fort'),
]]
def expand_abbreviations(text):
    """Replace known abbreviations ('mr.', 'dr.', ...) with full words."""
    for pattern, expansion in _abbreviations:
        text = pattern.sub(expansion, text)
    return text
def expand_numbers(text):
    """Spell out numeric expressions (delegates to numbers.normalize_numbers)."""
    return normalize_numbers(text)
def lowercase(text):
    """Fold every character of *text* to lower case."""
    return str.lower(text)
def collapse_whitespace(text):
    """Replace every run of whitespace with a single space."""
    return _whitespace_re.sub(' ', text)
def convert_to_ascii(text):
    """Transliterate non-ASCII characters to close ASCII equivalents."""
    return unidecode(text)
def basic_cleaners(text):
    '''Basic pipeline that lowercases and collapses whitespace without transliteration.'''
    return collapse_whitespace(lowercase(text))
def transliteration_cleaners(text):
    '''Pipeline for non-English text that transliterates to ASCII.'''
    return collapse_whitespace(lowercase(convert_to_ascii(text)))
def english_cleaners(text):
    '''Pipeline for English text, including number and abbreviation expansion.'''
    # Order matters: transliterate and lowercase before expanding numbers
    # and abbreviations, and collapse whitespace last.
    for step in (
        convert_to_ascii,
        lowercase,
        expand_numbers,
        expand_abbreviations,
        collapse_whitespace,
    ):
        text = step(text)
    return text
""" from https://github.com/keithito/tacotron """
import re
valid_symbols = [
"AA",
"AA0",
"AA1",
"AA2",
"AE",
"AE0",
"AE1",
"AE2",
"AH",
"AH0",
"AH1",
"AH2",
"AO",
"AO0",
"AO1",
"AO2",
"AW",
"AW0",
"AW1",
"AW2",
"AY",
"AY0",
"AY1",
"AY2",
"B",
"CH",
"D",
"DH",
"EH",
"EH0",
"EH1",
"EH2",
"ER",
"ER0",
"ER1",
"ER2",
"EY",
"EY0",
"EY1",
"EY2",
"F",
"G",
"HH",
"IH",
"IH0",
"IH1",
"IH2",
"IY",
"IY0",
"IY1",
"IY2",
"JH",
"K",
"L",
"M",
"N",
"NG",
"OW",
"OW0",
"OW1",
"OW2",
"OY",
"OY0",
"OY1",
"OY2",
"P",
"R",
"S",
"SH",
"T",
"TH",
"UH",
"UH0",
"UH1",
"UH2",
"UW",
"UW0",
"UW1",
"UW2",
"V",
"W",
"Y",
"Z",
"ZH",
]
_valid_symbol_set = set(valid_symbols)
class CMUDict:
    """Thin wrapper around CMUDict data. http://www.speech.cs.cmu.edu/cgi-bin/cmudict"""

    def __init__(self, file_or_path, keep_ambiguous=True):
        # Accept either a filesystem path or an already-open file object.
        if isinstance(file_or_path, str):
            with open(file_or_path, encoding="latin-1") as f:
                parsed = _parse_cmudict(f)
        else:
            parsed = _parse_cmudict(file_or_path)
        if not keep_ambiguous:
            # Keep only words with a single pronunciation.
            parsed = {word: pron for word, pron in parsed.items() if len(pron) == 1}
        self._entries = parsed

    def __len__(self):
        return len(self._entries)

    def lookup(self, word):
        """Returns list of ARPAbet pronunciations of the given word."""
        return self._entries.get(word.upper())
# Matches alternate-pronunciation markers like "(1)" on dictionary headwords.
_alt_re = re.compile(r"\([0-9]+\)")


def _parse_cmudict(file):
    """Parse an open cmudict file into {WORD: [pronunciation, ...]}."""
    cmudict = {}
    for line in file:
        # Entries start with A-Z or an apostrophe; skip everything else.
        if not len(line):
            continue
        first = line[0]
        if not ("A" <= first <= "Z" or first == "'"):
            continue
        parts = line.split(" ")
        # Drop the "(n)" alternate marker from the headword.
        word = re.sub(_alt_re, "", parts[0])
        pronunciation = _get_pronunciation(parts[1])
        if pronunciation:
            cmudict.setdefault(word, []).append(pronunciation)
    return cmudict
def _get_pronunciation(s):
    """Return the normalized phone string, or None if any phone is invalid."""
    parts = s.strip().split(" ")
    if all(part in _valid_symbol_set for part in parts):
        return " ".join(parts)
    return None
""" from https://github.com/keithito/tacotron """
import inflect
import re
_inflect = inflect.engine()
_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
_pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
_dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
_number_re = re.compile(r"[0-9]+")
def _remove_commas(m):
return m.group(1).replace(",", "")
def _expand_decimal_point(m):
return m.group(1).replace(".", " point ")
def _expand_dollars(m):
match = m.group(1)
parts = match.split(".")
if len(parts) > 2:
return match + " dollars" # Unexpected format
dollars = int(parts[0]) if parts[0] else 0
cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
if dollars and cents:
dollar_unit = "dollar" if dollars == 1 else "dollars"
cent_unit = "cent" if cents == 1 else "cents"
return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit)
elif dollars:
dollar_unit = "dollar" if dollars == 1 else "dollars"
return "%s %s" % (dollars, dollar_unit)
elif cents:
cent_unit = "cent" if cents == 1 else "cents"
return "%s %s" % (cents, cent_unit)
else:
return "zero dollars"
def _expand_ordinal(m):
    """Spell out an ordinal, e.g. '3rd' -> 'third'."""
    return _inflect.number_to_words(m.group(0))
def _expand_number(m):
    """Spell out a cardinal number, reading 1001-2999 year-style."""
    num = int(m.group(0))
    if not (1000 < num < 3000):
        return _inflect.number_to_words(num, andword="")
    if num == 2000:
        return "two thousand"
    if 2000 < num < 2010:
        return "two thousand " + _inflect.number_to_words(num % 100)
    if num % 100 == 0:
        # e.g. 1900 -> "nineteen hundred"
        return _inflect.number_to_words(num // 100) + " hundred"
    # e.g. 1984 -> "nineteen eighty four"; 1907 -> "nineteen oh seven"
    return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(
        ", ", " "
    )
def normalize_numbers(text):
    """Expand every numeric expression in *text* to English words."""
    # Applied in order: currency/decimal forms must be handled before the
    # generic digit-run fallback consumes them.
    substitutions = (
        (_comma_number_re, _remove_commas),
        (_pounds_re, r"\1 pounds"),
        (_dollars_re, _expand_dollars),
        (_decimal_number_re, _expand_decimal_point),
        (_ordinal_re, _expand_ordinal),
        (_number_re, _expand_number),
    )
    for pattern, replacement in substitutions:
        text = re.sub(pattern, replacement, text)
    return text
# Mandarin phone inventory for the text front-end: pinyin initials plus
# tone-numbered finals (tones 1-4, with 5 for the neutral tone).
initials = [
    "b", "c", "ch", "d", "f", "g", "h", "j", "k", "l", "m", "n",
    "p", "q", "r", "s", "sh", "t", "w", "x", "y", "z", "zh",
]

# Toneless final bases, in the same order as the original hand-written
# table; each base expands to five tone-numbered finals ("a" -> "a1".."a5").
_final_bases = [
    "a", "ai", "an", "ang", "ao",
    "e", "ei", "en", "eng", "er",
    "i", "ia", "ian", "iang", "iao", "ie", "ii", "iii", "in", "ing",
    "iong", "iou",
    "o", "ong", "ou",
    "u", "ua", "uai", "uan", "uang", "uei", "uen", "uo",
    "v", "van", "ve", "vn",
]
finals = [base + tone for base in _final_bases for tone in "12345"]

# "rr" marks erhua (rhotacized endings).
valid_symbols = initials + finals + ["rr"]
\ No newline at end of file
""" from https://github.com/keithito/tacotron """
"""
Defines the set of symbols used in text input to the model.
The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. """
from text import cmudict, pinyin
_pad = "_"
_punctuation = "!'(),.:;? "
_special = "-"
_letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
_silences = ["@sp", "@spn", "@sil"]
# Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
_arpabet = ["@" + s for s in cmudict.valid_symbols]
_pinyin = ["@" + s for s in pinyin.valid_symbols]
# Export all symbols:
symbols = (
[_pad]
+ list(_special)
+ list(_punctuation)
+ list(_letters)
+ _arpabet
+ _pinyin
+ _silences
)
This diff is collapsed.
This diff is collapsed.
# Special token IDs shared by the transformer vocabularies.
PAD = 0
UNK = 1
BOS = 2
EOS = 3

# String forms of the special tokens above.
PAD_WORD = "<blank>"
UNK_WORD = "<unk>"
BOS_WORD = "<s>"
EOS_WORD = "</s>"
This diff is collapsed.
This diff is collapsed.
import torch
import torch.nn as nn
import numpy as np
class ScaledDotProductAttention(nn.Module):
    """Scaled Dot-Product Attention: softmax(q @ k^T / temperature) @ v."""

    def __init__(self, temperature):
        super().__init__()
        self.temperature = temperature
        self.softmax = nn.Softmax(dim=2)

    def forward(self, q, k, v, mask=None):
        # Raw similarity scores, scaled by the temperature.
        scores = torch.bmm(q, k.transpose(1, 2)) / self.temperature
        if mask is not None:
            # Masked positions get -inf so they vanish after the softmax.
            scores = scores.masked_fill(mask, -np.inf)
        weights = self.softmax(scores)
        context = torch.bmm(weights, v)

        return context, weights
This diff is collapsed.
from .Models import Encoder, Decoder
from .Layers import PostNet
\ No newline at end of file
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment