"""
Cleaners are transformations that run over the input text at both training and eval time.
Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
hyperparameter. Some cleaners are English-specific. You'll typically want to use:
1. "english_cleaners" for English text
2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
the Unidecode library (https://pypi.python.org/pypi/Unidecode)
3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
the symbols in symbols.py to match your data).
"""
import re
from unidecode import unidecode
from .numbers import normalize_numbers
# Regular expression matching whitespace:
_whitespace_re = re.compile(r"\s+")
# List of (regular expression, replacement) pairs for abbreviations:
_abbreviations = [
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
for x in [
("mrs", "misess"),
("mr", "mister"),
("dr", "doctor"),
("st", "saint"),
("co", "company"),
("jr", "junior"),
("maj", "major"),
("gen", "general"),
("drs", "doctors"),
("rev", "reverend"),
("lt", "lieutenant"),
("hon", "honorable"),
("sgt", "sergeant"),
("capt", "captain"),
("esq", "esquire"),
("ltd", "limited"),
("col", "colonel"),
("ft", "fort"),
]
]
def expand_abbreviations(text):
for regex, replacement in _abbreviations:
text = re.sub(regex, replacement, text)
return text
def expand_numbers(text):
return normalize_numbers(text)
def lowercase(text):
return text.lower()
def collapse_whitespace(text):
return re.sub(_whitespace_re, " ", text)
def convert_to_ascii(text):
return unidecode(text)
def basic_cleaners(text):
"""Basic pipeline that lowercases and collapses whitespace without transliteration."""
text = lowercase(text)
text = collapse_whitespace(text)
return text
def transliteration_cleaners(text):
"""Pipeline for non-English text that transliterates to ASCII."""
text = convert_to_ascii(text)
text = lowercase(text)
text = collapse_whitespace(text)
return text
def english_cleaners(text):
"""Pipeline for English text, including number and abbreviation expansion."""
text = convert_to_ascii(text)
text = lowercase(text)
text = expand_numbers(text)
text = expand_abbreviations(text)
text = collapse_whitespace(text)
return text
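# Illustrative behaviour of the pipeline above (hypothetical input string):
#   english_cleaners("Mr. Smith  paid $15 on Dec 3rd")
# runs convert_to_ascii -> lowercase -> expand_numbers -> expand_abbreviations ->
# collapse_whitespace, yielding roughly "mister smith paid fifteen dollars on dec third".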
emotion_types = [
"emotion_none",
"emotion_neutral",
"emotion_angry",
"emotion_disgust",
"emotion_fear",
"emotion_happy",
"emotion_sad",
"emotion_surprise",
"emotion_calm",
"emotion_gentle",
"emotion_relax",
"emotion_lyrical",
"emotion_serious",
"emotion_disgruntled",
"emotion_satisfied",
"emotion_disappointed",
"emotion_excited",
"emotion_anxiety",
"emotion_jealousy",
"emotion_hate",
"emotion_pity",
"emotion_pleasure",
"emotion_arousal",
"emotion_dominance",
"emotion_placeholder1",
"emotion_placeholder2",
"emotion_placeholder3",
"emotion_placeholder4",
"emotion_placeholder5",
"emotion_placeholder6",
"emotion_placeholder7",
"emotion_placeholder8",
"emotion_placeholder9",
]
import xml.etree.ElementTree as ET
from kantts.preprocess.languages import languages
import logging
import os
syllable_flags = [
"s_begin",
"s_end",
"s_none",
"s_both",
"s_middle",
]
word_segments = [
"word_begin",
"word_end",
"word_middle",
"word_both",
"word_none",
]
LANGUAGES_DIR = os.path.join(
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))),
"preprocess",
"languages",
)
def parse_phoneset(phoneset_file):
"""Parse a phoneset file and return a list of symbols.
Args:
phoneset_file (str): Path to the phoneset file.
Returns:
list: A list of phones.
"""
ns = "{http://schemas.alibaba-inc.com/tts}"
phone_lst = []
phoneset_root = ET.parse(phoneset_file).getroot()
for phone_node in phoneset_root.findall(ns + "phone"):
phone_lst.append(phone_node.find(ns + "name").text)
for i in range(1, 5):
phone_lst.append("#{}".format(i))
return phone_lst
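# For reference, parse_phoneset() expects an XML file in the
# "http://schemas.alibaba-inc.com/tts" namespace whose <phone> children each contain a
# <name> element; a minimal sketch (root tag and phone names are hypothetical):
#   <phoneset xmlns="http://schemas.alibaba-inc.com/tts">
#     <phone><name>ge</name></phone>
#     <phone><name>en_c</name></phone>
#   </phoneset>
# The break symbols "#1".."#4" are then appended to the list returned above.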
def parse_tonelist(tonelist_file):
"""Parse a tonelist file and return a list of tones.
Args:
tonelist_file (str): Path to the tonelist file.
Returns:
        list: A list of tones.
"""
tone_lst = []
with open(tonelist_file, "r") as f:
lines = f.readlines()
for line in lines:
tone = line.strip()
if tone != "":
tone_lst.append("tone{}".format(tone))
else:
tone_lst.append("tone_none")
return tone_lst
def get_language_symbols(language):
"""Get symbols of a language.
Args:
language (str): Language name.
"""
language_dict = languages.get(language, None)
if language_dict is None:
logging.error("Language %s not supported. Using PinYin as default", language)
language_dict = languages["PinYin"]
language = "PinYin"
language_dir = os.path.join(LANGUAGES_DIR, language)
phoneset_file = os.path.join(language_dir, language_dict["phoneset_path"])
tonelist_file = os.path.join(language_dir, language_dict["tonelist_path"])
phones = parse_phoneset(phoneset_file)
tones = parse_tonelist(tonelist_file)
return phones, tones, syllable_flags, word_segments
import abc
import os
import shutil
import re
import numpy as np
from . import cleaners as cleaners
from .emotion_types import emotion_types
from .lang_symbols import get_language_symbols
# Regular expression matching text enclosed in curly braces:
_curly_re = re.compile(r"(.*?)\{(.+?)\}(.*)")
def _clean_text(text, cleaner_names):
for name in cleaner_names:
cleaner = getattr(cleaners, name)
if not cleaner:
raise Exception("Unknown cleaner: %s" % name)
text = cleaner(text)
return text
def get_fpdict(config):
    # emotion_neutral (and speaker F7) can be swapped for any other emotion (or speaker) type listed in the config file.
default_sp = config["linguistic_unit"]["speaker_list"].split(",")[0]
en_sy = f"{{ge$tone5$s_begin$word_begin$emotion_neutral${default_sp}}} {{en_c$tone5$s_end$word_end$emotion_neutral${default_sp}}} {{#3$tone_none$s_none$word_none$emotion_neutral${default_sp}}}" # NOQA: E501
a_sy = f"{{ga$tone5$s_begin$word_begin$emotion_neutral${default_sp}}} {{a_c$tone5$s_end$word_end$emotion_neutral${default_sp}}} {{#3$tone_none$s_none$word_none$emotion_neutral${default_sp}}}" # NOQA: E501
e_sy = f"{{ge$tone5$s_begin$word_begin$emotion_neutral${default_sp}}} {{e_c$tone5$s_end$word_end$emotion_neutral${default_sp}}} {{#3$tone_none$s_none$word_none$emotion_neutral${default_sp}}}" # NOQA: E501
ling_unit = KanTtsLinguisticUnit(config)
en_lings = ling_unit.encode_symbol_sequence(en_sy)
a_lings = ling_unit.encode_symbol_sequence(a_sy)
e_lings = ling_unit.encode_symbol_sequence(e_sy)
en_ling = np.stack(en_lings, axis=1)[:3, :4]
a_ling = np.stack(a_lings, axis=1)[:3, :4]
e_ling = np.stack(e_lings, axis=1)[:3, :4]
fp_dict = {1: en_ling, 2: a_ling, 3: e_ling}
return fp_dict
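# Note on the symbol format used above: each token packs one value per linguistic feature,
# joined by "$" inside braces, i.e. {sy$tone$syllable_flag$word_segment$emotion$speaker}.
# encode_symbol_sequence() splits a whitespace-separated string of such tokens into one
# integer sequence per feature; np.stack(..., axis=1) gives a (token, feature) matrix, and
# the [:3, :4] slice keeps the first three tokens with their sy/tone/syllable_flag/word_segment
# columns as the filler (FP) templates stored in fp_dict under keys 1, 2 and 3.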
class LinguisticBaseUnit(abc.ABC):
def set_config_params(self, config_params):
self.config_params = config_params
def save(self, config, config_name, path):
"""Save config to file"""
t_path = os.path.join(path, config_name)
if config != t_path:
os.makedirs(path, exist_ok=True)
shutil.copyfile(config, os.path.join(path, config_name))
class KanTtsLinguisticUnit(LinguisticBaseUnit):
def __init__(self, config):
super(KanTtsLinguisticUnit, self).__init__()
# special symbol
self._pad = "_"
self._eos = "~"
self._mask = "@[MASK]"
self.unit_config = config["linguistic_unit"]
self.lang_type = self.unit_config.get("language", "PinYin")
(
self.lang_phones,
self.lang_tones,
self.lang_syllable_flags,
self.lang_word_segments,
) = get_language_symbols(self.lang_type)
self._cleaner_names = [
x.strip() for x in self.unit_config["cleaners"].split(",")
]
_lfeat_type_list = self.unit_config["lfeat_type_list"].strip().split(",")
self._lfeat_type_list = _lfeat_type_list
self.fp_enable = config["Model"]["KanTtsSAMBERT"]["params"].get("FP", False)
if self.fp_enable:
self._fpadd_lfeat_type_list = [_lfeat_type_list[0], _lfeat_type_list[4]]
self.build()
def using_byte(self):
return "byte_index" in self._lfeat_type_list
def get_unit_size(self):
ling_unit_size = {}
if self.using_byte():
ling_unit_size["byte_index"] = len(self.byte_index)
else:
ling_unit_size["sy"] = len(self.sy)
ling_unit_size["tone"] = len(self.tone)
ling_unit_size["syllable_flag"] = len(self.syllable_flag)
ling_unit_size["word_segment"] = len(self.word_segment)
if "emo_category" in self._lfeat_type_list:
ling_unit_size["emotion"] = len(self.emo_category)
if "speaker_category" in self._lfeat_type_list:
ling_unit_size["speaker"] = len(self.speaker)
return ling_unit_size
def build(self):
self._sub_unit_dim = {}
self._sub_unit_pad = {}
if self.using_byte():
# Export all byte indices:
self.byte_index = ["@" + str(idx) for idx in range(256)] + [
self._pad,
self._eos,
self._mask,
]
self._byte_index_to_id = {s: i for i, s in enumerate(self.byte_index)}
self._id_to_byte_index = {i: s for i, s in enumerate(self.byte_index)}
self._sub_unit_dim["byte_index"] = len(self.byte_index)
self._sub_unit_pad["byte_index"] = self._byte_index_to_id["_"]
else:
# sy sub-unit
_characters = ""
# Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
# _arpabet = ['@' + s for s in cmudict.valid_symbols]
_arpabet = ["@" + s for s in self.lang_phones]
# Export all symbols:
self.sy = list(_characters) + _arpabet + [self._pad, self._eos, self._mask]
self._sy_to_id = {s: i for i, s in enumerate(self.sy)}
self._id_to_sy = {i: s for i, s in enumerate(self.sy)}
self._sub_unit_dim["sy"] = len(self.sy)
self._sub_unit_pad["sy"] = self._sy_to_id["_"]
# tone sub-unit
_characters = ""
# Export all tones:
self.tone = (
list(_characters) + self.lang_tones + [self._pad, self._eos, self._mask]
)
self._tone_to_id = {s: i for i, s in enumerate(self.tone)}
self._id_to_tone = {i: s for i, s in enumerate(self.tone)}
self._sub_unit_dim["tone"] = len(self.tone)
self._sub_unit_pad["tone"] = self._tone_to_id["_"]
# syllable flag sub-unit
_characters = ""
# Export all syllable_flags:
self.syllable_flag = (
list(_characters)
+ self.lang_syllable_flags
+ [self._pad, self._eos, self._mask]
)
self._syllable_flag_to_id = {s: i for i, s in enumerate(self.syllable_flag)}
self._id_to_syllable_flag = {i: s for i, s in enumerate(self.syllable_flag)}
self._sub_unit_dim["syllable_flag"] = len(self.syllable_flag)
self._sub_unit_pad["syllable_flag"] = self._syllable_flag_to_id["_"]
# word segment sub-unit
_characters = ""
        # Export all word segments:
self.word_segment = (
list(_characters)
+ self.lang_word_segments
+ [self._pad, self._eos, self._mask]
)
self._word_segment_to_id = {s: i for i, s in enumerate(self.word_segment)}
self._id_to_word_segment = {i: s for i, s in enumerate(self.word_segment)}
self._sub_unit_dim["word_segment"] = len(self.word_segment)
self._sub_unit_pad["word_segment"] = self._word_segment_to_id["_"]
if "emo_category" in self._lfeat_type_list:
# emotion category sub-unit
_characters = ""
self.emo_category = (
list(_characters) + emotion_types + [self._pad, self._eos, self._mask]
)
self._emo_category_to_id = {s: i for i, s in enumerate(self.emo_category)}
self._id_to_emo_category = {i: s for i, s in enumerate(self.emo_category)}
self._sub_unit_dim["emo_category"] = len(self.emo_category)
self._sub_unit_pad["emo_category"] = self._emo_category_to_id["_"]
if "speaker_category" in self._lfeat_type_list:
# speaker category sub-unit
_characters = ""
_ch_speakers = self.unit_config["speaker_list"].strip().split(",")
        # Export all speakers:
self.speaker = (
list(_characters) + _ch_speakers + [self._pad, self._eos, self._mask]
)
self._speaker_to_id = {s: i for i, s in enumerate(self.speaker)}
self._id_to_speaker = {i: s for i, s in enumerate(self.speaker)}
self._sub_unit_dim["speaker_category"] = len(self._speaker_to_id)
self._sub_unit_pad["speaker_category"] = self._speaker_to_id["_"]
def encode_symbol_sequence(self, lfeat_symbol):
lfeat_symbol = lfeat_symbol.strip().split(" ")
lfeat_symbol_separate = [""] * int(len(self._lfeat_type_list))
for this_lfeat_symbol in lfeat_symbol:
this_lfeat_symbol = this_lfeat_symbol.strip("{").strip("}").split("$")
# if len(this_lfeat_symbol) > len(self._lfeat_type_list):
# raise Exception(
# 'Length of this_lfeat_symbol in training data is longer than the length of lfeat_type_list, '\
# + str( len(this_lfeat_symbol))\
# + ' VS. '\
# + str(len(self._lfeat_type_list)))
index = 0
while index < len(lfeat_symbol_separate):
lfeat_symbol_separate[index] = (
lfeat_symbol_separate[index] + this_lfeat_symbol[index] + " "
)
index = index + 1
input_and_label_data = []
index = 0
while index < len(self._lfeat_type_list):
sequence = self.encode_sub_unit(
lfeat_symbol_separate[index].strip(), self._lfeat_type_list[index]
)
sequence_array = np.asarray(sequence, dtype=np.int32)
input_and_label_data.append(sequence_array)
index = index + 1
# # lfeat_type = 'emo_category'
# input_and_label_data.append(lfeat_symbol_separate[index].strip())
# index = index + 1
#
# # lfeat_type = 'speaker'
# input_and_label_data.append(lfeat_symbol_separate[index].strip())
return input_and_label_data
def decode_symbol_sequence(self, sequence):
result = []
for i, lfeat_type in enumerate(self._lfeat_type_list):
s = ""
sequence_item = sequence[i].tolist()
if lfeat_type == "sy":
s = self.decode_sy(sequence_item)
elif lfeat_type == "byte_index":
s = self.decode_byte_index(sequence_item)
elif lfeat_type == "tone":
s = self.decode_tone(sequence_item)
elif lfeat_type == "syllable_flag":
s = self.decode_syllable_flag(sequence_item)
elif lfeat_type == "word_segment":
s = self.decode_word_segment(sequence_item)
elif lfeat_type == "emo_category":
s = self.decode_emo_category(sequence_item)
elif lfeat_type == "speaker_category":
s = self.decode_speaker_category(sequence_item)
else:
raise Exception("Unknown lfeat type: %s" % lfeat_type)
result.append("%s:%s" % (lfeat_type, s))
return result
def encode_sub_unit(self, this_lfeat_symbol, lfeat_type):
sequence = []
if lfeat_type == "sy":
this_lfeat_symbol = this_lfeat_symbol.strip().split(" ")
this_lfeat_symbol_format = ""
index = 0
while index < len(this_lfeat_symbol):
this_lfeat_symbol_format = (
this_lfeat_symbol_format
+ "{"
+ this_lfeat_symbol[index]
+ "}"
+ " "
)
index = index + 1
sequence = self.encode_text(this_lfeat_symbol_format, self._cleaner_names)
elif lfeat_type == "byte_index":
sequence = self.encode_byte_index(this_lfeat_symbol)
elif lfeat_type == "tone":
sequence = self.encode_tone(this_lfeat_symbol)
elif lfeat_type == "syllable_flag":
sequence = self.encode_syllable_flag(this_lfeat_symbol)
elif lfeat_type == "word_segment":
sequence = self.encode_word_segment(this_lfeat_symbol)
elif lfeat_type == "emo_category":
sequence = self.encode_emo_category(this_lfeat_symbol)
elif lfeat_type == "speaker_category":
sequence = self.encode_speaker_category(this_lfeat_symbol)
else:
raise Exception("Unknown lfeat type: %s" % lfeat_type)
return sequence
def encode_text(self, text, cleaner_names):
sequence = []
# Check for curly braces and treat their contents as ARPAbet:
while len(text):
m = _curly_re.match(text)
if not m:
sequence += self.encode_sy(_clean_text(text, cleaner_names))
break
sequence += self.encode_sy(_clean_text(m.group(1), cleaner_names))
sequence += self.encode_arpanet(m.group(2))
text = m.group(3)
# Append EOS token
sequence.append(self._sy_to_id["~"])
return sequence
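    # Note: encode_sub_unit() wraps every phone token in curly braces before calling encode_text(),
    # so the _curly_re branch above handles phone symbols ("@"-prefixed lookups via encode_arpanet)
    # while any un-braced text falls back to the cleaners and encode_sy(); the EOS id for "~" is
    # always appended at the end.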
def encode_sy(self, sy):
return [self._sy_to_id[s] for s in sy if self.should_keep_sy(s)]
def decode_sy(self, id):
s = self._id_to_sy[id]
if len(s) > 1 and s[0] == "@":
s = s[1:]
return s
def should_keep_sy(self, s):
return s in self._sy_to_id and s != "_" and s != "~"
def encode_arpanet(self, text):
return self.encode_sy(["@" + s for s in text.split()])
def encode_byte_index(self, byte_index):
byte_indices = ["@" + s for s in byte_index.strip().split(" ")]
sequence = []
for this_byte_index in byte_indices:
sequence.append(self._byte_index_to_id[this_byte_index])
sequence.append(self._byte_index_to_id["~"])
return sequence
def decode_byte_index(self, id):
s = self._id_to_byte_index[id]
if len(s) > 1 and s[0] == "@":
s = s[1:]
return s
def encode_tone(self, tone):
tones = tone.strip().split(" ")
sequence = []
for this_tone in tones:
sequence.append(self._tone_to_id[this_tone])
sequence.append(self._tone_to_id["~"])
return sequence
def decode_tone(self, id):
return self._id_to_tone[id]
def encode_syllable_flag(self, syllable_flag):
syllable_flags = syllable_flag.strip().split(" ")
sequence = []
for this_syllable_flag in syllable_flags:
sequence.append(self._syllable_flag_to_id[this_syllable_flag])
sequence.append(self._syllable_flag_to_id["~"])
return sequence
def decode_syllable_flag(self, id):
return self._id_to_syllable_flag[id]
def encode_word_segment(self, word_segment):
word_segments = word_segment.strip().split(" ")
sequence = []
for this_word_segment in word_segments:
sequence.append(self._word_segment_to_id[this_word_segment])
sequence.append(self._word_segment_to_id["~"])
return sequence
def decode_word_segment(self, id):
return self._id_to_word_segment[id]
def encode_emo_category(self, emo_type):
emo_categories = emo_type.strip().split(" ")
sequence = []
for this_category in emo_categories:
sequence.append(self._emo_category_to_id[this_category])
sequence.append(self._emo_category_to_id["~"])
return sequence
def decode_emo_category(self, id):
return self._id_to_emo_category[id]
def encode_speaker_category(self, speaker):
speakers = speaker.strip().split(" ")
sequence = []
for this_speaker in speakers:
sequence.append(self._speaker_to_id[this_speaker])
sequence.append(self._speaker_to_id["~"])
return sequence
def decode_speaker_category(self, id):
return self._id_to_speaker[id]
import inflect
import re
_inflect = inflect.engine()
_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
_pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
_dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
_number_re = re.compile(r"[0-9]+")
def _remove_commas(m):
return m.group(1).replace(",", "")
def _expand_decimal_point(m):
return m.group(1).replace(".", " point ")
def _expand_dollars(m):
match = m.group(1)
parts = match.split(".")
if len(parts) > 2:
return match + " dollars" # Unexpected format
dollars = int(parts[0]) if parts[0] else 0
cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
if dollars and cents:
dollar_unit = "dollar" if dollars == 1 else "dollars"
cent_unit = "cent" if cents == 1 else "cents"
return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit)
elif dollars:
dollar_unit = "dollar" if dollars == 1 else "dollars"
return "%s %s" % (dollars, dollar_unit)
elif cents:
cent_unit = "cent" if cents == 1 else "cents"
return "%s %s" % (cents, cent_unit)
else:
return "zero dollars"
def _expand_ordinal(m):
return _inflect.number_to_words(m.group(0))
def _expand_number(m):
num = int(m.group(0))
if num > 1000 and num < 3000:
if num == 2000:
return "two thousand"
elif num > 2000 and num < 2010:
return "two thousand " + _inflect.number_to_words(num % 100)
elif num % 100 == 0:
return _inflect.number_to_words(num // 100) + " hundred"
else:
return _inflect.number_to_words(
num, andword="", zero="oh", group=2
).replace(", ", " ")
else:
return _inflect.number_to_words(num, andword="")
def normalize_numbers(text):
text = re.sub(_comma_number_re, _remove_commas, text)
text = re.sub(_pounds_re, r"\1 pounds", text)
text = re.sub(_dollars_re, _expand_dollars, text)
text = re.sub(_decimal_number_re, _expand_decimal_point, text)
text = re.sub(_ordinal_re, _expand_ordinal, text)
text = re.sub(_number_re, _expand_number, text)
return text
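# Illustrative behaviour (hypothetical inputs):
#   normalize_numbers("I owe $3.50")           -> "I owe three dollars, fifty cents"
#   normalize_numbers("Room 2001, 3rd floor")  -> "Room two thousand one, third floor"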
import logging
import subprocess
def logging_to_file(log_file):
logger = logging.getLogger()
handler = logging.FileHandler(log_file)
formatter = logging.Formatter(
"%(asctime)s %(levelname)-4s [%(filename)s:%(lineno)d] %(message)s",
datefmt="%Y-%m-%d:%H:%M:%S",
)
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.setLevel(logging.INFO)
def get_git_revision_short_hash():
return (
subprocess.check_output(["git", "rev-parse", "--short", "HEAD"])
.decode("ascii")
.strip()
)
def get_git_revision_hash():
return subprocess.check_output(["git", "rev-parse", "HEAD"]).decode("ascii").strip()
import matplotlib
matplotlib.use("Agg") # NOQA: E402
try:
import matplotlib.pyplot as plt
except ImportError:
raise ImportError("Please install matplotlib.")
def plot_spectrogram(spectrogram):
fig, ax = plt.subplots(figsize=(12, 8))
im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
plt.colorbar(im, ax=ax)
fig.canvas.draw()
plt.close()
return fig
def plot_alignment(alignment, info=None):
fig, ax = plt.subplots()
im = ax.imshow(alignment, aspect="auto", origin="lower", interpolation="none")
fig.colorbar(im, ax=ax)
xlabel = "Input timestep"
if info is not None:
xlabel += "\t" + info
plt.xlabel(xlabel)
plt.ylabel("Output timestep")
fig.canvas.draw()
plt.close()
return fig
from fastapi import FastAPI, File, UploadFile, Path
import uvicorn
from fastapi.responses import FileResponse
import os
from typing import List
from wav_to_label import wav_to_label
from kantts.bin.train_sambert import train as train_sambert
from text_to_wav_trans import text_to_wav as text_to_wav_trans
from text_to_wav_onnx import text_to_wav_onnx
wav_dir = "./Data/ptts_spk0_wav"
txt_dir = "./Data"
output_dir = "./res/ptts_syn"
# Create the FastAPI instance
app = FastAPI()
@app.get("/")
def get_root():
"""
    Register the root path.
"""
return {"message": "Welcome to try: Personal Text To Speech !"}
@app.post("/uploadwavs")
async def uploadwavs(files: List[UploadFile] = File(...)):
    # Save each file to the target directory; file path = directory + filename
if not os.path.exists(wav_dir):
os.makedirs(wav_dir, exist_ok=True)
for file in files:
with open(os.path.join(wav_dir, file.filename), "wb") as f:
f.write(await file.read())
return {"msg": "File upload success in directory 'Data/ptts_spk0_wav/'"}
@app.post("/uploadtxt")
async def uploadtxt(file: UploadFile = File(...)):
    # Save the file to the target directory; file path = directory + filename
if not os.path.exists(txt_dir):
os.makedirs(txt_dir, exist_ok=True)
with open(os.path.join(txt_dir, file.filename), "wb") as f:
f.write(await file.read())
return {"msg": "File upload success in directory 'Data/'"}
@app.get("/listfiles/{dirpath:path}")
async def listfiles(dirpath: str):
# return {"file_path": file_path}
res = os.listdir(dirpath)
return {"files": res}
@app.get("/cleardir")
async def cleardir():
for filename in os.listdir(wav_dir):
file_path = os.path.join(wav_dir, filename)
try:
if os.path.isfile(file_path):
os.remove(file_path)
except Exception as e:
print(f"Error: {e}")
return {"msg": "Director cleared and empty"}
@app.get("/deletetxt")
async def deletetxt():
    # Path of the file to delete
    file_path = "./Data/test.txt"
    # Check whether the file exists
    if os.path.isfile(file_path):
        # Delete the file
        os.remove(file_path)
        return {"msg": "File removed"}
    else:
        return {"msg": "File does not exist"}
@app.get("/downloadfile/{filename}")
async def downloadfile(filename: str):
file_path = os.path.join("res/ptts_syn/res_wavs", filename)
if os.path.exists(file_path):
return FileResponse(file_path)
else:
return {
"msg": "File not exis"
}
@app.get("/wav2label")
async def wav2label():
report = wav_to_label(wav_dir)
return report
@app.get("/featsextract")
async def featsextract():
    # Run the shell script
f = os.popen('./feats_extract.sh')
return f.read()
@app.get("/trainsambert")
async def trainsambert():
train_sambert(
"speech_personal_sambert-hifigan_nsf_tts_zh-cn_pretrain_16k/basemodel_16k/sambert/config.yaml",
"training_stage/ptts_feats",
"training_stage/ptts_sambert_ckpt",
"speech_personal_sambert-hifigan_nsf_tts_zh-cn_pretrain_16k/basemodel_16k/sambert/ckpt/checkpoint_2400000.pth"
)
return {"msg": "Traing finished"}
@app.get("/modeltransform")
async def modeltransform():
text_to_wav_trans(
"./Data/test.txt",
"res/ptts_syn_one",
"speech_personal_sambert-hifigan_nsf_tts_zh-cn_pretrain_16k/resource.zip",
"training_stage/ptts_sambert_ckpt/ckpt/checkpoint_2402200.pth",
"speech_personal_sambert-hifigan_nsf_tts_zh-cn_pretrain_16k//basemodel_16k/hifigan/ckpt/checkpoint_2400000.pth",
"F7",
"training_stage/ptts_feats/se/se.npy"
)
return {"msg": "Model transform finished"}
# from enum import Enum
# class TargetRate(Enum):
# rate0 = 1.0
# rate1 = 0.5
# rate2 = 0.75
# rate3 = 1.25
# rate4 = 1.5
# rate5= 1.75
# rate6 = 2.0
@app.get("/text2wav/{targetrate}")
async def text2wav(targetrate: float=1.0):
text_to_wav_onnx(
"./Data/test.txt",
output_dir,
"speech_personal_sambert-hifigan_nsf_tts_zh-cn_pretrain_16k/resource.zip",
"sambert_onnx/text_encoder.onnx",
"sambert_onnx/variance_adaptor_dict.pt",
"sambert_onnx/mel_decoder_dict.pt",
"sambert_onnx/mel_postnet.onnx",
"training_stage/ptts_sambert_ckpt/config.yaml",
"hifigan_onnx/hifigan.onnx",
"speech_personal_sambert-hifigan_nsf_tts_zh-cn_pretrain_16k//basemodel_16k/hifigan/config.yaml",
targetrate,
"F7",
"training_stage/ptts_feats/se/se.npy"
)
return {"msg": "Text to wav finished"}
@app.get("/oneclickstart/{targetrate}")
async def oneclickstart(targetrate: float=1.0):
report = wav_to_label(wav_dir)
f = os.popen('./feats_extract.sh')
train_sambert(
"speech_personal_sambert-hifigan_nsf_tts_zh-cn_pretrain_16k/basemodel_16k/sambert/config.yaml",
"training_stage/ptts_feats",
"training_stage/ptts_sambert_ckpt",
"speech_personal_sambert-hifigan_nsf_tts_zh-cn_pretrain_16k/basemodel_16k/sambert/ckpt/checkpoint_2400000.pth"
)
text_to_wav_trans(
"./Data/test.txt",
"res/ptts_syn_one",
"speech_personal_sambert-hifigan_nsf_tts_zh-cn_pretrain_16k/resource.zip",
"training_stage/ptts_sambert_ckpt/ckpt/checkpoint_2402200.pth",
"speech_personal_sambert-hifigan_nsf_tts_zh-cn_pretrain_16k//basemodel_16k/hifigan/ckpt/checkpoint_2400000.pth",
"F7",
"training_stage/ptts_feats/se/se.npy"
)
text_to_wav_onnx(
"./Data/test.txt",
output_dir,
"speech_personal_sambert-hifigan_nsf_tts_zh-cn_pretrain_16k/resource.zip",
"sambert_onnx/text_encoder.onnx",
"sambert_onnx/variance_adaptor_dict.pt",
"sambert_onnx/mel_decoder_dict.pt",
"sambert_onnx/mel_postnet.onnx",
"training_stage/ptts_sambert_ckpt/config.yaml",
"hifigan_onnx/hifigan.onnx",
"speech_personal_sambert-hifigan_nsf_tts_zh-cn_pretrain_16k//basemodel_16k/hifigan/config.yaml",
targetrate,
"F7",
"training_stage/ptts_feats/se/se.npy"
)
return {"msg": "Text to wav finished"}
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=8000)
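# Illustrative client calls once the server is running on port 8000 (file names are examples only):
#   curl -F "files=@rec_000.wav" -F "files=@rec_001.wav" http://localhost:8000/uploadwavs
#   curl -F "file=@test.txt" http://localhost:8000/uploadtxt
#   curl http://localhost:8000/oneclickstart/1.0
#   curl -o 0.wav http://localhost:8000/downloadfile/0.wav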
# Unique model identifier
modelCode = 534
# Model name
modelName = sambert-hifigan
# Model description
modelDescription = Personal TTS, i.e. personalized speech synthesis (voice cloning) built on Alibaba's KAN-TTS framework.
# Application scenarios
appScenario = inference,training,speech,transportation,
# Framework type
frameType = pytorch
#! /usr/bin/env bash
# Fine-tuning train_max_steps may need to be passed in as a parameter (e.g. patched in with sed)
# Function to display usage information
usage() {
echo "Usage: $0 <valueA> <valueB> [<valueC>]"
exit 1
}
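# Example invocation (illustrative; the script name is hypothetical, the paths match the defaults used elsewhere in this repo):
#   bash run_oneclick.sh Data/ptts_spk0_wav Data/test.txt 1.0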
target_rate=1.0
# Check if required parameters are provided
if [ -z "$1" ] || [ -z "$2" ]; then
echo "Error: Missing required parameters."
usage
exit 1
fi
# Assign input parameters to variables
spk_wav="$1"
text="$2"
# Check if optional parameters are provided
if [ -n "$3" ]; then
target_rate=$3
fi
# Display the input parameters
echo "spk_wav: $spk_wav"
echo "text: $text"
echo "target_rate: $target_rate"
# Add your script logic here
# Automatic data labeling
echo "**********************Start of wav to label**************************************"
echo "*********************************************************************************"
python3 wav_to_label.py --wav_data ${spk_wav}
# Feature extraction
echo "**********************Start of feats extract**************************************"
echo "**********************************************************************************"
bash feats_extract.sh
# Train the acoustic model
echo "**********************Start of train sambert**************************************"
echo "**********************************************************************************"
HIP_VISIBLE_DEVICES=0 python3 kantts/bin/train_sambert.py \
--model_config speech_personal_sambert-hifigan_nsf_tts_zh-cn_pretrain_16k/basemodel_16k/sambert/config.yaml \
--root_dir training_stage/ptts_feats \
--stage_dir training_stage/ptts_sambert_ckpt \
--resume_path speech_personal_sambert-hifigan_nsf_tts_zh-cn_pretrain_16k/basemodel_16k/sambert/ckpt/checkpoint_*.pth
sambert_onnx_folder=sambert_onnx
mkdir -p "$sambert_onnx_folder"
hifigan_onnx_folder=hifigan_onnx
mkdir -p "$hifigan_onnx_folder"
# Model conversion: run the PyTorch (.pt) model once.
echo "**********************Start of model transform**************************************"
echo "*************************************************************************************"
HIP_VISIBLE_DEVICES=0 python3 text_to_wav_trans.py \
--txt ${text} \
--output_dir res/ptts_syn_one \
--res_zip speech_personal_sambert-hifigan_nsf_tts_zh-cn_pretrain_16k/resource.zip \
--am_ckpt training_stage/ptts_sambert_ckpt/ckpt/checkpoint_2402200.pth \
--voc_ckpt speech_personal_sambert-hifigan_nsf_tts_zh-cn_pretrain_16k//basemodel_16k/hifigan/ckpt/checkpoint_2400000.pth \
--speaker F7 \
--se_file training_stage/ptts_feats/se/se.npy
# Run speech synthesis
echo "**********************Start of text to wav*******************************************"
echo "*************************************************************************************"
HIP_VISIBLE_DEVICES=0 python3 text_to_wav_onnx.py \
--txt ${text} \
--output_dir res/ptts_syn \
--res_zip speech_personal_sambert-hifigan_nsf_tts_zh-cn_pretrain_16k/resource.zip \
--text_encoder_onnx sambert_onnx/text_encoder.onnx \
--variance_adaptor_pt sambert_onnx/variance_adaptor_dict.pt \
--mel_decoder_pt sambert_onnx/mel_decoder_dict.pt \
--mel_postnet_onnx sambert_onnx/mel_postnet.onnx \
--am_config training_stage/ptts_sambert_ckpt/config.yaml \
--voc_onnx hifigan_onnx/hifigan.onnx \
--voc_config speech_personal_sambert-hifigan_nsf_tts_zh-cn_pretrain_16k//basemodel_16k/hifigan/config.yaml \
--target_rate ${target_rate} \
--speaker F7 \
--se_file training_stage/ptts_feats/se/se.npy
from setuptools import find_packages, setup
version = "0.0.1"
with open("README.md", "r", encoding="utf-8") as readme_file:
README = readme_file.read()
setup(
name="kantts",
version=version,
url="https://github.com/AlibabaResearch/KAN-TTS",
author="Jin",
description="Alibaba DAMO Speech-Lab Text to Speech deeplearning toolchain",
long_description=README,
long_description_content_type="text/markdown",
license="MIT",
# cython
# include_dirs=numpy.get_include(),
# ext_modules=find_cython_extensions(),
# package
include_package_data=True,
packages=find_packages(include=["kantts*"]),
project_urls={
"Documentation": "https://github.com/AlibabaResearch/KAN-TTS/wiki",
"Tracker": "",
"Repository": "https://github.com/AlibabaResearch/KAN-TTS",
"Discussions": "",
},
python_requires=">=3.7.0, <3.9",
classifiers=[
"Programming Language :: Python",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Development Status :: 3 - Alpha",
"Intended Audience :: Science/Research",
"Intended Audience :: Developers",
"Operating System :: POSIX :: Linux",
"License :: OSI Approved :: MIT License",
"Topic :: Software Development",
"Topic :: Software Development :: Libraries :: Python Modules",
"Topic :: Multimedia :: Sound/Audio :: Speech",
"Topic :: Multimedia :: Sound/Audio",
"Topic :: Multimedia",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
],
zip_safe=False,
)
#!/usr/bin/env python3
import onnxruntime
import zipfile
from glob import glob
try:
from kantts.utils.ling_unit import text_to_mit_symbols as text_to_symbols
from kantts.utils.ling_unit.ling_unit import KanTtsLinguisticUnit
from kantts.models.sambert.kantts_sambert_divide import VarianceAdaptor2, MelPNCADecoder
except ImportError:
raise ImportError("Please install kantts.")
try:
from kantts.utils.log import logging_to_file
except ImportError:
raise ImportError("Please install kantts.")
import os
import sys
import argparse
import torch
import soundfile as sf
import yaml
import logging
import numpy as np
import time
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) # NOQA: E402
sys.path.insert(0, os.path.dirname(ROOT_PATH)) # NOQA: E402
logging.basicConfig(
# filename=os.path.join(stage_dir, 'stdout.log'),
format="%(asctime)s, %(levelname)-4s [%(filename)s:%(lineno)d] %(message)s",
datefmt="%Y-%m-%d:%H:%M:%S",
level=logging.INFO,
)
def denorm_f0(mel, f0_threshold=30, uv_threshold=0.6, norm_type='mean_std', f0_feature=None):
if norm_type == 'mean_std':
f0_mvn = f0_feature
f0 = mel[:, -2]
uv = mel[:, -1]
uv[uv < uv_threshold] = 0.0
uv[uv >= uv_threshold] = 1.0
f0 = f0 * f0_mvn[1:, :] + f0_mvn[0:1, :]
f0[f0 < f0_threshold] = f0_threshold
mel[:, -2] = f0
mel[:, -1] = uv
else: # global
f0_global_max_min = f0_feature
f0 = mel[:, -2]
uv = mel[:, -1]
uv[uv < uv_threshold] = 0.0
uv[uv >= uv_threshold] = 1.0
f0 = f0 * (f0_global_max_min[0] - f0_global_max_min[1]) + f0_global_max_min[1]
f0[f0 < f0_threshold] = f0_threshold
mel[:, -2] = f0
mel[:, -1] = uv
return mel
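# Layout assumed by denorm_f0(): the last two mel channels hold the normalized f0 track and a
# voiced/unvoiced probability. The function binarizes uv at uv_threshold, undoes the f0
# normalization (mean/std rows of the mvn file, or the global max/min pair), and floors f0 at
# f0_threshold before the mel is handed to the NSF vocoder.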
def get_mask_from_lengths(lengths, max_len=None):
batch_size = lengths.shape[0]
if max_len is None:
max_len = torch.max(lengths).item()
ids = (
torch.arange(0, max_len).unsqueeze(0).expand(batch_size, -1).to(lengths.device)
)
mask = ids >= lengths.unsqueeze(1).expand(-1, max_len)
return mask
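# e.g. lengths = tensor([3, 5]) with max_len = 5 yields
#   [[False, False, False, True,  True ],
#    [False, False, False, False, False]]
# where True marks padded positions.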
def hifigan_infer(input_mel, onnx_file, output_dir, config=None):
if not torch.cuda.is_available():
device = torch.device("cpu")
else:
torch.backends.cudnn.benchmark = True
device = torch.device("cuda", 0)
# device = torch.device("cpu")
if config is not None:
with open(config, "r") as f:
config = yaml.load(f, Loader=yaml.Loader)
else:
config_path = os.path.join(
os.path.dirname(os.path.dirname(ckpt)), "config.yaml"
)
if not os.path.exists(config_path):
raise ValueError("config file not found: {}".format(config_path))
with open(config_path, "r") as f:
config = yaml.load(f, Loader=yaml.Loader)
# for key, value in config.items():
# logging.info(f"{key} = {value}")
# check directory existence
if not os.path.exists(output_dir):
os.makedirs(output_dir)
logging_to_file(os.path.join(output_dir, "stdout.log"))
if os.path.isfile(input_mel):
mel_lst = [input_mel]
elif os.path.isdir(input_mel):
mel_lst = glob(os.path.join(input_mel, "*.npy"))
else:
raise ValueError("input_mel should be a file or a directory")
# model = load_model(ckpt_path, config)
# logging.info(f"Loaded model parameters from {ckpt_path}.")
# model.remove_weight_norm()
# model = model.eval().to(device)
# providers=['CUDAExecutionProvider', {'device_id': 1}]
    # providers=['CPUExecutionProvider']  # this is the default
providers = ['ROCMExecutionProvider']
ort_session = onnxruntime.InferenceSession(onnx_file, providers=providers)
print(ort_session.get_providers())
# with torch.no_grad():
# pcm_len = 0
    # i = 0  # when exporting to ONNX, run the model only once
# for mel in mel_lst:
# if i>0:
# break
# i = i+1
# utt_id = os.path.splitext(os.path.basename(mel))[0]
# mel_data = np.load(mel)
# if model.nsf_enable:
# mel_data = binarize(mel_data)
# generate
# mel_data = torch.tensor(mel_data, dtype=torch.float).to(device)
# (T, C) -> (B, C, T)
# mel_data = mel_data.transpose(1, 0).unsqueeze(0)
    # GPU warm-up
# for _ in range(10):
# _ = model(mel_data)
    # Benchmark
    # iterations = 100
    # times = torch.zeros(iterations)  # store the time of each iteration
# starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
# for iter in range(iterations):
# starter.record()
# _ = model(mel_data)
# ender.record()
    # Synchronize GPU timing
    # torch.cuda.synchronize()
    # cur_time = starter.elapsed_time(ender)  # elapsed time
# times[iter] = cur_time
# print(cur_time)
# mean_time = times.mean().item()
# print("hifigan pth file infer single time: {:.6f}".format(mean_time))
# y = model(mel_data)
start = time.time()
pcm_len = 0
for mel in mel_lst:
start1 = time.time()
utt_id = os.path.splitext(os.path.basename(mel))[0]
logging.info("Inference sentence: {}".format(utt_id))
mel_data = np.load(mel)
# generate
mel_data = torch.tensor(mel_data, dtype=torch.float).to(device)
# (T, C) -> (B, C, T)
mel_data = mel_data.transpose(1, 0).unsqueeze(0)
ort_inputs = {'mel_data': mel_data.cpu().numpy()}
        # GPU warm-up
for _ in range(50):
_ = ort_session.run(['y'], ort_inputs)
        # Benchmark
        iterations = 100
        times = torch.zeros(iterations)  # store the time of each iteration
starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
for iter in range(iterations):
starter.record()
ort_list = ort_session.run(['y'], ort_inputs)
ender.record()
            # Synchronize GPU timing
            torch.cuda.synchronize()
            cur_time = starter.elapsed_time(ender)  # elapsed time in milliseconds
times[iter] = cur_time
mean_time = times.mean().item()
print("hifigan-onnx infer single time: {:.6f} ms".format(mean_time))
# logging.info("hifigan is running...")
# ort_list = ort_session.run(['y'], ort_inputs)
        # Export the PyTorch model to ONNX format
# x0 = mel_data
# torch.onnx.export(
# model,
# x0,
# "hifigan.onnx",
# opset_version=11,
# input_names=['mel_data'],
# output_names=['y']
# )
# if hasattr(model, "pqmf"):
# y = model.pqmf.synthesis(y)
# print("hifigan infer single time: {:.6f}".format(mean_time))
# ort_y = ort_y.view(-1).cpu().numpy()
ort_y = torch.from_numpy(ort_list[0]).view(-1).cpu().numpy()
pcm_len += len(ort_y)
# save as PCM 16 bit wav file
# samplerate = 16000
sf.write(
os.path.join(output_dir, f"{utt_id}_gen.wav"),
ort_y,
config["audio_config"]["sampling_rate"],
"PCM_16",
)
total_elapsed = time.time() - start1
print(f'Vocoder infer single time: {total_elapsed} seconds')
rtf = (time.time() - start) / (
pcm_len / config["audio_config"]["sampling_rate"]
)
# report average RTF
logging.info(
f"Finished generation of {len(mel_lst)} utterances (RTF = {rtf:.03f})."
)
def am_infer_divide(sentence,
text_encoder_onnx,
variance_adaptor_ckpt,
mel_decoder_ckpt,
mel_postnet_onnx,
output_dir,
target_rate=1.0,
se_file=None,
config=None):
if not torch.cuda.is_available():
device = torch.device("cpu")
else:
torch.backends.cudnn.benchmark = True
device = torch.device("cuda", 0)
# device = torch.device("cpu")
if config is not None:
with open(config, "r") as f:
config = yaml.load(f, Loader=yaml.Loader)
# else:
# am_config_file = os.path.join(
# os.path.dirname(os.path.dirname(ckpt)), "config.yaml"
# )
# with open(am_config_file, "r") as f:
# config = yaml.load(f, Loader=yaml.Loader)
ling_unit = KanTtsLinguisticUnit(config)
ling_unit_size = ling_unit.get_unit_size()
config["Model"]["KanTtsSAMBERT"]["params"].update(ling_unit_size)
se_enable = config["Model"]["KanTtsSAMBERT"]["params"].get("SE", False)
se = np.load(se_file) if se_enable else None
# nsf
nsf_enable = config["Model"]["KanTtsSAMBERT"]["params"].get("NSF", False)
if nsf_enable:
nsf_norm_type = config["Model"]["KanTtsSAMBERT"]["params"].get("nsf_norm_type", "mean_std")
if nsf_norm_type == "mean_std":
f0_mvn_file = os.path.join(
os.path.dirname(os.path.dirname(ckpt)), "mvn.npy"
)
f0_feature = np.load(f0_mvn_file)
else: # global
nsf_f0_global_minimum = config["Model"]["KanTtsSAMBERT"]["params"].get("nsf_f0_global_minimum", 30.0)
nsf_f0_global_maximum = config["Model"]["KanTtsSAMBERT"]["params"].get("nsf_f0_global_maximum", 730.0)
f0_feature = [nsf_f0_global_maximum, nsf_f0_global_minimum]
# model, _, _ = model_builder(config, device)
# fsnet = model["KanTtsSAMBERT"]
logging.info("ort_sess is building...")
providers = ['ROCMExecutionProvider']
logging.info("text_encoder_ort_sess is building...")
text_enxoder_ort_sess = onnxruntime.InferenceSession(text_encoder_onnx, providers=providers)
print(text_enxoder_ort_sess.get_providers())
# variance_adaptor_ort_sess = onnxruntime.InferenceSession(variance_adaptor_onnx, providers=providers)
# mel_decoder_ort_sess = onnxruntime.InferenceSession(mel_decoder_onnx, providers=providers)
logging.info("mel_postnet_ort_sess is building...")
mel_postnet_ort_sess = onnxruntime.InferenceSession(mel_postnet_onnx, providers=providers)
    # The variance_adaptor part uses the PyTorch (.pt) weights instead of ONNX
variance_adaptor = VarianceAdaptor2(config["Model"]["KanTtsSAMBERT"]["params"]).to(device)
logging.info("Loading checkpoint: {}".format(variance_adaptor_ckpt))
variance_adaptor_state_dict = torch.load(variance_adaptor_ckpt)
variance_adaptor.load_state_dict(variance_adaptor_state_dict, strict=False)
    # The mel_decoder part uses the PyTorch (.pt) weights instead of ONNX
    mel_decoder = MelPNCADecoder(config["Model"]["KanTtsSAMBERT"]["params"]).to(device)
logging.info("Loading checkpoint: {}".format(mel_decoder_ckpt))
mel_decoder_state_dict = torch.load(mel_decoder_ckpt)
mel_decoder.load_state_dict(mel_decoder_state_dict, strict=False)
results_dir = os.path.join(output_dir, "feat")
os.makedirs(results_dir, exist_ok=True)
# fsnet.eval()
    # i = 0  # when exporting to ONNX, run the model only once
with open(sentence, encoding="utf-8") as f:
for line in f:
# if i > 0:
# break
# i = i + 1
start = time.time()
line = line.strip().split("\t")
logging.info("Inference sentence: {}".format(line[0]))
mel_path = "%s/%s_mel.npy" % (results_dir, line[0])
dur_path = "%s/%s_dur.txt" % (results_dir, line[0])
f0_path = "%s/%s_f0.txt" % (results_dir, line[0])
energy_path = "%s/%s_energy.txt" % (results_dir, line[0])
with torch.no_grad():
# mel, mel_post, dur, f0, energy = am_synthesis(line[1], fsnet, ling_unit, device, se=se)
inputs_feat_lst = ling_unit.encode_symbol_sequence(line[1])
inputs_feat_index = 0
if ling_unit.using_byte():
inputs_byte_index = (
torch.from_numpy(inputs_feat_lst[inputs_feat_index]).long().to(device)
)
inputs_ling = torch.stack([inputs_byte_index], dim=-1).unsqueeze(0)
else:
inputs_sy = (
torch.from_numpy(inputs_feat_lst[inputs_feat_index]).long().to(device)
)
inputs_feat_index = inputs_feat_index + 1
inputs_tone = (
torch.from_numpy(inputs_feat_lst[inputs_feat_index]).long().to(device)
)
inputs_feat_index = inputs_feat_index + 1
inputs_syllable = (
torch.from_numpy(inputs_feat_lst[inputs_feat_index]).long().to(device)
)
inputs_feat_index = inputs_feat_index + 1
inputs_ws = (
torch.from_numpy(inputs_feat_lst[inputs_feat_index]).long().to(device)
)
inputs_ling = torch.stack(
[inputs_sy, inputs_tone, inputs_syllable, inputs_ws], dim=-1
).unsqueeze(0)
inputs_feat_index = inputs_feat_index + 1
inputs_emo = (
torch.from_numpy(inputs_feat_lst[inputs_feat_index])
.long()
.to(device)
.unsqueeze(0)
)
inputs_feat_index = inputs_feat_index + 1
se_enable = False if se is None else True
if se_enable:
inputs_spk = (
torch.from_numpy(se.repeat(len(inputs_feat_lst[inputs_feat_index]), axis=0))
.float()
.to(device)
.unsqueeze(0)[:, :-1, :]
)
else:
inputs_spk = (
torch.from_numpy(inputs_feat_lst[inputs_feat_index])
.long()
.to(device)
.unsqueeze(0)[:, :-1]
)
inputs_len = (
torch.zeros(1).long().to(device) + inputs_emo.size(1) - 1
) # minus 1 for "~"
inputs_ling = inputs_ling[:, :-1, :]
inputs_emotion = inputs_emo[:, :-1]
inputs_speaker = inputs_spk
inputs_lengths = inputs_len
batch_size = inputs_ling.size(0)
inputs_ling_masks = get_mask_from_lengths(inputs_lengths, max_len=inputs_ling.size(1))
text_enxoder_inputs = {'inputs_ling': inputs_ling.cpu().numpy(),
'inputs_emotion': inputs_emotion.cpu().numpy(),
'inputs_speaker': inputs_speaker.cpu().numpy(),
'inputs_ling_masks': inputs_ling_masks.cpu().numpy(),
}
                # # GPU warm-up
# for _ in range(50):
# (
# _0,
# _1,
# _2,
# _3
# ) = text_enxoder_ort_sess.run(['text_hid',
# 'ling_embedding',
# 'emo_hid',
# 'spk_hid'], text_enxoder_inputs)
# _ = fsnet(
# inputs_ling[:, :-1, :],
# inputs_emo[:, :-1],
# inputs_spk,
# inputs_len,)
# inputs_ling = inputs_ling[:, :-1, :]
# inputs_emotion = inputs_emo[:, :-1]
# inputs_speaker = inputs_spk
# inputs_lengths = inputs_len
                # Run the text_encoder
# text_hid, ling_embedding, emo_hid, spk_hid = text_encoder(
# inputs_ling,
# inputs_emotion,
# inputs_speaker,
# inputs_ling_masks=inputs_ling_masks,
# # return_attns=True)
                # # Benchmark
                # iterations = 100
                # times = torch.zeros(iterations)  # store the time of each iteration
# starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
# for iter in range(iterations):
# starter.record()
# # _ = fsnet(
# # inputs_ling[:, :-1, :],
# # inputs_emo[:, :-1],
# # inputs_spk,
# # inputs_len,)
# # logging.info("text_encoder is running...")
# # text_enxoder_inputs = {'inputs_ling': inputs_ling.cpu().numpy(),
# # 'inputs_emotion': inputs_emotion.cpu().numpy(),
# # 'inputs_speaker': inputs_speaker.cpu().numpy(),
# # 'inputs_ling_masks': inputs_ling_masks.cpu().numpy(),
# # }
# (
# text_hid,
# ling_embedding,
# emo_hid,
# spk_hid
# ) = text_enxoder_ort_sess.run(['text_hid',
# 'ling_embedding',
# 'emo_hid',
# 'spk_hid'], text_enxoder_inputs
# )
# ender.record()
                # # Synchronize GPU timing
                # torch.cuda.synchronize()
                # cur_time = starter.elapsed_time(ender)  # elapsed time
# times[iter] = cur_time
# mean_time = times.mean().item()
# print("text_enxoder-onnx single time: {:.6f} ms".format(mean_time))
(
text_hid,
ling_embedding,
emo_hid,
spk_hid
) = text_enxoder_ort_sess.run(
['text_hid',
'ling_embedding',
'emo_hid',
'spk_hid'],
text_enxoder_inputs
)
inter_lengths = inputs_lengths
inter_masks = get_mask_from_lengths(inter_lengths, max_len=text_hid.shape[1])
# output_masks = None
# logging.info("variance_adaptor is running...")
                # # Benchmark
                # iterations = 100
                # times = torch.zeros(iterations)  # store the time of each iteration
# starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
# for iter in range(iterations):
# starter.record()
                # # Run the variance adaptor (PyTorch)
# (
# LR_text_outputs,LR_emo_outputs,
# LR_spk_outputs,
# LR_length_rounded,
# log_duration_predictions,
# pitch_predictions,
# energy_predictions,
# ) = variance_adaptor(
# torch.from_numpy(text_hid).to(device),
# torch.from_numpy(emo_hid).to(device),
# torch.from_numpy(spk_hid).to(device),
# masks=inter_masks,
# # output_masks=output_masks,
# # duration_targets=None,
# # pitch_targets=None,
# # energy_targets=None,
# )
# ender.record()
                # # Synchronize GPU timing
                # torch.cuda.synchronize()
                # cur_time = starter.elapsed_time(ender)  # elapsed time
# times[iter] = cur_time
# mean_time = times.mean().item()
# print("variance_adaptor-pytorch single time: {:.6f} ms".format(mean_time))
                # Run the variance adaptor (PyTorch)
(
LR_text_outputs,LR_emo_outputs,
LR_spk_outputs,
LR_length_rounded,
log_duration_predictions,
pitch_predictions,
energy_predictions,
) = variance_adaptor(
torch.from_numpy(text_hid).to(device),
torch.from_numpy(emo_hid).to(device),
torch.from_numpy(spk_hid).to(device),
scale=1/target_rate,
masks=inter_masks,
# output_masks=output_masks,
# duration_targets=None,
# pitch_targets=None,
# energy_targets=None,
)
# variance_adaptor_inputs = {'text_hid': text_hid,
# 'emo_hid': emo_hid,
# 'spk_hid': spk_hid,
# 'inter_masks': inter_masks.cpu().numpy(),
# }
# (
# LR_text_outputs, LR_emo_outputs,
# LR_spk_outputs,
# LR_length_rounded,
# log_duration_predictions,
# pitch_predictions,
# energy_predictions,
# ) = variance_adaptor_ort_sess.run(['LR_text_outputs',
# 'LR_emo_outputs',
# 'LR_spk_outputs',
# 'LR_length_rounded',
# 'log_duration_predictions',
# 'pitch_predictions',
# 'energy_predictions'], variance_adaptor_inputs)
output_masks = get_mask_from_lengths(LR_length_rounded, max_len=LR_text_outputs.shape[1])
# lfr_masks = None
outputs_per_step = config["Model"]["KanTtsSAMBERT"]["params"]["outputs_per_step"]
r = outputs_per_step
# LFR with the factor of outputs_per_step
LFR_text_inputs = LR_text_outputs.contiguous().view(batch_size, -1, r * text_hid.shape[
-1]) # [1,153,32]->[1,51,96]
LFR_emo_inputs = LR_emo_outputs.contiguous().view(batch_size, -1, r * emo_hid.shape[-1])[
:, :, : emo_hid.shape[-1]]
LFR_spk_inputs = LR_spk_outputs.contiguous().view(batch_size, -1, r * spk_hid.shape[-1])[
:, :, : spk_hid.shape[-1]] # [1,153,192]->[1,51,192]
memory = torch.cat([LFR_text_inputs, LFR_spk_inputs, LFR_emo_inputs], dim=2)
x_band_width = int((torch.exp(log_duration_predictions) - 1).max() / r + 0.5)
h_band_width = x_band_width
# logging.info("mel_decoder is running...")
                # # Benchmark
                # iterations = 100
                # times = torch.zeros(iterations)  # store the time of each iteration
# starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
# for iter in range(iterations):
# starter.record()
                # # Run the mel_decoder
# dec_outputs = mel_decoder(
# memory,
# x_band_width,
# h_band_width,
# # target=None,
# # mask=lfr_masks,
# # return_attns=True,
# )
# ender.record()
                # # Synchronize GPU timing
                # torch.cuda.synchronize()
                # cur_time = starter.elapsed_time(ender)  # elapsed time
# times[iter] = cur_time
# mean_time = times.mean().item()
# print("mel_decoder-pytorch single time: {:.6f} ms".format(mean_time))
                # Run the mel_decoder
dec_outputs = mel_decoder(
memory,
x_band_width,
h_band_width,
# target=None,
# mask=lfr_masks,
# return_attns=True,
)
# mel_decoder_inputs = {'memory': memory.cpu().numpy(),
# 'x_band_width': np.array(x_band_width),
# 'h_band_width': np.array(x_band_width),
# }
# dec_outputs = mel_decoder_ort_sess.run(['dec_outputs'], mel_decoder_inputs)
d_mel = config["Model"]["KanTtsSAMBERT"]["params"]["num_mels"]
# De-LFR with the factor of outputs_per_step
dec_outputs = dec_outputs[0].contiguous().view(batch_size, -1, d_mel) # [1,51,246]->[1,153,82]
if output_masks is not None:
dec_outputs = dec_outputs.masked_fill(output_masks.unsqueeze(-1), 0)
# logging.info("mel_postnet is running...")
                # Run the mel_postnet
# postnet_outputs = mel_postnet(dec_outputs, output_masks) + dec_outputs
# postnet_outputs = mel_postnet(dec_outputs, output_masks)
mel_decoder_inputs = {'dec_outputs': dec_outputs.cpu().numpy(),
'output_masks': output_masks.cpu().numpy(),
}
                # # Benchmark
                # iterations = 100
                # times = torch.zeros(iterations)  # store the time of each iteration
# starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
# for iter in range(iterations):
# starter.record()
# postnet_outputs = mel_postnet_ort_sess.run(['postnet_outputs'], mel_decoder_inputs)
# ender.record()
                # # Synchronize GPU timing
                # torch.cuda.synchronize()
                # cur_time = starter.elapsed_time(ender)  # elapsed time
# times[iter] = cur_time
# mean_time = times.mean().item()
# print("mel_postnet-onnx single time: {:.6f} ms".format(mean_time))
postnet_outputs = mel_postnet_ort_sess.run(
['postnet_outputs'],
mel_decoder_inputs
)
postnet_outputs = torch.from_numpy(postnet_outputs[0]).to(device) + dec_outputs
if output_masks is not None:
postnet_outputs = postnet_outputs.masked_fill(output_masks.unsqueeze(-1), 0)
                # At this point the SAMBERT forward pass returns its outputs
# return torch.tensor(x_band_width), torch.tensor(h_band_width), dec_outputs, postnet_outputs,\
# LR_length_rounded, log_duration_predictions, pitch_predictions, energy_predictions
valid_length = int(LR_length_rounded[0].item())
dec_outputs = dec_outputs[0, :valid_length, :].cpu().numpy()
postnet_outputs = postnet_outputs[0, :valid_length, :].cpu().numpy()
duration_predictions = (
(torch.exp(log_duration_predictions) - 1 + 0.5).long().squeeze().cpu().numpy())
pitch_predictions = pitch_predictions.squeeze().cpu().numpy()
energy_predictions = energy_predictions.squeeze().cpu().numpy()
logging.info("x_band_width:{}, h_band_width: {}".format(x_band_width, h_band_width))
# return (
# dec_outputs,
# postnet_outputs,
# duration_predictions,
# pitch_predictions,
# energy_predictions,
                # )  # i.e. mel, mel_post, dur, f0, energy
mel, mel_post, dur, f0, energy = dec_outputs, postnet_outputs, duration_predictions, pitch_predictions, energy_predictions
if nsf_enable:
mel_post = denorm_f0(mel_post, norm_type=nsf_norm_type, f0_feature=f0_feature)
np.save(mel_path, mel_post)
np.savetxt(dur_path, dur)
np.savetxt(f0_path, f0)
np.savetxt(energy_path, energy)
total_elapsed = time.time() - start
print(f'AM infer single time: {total_elapsed} seconds')
def concat_process(chunked_dir, output_dir):
wav_files = sorted(glob(os.path.join(chunked_dir, "*.wav")))
sentence_sil = 0.28 # seconds
end_sil = 0.05 # seconds
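    # The chunked wavs produced by hifigan_infer() are named "{main_id}_{sub_id}_mel_gen.wav";
    # sentences sharing a main_id are concatenated with sentence_sil seconds of silence in
    # between, end_sil seconds are appended, and the result is written as "{main_id}.wav".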
cnt = 0
wav_concat = None
main_id, sub_id = 0, 0
while cnt < len(wav_files):
wav_file = os.path.join(
chunked_dir, "{}_{}_mel_gen.wav".format(main_id, sub_id)
)
if os.path.exists(wav_file):
wav, sr = sf.read(wav_file)
sentence_sil_samples = int(sentence_sil * sr)
end_sil_samples = int(end_sil * sr)
if sub_id == 0:
wav_concat = wav
else:
wav_concat = np.concatenate(
(wav_concat, np.zeros(sentence_sil_samples), wav), axis=0
)
sub_id += 1
cnt += 1
else:
if wav_concat is not None:
wav_concat = np.concatenate(
(wav_concat, np.zeros(end_sil_samples)), axis=0
)
sf.write(os.path.join(output_dir, f"{main_id}.wav"), wav_concat, sr)
main_id += 1
sub_id = 0
wav_concat = None
if cnt == len(wav_files):
wav_concat = np.concatenate((wav_concat, np.zeros(end_sil_samples)), axis=0)
sf.write(os.path.join(output_dir, f"{main_id}.wav"), wav_concat, sr)
def text_to_wav_onnx(
text_file,
output_dir,
resources_zip_file,
text_encoder_onnx,
variance_adaptor_pt,
    mel_decoder_pt,
mel_postnet_onnx,
am_config_file,
voc_onnx,
voc_config_file,
target_rate=1.0,
speaker=None,
se_file=None,
lang="PinYin",
):
os.makedirs(output_dir, exist_ok=True)
os.makedirs(os.path.join(output_dir, "res_wavs"), exist_ok=True)
resource_root_dir = os.path.dirname(resources_zip_file)
resource_dir = os.path.join(resource_root_dir, "resource")
if not os.path.exists(resource_dir):
logging.info("Extracting resources...")
with zipfile.ZipFile(resources_zip_file, "r") as zip_ref:
zip_ref.extractall(resource_root_dir)
with open(text_file, "r") as text_data:
texts = text_data.readlines()
logging.info("Converting text to symbols...")
# am_config = os.path.join(os.path.dirname(os.path.dirname(am_ckpt)), "config.yaml")
with open(am_config_file, "r") as f:
am_config = yaml.load(f, Loader=yaml.Loader)
if speaker is None:
speaker = am_config["linguistic_unit"]["speaker_list"].split(",")[0]
symbols_lst = text_to_symbols(texts, resource_dir, speaker, lang)
symbols_file = os.path.join(output_dir, "symbols.lst")
with open(symbols_file, "w") as symbol_data:
for symbol in symbols_lst:
symbol_data.write(symbol)
logging.info("AM is infering...")
start = time.time()
# am_infer(symbols_file, am_ckpt, output_dir, se_file)
am_infer_divide(symbols_file,
text_encoder_onnx,
variance_adaptor_pt,
        mel_decoder_pt,
mel_postnet_onnx,
output_dir,
target_rate=target_rate,
se_file=se_file,
config=am_config_file
)
total_elapsed = time.time() - start
print(f'AM infer time: {total_elapsed} seconds')
logging.info("Vocoder is infering...")
start = time.time()
hifigan_infer(os.path.join(output_dir, "feat"),
voc_onnx,
output_dir, config=voc_config_file)
total_elapsed = time.time() - start
print(f'Vocoder infer time: {total_elapsed} seconds')
concat_process(output_dir, os.path.join(output_dir, "res_wavs"))
logging.info("Text to wav finished!")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Text2wav_onnx")
parser.add_argument("--txt", type=str, required=True, help="Path to text file")
parser.add_argument("--output_dir", type=str, required=True, help="Path to output directory")
parser.add_argument("--res_zip", type=str, required=True, help="Path to resource zip file")
# parser.add_argument("--am_ckpt", type=str, required=True, help="Path to am ckpt file")
parser.add_argument("--text_encoder_onnx", type=str, required=True, help="Path to am -1 file")
parser.add_argument("--variance_adaptor_pt", type=str, required=True, help="Path to am -2 file")
parser.add_argument("--mel_decoder_pt", type=str, required=True, help="Path to am -3 file")
parser.add_argument("--mel_postnet_onnx", type=str, required=True, help="Path to am -4 file")
parser.add_argument("--am_config", type=str, required=True, help="Path to am config file")
parser.add_argument("--voc_onnx", type=str, required=True, help="Path to voc onnx file")
parser.add_argument("--voc_config", type=str, required=True, help="Path to voc config file")
parser.add_argument("--target_rate", type=float, required=False, default=1.0,
        choices=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0], help="Speech-rate factor for the final wav; one of 0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0")
parser.add_argument("--speaker", type=str, required=False, default=None,
help="The speaker name, default is the first speaker", )
parser.add_argument("--se_file", type=str, required=False, default=None,
help="The speaker embedding file , default is None", )
parser.add_argument("--lang", type=str, default="PinYin",
help="""The language of the text, default is PinYin, other options are:
English,
British,
ZhHK,
WuuShanghai,
Sichuan,
Indonesian,
Malay,
Filipino,
Vietnamese,
Korean,
Russian
""",
)
args = parser.parse_args()
start = time.time()
text_to_wav_onnx(
args.txt,
args.output_dir,
args.res_zip,
# args.am_ckpt,
args.text_encoder_onnx,
args.variance_adaptor_pt,
args.mel_decoder_pt,
args.mel_postnet_onnx,
args.am_config,
args.voc_onnx,
args.voc_config,
args.target_rate,
args.speaker,
args.se_file,
args.lang,
)
total_elapsed = time.time() - start
print(f'text to wave time: {total_elapsed} seconds')
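# Illustrative invocation (mirrors the arguments used in the accompanying shell script):
#   python3 text_to_wav_onnx.py \
#     --txt Data/test.txt --output_dir res/ptts_syn \
#     --res_zip speech_personal_sambert-hifigan_nsf_tts_zh-cn_pretrain_16k/resource.zip \
#     --text_encoder_onnx sambert_onnx/text_encoder.onnx \
#     --variance_adaptor_pt sambert_onnx/variance_adaptor_dict.pt \
#     --mel_decoder_pt sambert_onnx/mel_decoder_dict.pt \
#     --mel_postnet_onnx sambert_onnx/mel_postnet.onnx \
#     --am_config training_stage/ptts_sambert_ckpt/config.yaml \
#     --voc_onnx hifigan_onnx/hifigan.onnx \
#     --voc_config speech_personal_sambert-hifigan_nsf_tts_zh-cn_pretrain_16k/basemodel_16k/hifigan/config.yaml \
#     --target_rate 1.0 --speaker F7 --se_file training_stage/ptts_feats/se/se.npy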
import os
import sys
import argparse
import yaml
import logging
import zipfile
from glob import glob
import soundfile as sf
import numpy as np
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) # NOQA: E402
sys.path.insert(0, os.path.dirname(ROOT_PATH)) # NOQA: E402
try:
from kantts.bin.infer_sambert import am_infer
from kantts.bin.infer_sambert_divide_to_onnx import am_infer_divide
from kantts.bin.infer_hifigan_to_onnx import hifigan_infer
from kantts.utils.ling_unit import text_to_mit_symbols as text_to_symbols
except ImportError:
raise ImportError("Please install kantts.")
logging.basicConfig(
# filename=os.path.join(stage_dir, 'stdout.log'),
format="%(asctime)s, %(levelname)-4s [%(filename)s:%(lineno)d] %(message)s",
datefmt="%Y-%m-%d:%H:%M:%S",
level=logging.INFO,
)
def concat_process(chunked_dir, output_dir):
wav_files = sorted(glob(os.path.join(chunked_dir, "*.wav")))
print(wav_files)
sentence_sil = 0.28 # seconds
end_sil = 0.05 # seconds
cnt = 0
wav_concat = None
main_id, sub_id = 0, 0
while cnt < len(wav_files):
wav_file = os.path.join(
chunked_dir, "{}_{}_mel_gen.wav".format(main_id, sub_id)
)
if os.path.exists(wav_file):
wav, sr = sf.read(wav_file)
sentence_sil_samples = int(sentence_sil * sr)
end_sil_samples = int(end_sil * sr)
if sub_id == 0:
wav_concat = wav
else:
wav_concat = np.concatenate(
(wav_concat, np.zeros(sentence_sil_samples), wav), axis=0
)
sub_id += 1
cnt += 1
else:
if wav_concat is not None:
wav_concat = np.concatenate(
(wav_concat, np.zeros(end_sil_samples)), axis=0
)
sf.write(os.path.join(output_dir, f"{main_id}.wav"), wav_concat, sr)
main_id += 1
sub_id = 0
wav_concat = None
if cnt == len(wav_files):
wav_concat = np.concatenate((wav_concat, np.zeros(end_sil_samples)), axis=0)
sf.write(os.path.join(output_dir, f"{main_id}.wav"), wav_concat, sr)
def text_to_wav(
text_file,
output_dir,
resources_zip_file,
am_ckpt,
voc_ckpt,
speaker=None,
se_file=None,
lang="PinYin",
):
os.makedirs(output_dir, exist_ok=True)
os.makedirs(os.path.join(output_dir, "res_wavs"), exist_ok=True)
resource_root_dir = os.path.dirname(resources_zip_file)
resource_dir = os.path.join(resource_root_dir, "resource")
if not os.path.exists(resource_dir):
logging.info("Extracting resources...")
with zipfile.ZipFile(resources_zip_file, "r") as zip_ref:
zip_ref.extractall(resource_root_dir)
with open(text_file, "r") as text_data:
texts = text_data.readlines()
logging.info("Converting text to symbols...")
am_config = os.path.join(os.path.dirname(os.path.dirname(am_ckpt)), "config.yaml")
with open(am_config, "r") as f:
config = yaml.load(f, Loader=yaml.Loader)
if speaker is None:
speaker = config["linguistic_unit"]["speaker_list"].split(",")[0]
symbols_lst = text_to_symbols(texts, resource_dir, speaker, lang)
symbols_file = os.path.join(output_dir, "symbols.lst")
with open(symbols_file, "w") as symbol_data:
for symbol in symbols_lst:
symbol_data.write(symbol)
logging.info("AM is infering...")
# am_infer(symbols_file, am_ckpt, output_dir, se_file)
am_infer_divide(symbols_file, am_ckpt, output_dir, se_file)
logging.info("Vocoder is infering...")
hifigan_infer(os.path.join(output_dir, "feat"), voc_ckpt, output_dir)
concat_process(output_dir, os.path.join(output_dir, "res_wavs"))
# logging.info("Text to wav finished!")
logging.info("Model transform finished!")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Text to wav")
parser.add_argument("--txt", type=str, required=True, help="Path to text file")
parser.add_argument(
"--output_dir", type=str, required=True, help="Path to output directory"
)
parser.add_argument(
"--res_zip", type=str, required=True, help="Path to resource zip file"
)
parser.add_argument(
"--am_ckpt", type=str, required=True, help="Path to am ckpt file"
)
parser.add_argument(
"--voc_ckpt", type=str, required=True, help="Path to voc ckpt file"
)
parser.add_argument(
"--speaker",
type=str,
required=False,
default=None,
help="The speaker name, default is the first speaker",
)
parser.add_argument(
"--se_file",
type=str,
required=False,
default=None,
help="The speaker embedding file , default is None",
)
parser.add_argument(
"--lang",
type=str,
default="PinYin",
help="""The language of the text, default is PinYin, other options are:
English,
British,
ZhHK,
WuuShanghai,
Sichuan,
Indonesian,
Malay,
Filipino,
Vietnamese,
Korean,
Russian
""",
)
args = parser.parse_args()
text_to_wav(
args.txt,
args.output_dir,
args.res_zip,
args.am_ckpt,
args.voc_ckpt,
args.speaker,
args.se_file,
args.lang,
)
#!/usr/bin/env python3
# Import the run_auto_label tool; the first run will download the required library files
from modelscope.tools import run_auto_label
# Run autolabel for automatic annotation; auto-labeling 20 audio clips takes about 4 minutes
import os
import argparse
def wav_to_label(wav_data):
work_dir = os.path.join(os.path.split(wav_data)[0], os.path.split(wav_data)[1]+"_autolabel")
os.makedirs(work_dir, exist_ok=True)
ret, report = run_auto_label(input_wav=wav_data,
work_dir=work_dir,
resource_revision='v1.0.7'
)
# print(report)
return report
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="wav2label")
parser.add_argument("--wav_data", type=str, required=True, help="Path to wav data")
args = parser.parse_args()
wav_to_label(args.wav_data)
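# Illustrative invocation (the path matches the wav directory used elsewhere in this repo):
#   python3 wav_to_label.py --wav_data Data/ptts_spk0_wav
# which writes the auto-labeled corpus to Data/ptts_spk0_wav_autolabel/.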