Commit b309ea1b authored by chenzk

v1.0

docker run -it --shm-size=32G -v $PWD/StableTTS:/home/StableTTS -v /opt/hyhal:/opt/hyhal:ro --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video --name stabletts1 fea033ed400a bash
# docker run -it --shm-size=32G -v $PWD/StableTTS:/home/StableTTS -v /opt/hyhal:/opt/hyhal:ro --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video --name stabletts ffa1f63239fc bash
# python -m torch.utils.collect_env
./audio1.wav|你好,世界。
./audio2.wav|Hello, world.
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from dataclasses import asdict\n",
"import torch\n",
"import torchaudio\n",
"from IPython.display import Audio, display\n",
"\n",
"from utils.audio import LogMelSpectrogram\n",
"from config import ModelConfig, VocosConfig, MelConfig\n",
"from models.model import StableTTS\n",
"from vocos_pytorch.models.model import Vocos\n",
"from text.mandarin import chinese_to_cnm3\n",
"from text.english import english_to_ipa2\n",
"from text.japanese import japanese_to_ipa2\n",
"from text import cleaned_text_to_sequence\n",
"from text import symbols\n",
"from datas.dataset import intersperse\n",
"\n",
"device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
"\n",
"g2p_mapping = {\n",
" 'chinese': chinese_to_cnm3,\n",
" 'japanese': japanese_to_ipa2,\n",
" 'english': english_to_ipa2,\n",
"}\n",
"\n",
"@ torch.inference_mode()\n",
"def inference(text: str, ref_audio: torch.Tensor, tts_model: StableTTS, mel_extractor: LogMelSpectrogram, vocoder: Vocos, phonemizer, sample_rate: int, step: int=10) -> torch.Tensor:\n",
" x = torch.tensor(intersperse(cleaned_text_to_sequence(phonemizer(text)), item=0), dtype=torch.long, device=device).unsqueeze(0)\n",
" x_len = torch.tensor([x.size(-1)], dtype=torch.long, device=device)\n",
" waveform, sr = torchaudio.load(ref_audio)\n",
" if sr != sample_rate:\n",
" waveform = torchaudio.functional.resample(waveform, sr, sample_rate)\n",
" y = mel_extractor(waveform).to(device)\n",
" mel = tts_model.synthesise(x, x_len, step, y=y, temperature=0.667, length_scale=1)['decoder_outputs']\n",
" audio = vocoder(mel)\n",
" return audio.cpu(), mel.cpu()\n",
"\n",
"def get_pipeline(n_vocab: int, tts_model_config: ModelConfig, mel_config: MelConfig, vocoder_config: VocosConfig, tts_checkpoint_path: str, vocoder_checkpoint_path: str):\n",
" tts_model = StableTTS(n_vocab, mel_config.n_mels, **asdict(tts_model_config))\n",
" mel_extractor = LogMelSpectrogram(mel_config)\n",
" vocoder = Vocos(vocoder_config, mel_config)\n",
" tts_model.load_state_dict(torch.load(tts_checkpoint_path, map_location='cpu'))\n",
" tts_model.to(device)\n",
" tts_model.eval()\n",
" vocoder.load_state_dict(torch.load(vocoder_checkpoint_path, map_location='cpu'))\n",
" vocoder.to(device)\n",
" vocoder.eval()\n",
" return tts_model, mel_extractor, vocoder\n",
"\n",
"tts_model_config = ModelConfig()\n",
"mel_config = MelConfig()\n",
"vocoder_config = VocosConfig()\n",
"\n",
"tts_checkpoint_path = './checkpoints/checkpoint-zh_0.pt'\n",
"vocoder_checkpoint_path = './checkpoints/vocoder.pt'\n",
"\n",
"tts_model, mel_extractor, vocoder = get_pipeline(len(symbols), tts_model_config, mel_config, vocoder_config, tts_checkpoint_path, vocoder_checkpoint_path)\n",
"total_params = sum(p.numel() for p in tts_model.parameters()) / 1e6\n",
"print(f'Total params: {total_params} M')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"language = 'chinese' # now we only support chinese, japanese and english\n",
"\n",
"phonemizer = g2p_mapping.get(language)\n",
"\n",
"text = '你好,世界!'\n",
"ref_audio = './audio.wav'\n",
"output, mel = inference(text, ref_audio, tts_model, mel_extractor, vocoder, phonemizer, mel_config.sample_rate, 15)\n",
"display(Audio(ref_audio))\n",
"display(Audio(output, rate=mel_config.sample_rate))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "lxn_vits",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.8"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
from dataclasses import asdict
import torch
import torchaudio
from IPython.display import Audio, display
from utils.audio import LogMelSpectrogram
from config import ModelConfig, VocosConfig, MelConfig
from models.model import StableTTS
from vocos_pytorch.models.model import Vocos
from text.mandarin import chinese_to_cnm3
from text.english import english_to_ipa2
from text.japanese import japanese_to_ipa2
from text import cleaned_text_to_sequence
from text import symbols
from datas.dataset import intersperse
from scipy.io import wavfile
import time
device = 'cuda' if torch.cuda.is_available() else 'cpu'
g2p_mapping = {
'chinese': chinese_to_cnm3,
'japanese': japanese_to_ipa2,
'english': english_to_ipa2,
}
@torch.inference_mode()
def inference(text: str, ref_audio: str, tts_model: StableTTS, mel_extractor: LogMelSpectrogram, vocoder: Vocos, phonemizer, sample_rate: int, step: int=10) -> tuple[torch.Tensor, torch.Tensor]:
x = torch.tensor(intersperse(cleaned_text_to_sequence(phonemizer(text)), item=0), dtype=torch.long, device=device).unsqueeze(0)
x_len = torch.tensor([x.size(-1)], dtype=torch.long, device=device)
waveform, sr = torchaudio.load(ref_audio)
if sr != sample_rate:
waveform = torchaudio.functional.resample(waveform, sr, sample_rate)
y = mel_extractor(waveform).to(device)
mel = tts_model.synthesise(x, x_len, step, y=y, temperature=0.667, length_scale=1)['decoder_outputs']
audio = vocoder(mel)
return audio.cpu(), mel.cpu()
def get_pipeline(n_vocab: int, tts_model_config: ModelConfig, mel_config: MelConfig, vocoder_config: VocosConfig, tts_checkpoint_path: str, vocoder_checkpoint_path: str):
tts_model = StableTTS(n_vocab, mel_config.n_mels, **asdict(tts_model_config))
mel_extractor = LogMelSpectrogram(mel_config)
vocoder = Vocos(vocoder_config, mel_config)
tts_model.load_state_dict(torch.load(tts_checkpoint_path, map_location='cpu'))
tts_model.to(device)
tts_model.eval()
vocoder.load_state_dict(torch.load(vocoder_checkpoint_path, map_location='cpu'))
vocoder.to(device)
vocoder.eval()
return tts_model, mel_extractor, vocoder
tts_model_config = ModelConfig()
mel_config = MelConfig()
vocoder_config = VocosConfig()
tts_checkpoint_path = './checkpoints/checkpoint-zh_0.pt'
vocoder_checkpoint_path = './checkpoints/vocoder.pt'
tts_model, mel_extractor, vocoder = get_pipeline(len(symbols), tts_model_config, mel_config, vocoder_config, tts_checkpoint_path, vocoder_checkpoint_path)
total_params = sum(p.numel() for p in tts_model.parameters()) / 1e6
print(f'Total params: {total_params} M')
language = 'chinese' # currently only chinese, japanese and english are supported
phonemizer = g2p_mapping.get(language)
text = '你好,世界!'
ref_audio = './audio.wav'
# start_time = time.time()
output, mel = inference(text, ref_audio, tts_model, mel_extractor, vocoder, phonemizer, mel_config.sample_rate, 15)
# print("infer time:", time.time() - start_time, "s")
display(Audio(ref_audio))
display(Audio(output, rate=mel_config.sample_rate))
wavfile.write('generate.wav', mel_config.sample_rate, output.numpy().T)
# Model code
modelCode=615
# Model name
modelName=StableTTS_pytorch
# Model description
modelDescription=StableTTS is a fast and lightweight TTS model for Chinese and English speech generation, with only 10M parameters.
# Application scenarios
appScenario=inference,training,finance,e-commerce,education,manufacturing,healthcare,energy
# Framework type
frameType=pytorch
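# The block above is a plain key=value metadata file. A minimal, hedged sketch for reading it
# (the filename 'model.properties' and the helper name are illustrative assumptions):
def load_model_metadata(path: str = 'model.properties') -> dict:
    metadata = {}
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line or line.startswith('#'):
                continue  # skip blank lines and comment lines
            key, value = line.split('=', 1)  # split on the first '=' only
            metadata[key.strip()] = value.strip()
    return metadata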
# References:
# https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/models/components/transformer.py
# https://github.com/jaywalnut310/vits/blob/main/attentions.py
# https://github.com/pytorch-labs/gpt-fast/blob/main/model.py
import torch
import torch.nn as nn
import torch.nn.functional as F
class FFN(nn.Module):
def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0., gin_channels=0):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.filter_channels = filter_channels
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.gin_channels = gin_channels
self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size, padding=kernel_size // 2)
self.drop = nn.Dropout(p_dropout)
self.act1 = nn.GELU(approximate="tanh")
def forward(self, x, x_mask):
x = self.conv_1(x * x_mask)
x = self.act1(x)
x = self.drop(x)
x = self.conv_2(x * x_mask)
return x * x_mask
class MultiHeadAttention(nn.Module):
def __init__(self, channels, out_channels, n_heads, p_dropout=0.):
super().__init__()
assert channels % n_heads == 0
self.channels = channels
self.out_channels = out_channels
self.n_heads = n_heads
self.p_dropout = p_dropout
self.k_channels = channels // n_heads
self.conv_q = torch.nn.Conv1d(channels, channels, 1)
self.conv_k = torch.nn.Conv1d(channels, channels, 1)
self.conv_v = torch.nn.Conv1d(channels, channels, 1)
# from https://nn.labml.ai/transformers/rope/index.html
self.query_rotary_pe = RotaryPositionalEmbeddings(self.k_channels * 0.5)
self.key_rotary_pe = RotaryPositionalEmbeddings(self.k_channels * 0.5)
self.conv_o = torch.nn.Conv1d(channels, out_channels, 1)
self.drop = torch.nn.Dropout(p_dropout)
torch.nn.init.xavier_uniform_(self.conv_q.weight)
torch.nn.init.xavier_uniform_(self.conv_k.weight)
torch.nn.init.xavier_uniform_(self.conv_v.weight)
def forward(self, x, attn_mask=None):
q = self.conv_q(x)
k = self.conv_k(x)
v = self.conv_v(x)
x = self.attention(q, k, v, mask=attn_mask)
x = self.conv_o(x)
return x
def attention(self, query, key, value, mask=None):
b, d, t_s, t_t = (*key.size(), query.size(2))
query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
query = self.query_rotary_pe(query) # [b, n_head, t, c // n_head]
key = self.key_rotary_pe(key)
output = F.scaled_dot_product_attention(query, key, value, attn_mask=mask, dropout_p=self.p_dropout if self.training else 0)
output = output.transpose(2, 3).contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t]
return output
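# Shape convention (illustrative sizes): with channels=192 and n_heads=2, forward maps x of shape
# [batch, 192, time] to [batch, 192, time]; rotary position embeddings are applied to the first
# half (48 of 96) of each head's feature dims before scaled dot-product attention.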
# modified from https://github.com/sh-lee-prml/HierSpeechpp/blob/main/modules.py#L390
class DiTConVBlock(nn.Module):
"""
A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
"""
def __init__(self, hidden_channels, filter_channels, num_heads, kernel_size=3, p_dropout=0.1, gin_channels=0):
super().__init__()
self.norm1 = nn.LayerNorm(hidden_channels, elementwise_affine=False, eps=1e-6)
self.attn = MultiHeadAttention(hidden_channels, hidden_channels, num_heads, p_dropout)
self.norm2 = nn.LayerNorm(hidden_channels, elementwise_affine=False, eps=1e-6)
self.mlp = FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout)
self.adaLN_modulation = nn.Sequential(
nn.Linear(gin_channels, hidden_channels) if gin_channels != hidden_channels else nn.Identity(),
nn.SiLU(),
nn.Linear(hidden_channels, 6 * hidden_channels, bias=True)
)
def forward(self, x, c, x_mask):
"""
Args:
x : [batch_size, channel, time]
c : [batch_size, channel]
x_mask : [batch_size, 1, time]
return the same shape as x
"""
x = x * x_mask
attn_mask = x_mask.unsqueeze(1) * x_mask.unsqueeze(-1) # shape: [batch_size, 1, time, time]
# attn_mask = attn_mask.to(torch.bool)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).unsqueeze(2).chunk(6, dim=1) # shape: [batch_size, channel, 1]
x = x + gate_msa * self.attn(self.modulate(self.norm1(x.transpose(1,2)).transpose(1,2), shift_msa, scale_msa), attn_mask) * x_mask
x = x + gate_mlp * self.mlp(self.modulate(self.norm2(x.transpose(1,2)).transpose(1,2), shift_mlp, scale_mlp), x_mask)
# no condition version
# x = x + self.attn(self.norm1(x.transpose(1,2)).transpose(1,2), attn_mask)
# x = x + self.mlp(self.norm1(x.transpose(1,2)).transpose(1,2), x_mask)
return x
@staticmethod
def modulate(x, shift, scale):
return x * (1 + scale) + shift
class RotaryPositionalEmbeddings(nn.Module):
"""
## RoPE module
Rotary encoding transforms pairs of features by rotating in the 2D plane.
That is, it organizes the $d$ features as $\frac{d}{2}$ pairs.
Each pair can be considered a coordinate in a 2D plane, and the encoding will rotate it
by an angle depending on the position of the token.
"""
def __init__(self, d: int, base: int = 10_000):
r"""
* `d` is the number of features $d$
* `base` is the constant used for calculating $\Theta$
"""
super().__init__()
self.base = base
self.d = int(d)
self.cos_cached = None
self.sin_cached = None
def _build_cache(self, x: torch.Tensor):
r"""
Cache $\cos$ and $\sin$ values
"""
# Return if cache is already built
if self.cos_cached is not None and x.shape[0] <= self.cos_cached.shape[0]:
return
# Get sequence length
seq_len = x.shape[0]
# $\Theta = {\theta_i = 10000^{-\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
theta = 1.0 / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(x.device)
# Create position indexes `[0, 1, ..., seq_len - 1]`
seq_idx = torch.arange(seq_len, device=x.device).float().to(x.device)
# Calculate the product of position index and $\theta_i$
idx_theta = torch.einsum("n,d->nd", seq_idx, theta)
# Concatenate so that for row $m$ we have
# $[m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}, m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}]$
idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)
# Cache them
self.cos_cached = idx_theta2.cos()[:, None, None, :]
self.sin_cached = idx_theta2.sin()[:, None, None, :]
def _neg_half(self, x: torch.Tensor):
# $\frac{d}{2}$
d_2 = self.d // 2
# Calculate $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)
def forward(self, x: torch.Tensor):
"""
* `x` is the Tensor at the head of a key or a query with shape `[seq_len, batch_size, n_heads, d]`
"""
# Cache $\cos$ and $\sin$ values
x = x.permute(2, 0, 1, 3) # b h t d -> t b h d
self._build_cache(x)
# Split the features, we can choose to apply rotary embeddings only to a partial set of features.
x_rope, x_pass = x[..., : self.d], x[..., self.d :]
# Calculate
# $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
neg_half_x = self._neg_half(x_rope)
x_rope = (x_rope * self.cos_cached[: x.shape[0]]) + (neg_half_x * self.sin_cached[: x.shape[0]])
return torch.cat((x_rope, x_pass), dim=-1).permute(1, 2, 0, 3) # t b h d -> b h t d
class Transpose(nn.Identity):
"""(N, T, D) -> (N, D, T)"""
def forward(self, input: torch.Tensor) -> torch.Tensor:
return input.transpose(1, 2)
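# A minimal shape sanity check for DiTConVBlock. The layer sizes below are illustrative
# assumptions, not the project's config values (those come from config.ModelConfig).
if __name__ == '__main__':
    block = DiTConVBlock(hidden_channels=192, filter_channels=768, num_heads=2, gin_channels=256)
    x = torch.randn(2, 192, 50)    # [batch, channel, time]
    c = torch.randn(2, 256)        # [batch, gin_channels] speaker/style condition
    x_mask = torch.ones(2, 1, 50)  # [batch, 1, time]
    out = block(x, c, x_mask)
    print(out.shape)               # torch.Size([2, 192, 50])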
import torch
import torch.nn as nn
# modified from https://github.com/jaywalnut310/vits/blob/main/models.py#L98
class DurationPredictor(nn.Module):
def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0):
super().__init__()
self.in_channels = in_channels
self.filter_channels = filter_channels
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.gin_channels = gin_channels
self.drop = nn.Dropout(p_dropout)
self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size//2)
self.norm_1 = nn.LayerNorm(filter_channels)
self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size//2)
self.norm_2 = nn.LayerNorm(filter_channels)
self.proj = nn.Conv1d(filter_channels, 1, 1)
self.cond = nn.Conv1d(gin_channels, in_channels, 1)
def forward(self, x, x_mask, g):
x = x.detach()
x = x + self.cond(g.unsqueeze(2).detach())
x = self.conv_1(x * x_mask)
x = torch.relu(x)
x = self.norm_1(x.transpose(1,2)).transpose(1,2)
x = self.drop(x)
x = self.conv_2(x * x_mask)
x = torch.relu(x)
x = self.norm_2(x.transpose(1,2)).transpose(1,2)
x = self.drop(x)
x = self.proj(x * x_mask)
return x * x_mask
def duration_loss(logw, logw_, lengths):
loss = torch.sum((logw - logw_) ** 2) / torch.sum(lengths)
return loss
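# A minimal usage sketch of the duration predictor; the sizes below are illustrative assumptions.
if __name__ == '__main__':
    dp = DurationPredictor(in_channels=192, filter_channels=256, kernel_size=3, p_dropout=0.1, gin_channels=256)
    x = torch.randn(2, 192, 40)    # encoder hidden states [batch, channel, text_time]
    x_mask = torch.ones(2, 1, 40)  # [batch, 1, text_time]
    g = torch.randn(2, 256)        # speaker/style embedding from the reference encoder
    logw = dp(x, x_mask, g)        # predicted log-durations, shape [2, 1, 40]
    print(logw.shape)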
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.dit import DiTConVBlock
class DitWrapper(nn.Module):
""" add FiLM layer to condition time embedding to DiT """
def __init__(self, hidden_channels, filter_channels, num_heads, kernel_size=3, p_dropout=0.1, gin_channels=0, time_channels=0):
super().__init__()
self.time_fusion = FiLMLayer(hidden_channels, time_channels)
self.conv1 = ConvNeXtBlock(hidden_channels, filter_channels, gin_channels)
self.conv2 = ConvNeXtBlock(hidden_channels, filter_channels, gin_channels)
self.conv3 = ConvNeXtBlock(hidden_channels, filter_channels, gin_channels)
self.block = DiTConVBlock(hidden_channels, hidden_channels, num_heads, kernel_size, p_dropout, gin_channels)
def forward(self, x, c, t, x_mask):
x = self.time_fusion(x, t) * x_mask
x = self.conv1(x, c, x_mask)
x = self.conv2(x, c, x_mask)
x = self.conv3(x, c, x_mask)
x = self.block(x, c, x_mask)
return x
class FiLMLayer(nn.Module):
"""
Feature-wise Linear Modulation (FiLM) layer
Reference: https://arxiv.org/abs/1709.07871
"""
def __init__(self, in_channels, cond_channels):
super(FiLMLayer, self).__init__()
self.in_channels = in_channels
self.film = nn.Conv1d(cond_channels, in_channels * 2, 1)
def forward(self, x, c):
gamma, beta = torch.chunk(self.film(c.unsqueeze(2)), chunks=2, dim=1)
return gamma * x + beta
class ConvNeXtBlock(nn.Module):
def __init__(self, in_channels, filter_channels, gin_channels):
super().__init__()
self.dwconv = nn.Conv1d(in_channels, in_channels, kernel_size=7, padding=3, groups=in_channels)
self.norm = StyleAdaptiveLayerNorm(in_channels, gin_channels)
self.pwconv = nn.Sequential(nn.Linear(in_channels, filter_channels),
nn.GELU(),
nn.Linear(filter_channels, in_channels))
def forward(self, x, c, x_mask) -> torch.Tensor:
residual = x
x = self.dwconv(x) * x_mask
x = self.norm(x.transpose(1, 2), c)
x = self.pwconv(x).transpose(1, 2)
x = residual + x
return x * x_mask
class StyleAdaptiveLayerNorm(nn.Module):
def __init__(self, in_channels, cond_channels):
"""
Style Adaptive Layer Normalization (SALN) module.
Parameters:
in_channels: The number of channels in the input feature maps.
cond_channels: The number of channels in the conditioning input.
"""
super(StyleAdaptiveLayerNorm, self).__init__()
self.in_channels = in_channels
self.saln = nn.Linear(cond_channels, in_channels * 2)
self.norm = nn.LayerNorm(in_channels, elementwise_affine=False)
self.reset_parameters()
def reset_parameters(self):
nn.init.constant_(self.saln.bias.data[:self.in_channels], 1)
nn.init.constant_(self.saln.bias.data[self.in_channels:], 0)
def forward(self, x, c):
gamma, beta = torch.chunk(self.saln(c.unsqueeze(1)), chunks=2, dim=-1)
return gamma * self.norm(x) + beta
class SinusoidalPosEmb(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even"
def forward(self, x, scale=1000):
if x.ndim < 1:
x = x.unsqueeze(0)
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=x.device).float() * -emb)
emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
class TimestepEmbedding(nn.Module):
def __init__(self, in_channels, out_channels, filter_channels):
super().__init__()
self.layer = nn.Sequential(
nn.Linear(in_channels, filter_channels),
nn.SiLU(inplace=True),
nn.Linear(filter_channels, out_channels)
)
def forward(self, x):
return self.layer(x)
# reference: https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/models/components/decoder.py
class Decoder(nn.Module):
def __init__(self, hidden_channels, out_channels, filter_channels, dropout=0.05, n_layers=1, n_heads=4, kernel_size=3, gin_channels=0):
super().__init__()
self.hidden_channels = hidden_channels
self.out_channels = out_channels
self.filter_channels = filter_channels
self.time_embeddings = SinusoidalPosEmb(hidden_channels)
self.time_mlp = TimestepEmbedding(hidden_channels, hidden_channels, filter_channels)
self.blocks = nn.ModuleList([DitWrapper(hidden_channels, filter_channels, n_heads, kernel_size, dropout, gin_channels, hidden_channels) for _ in range(n_layers)])
self.final_proj = nn.Conv1d(hidden_channels, out_channels, 1)
self.initialize_weights()
def initialize_weights(self):
for block in self.blocks:
nn.init.constant_(block.block.adaLN_modulation[-1].weight, 0)
nn.init.constant_(block.block.adaLN_modulation[-1].bias, 0)
def forward(self, x, mask, mu, t, c):
"""Forward pass of the UNet1DConditional model.
Args:
x (torch.Tensor): shape (batch_size, in_channels, time)
mask (_type_): shape (batch_size, 1, time)
t (_type_): shape (batch_size)
c (_type_): shape (batch_size, gin_channels)
Raises:
ValueError: _description_
ValueError: _description_
Returns:
_type_: _description_
"""
t = self.time_mlp(self.time_embeddings(t))
x = torch.cat((x, mu), dim=1)
for block in self.blocks:
x = block(x, c, t, mask)
output = self.final_proj(x * mask)
return output * mask
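# A minimal shape sketch of the estimator; 80 mel bins and the other sizes are illustrative
# assumptions (the real values come from config.py).
if __name__ == '__main__':
    n_mels = 80
    decoder = Decoder(hidden_channels=2 * n_mels, out_channels=n_mels, filter_channels=256,
                      n_layers=2, n_heads=2, gin_channels=256)
    x = torch.randn(2, n_mels, 100)   # noisy mel sample x_t
    mu = torch.randn(2, n_mels, 100)  # aligned encoder output
    mask = torch.ones(2, 1, 100)
    t = torch.rand(2)                 # flow matching timestep per sample
    c = torch.randn(2, 256)           # speaker/style condition
    v = decoder(x, mask, mu, t, c)    # predicted vector field
    print(v.shape)                    # torch.Size([2, 80, 100])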
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.estimator import Decoder
# copied from https://github.com/jaywalnut310/vits/blob/main/commons.py#L121
def sequence_mask(length: torch.Tensor, max_length: int = None) -> torch.Tensor:
if max_length is None:
max_length = length.max()
x = torch.arange(max_length, dtype=length.dtype, device=length.device)
return x.unsqueeze(0) < length.unsqueeze(1)
# modified from https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/models/components/flow_matching.py
class CFMDecoder(torch.nn.Module):
def __init__(self, hidden_channels, out_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, gin_channels):
super().__init__()
self.hidden_channels = hidden_channels
self.out_channels = out_channels
self.filter_channels = filter_channels
self.gin_channels = gin_channels
self.sigma_min = 1e-4
self.estimator = Decoder(hidden_channels, out_channels, filter_channels, p_dropout, n_layers, n_heads, kernel_size, gin_channels)
@torch.inference_mode()
def forward(self, mu, mask, n_timesteps, temperature=1.0, c=None):
"""Forward diffusion
Args:
mu (torch.Tensor): output of encoder
shape: (batch_size, n_feats, mel_timesteps)
mask (torch.Tensor): output_mask
shape: (batch_size, 1, mel_timesteps)
n_timesteps (int): number of diffusion steps
temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
c (torch.Tensor, optional): shape: (batch_size, gin_channels)
Returns:
sample: generated mel-spectrogram
shape: (batch_size, n_feats, mel_timesteps)
"""
z = torch.randn_like(mu) * temperature
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, c=c)
def solve_euler(self, x, t_span, mu, mask, c):
"""
Fixed euler solver for ODEs.
Args:
x (torch.Tensor): random noise
t_span (torch.Tensor): n_timesteps interpolated
shape: (n_timesteps + 1,)
mu (torch.Tensor): output of encoder
shape: (batch_size, n_feats, mel_timesteps)
mask (torch.Tensor): output_mask
shape: (batch_size, 1, mel_timesteps)
c (torch.Tensor, optional): speaker condition.
shape: (batch_size, gin_channels)
"""
t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
# Store every intermediate step so it can be inspected (e.g. via a debugger);
# a return_all_steps flag may be added in the future.
sol = []
for step in range(1, len(t_span)):
dphi_dt = self.estimator(x, mask, mu, t, c)
x = x + dt * dphi_dt
t = t + dt
sol.append(x)
if step < len(t_span) - 1:
dt = t_span[step + 1] - t
return sol[-1]
def compute_loss(self, x1, mask, mu, c):
"""Computes diffusion loss
Args:
x1 (torch.Tensor): Target
shape: (batch_size, n_feats, mel_timesteps)
mask (torch.Tensor): target mask
shape: (batch_size, 1, mel_timesteps)
mu (torch.Tensor): output of encoder
shape: (batch_size, n_feats, mel_timesteps)
c (torch.Tensor, optional): speaker condition.
Returns:
loss: conditional flow matching loss
y: conditional flow
shape: (batch_size, n_feats, mel_timesteps)
"""
b, _, t = mu.shape
# random timestep
t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
# sample noise p(x_0)
z = torch.randn_like(x1)
y = (1 - (1 - self.sigma_min) * t) * z + t * x1
u = x1 - (1 - self.sigma_min) * z
loss = F.mse_loss(self.estimator(y, mask, mu, t.squeeze(), c), u, reduction="sum") / (
torch.sum(mask) * u.shape[1]
)
return loss, y
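# A minimal end-to-end sketch of the CFM decoder: the flow matching loss for training and the
# Euler ODE solve for sampling. Sizes are illustrative assumptions, not the project defaults.
if __name__ == '__main__':
    n_mels = 80
    cfm = CFMDecoder(hidden_channels=2 * n_mels, out_channels=n_mels, filter_channels=256,
                     n_heads=2, n_layers=2, kernel_size=3, p_dropout=0.05, gin_channels=256)
    mu = torch.randn(2, n_mels, 100)  # aligned encoder output
    mask = sequence_mask(torch.tensor([100, 80]), 100).unsqueeze(1).float()
    c = torch.randn(2, 256)           # speaker/style condition
    loss, _ = cfm.compute_loss(torch.randn(2, n_mels, 100), mask, mu, c)  # training objective
    mel = cfm(mu, mask, n_timesteps=10, temperature=0.667, c=c)           # 10 Euler steps from noise
    print(loss.item(), mel.shape)     # scalar, torch.Size([2, 80, 100])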
import math
import torch
import torch.nn as nn
import monotonic_align
from models.text_encoder import TextEncoder
from models.flow_matching import CFMDecoder
from models.reference_encoder import MelStyleEncoder
from models.duration_predictor import DurationPredictor, duration_loss
def sequence_mask(length: torch.Tensor, max_length: int = None) -> torch.Tensor:
if max_length is None:
max_length = length.max()
x = torch.arange(max_length, dtype=length.dtype, device=length.device)
return x.unsqueeze(0) < length.unsqueeze(1)
def convert_pad_shape(pad_shape):
inverted_shape = pad_shape[::-1]
pad_shape = [item for sublist in inverted_shape for item in sublist]
return pad_shape
def generate_path(duration, mask):
device = duration.device
b, t_x, t_y = mask.shape
cum_duration = torch.cumsum(duration, 1)
path = torch.zeros(b, t_x, t_y, dtype=mask.dtype).to(device=device)
cum_duration_flat = cum_duration.view(b * t_x)
path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
path = path.view(b, t_x, t_y)
path = path - torch.nn.functional.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
path = path * mask
return path
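# Example: durations [[2., 1., 3.]] with an all-ones [1, 3, 6] mask yield a 0/1 path in which
# text token 0 covers mel frames 0-1, token 1 covers frame 2, and token 2 covers frames 3-5.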
# modified from https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/models/matcha_tts.py
class StableTTS(nn.Module):
def __init__(self, n_vocab, mel_channels, hidden_channels, filter_channels, n_heads, n_enc_layers, n_dec_layers, kernel_size, p_dropout, gin_channels):
super().__init__()
self.n_vocab = n_vocab
self.mel_channels = mel_channels
self.encoder = TextEncoder(n_vocab, mel_channels, hidden_channels, filter_channels, n_heads, n_enc_layers, kernel_size, p_dropout, gin_channels)
self.ref_encoder = MelStyleEncoder(mel_channels, style_vector_dim=gin_channels, style_kernel_size=3)
self.dp = DurationPredictor(hidden_channels, filter_channels, kernel_size, p_dropout, gin_channels)
self.decoder = CFMDecoder(mel_channels + mel_channels, mel_channels, filter_channels, n_heads, n_dec_layers, kernel_size, p_dropout, gin_channels)
@torch.inference_mode()
def synthesise(self, x, x_lengths, n_timesteps, temperature=1.0, y=None, length_scale=1.0):
"""
Generates mel-spectrogram from text. Returns:
1. encoder outputs
2. decoder outputs
3. generated alignment
Args:
x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids.
shape: (batch_size, max_text_length)
x_lengths (torch.Tensor): lengths of texts in batch.
shape: (batch_size,)
n_timesteps (int): number of steps to use for reverse diffusion in decoder.
temperature (float, optional): controls variance of terminal distribution.
y (torch.Tensor): mel spectrogram of reference audio
shape: (batch_size, mel_channels, time)
length_scale (float, optional): controls speech pace.
Increase value to slow down generated speech and vice versa.
Returns:
dict: {
"encoder_outputs": torch.Tensor, shape: (batch_size, n_feats, max_mel_length),
# Average mel spectrogram generated by the encoder
"decoder_outputs": torch.Tensor, shape: (batch_size, n_feats, max_mel_length),
# Refined mel spectrogram improved by the CFM
"attn": torch.Tensor, shape: (batch_size, max_text_length, max_mel_length),
# Alignment map between text and mel spectrogram
}
"""
# Get encoder_outputs `mu_x` and log-scaled token durations `logw`
c = self.ref_encoder(y, None)
x, mu_x, x_mask = self.encoder(x, c, x_lengths)
logw = self.dp(x, x_mask, c)
w = torch.exp(logw) * x_mask
w_ceil = torch.ceil(w) * length_scale
y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
y_max_length = y_lengths.max()
# Using obtained durations `w` construct alignment map `attn`
y_mask = sequence_mask(y_lengths, y_max_length).unsqueeze(1).to(x_mask.dtype)
attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1)
# Align encoded text and get mu_y
mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2))
mu_y = mu_y.transpose(1, 2)
encoder_outputs = mu_y[:, :, :y_max_length]
# Generate sample tracing the probability flow
decoder_outputs = self.decoder(mu_y, y_mask, n_timesteps, temperature, c)
decoder_outputs = decoder_outputs[:, :, :y_max_length]
return {
"encoder_outputs": encoder_outputs,
"decoder_outputs": decoder_outputs,
"attn": attn[:, :, :y_max_length],
}
def forward(self, x, x_lengths, y, y_lengths):
"""
Computes 3 losses:
1. duration loss: loss between predicted token durations and those extracted by Monotonic Alignment Search (MAS).
2. prior loss: loss between mel-spectrogram and encoder outputs.
3. flow matching loss: loss between mel-spectrogram and decoder outputs.
Args:
x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids.
shape: (batch_size, max_text_length)
x_lengths (torch.Tensor): lengths of texts in batch.
shape: (batch_size,)
y (torch.Tensor): batch of corresponding mel-spectrograms.
shape: (batch_size, n_feats, max_mel_length)
y_lengths (torch.Tensor): lengths of mel-spectrograms in batch.
shape: (batch_size,)
"""
# Get encoder_outputs `mu_x` and log-scaled token durations `logw`
y_mask = sequence_mask(y_lengths, y.size(2)).unsqueeze(1).to(y.dtype)
c = self.ref_encoder(y, y_mask)
x, mu_x, x_mask = self.encoder(x, c, x_lengths)
logw = self.dp(x, x_mask, c)
attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
# Use MAS to find most likely alignment `attn` between text and mel-spectrogram
# The MAS code from Matcha-TTS and Grad-TTS does not produce a proper alignment in StableTTS,
# so the code from https://github.com/p0p4k/pflowtts_pytorch/blob/master/pflow/models/pflow_tts.py is used instead, which works.
# Contributions that explain or fix this are welcome.
with torch.no_grad():
# const = -0.5 * math.log(2 * math.pi) * self.n_feats
# const = -0.5 * math.log(2 * math.pi) * self.mel_channels
# factor = -0.5 * torch.ones(mu_x.shape, dtype=mu_x.dtype, device=mu_x.device)
# y_square = torch.matmul(factor.transpose(1, 2), y**2)
# y_mu_double = torch.matmul(2.0 * (factor * mu_x).transpose(1, 2), y)
# mu_square = torch.sum(factor * (mu_x**2), 1).unsqueeze(-1)
# log_prior = y_square - y_mu_double + mu_square + const
s_p_sq_r = torch.ones_like(mu_x) # [b, d, t]
# s_p_sq_r = torch.exp(-2 * logx)
neg_cent1 = torch.sum(
-0.5 * math.log(2 * math.pi)- torch.zeros_like(mu_x), [1], keepdim=True
)
# neg_cent1 = torch.sum(
# -0.5 * math.log(2 * math.pi) - logx, [1], keepdim=True
# ) # [b, 1, t_s]
neg_cent2 = torch.einsum("bdt, bds -> bts", -0.5 * (y**2), s_p_sq_r)
neg_cent3 = torch.einsum("bdt, bds -> bts", y, (mu_x * s_p_sq_r))
neg_cent4 = torch.sum(
-0.5 * (mu_x**2) * s_p_sq_r, [1], keepdim=True
)
neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4
attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
attn = (
monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1)).unsqueeze(1).detach()
)
# attn = monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1))
# attn = attn.detach()
# Compute loss between predicted log-scaled durations and those obtained from MAS
# referred to as prior loss in the paper
logw_ = torch.log(1e-8 + attn.sum(2)) * x_mask
# logw_ = torch.log(1e-8 + torch.sum(attn.unsqueeze(1), -1)) * x_mask
dur_loss = duration_loss(logw, logw_, x_lengths)
# Align encoded text with mel-spectrogram and get mu_y segment
attn = attn.squeeze(1).transpose(1,2)
mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2))
mu_y = mu_y.transpose(1, 2)
# Compute loss of the decoder
diff_loss, _ = self.decoder.compute_loss(y, y_mask, mu_y, c)
# diff_loss = torch.tensor([0], device=mu_y.device)
prior_loss = torch.sum(0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask)
prior_loss = prior_loss / (torch.sum(y_mask) * self.mel_channels)
return dur_loss, diff_loss, prior_loss, attn
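# A minimal smoke test of the full model; the hyperparameters are illustrative assumptions
# (the real ones come from config.ModelConfig and text.symbols), and the MAS step requires the
# numba-based monotonic_align package to be importable.
if __name__ == '__main__':
    model = StableTTS(n_vocab=100, mel_channels=80, hidden_channels=192, filter_channels=512,
                      n_heads=2, n_enc_layers=2, n_dec_layers=2, kernel_size=3,
                      p_dropout=0.1, gin_channels=256)
    x = torch.randint(0, 100, (2, 20))   # phoneme ids
    x_lengths = torch.tensor([20, 15])
    y = torch.randn(2, 80, 120)          # target mel-spectrograms
    y_lengths = torch.tensor([120, 90])
    dur_loss, diff_loss, prior_loss, attn = model(x, x_lengths, y, y_lengths)
    out = model.synthesise(x, x_lengths, n_timesteps=10, y=y)
    print(dur_loss.item(), diff_loss.item(), prior_loss.item(), out['decoder_outputs'].shape)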
import torch
import torch.nn as nn
class Conv1dGLU(nn.Module):
"""
Conv1d + GLU(Gated Linear Unit) with residual connection.
For GLU refer to https://arxiv.org/abs/1612.08083 paper.
"""
def __init__(self, in_channels, out_channels, kernel_size, dropout):
super(Conv1dGLU, self).__init__()
self.out_channels = out_channels
self.conv1 = nn.Conv1d(in_channels, 2 * out_channels, kernel_size=kernel_size, padding=kernel_size // 2)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
residual = x
x = self.conv1(x)
x1, x2 = torch.split(x, self.out_channels, dim=1)
x = x1 * torch.sigmoid(x2)
x = residual + self.dropout(x)
return x
# modified from https://github.com/RVC-Boss/GPT-SoVITS/blob/main/GPT_SoVITS/module/modules.py#L766
class MelStyleEncoder(nn.Module):
"""MelStyleEncoder"""
def __init__(
self,
n_mel_channels=80,
style_hidden=128,
style_vector_dim=256,
style_kernel_size=5,
style_head=2,
dropout=0.1,
):
super(MelStyleEncoder, self).__init__()
self.in_dim = n_mel_channels
self.hidden_dim = style_hidden
self.out_dim = style_vector_dim
self.kernel_size = style_kernel_size
self.n_head = style_head
self.dropout = dropout
self.spectral = nn.Sequential(
nn.Linear(self.in_dim, self.hidden_dim),
nn.Mish(inplace=True),
nn.Dropout(self.dropout),
nn.Linear(self.hidden_dim, self.hidden_dim),
nn.Mish(inplace=True),
nn.Dropout(self.dropout),
)
self.temporal = nn.Sequential(
Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout),
Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout),
)
self.slf_attn = nn.MultiheadAttention(
self.hidden_dim,
self.n_head,
self.dropout,
batch_first=True
)
self.fc = nn.Linear(self.hidden_dim, self.out_dim)
def temporal_avg_pool(self, x, mask=None):
if mask is None:
return torch.mean(x, dim=1)
else:
len_ = (~mask).sum(dim=1).unsqueeze(1).type_as(x)
return torch.sum(x * ~mask.unsqueeze(-1), dim=1) / len_
def forward(self, x, x_mask=None):
x = x.transpose(1, 2)
# spectral
x = self.spectral(x)
# temporal
x = x.transpose(1, 2)
x = self.temporal(x)
x = x.transpose(1, 2)
# self-attention
if x_mask is not None:
x_mask = ~x_mask.squeeze(1).to(torch.bool)
x, _ = self.slf_attn(x, x, x, key_padding_mask=x_mask)
# fc
x = self.fc(x)
# temporal average pooling
w = self.temporal_avg_pool(x, mask=x_mask)
return w
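# A minimal usage sketch: pool a variable-length mel-spectrogram into a fixed-size style vector.
# The sizes below are illustrative assumptions.
if __name__ == '__main__':
    enc = MelStyleEncoder(n_mel_channels=80, style_vector_dim=256)
    mel = torch.randn(2, 80, 120)  # [batch, n_mels, frames]
    mask = torch.ones(2, 1, 120)   # 1 for valid frames, 0 for padding
    style = enc(mel, mask)         # [2, 256]
    print(style.shape)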
import torch
import torch.nn as nn
from models.dit import DiTConVBlock
def sequence_mask(length: torch.Tensor, max_length: int = None) -> torch.Tensor:
if max_length is None:
max_length = length.max()
x = torch.arange(max_length, dtype=length.dtype, device=length.device)
return x.unsqueeze(0) < length.unsqueeze(1)
# modified from https://github.com/jaywalnut310/vits/blob/main/models.py
class TextEncoder(nn.Module):
def __init__(self, n_vocab, out_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, gin_channels):
super().__init__()
self.n_vocab = n_vocab
self.out_channels = out_channels
self.hidden_channels = hidden_channels
self.filter_channels = filter_channels
self.n_heads = n_heads
self.n_layers = n_layers
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.scale = self.hidden_channels ** 0.5
self.emb = nn.Embedding(n_vocab, hidden_channels)
nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
self.encoder = nn.ModuleList([DiTConVBlock(hidden_channels, filter_channels, n_heads, kernel_size, p_dropout, gin_channels) for _ in range(n_layers)])
self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
self.initialize_weights()
def initialize_weights(self):
for block in self.encoder:
nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
def forward(self, x: torch.Tensor, c: torch.Tensor, x_lengths: torch.Tensor):
x = self.emb(x) * self.scale # [b, t, h]
x = x.transpose(1, -1) # [b, h, t]
x_mask = sequence_mask(x_lengths, x.size(2)).unsqueeze(1).to(x.dtype)
for layer in self.encoder:
x = layer(x, c, x_mask)
mu_x = self.proj(x) * x_mask
return x, mu_x, x_mask
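# A minimal usage sketch; the sizes below are illustrative assumptions.
if __name__ == '__main__':
    enc = TextEncoder(n_vocab=100, out_channels=80, hidden_channels=192, filter_channels=512,
                      n_heads=2, n_layers=2, kernel_size=3, p_dropout=0.1, gin_channels=256)
    tokens = torch.randint(0, 100, (2, 20))  # phoneme ids
    lengths = torch.tensor([20, 15])
    c = torch.randn(2, 256)                  # style vector from MelStyleEncoder
    h, mu, mask = enc(tokens, c, lengths)
    print(h.shape, mu.shape, mask.shape)     # [2, 192, 20] [2, 80, 20] [2, 1, 20]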
from numpy import zeros, int32, float32
from torch import from_numpy
from .core import maximum_path_jit
def maximum_path(neg_cent, mask):
device = neg_cent.device
dtype = neg_cent.dtype
neg_cent = neg_cent.data.cpu().numpy().astype(float32)
path = zeros(neg_cent.shape, dtype=int32)
t_t_max = mask.sum(1)[:, 0].data.cpu().numpy().astype(int32)
t_s_max = mask.sum(2)[:, 0].data.cpu().numpy().astype(int32)
maximum_path_jit(path, neg_cent, t_t_max, t_s_max)
return from_numpy(path).to(device=device, dtype=dtype)
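# A tiny illustrative call: neg_cent holds alignment scores of shape [batch, mel_frames, text_tokens];
# the returned path is a monotone 0/1 alignment of the same shape.
if __name__ == '__main__':
    import torch
    scores = torch.randn(1, 6, 3)  # [batch, t_mel, t_text]
    mask = torch.ones(1, 6, 3)
    path = maximum_path(scores, mask)
    print(path[0].long())          # each mel frame is assigned to exactly one text token, monotonically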
import numba
@numba.jit(
numba.void(
numba.int32[:, :, ::1],
numba.float32[:, :, ::1],
numba.int32[::1],
numba.int32[::1],
),
nopython=True,
nogil=True,
)
def maximum_path_jit(paths, values, t_ys, t_xs):
b = paths.shape[0]
max_neg_val = -1e9
for i in range(int(b)):
path = paths[i]
value = values[i]
t_y = t_ys[i]
t_x = t_xs[i]
v_prev = v_cur = 0.0
index = t_x - 1
for y in range(t_y):
for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
if x == y:
v_cur = max_neg_val
else:
v_cur = value[y - 1, x]
if x == 0:
if y == 0:
v_prev = 0.0
else:
v_prev = max_neg_val
else:
v_prev = value[y - 1, x - 1]
value[y, x] += max(v_prev, v_cur)
for y in range(t_y - 1, -1, -1):
path[y, index] = 1
if index != 0 and (
index == y or value[y - 1, index] < value[y - 1, index - 1]
):
index = index - 1
import os
import json
from dataclasses import dataclass
from tqdm import tqdm
import torch
from torch.multiprocessing import Pool, set_start_method
import torchaudio
from config import MelConfig, TrainConfig
from text.mandarin import chinese_to_cnm3
from text.english import english_to_ipa2
from text.japanese import japanese_to_ipa2
from utils.audio import LogMelSpectrogram, load_and_resample_audio
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
@dataclass
class DataConfig:
input_filelist_path = './filelists/filelist.txt' # a filelist whose lines are 'audiopath|text'
output_filelist_path = './filelists/filelist.json' # path to save filelist
output_feature_path = './stableTTS_datasets' # path to save resampled audios and mel features
language = 'chinese' # chinese, japanese or english
resample = False # the waveform itself is not used in training, but its length is used by DistributedBucketSampler; a different sample rate or format may lead to wrong bucketing
g2p_mapping = {
'chinese': chinese_to_cnm3,
'japanese': japanese_to_ipa2,
'english': english_to_ipa2,
}
data_config = DataConfig()
train_config = TrainConfig()
mel_config = MelConfig()
input_filelist_path = data_config.input_filelist_path
output_filelist_path = data_config.output_filelist_path
output_feature_path = data_config.output_feature_path
# Ensure output directories exist
output_mel_dir = os.path.join(output_feature_path, 'mels')
os.makedirs(output_mel_dir, exist_ok=True)
if data_config.resample:
output_wav_dir = os.path.join(output_feature_path, 'waves')
os.makedirs(output_wav_dir, exist_ok=True)
os.makedirs(os.path.dirname(output_filelist_path), exist_ok=True)
mel_extractor = LogMelSpectrogram(mel_config).to(device)
g2p = g2p_mapping.get(data_config.language)
def load_filelist(path) -> list:
file_list = []
with open(path, 'r', encoding='utf-8') as f:
for idx, line in enumerate(f):
audio_path, text = line.strip().split('|', maxsplit=1)
file_list.append((str(idx), audio_path, text))
return file_list
@torch.inference_mode()
def process_filelist(line) -> str:
idx, audio_path, text = line
audio = load_and_resample_audio(audio_path, mel_config.sample_rate, device=device) # shape: [1, time]
if audio is not None:
# get output path
audio_name, _ = os.path.splitext(os.path.basename(audio_path))
try:
phone = g2p(text)
if len(phone) > 0:
mel = mel_extractor(audio.to(device)).cpu().squeeze(0) # shape: [n_mels, time // hop_length]
output_mel_path = os.path.join(output_mel_dir, f'{idx}_{audio_name}.pt')
torch.save(mel, output_mel_path)
if data_config.resample:
audio_path = os.path.join(output_wav_dir, f'{idx}_{audio_name}.wav')
torchaudio.save(audio_path, audio.cpu(), mel_config.sample_rate)
return json.dumps({'mel_path': output_mel_path, 'phone': phone, 'audio_path': audio_path, 'text': text}, ensure_ascii=False, allow_nan=False)
except Exception as e:
print(f'Error processing {audio_path}: {str(e)}')
def main():
set_start_method('spawn') # CUDA must use spawn method
input_filelist = load_filelist(input_filelist_path)
results = []
with Pool(processes=2) as pool:
for result in tqdm(pool.imap(process_filelist, input_filelist), total=len(input_filelist)):
if result is not None:
results.append(f'{result}\n')
# save filelist
with open(output_filelist_path, 'w', encoding='utf-8') as f:
f.writelines(results)
print(f"filelist file has been saved to {output_filelist_path}")
# faster and use much less CPU
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == '__main__':
main()
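# Each surviving entry is written to the output filelist as one JSON object per line, roughly
# (paths are illustrative): {"mel_path": "./stableTTS_datasets/mels/0_audio1.pt", "phone": <g2p output>,
# "audio_path": "./audio1.wav", "text": "..."}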
import os
import re
from dataclasses import dataclass
import concurrent.futures
from tqdm.auto import tqdm
# download_link: https://www.openslr.org/93/
@dataclass
class DataConfig:
dataset_path = './raw_datasets/Aishell3/train/wav'
txt_path = './raw_datasets/Aishell3/train/content.txt'
output_filelist_path = './filelists/aishell3.txt'
data_config = DataConfig()
def process_filelist(line):
dir_name, audio_path, text = line
input_audio_path = os.path.abspath(os.path.join(data_config.dataset_path, dir_name, audio_path))
if os.path.exists(input_audio_path):
return f'{input_audio_path}|{text}\n'
if __name__ == '__main__':
filelist = []
results = []
with open(data_config.txt_path, 'r', encoding='utf-8') as f:
for idx, line in enumerate(f):
audio_path, text = line.strip().split(maxsplit=1)
dir_name = audio_path[:7]
text = re.sub(r'[a-zA-Z0-9\s]', '', text) # remove pinyin and tone
filelist.append((dir_name, audio_path, text))
with concurrent.futures.ProcessPoolExecutor(max_workers=2) as executor:
futures = [executor.submit(process_filelist, line) for line in filelist]
for future in tqdm(concurrent.futures.as_completed(futures), total=len(filelist)):
result = future.result()
if result is not None:
results.append(result)
# make sure the parent dir exists; raising an error at the very last step would be painful
os.makedirs(os.path.dirname(data_config.output_filelist_path), exist_ok=True)
with open(data_config.output_filelist_path, 'w', encoding='utf-8') as f:
f.writelines(results)
import os
import re
from dataclasses import dataclass
import concurrent.futures
from tqdm.auto import tqdm
# submit the form on: https://www.data-baker.com/data/index/TNtts/
# then you will get the download link
@dataclass
class DataConfig:
dataset_path = './raw_datasets/BZNSYP/Wave'
txt_path = './raw_datasets/BZNSYP/ProsodyLabeling/000001-010000.txt'
output_filelist_path = './filelists/bznsyp.txt'
data_config = DataConfig()
def process_filelist(line):
audio_name, text = line.split('\t')
text = re.sub(r'[#\d]+', '', text) # remove '#' and numbers
input_audio_path = os.path.abspath(os.path.join(data_config.dataset_path, f'{audio_name}.wav'))
if os.path.exists(input_audio_path):
return f'{input_audio_path}|{text}\n'
if __name__ == '__main__':
filelist = []
results = []
with open(data_config.txt_path, 'r', encoding='utf-8') as f:
for idx, line in enumerate(f):
if idx % 2 == 0:
filelist.append(line.strip())
with concurrent.futures.ProcessPoolExecutor(max_workers=2) as executor:
futures = [executor.submit(process_filelist, line) for line in filelist]
for future in tqdm(concurrent.futures.as_completed(futures), total=len(filelist)):
result = future.result()
if result is not None:
results.append(result)
# make sure the parent dir exists; raising an error at the very last step would be painful
os.makedirs(os.path.dirname(data_config.output_filelist_path), exist_ok=True)
with open(data_config.output_filelist_path, 'w', encoding='utf-8') as f:
f.writelines(results)