Commit 51782715 authored by liugh5

update

parent 8b4e9acd
FROM image.sourcefind.cn:5000/dcu/admin/base/pytorch:1.13.1-centos7.6-dtk23.10-py38
COPY requirements.txt requirements.txt
RUN pip3 install -r requirements.txt -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
## Papers
[RobuTrans: A Robust Transformer-Based Text-to-Speech Model](https://ojs.aaai.org/index.php/AAAI/article/view/6337)
[HiFi-GAN: Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis](https://arxiv.org/abs/2010.05646)
## Model Architecture
```
docker pull image.sourcefind.cn:5000/dcu/admin/base/pytorch:1.13.1-centos7.6-dtk23.10-py38
# Replace imageID below with the ID of the docker image pulled above
docker run -it -v /path/your_code_data/:/path/your_code_data/ -v /opt/hyhal:/opt/hyhal --shm-size=32G --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video --name container_name imageID bash
cd /path/your_code_data/
```
Set up the KAN-TTS environment:
```
git clone -b develop https://github.com/alibaba-damo-academy/KAN-TTS.git
cd KAN-TTS
pip3 install -r requirements.txt
```
Clone this repository:
```
git clone http://developer.hpccube.com/codes/modelzoo/sambert-hifigan_pytorch.git
cd sambert-hifigan_pytorch
```
Install the required Python dependencies:
```
pip3 install -r requirements.txt -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
```
### Dockerfile (Method 2)
```
docker build -t <image_name> .
docker run -it -v /path/your_code_data/:/path/your_code_data/ -v /opt/hyhal:/opt/hyhal --shm-size=32G --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video --name container_name image_name bash
cd /path/your_code_data/
```
Then follow the steps in Method 1.
### Conda (Method 3)
1. Create a conda virtual environment:
```
conda create -n <env_name> python=3.8
```
2. The toolkits and deep learning libraries this project requires for DCU GPUs can all be downloaded and installed from the [光合](https://www.hpccube.com/sso/login?service=https://developer.hpccube.com/tool/) developer community:
- DTK driver: [dtk23.10](https://cancon.hpccube.com:65024/1/main/DTK-23.10)
- PyTorch: [1.13.1](https://cancon.hpccube.com:65024/4/main/pytorch/dtk23.10)
<table><tr><td bgcolor=gray>Tips: the versions of the DTK driver, Python, PyTorch and other DCU-related tools above must match each other exactly.</td></tr></table>
Install the remaining non-deep-learning libraries according to requirements.txt:
```
pip install -r requirements.txt -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
```
## Dataset
You can download the AISHELL-3 open-source speech synthesis dataset, already processed into Alibaba's standard format, from ModelScope for the steps below. If you only have ordinary audio recordings, you can convert them with the PTTS Autolabel automatic annotation tool.
Example [training data](https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/TTS/download_files/test_female.zip) is available for reference; download it and place it in a directory such as Data/ptts_spk0_wav.
The dataset directory structure is as follows:
```
ptts_spk0_wav
├─SSB00180007.wav
├─SSB00180012.wav
├─......
```
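Before annotation it can help to sanity-check the recordings. Below is a minimal sketch using soundfile (already a dependency of this repo); the directory path follows the layout above:
```
import glob
import soundfile as sf

# Print duration and sampling rate for every recording in the dataset directory.
for wav_path in sorted(glob.glob("Data/ptts_spk0_wav/*.wav")):
    info = sf.info(wav_path)
    print(f"{wav_path}: {info.duration:.2f}s @ {info.samplerate} Hz")
```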
Data preprocessing consists of two steps:
1. Automatic annotation with autolabel
```
python3 wav_to_label.py --wav_data Data/ptts_spk0_wav
```
2. Feature extraction
```
bash feats_extract.sh
```
Remember to update the model paths accordingly.
Feature extraction only takes a short while to run; once it finishes you will find files with the following structure under the training_stage/test_male_ptts_feats directory:
```
├── am_train.lst
├── am_valid.lst
├── audio_config.yaml
├── badlist.txt
├── data_process_stdout.log
├── duration
├── energy
├── f0
├── frame_energy
├── frame_f0
├── frame_uv
├── mel
├── raw_duration
├── raw_metafile.txt
├── Script.xml
├── se
├── train.lst
├── valid.lst
└── wav
```
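Each subdirectory holds one NumPy file per utterance. A quick way to inspect an extracted mel feature (a sketch; the stage directory matches the listing above, and in NSF mode the last two feature columns hold f0 and the voiced/unvoiced flag, as the inference code later in this repo assumes):
```
import glob
import numpy as np

# Load the first extracted mel file and report its (frames, channels) shape.
mel_files = sorted(glob.glob("training_stage/test_male_ptts_feats/mel/*.npy"))
mel = np.load(mel_files[0])
print(mel_files[0], mel.shape)
```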
## Training
#### Single-GPU Training
```
HIP_VISIBLE_DEVICES=0 python3 kantts/bin/train_sambert.py \
--stage_dir training_stage/ptts_sambert_ckpt \
--resume_path speech_personal_sambert-hifigan_nsf_tts_zh-cn_pretrain_16k/basemodel_16k/sambert/ckpt/checkpoint_*.pth
```
Remember to update the model paths accordingly.
#### Single-GPU Inference
```
HIP_VISIBLE_DEVICES=0 python3 kantts/bin/text_to_wav.py \
--voc_ckpt speech_personal_sambert-hifigan_nsf_tts_zh-cn_pretrain_16k/basemodel_16k/hifigan/ckpt/checkpoint_2400000.pth \
--se_file training_stage/ptts_feats/se/se.npy
```
Remember to update the model paths accordingly.
## Result
The cloned speech files can be found in the output folder res/ptts_syn.
### Key Application Industries
Manufacturing, broadcast media, energy, healthcare, smart home, education
### Pretrained Weights
```
git clone https://www.modelscope.cn/damo/speech_personal_sambert-hifigan_nsf_tts_zh-cn_pretrain_16k.git
```
## Source Repository and Issue Reporting
https://developer.hpccube.com/codes/modelzoo/sambert-hifigan_pytorch
import os
import sys
import argparse
import torch
import soundfile as sf
import yaml
import logging
import numpy as np
import time
import glob
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) # NOQA: E402
sys.path.insert(0, os.path.dirname(ROOT_PATH)) # NOQA: E402
try:
from kantts.utils.log import logging_to_file
except ImportError:
raise ImportError("Please install kantts.")
logging.basicConfig(
# filename=os.path.join(stage_dir, 'stdout.log'),
format="%(asctime)s, %(levelname)-4s [%(filename)s:%(lineno)d] %(message)s",
datefmt="%Y-%m-%d:%H:%M:%S",
level=logging.INFO,
)
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
def load_model(ckpt, config=None):
# load config if not provided
if config is None:
dirname = os.path.dirname(os.path.dirname(ckpt))
config = os.path.join(dirname, "config.yaml")
with open(config) as f:
config = yaml.load(f, Loader=yaml.Loader)
# lazy load for circular error
from kantts.models.hifigan.hifigan import Generator
model = Generator(**config["Model"]["Generator"]["params"])
states = torch.load(ckpt, map_location="cpu")
model.load_state_dict(states["model"]["generator"])
# add pqmf if needed
if config["Model"]["Generator"]["params"]["out_channels"] > 1:
# lazy load for circular error
from kantts.models.pqmf import PQMF
model.pqmf = PQMF()
return model
def binarize(mel, threshold=0.6):
    # Binarize the voiced/unvoiced (V/UV) flag stored in the last mel column:
    # frames below the threshold become 0.0 (unvoiced), all others 1.0 (voiced).
    res_mel = mel.copy()
    index = np.where(mel[:, -1] < threshold)[0]
    res_mel[:, -1] = 1.0
    res_mel[:, -1][index] = 0.0
    return res_mel
def hifigan_infer(input_mel, ckpt_path, output_dir, config=None):
if not torch.cuda.is_available():
device = torch.device("cpu")
else:
torch.backends.cudnn.benchmark = True
device = torch.device("cuda", 0)
if config is not None:
with open(config, "r") as f:
config = yaml.load(f, Loader=yaml.Loader)
else:
config_path = os.path.join(
os.path.dirname(os.path.dirname(ckpt_path)), "config.yaml"
)
if not os.path.exists(config_path):
raise ValueError("config file not found: {}".format(config_path))
with open(config_path, "r") as f:
config = yaml.load(f, Loader=yaml.Loader)
for key, value in config.items():
logging.info(f"{key} = {value}")
# check directory existence
if not os.path.exists(output_dir):
os.makedirs(output_dir)
logging_to_file(os.path.join(output_dir, "stdout.log"))
if os.path.isfile(input_mel):
mel_lst = [input_mel]
elif os.path.isdir(input_mel):
mel_lst = glob.glob(os.path.join(input_mel, "*.npy"))
else:
raise ValueError("input_mel should be a file or a directory")
model = load_model(ckpt_path, config)
logging.info(f"Loaded model parameters from {ckpt_path}.")
model.remove_weight_norm()
model = model.eval().to(device)
with torch.no_grad():
start = time.time()
pcm_len = 0
for mel in mel_lst:
utt_id = os.path.splitext(os.path.basename(mel))[0]
mel_data = np.load(mel)
if model.nsf_enable:
mel_data = binarize(mel_data)
# generate
mel_data = torch.tensor(mel_data, dtype=torch.float).to(device)
# (T, C) -> (B, C, T)
mel_data = mel_data.transpose(1, 0).unsqueeze(0)
y = model(mel_data)
if hasattr(model, "pqmf"):
y = model.pqmf.synthesis(y)
y = y.view(-1).cpu().numpy()
pcm_len += len(y)
# save as PCM 16 bit wav file
sf.write(
os.path.join(output_dir, f"{utt_id}_gen.wav"),
y,
config["audio_config"]["sampling_rate"],
"PCM_16",
)
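        # real-time factor (RTF): wall-clock seconds per second of generated audio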
rtf = (time.time() - start) / (
pcm_len / config["audio_config"]["sampling_rate"]
)
# report average RTF
logging.info(
f"Finished generation of {len(mel_lst)} utterances (RTF = {rtf:.03f})."
)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Infer hifigan model")
parser.add_argument(
"--ckpt", type=str, required=True, help="Path to model checkpoint"
)
parser.add_argument(
"--input_mel",
type=str,
required=True,
help="Path to input mel file or directory containing mel files",
)
parser.add_argument(
"--output_dir", type=str, required=True, help="Path to output directory"
)
parser.add_argument("--config", type=str, default=None, help="Path to config file")
args = parser.parse_args()
hifigan_infer(
args.input_mel,
args.ckpt,
args.output_dir,
args.config,
)
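hifigan_infer can also be called programmatically to vocode a directory of mel files; a sketch with placeholder paths (the checkpoint name follows the README above):
```
from kantts.bin.infer_hifigan import hifigan_infer

# Vocode every *.npy mel file in the feat directory into 16-bit PCM wavs.
hifigan_infer(
    "res/ptts_syn/feat",  # placeholder: directory of mel .npy files
    "speech_personal_sambert-hifigan_nsf_tts_zh-cn_pretrain_16k/basemodel_16k/hifigan/ckpt/checkpoint_2400000.pth",
    "res/ptts_syn",
)
```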
import os
import sys
import argparse
import torch
import soundfile as sf
import yaml
import logging
import numpy as np
import time
import glob
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) # NOQA: E402
sys.path.insert(0, os.path.dirname(ROOT_PATH)) # NOQA: E402
try:
from kantts.utils.log import logging_to_file
except ImportError:
raise ImportError("Please install kantts.")
logging.basicConfig(
# filename=os.path.join(stage_dir, 'stdout.log'),
format="%(asctime)s, %(levelname)-4s [%(filename)s:%(lineno)d] %(message)s",
datefmt="%Y-%m-%d:%H:%M:%S",
level=logging.INFO,
)
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
def load_model(ckpt, config=None):
# load config if not provided
if config is None:
dirname = os.path.dirname(os.path.dirname(ckpt))
config = os.path.join(dirname, "config.yaml")
with open(config) as f:
config = yaml.load(f, Loader=yaml.Loader)
# lazy load for circular error
from kantts.models.hifigan.hifigan import Generator
model = Generator(**config["Model"]["Generator"]["params"])
states = torch.load(ckpt, map_location="cpu")
model.load_state_dict(states["model"]["generator"])
# add pqmf if needed
if config["Model"]["Generator"]["params"]["out_channels"] > 1:
# lazy load for circular error
from kantts.models.pqmf import PQMF
model.pqmf = PQMF()
return model
def binarize(mel, threshold=0.6):
# vuv binarize
res_mel = mel.copy()
index = np.where(mel[:, -1] < threshold)[0]
res_mel[:, -1] = 1.0
res_mel[:, -1][index] = 0.0
return res_mel
def hifigan_infer(input_mel, ckpt, output_dir, config=None):
if not torch.cuda.is_available():
device = torch.device("cpu")
else:
torch.backends.cudnn.benchmark = True
device = torch.device("cuda", 0)
# device = torch.device("cpu")
if config is not None:
with open(config, "r") as f:
config = yaml.load(f, Loader=yaml.Loader)
else:
config_path = os.path.join(
os.path.dirname(os.path.dirname(ckpt)), "config.yaml"
)
if not os.path.exists(config_path):
raise ValueError("config file not found: {}".format(config_path))
with open(config_path, "r") as f:
config = yaml.load(f, Loader=yaml.Loader)
for key, value in config.items():
logging.info(f"{key} = {value}")
# check directory existence
if not os.path.exists(output_dir):
os.makedirs(output_dir)
logging_to_file(os.path.join(output_dir, "stdout.log"))
if os.path.isfile(input_mel):
mel_lst = [input_mel]
elif os.path.isdir(input_mel):
mel_lst = glob.glob(os.path.join(input_mel, "*.npy"))
else:
raise ValueError("input_mel should be a file or a directory")
model = load_model(ckpt, config)
logging.info(f"Loaded model parameters from {ckpt}.")
model.remove_weight_norm()
model = model.eval().to(device)
with torch.no_grad():
start = time.time()
pcm_len = 0
        # i = 0  # when exporting to ONNX, uncomment with the break below to run only one pass
        for mel in mel_lst:
            # if i > 0:
            #     break
            # i = i + 1
start1 = time.time()
utt_id = os.path.splitext(os.path.basename(mel))[0]
logging.info("Inference sentence: {}".format(utt_id))
mel_data = np.load(mel)
if model.nsf_enable:
mel_data = binarize(mel_data)
# generate
mel_data = torch.tensor(mel_data, dtype=torch.float).to(device)
# (T, C) -> (B, C, T)
mel_data = mel_data.transpose(1, 0).unsqueeze(0)
            # # GPU warm-up
            # for _ in range(10):
            #     _ = model(mel_data)
            # # speed measurement
            # iterations = 100
            # times = torch.zeros(iterations)  # per-iteration elapsed time
            # starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
            # for iter in range(iterations):
            #     starter.record()
            #     y = model(mel_data)
            #     ender.record()
            #     # wait for the GPU to finish
            #     torch.cuda.synchronize()
            #     cur_time = starter.elapsed_time(ender)  # elapsed time in ms
            #     times[iter] = cur_time
            # mean_time = times.mean().item()
            # print("hifigan-pytorch infer single time: {:.6f} ms".format(mean_time))
            y = model(mel_data)
            # Export the PyTorch model to ONNX format. Note: this runs on every
            # utterance and rewrites the same file; use the i-counter above to
            # limit inference to a single pass when exporting.
            x0 = mel_data
            dynamic_axes = {
                "mel_data": {2: "hifigan_input_dim2"},
                "y": {2: "hifigan_output_dim2"},
            }
            os.makedirs("hifigan_onnx", exist_ok=True)  # torch.onnx.export does not create the directory
            torch.onnx.export(
                model,
                x0,
                "hifigan_onnx/hifigan.onnx",
                opset_version=13,
                input_names=["mel_data"],
                output_names=["y"],
                dynamic_axes=dynamic_axes,
            )
if hasattr(model, "pqmf"):
y = model.pqmf.synthesis(y)
y = y.view(-1).cpu().numpy()
pcm_len += len(y)
# save as PCM 16 bit wav file
sf.write(
os.path.join(output_dir, f"{utt_id}_gen.wav"),
y,
config["audio_config"]["sampling_rate"],
"PCM_16",
)
total_elapsed = time.time() - start1
print(f'Vocoder infer single time: {total_elapsed} seconds')
rtf = (time.time() - start) / (
pcm_len / config["audio_config"]["sampling_rate"]
)
# report average RTF
logging.info(
f"Finished generation of {len(mel_lst)} utterances (RTF = {rtf:.03f})."
)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Infer hifigan model")
parser.add_argument(
"--ckpt", type=str, required=True, help="Path to model checkpoint"
)
parser.add_argument(
"--input_mel",
type=str,
required=True,
help="Path to input mel file or directory containing mel files",
)
parser.add_argument(
"--output_dir", type=str, required=True, help="Path to output directory"
)
parser.add_argument("--config", type=str, default=None, help="Path to config file")
args = parser.parse_args()
hifigan_infer(
args.input_mel,
args.ckpt,
args.output_dir,
args.config,
)
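Once exported, the ONNX vocoder can be cross-checked against the PyTorch output with onnxruntime; a sketch under the assumption that onnxruntime is installed, with a placeholder mel path:
```
import numpy as np
import onnxruntime as ort

# Load the exported generator and run it on one mel input of shape (1, C, T).
sess = ort.InferenceSession("hifigan_onnx/hifigan.onnx", providers=["CPUExecutionProvider"])
mel = np.load("training_stage/ptts_feats/mel/example.npy")  # placeholder mel file, shape (T, C)
mel = mel.T[None].astype(np.float32)  # (T, C) -> (1, C, T), matching the export above
(y,) = sess.run(["y"], {"mel_data": mel})
print("waveform shape:", y.shape)
```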
import sys
import torch
import os
import numpy as np
import argparse
import yaml
import logging
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) # NOQA: E402
sys.path.insert(0, os.path.dirname(ROOT_PATH)) # NOQA: E402
try:
from kantts.models import model_builder
from kantts.utils.ling_unit.ling_unit import KanTtsLinguisticUnit
except ImportError:
raise ImportError("Please install kantts.")
logging.basicConfig(
# filename=os.path.join(stage_dir, 'stdout.log'),
format="%(asctime)s, %(levelname)-4s [%(filename)s:%(lineno)d] %(message)s",
datefmt="%Y-%m-%d:%H:%M:%S",
level=logging.INFO,
)
def denorm_f0(mel, f0_threshold=30, uv_threshold=0.6, norm_type='mean_std', f0_feature=None):
    # De-normalize the f0 track (second-to-last mel column) and binarize the
    # voiced/unvoiced flag (last column) of the acoustic model's output.
    if norm_type == 'mean_std':
        f0_mvn = f0_feature  # stacked [mean; std] statistics
        f0 = mel[:, -2]
        uv = mel[:, -1]
        uv[uv < uv_threshold] = 0.0
        uv[uv >= uv_threshold] = 1.0
        f0 = f0 * f0_mvn[1:, :] + f0_mvn[0:1, :]
        f0[f0 < f0_threshold] = f0_threshold
        mel[:, -2] = f0
        mel[:, -1] = uv
    else:  # global min-max normalization
        f0_global_max_min = f0_feature  # [global_max, global_min]
        f0 = mel[:, -2]
        uv = mel[:, -1]
        uv[uv < uv_threshold] = 0.0
        uv[uv >= uv_threshold] = 1.0
        f0 = f0 * (f0_global_max_min[0] - f0_global_max_min[1]) + f0_global_max_min[1]
        f0[f0 < f0_threshold] = f0_threshold
        mel[:, -2] = f0
        mel[:, -1] = uv
    return mel
def am_synthesis(symbol_seq, fsnet, ling_unit, device, se=None):
inputs_feat_lst = ling_unit.encode_symbol_sequence(symbol_seq)
inputs_feat_index = 0
if ling_unit.using_byte():
inputs_byte_index = (
torch.from_numpy(inputs_feat_lst[inputs_feat_index]).long().to(device)
)
inputs_ling = torch.stack([inputs_byte_index], dim=-1).unsqueeze(0)
else:
inputs_sy = (
torch.from_numpy(inputs_feat_lst[inputs_feat_index]).long().to(device)
)
inputs_feat_index = inputs_feat_index + 1
inputs_tone = (
torch.from_numpy(inputs_feat_lst[inputs_feat_index]).long().to(device)
)
inputs_feat_index = inputs_feat_index + 1
inputs_syllable = (
torch.from_numpy(inputs_feat_lst[inputs_feat_index]).long().to(device)
)
inputs_feat_index = inputs_feat_index + 1
inputs_ws = (
torch.from_numpy(inputs_feat_lst[inputs_feat_index]).long().to(device)
)
inputs_ling = torch.stack(
[inputs_sy, inputs_tone, inputs_syllable, inputs_ws], dim=-1
).unsqueeze(0)
inputs_feat_index = inputs_feat_index + 1
inputs_emo = (
torch.from_numpy(inputs_feat_lst[inputs_feat_index])
.long()
.to(device)
.unsqueeze(0)
)
inputs_feat_index = inputs_feat_index + 1
    se_enable = se is not None
if se_enable:
inputs_spk = (
torch.from_numpy(se.repeat(len(inputs_feat_lst[inputs_feat_index]), axis=0))
.float()
.to(device)
.unsqueeze(0)[:, :-1, :]
)
else:
inputs_spk = (
torch.from_numpy(inputs_feat_lst[inputs_feat_index])
.long()
.to(device)
.unsqueeze(0)[:, :-1]
)
inputs_len = (
torch.zeros(1).to(device).long() + inputs_emo.size(1) - 1
) # minus 1 for "~"
res = fsnet(
inputs_ling[:, :-1, :],
inputs_emo[:, :-1],
inputs_spk,
inputs_len,
)
x_band_width = res["x_band_width"]
h_band_width = res["h_band_width"]
# enc_slf_attn_lst = res["enc_slf_attn_lst"]
# pnca_x_attn_lst = res["pnca_x_attn_lst"]
# pnca_h_attn_lst = res["pnca_h_attn_lst"]
dec_outputs = res["dec_outputs"]
postnet_outputs = res["postnet_outputs"]
LR_length_rounded = res["LR_length_rounded"]
log_duration_predictions = res["log_duration_predictions"]
pitch_predictions = res["pitch_predictions"]
energy_predictions = res["energy_predictions"]
valid_length = int(LR_length_rounded[0].item())
dec_outputs = dec_outputs[0, :valid_length, :].cpu().numpy()
postnet_outputs = postnet_outputs[0, :valid_length, :].cpu().numpy()
duration_predictions = (
(torch.exp(log_duration_predictions) - 1 + 0.5).long().squeeze().cpu().numpy()
)
pitch_predictions = pitch_predictions.squeeze().cpu().numpy()
energy_predictions = energy_predictions.squeeze().cpu().numpy()
logging.info("x_band_width:{}, h_band_width: {}".format(x_band_width, h_band_width))
return (
dec_outputs,
postnet_outputs,
duration_predictions,
pitch_predictions,
energy_predictions,
)
def am_infer(sentence, ckpt, output_dir, se_file=None, config=None):
if not torch.cuda.is_available():
device = torch.device("cpu")
else:
torch.backends.cudnn.benchmark = True
device = torch.device("cuda", 0)
if config is not None:
with open(config, "r") as f:
config = yaml.load(f, Loader=yaml.Loader)
else:
am_config_file = os.path.join(
os.path.dirname(os.path.dirname(ckpt)), "config.yaml"
)
with open(am_config_file, "r") as f:
config = yaml.load(f, Loader=yaml.Loader)
ling_unit = KanTtsLinguisticUnit(config)
ling_unit_size = ling_unit.get_unit_size()
config["Model"]["KanTtsSAMBERT"]["params"].update(ling_unit_size)
se_enable = config["Model"]["KanTtsSAMBERT"]["params"].get("SE", False)
se = np.load(se_file) if se_enable else None
# nsf
nsf_enable = config["Model"]["KanTtsSAMBERT"]["params"].get("NSF", False)
if nsf_enable:
nsf_norm_type = config["Model"]["KanTtsSAMBERT"]["params"].get("nsf_norm_type", "mean_std")
if nsf_norm_type == "mean_std":
f0_mvn_file = os.path.join(
os.path.dirname(os.path.dirname(ckpt)), "mvn.npy"
)
f0_feature = np.load(f0_mvn_file)
else: # global
nsf_f0_global_minimum = config["Model"]["KanTtsSAMBERT"]["params"].get("nsf_f0_global_minimum", 30.0)
nsf_f0_global_maximum = config["Model"]["KanTtsSAMBERT"]["params"].get("nsf_f0_global_maximum", 730.0)
f0_feature = [nsf_f0_global_maximum, nsf_f0_global_minimum]
model, _, _ = model_builder(config, device)
fsnet = model["KanTtsSAMBERT"]
logging.info("Loading checkpoint: {}".format(ckpt))
state_dict = torch.load(ckpt)
fsnet.load_state_dict(state_dict["model"], strict=False)
results_dir = os.path.join(output_dir, "feat")
os.makedirs(results_dir, exist_ok=True)
fsnet.eval()
with open(sentence, encoding="utf-8") as f:
for line in f:
line = line.strip().split("\t")
logging.info("Inference sentence: {}".format(line[0]))
mel_path = "%s/%s_mel.npy" % (results_dir, line[0])
dur_path = "%s/%s_dur.txt" % (results_dir, line[0])
f0_path = "%s/%s_f0.txt" % (results_dir, line[0])
energy_path = "%s/%s_energy.txt" % (results_dir, line[0])
with torch.no_grad():
mel, mel_post, dur, f0, energy = am_synthesis(
line[1], fsnet, ling_unit, device, se=se
)
if nsf_enable:
mel_post = denorm_f0(mel_post, norm_type=nsf_norm_type, f0_feature=f0_feature)
np.save(mel_path, mel_post)
np.savetxt(dur_path, dur)
np.savetxt(f0_path, f0)
np.savetxt(energy_path, energy)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--sentence", type=str, required=True)
parser.add_argument("--output_dir", type=str, required=True)
parser.add_argument("--ckpt", type=str, required=True)
parser.add_argument("--se_file", type=str, required=False)
args = parser.parse_args()
am_infer(args.sentence, args.ckpt, args.output_dir, args.se_file)
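As the loop above shows, am_infer expects a tab-separated file of (utterance id, symbol sequence) lines; programmatic use is a thin wrapper (a sketch, all paths are placeholders):
```
from kantts.bin.infer_sambert import am_infer

# Writes {utt_id}_mel.npy, _dur.txt, _f0.txt and _energy.txt under res/ptts_syn/feat.
am_infer(
    "res/ptts_syn/symbols.lst",  # placeholder: utt_id "\t" symbol sequence per line
    "training_stage/ptts_sambert_ckpt/ckpt/checkpoint_latest.pth",  # placeholder checkpoint
    "res/ptts_syn",
    se_file="training_stage/ptts_feats/se/se.npy",  # required if trained with speaker embeddings
)
```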
import os
import sys
import argparse
import yaml
import logging
import zipfile
from glob import glob
import soundfile as sf
import numpy as np
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) # NOQA: E402
sys.path.insert(0, os.path.dirname(ROOT_PATH)) # NOQA: E402
try:
from kantts.bin.infer_sambert import am_infer
from kantts.bin.infer_hifigan import hifigan_infer
from kantts.utils.ling_unit import text_to_mit_symbols as text_to_symbols
except ImportError:
raise ImportError("Please install kantts.")
logging.basicConfig(
# filename=os.path.join(stage_dir, 'stdout.log'),
format="%(asctime)s, %(levelname)-4s [%(filename)s:%(lineno)d] %(message)s",
datefmt="%Y-%m-%d:%H:%M:%S",
level=logging.INFO,
)
def concat_process(chunked_dir, output_dir):
    # Stitch per-sentence chunks named "{main_id}_{sub_id}_mel_gen.wav" back into
    # one wav per main_id, inserting a short pause between sentences and at the end.
    wav_files = sorted(glob(os.path.join(chunked_dir, "*.wav")))
    print(wav_files)
    sentence_sil = 0.28  # seconds of silence between sentences
    end_sil = 0.05  # seconds of silence appended at the end
    cnt = 0
    wav_concat = None
    main_id, sub_id = 0, 0
while cnt < len(wav_files):
wav_file = os.path.join(
chunked_dir, "{}_{}_mel_gen.wav".format(main_id, sub_id)
)
if os.path.exists(wav_file):
wav, sr = sf.read(wav_file)
sentence_sil_samples = int(sentence_sil * sr)
end_sil_samples = int(end_sil * sr)
if sub_id == 0:
wav_concat = wav
else:
wav_concat = np.concatenate(
(wav_concat, np.zeros(sentence_sil_samples), wav), axis=0
)
sub_id += 1
cnt += 1
else:
if wav_concat is not None:
wav_concat = np.concatenate(
(wav_concat, np.zeros(end_sil_samples)), axis=0
)
sf.write(os.path.join(output_dir, f"{main_id}.wav"), wav_concat, sr)
main_id += 1
sub_id = 0
wav_concat = None
if cnt == len(wav_files):
wav_concat = np.concatenate((wav_concat, np.zeros(end_sil_samples)), axis=0)
sf.write(os.path.join(output_dir, f"{main_id}.wav"), wav_concat, sr)
def text_to_wav(
text_file,
output_dir,
resources_zip_file,
am_ckpt,
voc_ckpt,
speaker=None,
se_file=None,
lang="PinYin",
):
os.makedirs(output_dir, exist_ok=True)
os.makedirs(os.path.join(output_dir, "res_wavs"), exist_ok=True)
resource_root_dir = os.path.dirname(resources_zip_file)
resource_dir = os.path.join(resource_root_dir, "resource")
if not os.path.exists(resource_dir):
logging.info("Extracting resources...")
with zipfile.ZipFile(resources_zip_file, "r") as zip_ref:
zip_ref.extractall(resource_root_dir)
with open(text_file, "r") as text_data:
texts = text_data.readlines()
logging.info("Converting text to symbols...")
am_config = os.path.join(os.path.dirname(os.path.dirname(am_ckpt)), "config.yaml")
with open(am_config, "r") as f:
config = yaml.load(f, Loader=yaml.Loader)
if speaker is None:
speaker = config["linguistic_unit"]["speaker_list"].split(",")[0]
symbols_lst = text_to_symbols(texts, resource_dir, speaker, lang)
symbols_file = os.path.join(output_dir, "symbols.lst")
with open(symbols_file, "w") as symbol_data:
for symbol in symbols_lst:
symbol_data.write(symbol)
logging.info("AM is infering...")
am_infer(symbols_file, am_ckpt, output_dir, se_file)
logging.info("Vocoder is infering...")
hifigan_infer(os.path.join(output_dir, "feat"), voc_ckpt, output_dir)
concat_process(output_dir, os.path.join(output_dir, "res_wavs"))
logging.info("Text to wav finished!")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Text to wav")
parser.add_argument("--txt", type=str, required=True, help="Path to text file")
parser.add_argument(
"--output_dir", type=str, required=True, help="Path to output directory"
)
parser.add_argument(
"--res_zip", type=str, required=True, help="Path to resource zip file"
)
parser.add_argument(
"--am_ckpt", type=str, required=True, help="Path to am ckpt file"
)
parser.add_argument(
"--voc_ckpt", type=str, required=True, help="Path to voc ckpt file"
)
parser.add_argument(
"--speaker",
type=str,
required=False,
default=None,
help="The speaker name, default is the first speaker",
)
parser.add_argument(
"--se_file",
type=str,
required=False,
default=None,
help="The speaker embedding file , default is None",
)
parser.add_argument(
"--lang",
type=str,
default="PinYin",
help="""The language of the text, default is PinYin, other options are:
English,
British,
ZhHK,
WuuShanghai,
Sichuan,
Indonesian,
Malay,
Filipino,
Vietnamese,
Korean,
Russian
""",
)
args = parser.parse_args()
text_to_wav(
args.txt,
args.output_dir,
args.res_zip,
args.am_ckpt,
args.voc_ckpt,
args.speaker,
args.se_file,
args.lang,
)
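The same end-to-end pipeline can be driven from Python; a sketch mirroring the CLI above (the text file and checkpoint paths are placeholders, and the resource zip ships with the pretrained model):
```
from kantts.bin.text_to_wav import text_to_wav

# Text file in, concatenated wavs out under res/ptts_syn/res_wavs/.
text_to_wav(
    "test.txt",  # placeholder: one sentence per line
    "res/ptts_syn",
    "speech_personal_sambert-hifigan_nsf_tts_zh-cn_pretrain_16k/resource.zip",  # placeholder zip path
    "training_stage/ptts_sambert_ckpt/ckpt/checkpoint_latest.pth",  # placeholder AM checkpoint
    "speech_personal_sambert-hifigan_nsf_tts_zh-cn_pretrain_16k/basemodel_16k/hifigan/ckpt/checkpoint_2400000.pth",
    se_file="training_stage/ptts_feats/se/se.npy",
)
```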
import os
import sys
import argparse
import torch
from torch.utils.data import DataLoader
import logging
import time
import yaml
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) # NOQA: E402
sys.path.insert(0, os.path.dirname(ROOT_PATH)) # NOQA: E402
try:
from kantts.models import model_builder
from kantts.train.loss import criterion_builder
from kantts.datasets.dataset import get_voc_datasets
from kantts.train.trainer import GAN_Trainer, distributed_init
from kantts.utils.log import logging_to_file, get_git_revision_hash
except ImportError:
raise ImportError("Please install kantts.")
logging.basicConfig(
# filename=os.path.join(stage_dir, 'stdout.log'),
format="%(asctime)s, %(levelname)-4s [%(filename)s:%(lineno)d] %(message)s",
datefmt="%Y-%m-%d:%H:%M:%S",
level=logging.INFO,
)
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
# TODO: distributed training
def train(
model_config,
root_dir,
stage_dir,
resume_path=None,
local_rank=0,
):
if not torch.cuda.is_available():
device = torch.device("cpu")
distributed = False
else:
torch.backends.cudnn.benchmark = True
logging.info("Args local rank: {}".format(local_rank))
distributed, device, local_rank, world_size = distributed_init()
if local_rank != 0:
sys.stdout = open(os.devnull, "w")
logger = logging.getLogger()
logger.disabled = True
# TODO: make sure all root dir audio_configs are the same
if not isinstance(root_dir, list):
root_dir = [root_dir]
if len(root_dir) == 1 and os.path.isfile(root_dir[0]):
with open(root_dir[0], "r") as f:
dir_lines = f.readlines()
root_dir = [line.strip() for line in dir_lines]
if local_rank == 0 and not os.path.exists(stage_dir):
os.makedirs(stage_dir)
audio_config = os.path.join(root_dir[0], "audio_config.yaml")
with open(audio_config, "r") as f:
config = yaml.load(f, Loader=yaml.Loader)
with open(model_config, "r") as f:
config.update(yaml.load(f, Loader=yaml.Loader))
logging_to_file(os.path.join(stage_dir, "stdout.log"))
# TODO: record some info in config, such as create time, git commit revision
config["create_time"] = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
config["git_revision_hash"] = get_git_revision_hash()
with open(os.path.join(stage_dir, "config.yaml"), "w") as f:
yaml.dump(config, f, Dumper=yaml.Dumper, default_flow_style=None)
for key, value in config.items():
logging.info(f"{key} = {value}")
if distributed:
config["rank"] = torch.distributed.get_rank()
config["distributed"] = True
# TODO: abstract dataloader
# Dataset prepare
train_dataset, valid_dataset = get_voc_datasets(
config,
root_dir,
)
logging.info(f"The number of training files = {len(train_dataset)}.")
logging.info(f"The number of validation files = {len(valid_dataset)}.")
sampler = {"train": None, "valid": None}
if distributed:
# setup sampler for distributed training
from torch.utils.data.distributed import DistributedSampler
sampler["train"] = DistributedSampler(
dataset=train_dataset,
num_replicas=world_size,
shuffle=True,
)
sampler["valid"] = DistributedSampler(
dataset=valid_dataset,
num_replicas=world_size,
shuffle=False,
)
train_dataloader = DataLoader(
train_dataset,
shuffle=False if distributed else True,
collate_fn=train_dataset.collate_fn,
batch_size=config["batch_size"],
num_workers=config["num_workers"],
sampler=sampler["train"],
pin_memory=config["pin_memory"],
)
valid_dataloader = DataLoader(
valid_dataset,
shuffle=False if distributed else True,
collate_fn=valid_dataset.collate_fn,
batch_size=config["batch_size"],
num_workers=config["num_workers"],
sampler=sampler["valid"],
pin_memory=config["pin_memory"],
)
model, optimizer, scheduler = model_builder(config, device, local_rank, distributed)
criterion = criterion_builder(config, device)
logging.info(model["generator"])
logging.info(
"Generator mdoel parameters count: {}".format(
count_parameters(model["generator"])
)
)
logging.info(model["discriminator"])
logging.info(optimizer["generator"])
logging.info(optimizer["discriminator"])
logging.info(scheduler["generator"])
logging.info(scheduler["discriminator"])
for criterion_ in criterion.values():
logging.info(criterion_)
trainer = GAN_Trainer(
config=config,
model=model,
optimizer=optimizer,
scheduler=scheduler,
criterion=criterion,
device=device,
sampler=sampler,
train_loader=train_dataloader,
valid_loader=valid_dataloader,
max_steps=config["train_max_steps"],
save_dir=stage_dir,
save_interval=config["save_interval_steps"],
valid_interval=config["eval_interval_steps"],
log_interval=config["log_interval_steps"],
)
if resume_path is not None:
trainer.load_checkpoint(resume_path, restore_training_state=False)
logging.info(f"Successfully resumed from {resume_path}.")
try:
trainer.train()
except (Exception, KeyboardInterrupt) as e:
logging.error(e, exc_info=True)
trainer.save_checkpoint(
os.path.join(
os.path.join(stage_dir, "ckpt"), f"checkpoint-{trainer.steps}.pth"
)
)
logging.info(f"Successfully saved checkpoint @ {trainer.steps}steps.")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Train a model for speech synthesis")
parser.add_argument(
"--model_config", type=str, required=True, help="model config file"
)
parser.add_argument(
"--root_dir",
nargs="+",
type=str,
required=True,
help="root dir of dataset; cloud be multiple directories",
)
parser.add_argument(
"--stage_dir",
type=str,
required=True,
help="stage dir of checkpoint, log and intermidate results ",
)
parser.add_argument(
"--resume_path", type=str, default=None, help="path to resume checkpoint"
)
parser.add_argument(
"--local_rank", type=int, default=0, help="local rank for distributed training"
)
args = parser.parse_args()
train(
args.model_config,
args.root_dir,
args.stage_dir,
args.resume_path,
args.local_rank,
)
import os
import sys
import argparse
import torch
from torch.utils.data import DataLoader
import logging
import time
import yaml
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) # NOQA: E402
sys.path.insert(0, os.path.dirname(ROOT_PATH)) # NOQA: E402
try:
from kantts.models import model_builder
from kantts.train.loss import criterion_builder
from kantts.datasets.dataset import get_am_datasets
from kantts.train.trainer import Sambert_Trainer, distributed_init
from kantts.utils.log import logging_to_file, get_git_revision_hash
except ImportError:
raise ImportError("Please install kantts.")
logging.basicConfig(
# filename=os.path.join(stage_dir, 'stdout.log'),
format="%(asctime)s %(levelname)-4s [%(filename)s:%(lineno)d] %(message)s",
datefmt="%Y-%m-%d:%H:%M:%S",
level=logging.INFO,
)
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
# TODO: distributed training
def train(
model_config,
root_dir,
stage_dir,
resume_path=None,
resume_bert_path=None,
local_rank=0,
):
if not torch.cuda.is_available():
device = torch.device("cpu")
distributed = False
else:
torch.backends.cudnn.benchmark = True
logging.info("Args local rank: {}".format(local_rank))
distributed, device, local_rank, world_size = distributed_init()
if local_rank != 0:
sys.stdout = open(os.devnull, "w")
logger = logging.getLogger()
logger.disabled = True
if not isinstance(root_dir, list):
root_dir = [root_dir]
if len(root_dir) == 1 and os.path.isfile(root_dir[0]):
with open(root_dir[0], "r") as f:
dir_lines = f.readlines()
root_dir = [line.strip() for line in dir_lines]
    if local_rank == 0 and not os.path.exists(stage_dir):  # only rank 0 creates the stage dir
os.makedirs(stage_dir)
audio_config = os.path.join(root_dir[0], "audio_config.yaml")
with open(audio_config, "r") as f:
config = yaml.load(f, Loader=yaml.Loader)
with open(model_config, "r") as f:
config.update(yaml.load(f, Loader=yaml.Loader))
logging_to_file(os.path.join(stage_dir, "stdout.log"))
# TODO: record some info in config, such as create time, git commit revision
config["create_time"] = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
config["git_revision_hash"] = get_git_revision_hash()
with open(os.path.join(stage_dir, "config.yaml"), "w") as f:
yaml.dump(config, f, Dumper=yaml.Dumper, default_flow_style=None)
for key, value in config.items():
logging.info(f"{key} = {value}")
if distributed:
config["rank"] = torch.distributed.get_rank()
config["distributed"] = True
se_enable = config["Model"]["KanTtsSAMBERT"]["params"].get("SE", False)
if se_enable:
valid_enable = False
valid_split_ratio = 0.00
else:
valid_enable = True
valid_split_ratio = 0.02
fp_enable = config["Model"]["KanTtsSAMBERT"]["params"].get("FP", False)
meta_file = [
os.path.join(d, "raw_metafile.txt" if not fp_enable else "fprm_metafile.txt")
for d in root_dir
]
# TODO: abstract dataloader
# Dataset prepare
train_dataset, valid_dataset = get_am_datasets(
meta_file,
root_dir,
config,
config["allow_cache"],
split_ratio=1.0 - valid_split_ratio,
)
logging.info(f"The number of training files = {len(train_dataset)}.")
logging.info(f"The number of validation files = {len(valid_dataset)}.")
sampler = {"train": None, "valid": None}
if distributed:
# setup sampler for distributed training
from torch.utils.data.distributed import DistributedSampler
sampler["train"] = DistributedSampler(
dataset=train_dataset,
num_replicas=world_size,
shuffle=True,
)
sampler["valid"] = (
DistributedSampler(
dataset=valid_dataset,
num_replicas=world_size,
shuffle=False,
)
if valid_enable
else None
)
train_dataloader = DataLoader(
train_dataset,
shuffle=False if distributed else True,
collate_fn=train_dataset.collate_fn,
batch_size=config["batch_size"],
num_workers=config["num_workers"],
sampler=sampler["train"],
pin_memory=config["pin_memory"],
)
valid_dataloader = (
DataLoader(
valid_dataset,
shuffle=False if distributed else True,
collate_fn=valid_dataset.collate_fn,
batch_size=config["batch_size"],
num_workers=config["num_workers"],
sampler=sampler["valid"],
pin_memory=config["pin_memory"],
)
if valid_enable
else None
)
ling_unit_size = train_dataset.ling_unit.get_unit_size()
config["Model"]["KanTtsSAMBERT"]["params"].update(ling_unit_size)
model, optimizer, scheduler = model_builder(config, device, local_rank, distributed)
criterion = criterion_builder(config, device)
logging.info(model["KanTtsSAMBERT"])
logging.info(
"Sambert mdoel parameters count: {}".format(
count_parameters(model["KanTtsSAMBERT"])
)
)
logging.info(optimizer["KanTtsSAMBERT"])
logging.info(scheduler["KanTtsSAMBERT"])
for criterion_ in criterion.values():
logging.info(criterion_)
trainer = Sambert_Trainer(
config=config,
model=model,
optimizer=optimizer,
scheduler=scheduler,
criterion=criterion,
device=device,
sampler=sampler,
train_loader=train_dataloader,
valid_loader=valid_dataloader,
max_steps=config["train_max_steps"],
save_dir=stage_dir,
save_interval=config["save_interval_steps"],
valid_interval=config["eval_interval_steps"],
log_interval=config["log_interval_steps"],
grad_clip=config["grad_norm"],
)
if resume_path is not None:
trainer.load_checkpoint(resume_path, True, False)
logging.info(f"Successfully resumed from {resume_path}.")
if resume_bert_path is not None:
trainer.load_checkpoint(resume_bert_path, False, False)
logging.info(f"Successfully resumed from {resume_bert_path}.")
try:
trainer.train()
except (Exception, KeyboardInterrupt) as e:
logging.error(e, exc_info=True)
trainer.save_checkpoint(
os.path.join(
os.path.join(stage_dir, "ckpt"), f"checkpoint-{trainer.steps}.pth"
)
)
logging.info(f"Successfully saved checkpoint @ {trainer.steps}steps.")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Train a model for speech synthesis")
parser.add_argument(
"--model_config", type=str, required=True, help="model config file"
)
parser.add_argument(
"--root_dir",
nargs="+",
type=str,
required=True,
help="root dir of dataset; cloud be multiple directories",
)
parser.add_argument(
"--stage_dir",
type=str,
required=True,
help="stage dir of checkpoint, log and intermidate results ",
)
parser.add_argument(
"--resume_path", type=str, default=None, help="path to resume checkpoint"
)
parser.add_argument(
"--resume_bert_path",
type=str,
default=None,
help="path to resume pretrained-bert checkpoint",
)
parser.add_argument(
"--local_rank", type=int, default=0, help="local rank for distributed training"
)
args = parser.parse_args()
train(
args.model_config,
args.root_dir,
args.stage_dir,
args.resume_path,
args.resume_bert_path,
args.local_rank,
)
import os
import sys
import argparse
import torch
from torch.utils.data import DataLoader
import logging
import time
import yaml
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) # NOQA: E402
sys.path.insert(0, os.path.dirname(ROOT_PATH)) # NOQA: E402
try:
from kantts.models import model_builder
from kantts.train.loss import criterion_builder
from kantts.datasets.dataset import get_bert_text_datasets
from kantts.train.trainer import distributed_init, Textsy_BERT_Trainer
from kantts.utils.log import logging_to_file, get_git_revision_hash
except ImportError:
raise ImportError("Please install kantts.")
logging.basicConfig(
# filename=os.path.join(stage_dir, 'stdout.log'),
format="%(asctime)s %(levelname)-4s [%(filename)s:%(lineno)d] %(message)s",
datefmt="%Y-%m-%d:%H:%M:%S",
level=logging.INFO,
)
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
# TODO: distributed training
def train(
model_config,
root_dir,
stage_dir,
resume_path=None,
local_rank=0,
):
if not torch.cuda.is_available():
device = torch.device("cpu")
distributed = False
else:
torch.backends.cudnn.benchmark = True
logging.info("Args local rank: {}".format(local_rank))
distributed, device, local_rank, world_size = distributed_init()
if local_rank != 0:
sys.stdout = open(os.devnull, "w")
logger = logging.getLogger()
logger.disabled = True
if not isinstance(root_dir, list):
root_dir = [root_dir]
meta_file = [os.path.join(d, "raw_metafile.txt") for d in root_dir]
if local_rank == 0 and not os.path.exists(stage_dir):
os.makedirs(stage_dir)
with open(model_config, "r") as f:
config = yaml.load(f, Loader=yaml.Loader)
logging_to_file(os.path.join(stage_dir, "stdout.log"))
# TODO: record some info in config, such as create time, git commit
# revision
config["create_time"] = time.strftime(
"%Y-%m-%d %H:%M:%S", time.localtime())
config["git_revision_hash"] = get_git_revision_hash()
with open(os.path.join(stage_dir, "config.yaml"), "w") as f:
yaml.dump(config, f, Dumper=yaml.Dumper, default_flow_style=None)
for key, value in config.items():
logging.info(f"{key} = {value}")
if distributed:
config["rank"] = torch.distributed.get_rank()
config["distributed"] = True
# TODO: abstract dataloader
# Dataset prepare
train_dataset, valid_dataset = get_bert_text_datasets(
meta_file, root_dir, config, config["allow_cache"]
)
logging.info(f"The number of training files = {len(train_dataset)}.")
logging.info(f"The number of validation files = {len(valid_dataset)}.")
sampler = {"train": None, "valid": None}
if distributed:
# setup sampler for distributed training
from torch.utils.data.distributed import DistributedSampler
sampler["train"] = DistributedSampler(
dataset=train_dataset,
num_replicas=world_size,
shuffle=True,
)
sampler["valid"] = DistributedSampler(
dataset=valid_dataset,
num_replicas=world_size,
shuffle=False,
)
train_dataloader = DataLoader(
train_dataset,
shuffle=False if distributed else True,
collate_fn=train_dataset.collate_fn,
batch_size=config["batch_size"],
num_workers=config["num_workers"],
sampler=sampler["train"],
pin_memory=config["pin_memory"],
)
valid_dataloader = DataLoader(
valid_dataset,
shuffle=False if distributed else True,
collate_fn=valid_dataset.collate_fn,
batch_size=config["batch_size"],
num_workers=config["num_workers"],
sampler=sampler["valid"],
pin_memory=config["pin_memory"],
)
ling_unit_size = train_dataset.ling_unit.get_unit_size()
config["Model"]["KanTtsTextsyBERT"]["params"].update(ling_unit_size)
model, optimizer, scheduler = model_builder(
config, device, local_rank, distributed)
criterion = criterion_builder(config, device)
logging.info(model["KanTtsTextsyBERT"])
logging.info(
"TextsyBERT mdoel parameters count: {}".format(
count_parameters(model["KanTtsTextsyBERT"])
)
)
logging.info(optimizer["KanTtsTextsyBERT"])
logging.info(scheduler["KanTtsTextsyBERT"])
for criterion_ in criterion.values():
logging.info(criterion_)
trainer = Textsy_BERT_Trainer(
config=config,
model=model,
optimizer=optimizer,
scheduler=scheduler,
criterion=criterion,
device=device,
sampler=sampler,
train_loader=train_dataloader,
valid_loader=valid_dataloader,
max_steps=config["train_max_steps"],
save_dir=stage_dir,
save_interval=config["save_interval_steps"],
valid_interval=config["eval_interval_steps"],
log_interval=config["log_interval_steps"],
grad_clip=config["grad_norm"],
)
if resume_path is not None:
trainer.load_checkpoint(resume_path, True)
logging.info(f"Successfully resumed from {resume_path}.")
try:
trainer.train()
except (Exception, KeyboardInterrupt) as e:
logging.error(e, exc_info=True)
trainer.save_checkpoint(
os.path.join(
os.path.join(
stage_dir,
"ckpt"),
f"checkpoint-{trainer.steps}.pth"))
logging.info(f"Successfully saved checkpoint @ {trainer.steps}steps.")
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Train a model for speech synthesis")
parser.add_argument(
"--model_config", type=str, required=True, help="model config file"
)
parser.add_argument(
"--root_dir",
nargs="+",
type=str,
required=True,
help="root dir of dataset; cloud be multiple directories",
)
parser.add_argument(
"--stage_dir",
type=str,
required=True,
help="stage dir of checkpoint, log and intermidate results ",
)
parser.add_argument(
"--resume_path",
type=str,
default=None,
help="path to resume checkpoint")
parser.add_argument(
"--local_rank",
type=int,
default=0,
help="local rank for distributed training")
args = parser.parse_args()
train(
args.model_config,
args.root_dir,
args.stage_dir,
args.resume_path,
args.local_rank,
)
# Audio processing configs
audio_config:
# Preprocess
wav_normalize: True
trim_silence: True
trim_silence_threshold_db: 60
preemphasize: False
# Feature extraction
sampling_rate: 16000
hop_length: 200
win_length: 1000
n_fft: 2048
n_mels: 80
fmin: 0.0
fmax: 8000.0
phone_level_feature: True
# Normalization
norm_type: "mean_std" # "mean_std" or "global"
max_norm: 1.0
symmetric: False
min_level_db: -100.0
ref_level_db: 20
num_workers: 16
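With hop_length 200 at a 16 kHz sampling rate, the extracted features run at 80 frames per second; a small sketch that loads this config (the filename audio_config.yaml is assumed, matching the feature directory listing earlier) and derives that:
```
import yaml

with open("audio_config.yaml") as f:
    cfg = yaml.safe_load(f)["audio_config"]

# Feature frame rate: sampling_rate / hop_length = 16000 / 200 = 80 frames per second.
print(cfg["sampling_rate"] / cfg["hop_length"])
```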