# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
MIT License
Copyright (c) 2024 KdaiP
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
# StableTTS
StableTTS is a fast and lightweight TTS model for Chinese and English speech generation, with only 10M parameters.
## Paper
`Unpublished`
## Model Structure
Inspired by Stable Diffusion 3, StableTTS combines flow matching and DiT into an open-source TTS model.
<div align=center>
<img src="./doc/structure.png"/>
</div>
## Algorithm
The Diffusion Convolution Transformer block from HierSpeech++, a combination of the original DiT and FFT, is used for better prosody. In the flow-matching decoder, a FiLM layer is added before the DiT block to condition the timestep embedding into the model.
<div align=center>
<img src="./doc/algorithm.png"/>
</div>
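A minimal FiLM-layer sketch is given below to illustrate the timestep conditioning described above; the channel sizes and tensor layout are assumptions for illustration, not the project's exact implementation:
```python
import torch
import torch.nn as nn

class FiLM(nn.Module):
    """Minimal FiLM sketch: per-channel scale and shift predicted from a conditioning vector."""
    def __init__(self, channels, cond_dim):
        super().__init__()
        self.proj = nn.Linear(cond_dim, 2 * channels)  # predicts gamma and beta

    def forward(self, x, cond):
        # x: (batch, channels, time); cond: (batch, cond_dim), e.g. a timestep embedding
        gamma, beta = self.proj(cond).chunk(2, dim=-1)
        return gamma.unsqueeze(-1) * x + beta.unsqueeze(-1)
```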
## Environment Setup
```
mv stabletts_pytorch StableTTS # strip the framework suffix from the directory name
```
### Docker (Option 1)
```
docker pull image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.1.0-centos7.6-dtk23.10-py38
# Replace <your IMAGE ID> below with the ID of the image pulled above; for this image it is ffa1f63239fc
docker run -it --shm-size=32G -v $PWD/StableTTS:/home/StableTTS -v /opt/hyhal:/opt/hyhal:ro --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video --name stabletts <your IMAGE ID> bash
cd /home/StableTTS
pip install -r requirements.txt # requirements.txt
# torchaudio can be installed from the wheel provided in whl.zip:
pip install torchaudio-2.1.2+4b32183.abi0.dtk2310.torch2.1.0a0-cp38-cp38-linux_x86_64.whl
```
### Dockerfile (Option 2)
```
cd StableTTS/docker
docker build --no-cache -t stabletts:latest .
docker run --shm-size=32G --name stabletts -v /opt/hyhal:/opt/hyhal:ro --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video -v $PWD/../../StableTTS:/home/StableTTS -it stabletts bash
# If installing the environment through the Dockerfile takes too long, comment out the pip install steps in it and install the Python packages after starting the container: pip install -r requirements.txt
# torchaudio can be installed from the wheel provided in whl.zip:
pip install torchaudio-2.1.2+4b32183.abi0.dtk2310.torch2.1.0a0-cp38-cp38-linux_x86_64.whl
```
### Anaconda (Option 3)
1. The DCU-specific deep learning libraries required by this project can be downloaded and installed from the 光合 developer community:
- https://developer.hpccube.com/tool/
```
DTK driver: dtk23.10
python: 3.8
torch: 2.1.0
torchvision: 0.16.0
torchaudio: 2.1.2
```
```
# torchaudio can be installed from the wheel provided in whl.zip:
pip install torchaudio-2.1.2+4b32183.abi0.dtk2310.torch2.1.0a0-cp38-cp38-linux_x86_64.whl
```
`Tip: the DTK driver, Python, torch, and other DCU-related tool versions listed above must match each other exactly.`
2. Install the remaining (non-DCU-specific) libraries from requirements.txt:
```
pip install -r requirements.txt # requirements.txt
```
## Dataset
This walkthrough uses the Databaker female voice dataset `BZNSYP`. For other voices, follow the instructions under [`recipes`](./recipes/) to download and prepare the data. A mini [`BZNSYP`](./recipes/raw_datasets/BZNSYP.zip) dataset is bundled with the project for a quick trial; simply unzip it. The full BZNSYP dataset can be downloaded from the official site:
- https://www.data-baker.com/data/index/TNtts/
The data directory structure is as follows:
```
recipes/raw_datasets/BZNSYP
├── Wave
│   ├── xxx.wav
│   └── xxx.wav
├── PhoneLabeling
│   ├── xxx.interval
│   └── xxx.interval
└── ProsodyLabeling
    └── 000001-010000.txt
```
Run the data preprocessing commands:
```
cd recipes
python BZNSYP_标贝女声.py
mv filelists/bznsyp.txt ../filelists/filelist.txt
cd ..
python preprocess.py # generates filelists/filelist.json and stableTTS_datasets/mels required for training
```
## Training
### Single machine, single card
```
export HIP_VISIBLE_DEVICES=0
cd StableTTS
python train.py
```
For more details, see the original project's [`README_origin`](./README_origin.md).
## Inference
Download the pretrained checkpoints used for inference:
TTS model (100h Chinese): `checkpoint-zh_0.pt`
- https://huggingface.co/KdaiP/StableTTS/blob/main/checkpoint-zh_0.pt
Vocos vocoder (2k English + Chinese + Japanese): `vocoder.pt`
- https://huggingface.co/KdaiP/StableTTS/blob/main/vocoder.pt
```
export HIP_VISIBLE_DEVICES=0
mv vocoder.pt ./checkpoints/vocoder.pt
python inference.py
# default checkpoint paths:
# tts_checkpoint_path = './checkpoints/checkpoint-zh_0.pt'
# vocoder_checkpoint_path = './checkpoints/vocoder.pt'
```
## Results
`Input:`
```
'你好,世界!' # input text
'./audio.wav' # reference audio (voice timbre)
```
`Output:`
```
'generate.wav' # synthesized audio
```
### Accuracy
Max epochs: 1000; inference framework: PyTorch.
| device | Loss |
|:---------:|:------:|
| DCU Z100L | 1.9369 |
| GPU V100S | 1.9382 |
## Application Scenarios
### Algorithm Category
`Speech synthesis`
### Key Application Industries
`Finance, e-commerce, education, manufacturing, healthcare, energy`
## Source Repository and Issue Feedback
- http://developer.hpccube.com/codes/modelzoo/stabletts_pytorch.git
## References
- https://github.com/KdaiP/StableTTS.git
- https://stabilityai-public-packages.s3.us-west-2.amazonaws.com/Stable+Diffusion+3+Paper.pdf
<div align="center">
# StableTTS
Next-generation TTS model using flow-matching and DiT, inspired by [Stable Diffusion 3](https://stability.ai/news/stable-diffusion-3).
</div>
## Introduction
As the first open-source TTS model to combine flow-matching and DiT, StableTTS is a fast and lightweight TTS model for Chinese and English speech generation. It has only 10M parameters.
**Huggingface demo:** [chinese_version](https://huggingface.co/spaces/KdaiP/StableTTS_zh-demo) [english_version](https://huggingface.co/spaces/KdaiP/StableTTS_en-demo)
## Pretrained models
We provide pretrained models ready for inference, finetuning and webui. Simply download and place the models in the `./checkpoints` directory to get started.
| Model Name | Task Details | Dataset | Download Link |
|:----------:|:------------:|:-------------:|:-------------:|
| StableTTS | text to mel | 400h english | [🤗](https://huggingface.co/KdaiP/StableTTS/blob/main/checkpoint-en_0.pt)|
| StableTTS | text to mel | 100h chinese | [🤗](https://huggingface.co/KdaiP/StableTTS/blob/main/checkpoint-zh_0.pt)|
| Vocos | mel to wav | 2k english + chinese + japanese | [🤗](https://huggingface.co/KdaiP/StableTTS/blob/main/vocoder.pt)|
**Larger models, better pretrained models, and multilingual models are coming soon...**
## Installation
1. **Set up pytorch**: Follow the [official PyTorch guide](https://pytorch.org/get-started/locally/) to install pytorch and torchaudio. We recommend using the latest version for optimal performance.
2. **Install Dependencies**: Run the following command to install the required Python packages:
```bash
pip install -r requirements.txt
```
## Inference
For detailed inference instructions, please refer to `inference.ipynb`.
We also provide a Gradio-based web UI; please refer to `webui.py`.
## Training
Training your models with StableTTS is designed to be straightforward and efficient. Here’s how to get started:
### Preparing Your Data
1. **Generate Text and Audio pairs**: Generate the text and audio pair filelist as `./filelists/example.txt`. Some recipes for open-source datasets can be found in `./recipes`.
2. **Run Preprocessing**: Adjust the `DataConfig` in `preprocess.py` to set your input and output paths, then run the script. This will process the audio and text according to your list, outputting a JSON file with paths to mel features and phonemes (see the example line after the note below). **Note: Be sure to change `language = 'chinese'` in `DataConfig` when processing English or Japanese text.**
Note: Since we use a `reference encoder` to capture speaker identity during training, no speaker ID is needed for multi-speaker synthesis and training.
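For reference, each line of the generated filelist is a standalone JSON object. The dataset loader (`StableDataset` in this repo) reads the `mel_path`, `phone`, and `audio_path` keys, so a line looks roughly like the following (the paths and phoneme string are illustrative, not actual outputs):
```json
{"mel_path": "stableTTS_datasets/mels/000001.pt", "phone": "...", "audio_path": "recipes/raw_datasets/BZNSYP/Wave/000001.wav"}
```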
### Start training
1. **Adjust Training Configuration**: In `config.py`, modify `TrainConfig` to set your file list path and adjust training parameters as needed.
2. **Start the Training Process**: Launch `train.py` to start training your model.
Note: For finetuning, download the pretrained model and place it in the `model_save_path` directory specified in `TrainConfig`. The training script will automatically detect and load the pretrained checkpoint.
### Experiment with Configurations
Feel free to explore and modify the settings in `config.py` to adjust the hyperparameters!
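For instance, the training hyperparameters are collected in the `TrainConfig` dataclass (reproduced in `config.py` further down in this repo). The usual workflow is to edit its defaults in place; the sketch below simply lists the fields it exposes, with illustrative values rather than recommendations:
```python
from config import TrainConfig

# Illustrative only: these fields come from TrainConfig in config.py.
cfg = TrainConfig(
    train_dataset_path='filelists/filelist.json',
    batch_size=32,          # default is 48
    learning_rate=1e-4,
    num_epochs=10000,
    model_save_path='./checkpoints',
    log_interval=128,
    warmup_steps=200,
)
```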
## Model structure
<div align="center">
<p style="text-align: center;">
<img src="./figures/structure.jpg" height="512"/>
</p>
</div>
- We use the Diffusion Convolution Transformer block from [Hierspeech++](https://github.com/sh-lee-prml/HierSpeechpp), which is a combination of the original [DiT](https://github.com/facebookresearch/DiT) and [FFT](https://arxiv.org/pdf/1905.09263.pdf) (Feed-Forward Transformer from FastSpeech), for better prosody.
- In the flow-matching decoder, we add a [FiLM layer](https://arxiv.org/abs/1709.07871) before the DiT block to condition the timestep embedding into the model. We also add three ConvNeXt blocks before the DiT. We found this helps with model convergence and sound quality.
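For background, the flow-matching decoder is trained with a conditional flow-matching objective in the style of Matcha-TTS (listed under Direct Inspirations below). The sketch below shows that generic OT-CFM loss; the exact conditioning and decoder signature used in StableTTS may differ:
```python
import torch

def cfm_loss(decoder, x1, mu, sigma_min=1e-4):
    """Generic OT-CFM objective (Matcha-TTS style); illustrative sketch only."""
    b = x1.size(0)
    t = torch.rand(b, 1, 1, device=x1.device)        # random timestep in [0, 1]
    x0 = torch.randn_like(x1)                        # noise endpoint of the path
    xt = (1 - (1 - sigma_min) * t) * x0 + t * x1     # point on the straight path to the mel x1
    ut = x1 - (1 - sigma_min) * x0                   # target velocity field
    vt = decoder(xt, t.squeeze(), mu)                # predicted velocity; signature is assumed
    return torch.mean((vt - ut) ** 2)
```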
## References
The development of our models heavily relies on insights and code from various projects. We express our heartfelt thanks to the creators of the following:
### Direct Inspirations
[Matcha TTS](https://github.com/shivammehta25/Matcha-TTS): Essential flow-matching code.
[Grad TTS](https://github.com/huawei-noah/Speech-Backbones/tree/main/Grad-TTS): Diffusion model structure.
[Stable Diffusion 3](https://stability.ai/news/stable-diffusion-3): Idea of combining flow-matching and DiT.
[Vits](https://github.com/jaywalnut310/vits): Code style and MAS insights, DistributedBucketSampler.
### Additional References:
[plowtts-pytorch](https://github.com/p0p4k/pflowtts_pytorch): codes of MAS in training
[Bert-VITS2](https://github.com/Plachtaa/VITS-fast-fine-tuning) : numba version of MAS and modern pytorch codes of Vits
[fish-speech](https://github.com/fishaudio/fish-speech): dataclass usage and mel-spectrogram transforms using torchaudio
[gpt-sovits](https://github.com/RVC-Boss/GPT-SoVITS): melstyle encoder for voice clone
[diffsinger](https://github.com/openvpi/DiffSinger): chinese three section phoneme scheme for chinese g2p
[coqui xtts](https://huggingface.co/spaces/coqui/xtts): gradio webui
## TODO
- [ ] Release pretrained models.
- [ ] Provide detailed finetuning instructions.
- [x] Support Japanese language.
- [ ] User friendly preprocess and inference script.
- [ ] Enhance documentation and citations.
- [ ] Add chinese version of readme.
- [ ] Release multilingual checkpoint.
## Disclaimer
Any organization or individual is prohibited from using any technology in this repo to generate or edit someone's speech without his/her consent, including but not limited to government leaders, political figures, and celebrities. If you do not comply with this item, you could be in violation of copyright laws.
from dataclasses import dataclass
from typing import Optional
@dataclass
class MelConfig:
sample_rate: int = 44100
n_fft: int = 2048
win_length: int = 2048
hop_length: int = 512
f_min: float = 0.0
    f_max: Optional[float] = None  # None falls back to sample_rate / 2 in torchaudio
pad: int = 0
n_mels: int = 128
power: float = 1.0
normalized: bool = False
center: bool = False
pad_mode: str = "reflect"
mel_scale: str = "htk"
def __post_init__(self):
if self.pad == 0:
self.pad = (self.n_fft - self.hop_length) // 2
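# Illustrative note (not part of the original file): MelConfig mirrors the keyword
# arguments of torchaudio.transforms.MelSpectrogram, so a mel transform can be built
# directly from it, e.g.:
#
#   from dataclasses import asdict
#   import torchaudio
#   mel_transform = torchaudio.transforms.MelSpectrogram(**asdict(MelConfig()))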
@dataclass
class ModelConfig:
hidden_channels: int = 192
filter_channels: int = 512
n_heads: int = 2
n_enc_layers: int = 3
n_dec_layers: int = 2
kernel_size: int = 3
    p_dropout: float = 0.1
gin_channels: int = 192
@dataclass
class TrainConfig:
train_dataset_path: str = 'filelists/filelist.json'
test_dataset_path: str = 'filelists/filelist.json'
batch_size: int = 48
learning_rate: float = 1e-4
num_epochs: int = 10000
model_save_path: str = './checkpoints'
log_dir: str = './runs'
log_interval: int = 128
save_interval: int = 15
warmup_steps: int = 200
@dataclass
class VocosConfig:
input_channels: int = 128
dim: int = 512
intermediate_dim: int = 1536
num_layers: int = 8
import os
import json
import torch
from torch.utils.data import Dataset
from text import cleaned_text_to_sequence
def intersperse(lst, item):
result = [item] * (len(lst) * 2 + 1)
result[1::2] = lst
return result
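# Example: intersperse([5, 3, 7], 0) -> [0, 5, 0, 3, 0, 7, 0]
# (inserts a blank token between and around phoneme IDs, as in VITS-style models)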
class StableDataset(Dataset):
def __init__(self, filelist_path, hop_length):
self.filelist_path = filelist_path
self.hop_length = hop_length
self._load_filelist(filelist_path)
def _load_filelist(self, filelist_path):
filelist, lengths = [], []
with open(filelist_path, 'r', encoding='utf-8') as f:
for line in f:
line = json.loads(line.strip())
filelist.append((line['mel_path'], line['phone']))
lengths.append(os.path.getsize(line['audio_path']) // (2 * self.hop_length))
self.filelist = filelist
self.lengths = lengths
def __len__(self):
return len(self.filelist)
def __getitem__(self, idx):
mel_path, phone = self.filelist[idx]
mel = torch.load(mel_path, map_location='cpu')
phone = torch.tensor(intersperse(cleaned_text_to_sequence(phone), 0), dtype=torch.long)
return mel, phone
def collate_fn(batch):
texts = [item[1] for item in batch]
mels = [item[0] for item in batch]
text_lengths = torch.tensor([text.size(-1) for text in texts], dtype=torch.long)
mel_lengths = torch.tensor([mel.size(-1) for mel in mels], dtype=torch.long)
# pad to the same length
texts_padded = torch.nested.to_padded_tensor(torch.nested.nested_tensor(texts), padding=0)
mels_padded = torch.nested.to_padded_tensor(torch.nested.nested_tensor(mels), padding=0)
return texts_padded, text_lengths, mels_padded, mel_lengths
import torch
# reference: https://github.com/jaywalnut310/vits/blob/main/data_utils.py
class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
"""
Maintain similar input lengths in a batch.
Length groups are specified by boundaries.
Ex) boundaries = [b1, b2, b3] -> any batch is included either {x | b1 < length(x) <=b2} or {x | b2 < length(x) <= b3}.
It removes samples which are not included in the boundaries.
Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 are discarded.
"""
def __init__(
self,
dataset,
batch_size,
boundaries,
num_replicas=None,
rank=None,
shuffle=True,
):
super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
self.lengths = dataset.lengths
self.batch_size = batch_size
self.boundaries = boundaries
self.buckets, self.num_samples_per_bucket = self._create_buckets()
self.total_size = sum(self.num_samples_per_bucket)
self.num_samples = self.total_size // self.num_replicas
def _create_buckets(self):
buckets = [[] for _ in range(len(self.boundaries) - 1)]
for i in range(len(self.lengths)):
length = self.lengths[i]
idx_bucket = self._bisect(length)
if idx_bucket != -1:
buckets[idx_bucket].append(i)
# from https://github.com/Plachtaa/VITS-fast-fine-tuning/blob/main/data_utils.py
# avoid "integer division or modulo by zero" error for very small dataset
try:
for i in range(len(buckets) - 1, 0, -1):
if len(buckets[i]) == 0:
buckets.pop(i)
self.boundaries.pop(i + 1)
assert all(len(bucket) > 0 for bucket in buckets)
# When one bucket is not traversed
except Exception as e:
print('Bucket warning ', e)
for i in range(len(buckets) - 1, -1, -1):
if len(buckets[i]) == 0:
buckets.pop(i)
self.boundaries.pop(i + 1)
num_samples_per_bucket = []
for i in range(len(buckets)):
len_bucket = len(buckets[i])
total_batch_size = self.num_replicas * self.batch_size
rem = (
total_batch_size - (len_bucket % total_batch_size)
) % total_batch_size
num_samples_per_bucket.append(len_bucket + rem)
return buckets, num_samples_per_bucket
def __iter__(self):
# deterministically shuffle based on epoch
g = torch.Generator()
g.manual_seed(self.epoch)
indices = []
if self.shuffle:
for bucket in self.buckets:
indices.append(torch.randperm(len(bucket), generator=g).tolist())
else:
for bucket in self.buckets:
indices.append(list(range(len(bucket))))
batches = []
for i in range(len(self.buckets)):
bucket = self.buckets[i]
len_bucket = len(bucket)
ids_bucket = indices[i]
num_samples_bucket = self.num_samples_per_bucket[i]
# add extra samples to make it evenly divisible
rem = num_samples_bucket - len_bucket
ids_bucket = (
ids_bucket
+ ids_bucket * (rem // len_bucket)
+ ids_bucket[: (rem % len_bucket)]
)
# subsample
ids_bucket = ids_bucket[self.rank :: self.num_replicas]
# batching
for j in range(len(ids_bucket) // self.batch_size):
batch = [
bucket[idx]
for idx in ids_bucket[
j * self.batch_size : (j + 1) * self.batch_size
]
]
batches.append(batch)
if self.shuffle:
batch_ids = torch.randperm(len(batches), generator=g).tolist()
batches = [batches[i] for i in batch_ids]
self.batches = batches
assert len(self.batches) * self.batch_size == self.num_samples
return iter(self.batches)
def _bisect(self, x, lo=0, hi=None):
if hi is None:
hi = len(self.boundaries) - 1
if hi > lo:
mid = (hi + lo) // 2
if self.boundaries[mid] < x and x <= self.boundaries[mid + 1]:
return mid
elif x <= self.boundaries[mid]:
return self._bisect(x, lo, mid)
else:
return self._bisect(x, mid + 1, hi)
else:
return -1
def __len__(self):
return self.num_samples // self.batch_size
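# Illustrative usage (not part of the original file). The sampler yields whole batches
# of indices, so it is passed to the DataLoader as `batch_sampler` together with the
# padding collate_fn from the dataset module. Import paths and boundary values below
# are assumptions for illustration:
#
#   from torch.utils.data import DataLoader
#   from datas.dataset import StableDataset, collate_fn
#
#   dataset = StableDataset('filelists/filelist.json', hop_length=512)
#   sampler = DistributedBucketSampler(dataset, batch_size=48,
#                                      boundaries=[32, 300, 400, 500, 600, 700, 800, 900, 1000],
#                                      num_replicas=1, rank=0)
#   loader = DataLoader(dataset, batch_sampler=sampler, collate_fn=collate_fn)
#   for texts, text_lengths, mels, mel_lengths in loader:
#       ...  # training step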