Commit b309ea1b authored by chenzk

v1.0
from dataclasses import dataclass
from typing import Optional
@dataclass
class MelConfig:
sample_rate: int = 44100
n_fft: int = 2048
win_length: int = 2048
hop_length: int = 512
f_min: float = 0.0
    f_max: Optional[float] = None
pad: int = 0
n_mels: int = 128
power: float = 1.0
normalized: bool = False
center: bool = False
pad_mode: str = "reflect"
mel_scale: str = "htk"
def __post_init__(self):
if self.pad == 0:
self.pad = (self.n_fft - self.hop_length) // 2
@dataclass
class VocosConfig:
input_channels: int = 128
dim: int = 512
intermediate_dim: int = 1536
num_layers: int = 8
@dataclass
class TrainConfig:
train_dataset_path: str = './filelists/filelist.txt'
test_dataset_path: str = './filelists/filelist.txt'
batch_size: int = 22
learning_rate: float = 1e-4
num_epochs: int = 10000
model_save_path: str = './checkpoints'
log_dir: str = './runs'
log_interval: int = 128
warmup_steps: int = 200
    segment_size: int = 20480
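# Illustrative sketch, assuming the dataclasses above are saved as config.py:
# with the defaults, __post_init__ sets pad = (n_fft - hop_length) // 2
# = (2048 - 512) // 2 = 768, so a 20480-sample training segment produces exactly
# segment_size / hop_length = 20480 / 512 = 40 mel frames when center=False.
from config import MelConfig, TrainConfig

mel_config = MelConfig()
train_config = TrainConfig()
assert mel_config.pad == 768
assert train_config.segment_size // mel_config.hop_length == 40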
import os
import torch
import torchaudio
from torch.utils.data import Dataset
from utils.audio import LogMelSpectrogram
from config import MelConfig
class VocosDataset(Dataset):
def __init__(self, filelist_path, segment_size: int, mel_config: MelConfig):
self.filelist_path = filelist_path
self.segment_size = segment_size
self.mel_extractor = LogMelSpectrogram(mel_config)
self.filelist = self._load_filelist(filelist_path)
def _load_filelist(self, filelist_path):
with open(filelist_path, 'r', encoding='utf-8') as f:
filelist = [line.strip() for line in f if os.path.exists(line.strip())]
return filelist
def __len__(self):
return len(self.filelist)
def __getitem__(self, idx):
audio, _ = torchaudio.load(self.filelist[idx])
# select a random segment from the audio file
# audio is validated in the preprocess stage, so we skip checking sample_rate and padding short audio
start_index = torch.randint(0, audio.size(-1) - self.segment_size + 1, (1,)).item()
audio = audio[:, start_index:start_index + self.segment_size] # shape: [1, segment_size]
mel = self.mel_extractor(audio).squeeze(0) # shape: [n_mels, segment_size // hop_length]
return audio, mel
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torchaudio\n",
"from IPython.display import Audio, display\n",
"\n",
"from models.model import Vocos\n",
"from utils.audio import LogMelSpectrogram\n",
"from config import MelConfig, VocosConfig\n",
"\n",
"from pathlib import Path\n",
"import random\n",
"\n",
"def load_and_resample_audio(audio_path, target_sr):\n",
" y, sr = torchaudio.load(audio_path)\n",
" if y.size(0) > 1:\n",
" y = y[0, :].unsqueeze(0) # shape: [2, time] -> [time] -> [1, time]\n",
" if sr != target_sr:\n",
" y = torchaudio.functional.resample(y, sr, target_sr)\n",
" return y\n",
"\n",
"device = 'cpu'\n",
"\n",
"mel_config = MelConfig()\n",
"vocos_config = VocosConfig()\n",
"\n",
"mel_extractor = LogMelSpectrogram(mel_config)\n",
"model = Vocos(vocos_config, mel_config).to(device)\n",
"model.load_state_dict(torch.load('./checkpoints/generator_0.pt', map_location='cpu'))\n",
"model.eval()\n",
"\n",
"audio_paths = list(Path('./audios').rglob('*.wav'))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"audio_path = random.choice(audio_paths)\n",
"with torch.inference_mode():\n",
" audio = load_and_resample_audio(audio_path, mel_config.sample_rate).to(device)\n",
" mel = mel_extractor(audio)\n",
" recon_audio = model(mel)\n",
"display(Audio(audio, rate=mel_config.sample_rate))\n",
"display(Audio(recon_audio, rate=mel_config.sample_rate))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "lxn_vits",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.18"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
from typing import Optional
import torch
from torch import nn
from .module import ConvNeXtBlock
class VocosBackbone(nn.Module):
"""
    Vocos backbone module built with ConvNeXt blocks.
    Args:
        input_channels (int): Number of input feature channels.
dim (int): Hidden dimension of the model.
intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock.
num_layers (int): Number of ConvNeXtBlock layers.
layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`.
"""
def __init__(
self,
input_channels: int,
dim: int,
intermediate_dim: int,
num_layers: int,
layer_scale_init_value: Optional[float] = None,
):
super().__init__()
self.input_channels = input_channels
self.embed = nn.Conv1d(input_channels, dim, kernel_size=7, padding=3)
self.norm = nn.LayerNorm(dim, eps=1e-6)
layer_scale_init_value = layer_scale_init_value or 1 / num_layers
self.convnext = nn.ModuleList(
[
ConvNeXtBlock(
dim=dim,
intermediate_dim=intermediate_dim,
layer_scale_init_value=layer_scale_init_value,
)
for _ in range(num_layers)
]
)
self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, (nn.Conv1d, nn.Linear)):
nn.init.trunc_normal_(m.weight, std=0.02)
nn.init.constant_(m.bias, 0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.embed(x)
x = self.norm(x.transpose(1, 2)).transpose(1, 2)
for conv_block in self.convnext:
x = conv_block(x)
x = self.final_layer_norm(x.transpose(1, 2))
return x
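# Illustrative sketch, assuming the class above lives in models/backbone.py:
# the backbone maps mel features of shape (B, input_channels, T) to hidden states
# of shape (B, T, dim), which the ISTFT head defined further below consumes.
import torch
from models.backbone import VocosBackbone

backbone = VocosBackbone(input_channels=128, dim=512, intermediate_dim=1536, num_layers=8)
hidden = backbone(torch.randn(2, 128, 40))
assert hidden.shape == (2, 40, 512)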
from typing import List, Tuple
import torch
from torch import nn
from torch import Tensor
from torch.nn import Conv2d
from torch.nn.utils import weight_norm
from torchaudio.transforms import Spectrogram
class MultiPeriodDiscriminator(nn.Module):
def __init__(self, periods: Tuple[int, ...] = (2, 3, 5, 7, 11)):
super().__init__()
self.discriminators = nn.ModuleList([DiscriminatorP(period=p) for p in periods])
def forward(self, y: Tensor, y_hat: Tensor):
y_d_rs = []
y_d_gs = []
fmap_rs = []
fmap_gs = []
for d in self.discriminators:
y_d_r, fmap_r = d(y)
y_d_g, fmap_g = d(y_hat)
y_d_rs.append(y_d_r)
fmap_rs.append(fmap_r)
y_d_gs.append(y_d_g)
fmap_gs.append(fmap_g)
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
class DiscriminatorP(nn.Module):
def __init__(
self,
period: int,
in_channels: int = 1,
kernel_size: int = 5,
stride: int = 3,
lrelu_slope: float = 0.1,
):
super().__init__()
self.period = period
self.convs = nn.ModuleList(
[
weight_norm(Conv2d(in_channels, 32, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))),
weight_norm(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))),
weight_norm(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))),
weight_norm(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))),
weight_norm(Conv2d(1024, 1024, (kernel_size, 1), (1, 1), padding=(kernel_size // 2, 0))),
]
)
self.conv_post = weight_norm(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
self.lrelu_slope = lrelu_slope
def forward(self, x: Tensor) -> Tuple[Tensor, List[Tensor]]:
fmap = []
# 1d to 2d
b, c, t = x.shape
if t % self.period != 0: # pad first
n_pad = self.period - (t % self.period)
x = torch.nn.functional.pad(x, (0, n_pad), "reflect")
t = t + n_pad
x = x.view(b, c, t // self.period, self.period)
for i, l in enumerate(self.convs):
x = l(x)
x = torch.nn.functional.leaky_relu(x, self.lrelu_slope)
if i > 0:
fmap.append(x)
x = self.conv_post(x)
fmap.append(x)
x = torch.flatten(x, 1, -1)
return x, fmap
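# Worked example for the 1D-to-2D reshape above: with period p = 3 and t = 20480 samples,
# n_pad = 3 - (20480 % 3) = 1, so the padded signal of length 20481 is viewed as
# (b, c, 6827, 3) before the stacked 2D convolutions.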
class MultiResolutionDiscriminator(nn.Module):
def __init__(
self,
fft_sizes: Tuple[int, ...] = (2048, 1024, 512),
):
"""
Multi-Resolution Discriminator module adapted from https://github.com/descriptinc/descript-audio-codec.
Args:
fft_sizes (tuple[int]): Tuple of window lengths for FFT. Defaults to (2048, 1024, 512).
"""
super().__init__()
self.discriminators = nn.ModuleList(
[DiscriminatorR(window_length=w) for w in fft_sizes]
)
def forward(self, y: Tensor, y_hat: Tensor) -> Tuple[List[Tensor], List[Tensor], List[List[Tensor]], List[List[Tensor]]]:
y_d_rs = []
y_d_gs = []
fmap_rs = []
fmap_gs = []
for d in self.discriminators:
y_d_r, fmap_r = d(x=y)
y_d_g, fmap_g = d(x=y_hat)
y_d_rs.append(y_d_r)
fmap_rs.append(fmap_r)
y_d_gs.append(y_d_g)
fmap_gs.append(fmap_g)
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
class DiscriminatorR(nn.Module):
def __init__(
self,
window_length: int,
channels: int = 32,
hop_factor: float = 0.25,
bands: Tuple[Tuple[float, float], ...] = ((0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)),
):
super().__init__()
self.window_length = window_length
self.hop_factor = hop_factor
self.spec_fn = Spectrogram(
n_fft=window_length, hop_length=int(window_length * hop_factor), win_length=window_length, power=None
)
n_fft = window_length // 2 + 1
bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
self.bands = bands
convs = lambda: nn.ModuleList(
[
weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))),
weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
weight_norm(nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))),
]
)
self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1)))
def spectrogram(self, x):
x = x.squeeze(1)
# Remove DC offset
x = x - x.mean(dim=-1, keepdims=True)
# Peak normalize the volume of input audio
x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
x = self.spec_fn(x)
x = torch.view_as_real(x)
x = x.permute(0, 3, 2, 1) # b f t c -> b c t f
# Split into bands
x_bands = [x[..., b[0] : b[1]] for b in self.bands]
return x_bands
def forward(self, x: Tensor):
x_bands = self.spectrogram(x)
fmap = []
x = []
for band, stack in zip(x_bands, self.band_convs):
for i, layer in enumerate(stack):
band = layer(band)
band = torch.nn.functional.leaky_relu(band, 0.1)
if i > 0:
fmap.append(band)
x.append(band)
x = torch.cat(x, dim=-1)
x = self.conv_post(x)
fmap.append(x)
return x, fmap
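# Worked example for the band split above: with window_length = 2048 the spectrogram has
# 2048 // 2 + 1 = 1025 frequency bins, so the default band fractions slice the bins at
# (0, 102), (102, 256), (256, 512), (512, 768) and (768, 1025).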
import torch
from torch import nn
class ISTFT(nn.Module):
"""
Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with
windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges.
See issue: https://github.com/pytorch/pytorch/issues/62323
Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs.
The NOLA constraint is met as we trim padded samples anyway.
Args:
n_fft (int): Size of Fourier transform.
hop_length (int): The distance between neighboring sliding window frames.
win_length (int): The size of window frame and STFT filter.
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
"""
def __init__(self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"):
super().__init__()
if padding not in ["center", "same"]:
raise ValueError("Padding must be 'center' or 'same'.")
self.padding = padding
self.n_fft = n_fft
self.hop_length = hop_length
self.win_length = win_length
window = torch.hann_window(win_length)
self.register_buffer("window", window)
def forward(self, spec: torch.Tensor) -> torch.Tensor:
"""
Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.
Args:
spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,
N is the number of frequency bins, and T is the number of time frames.
Returns:
Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal.
"""
if self.padding == "center":
# Fallback to pytorch native implementation
return torch.istft(spec, self.n_fft, self.hop_length, self.win_length, self.window, center=True)
elif self.padding == "same":
pad = (self.win_length - self.hop_length) // 2
else:
raise ValueError("Padding must be 'center' or 'same'.")
assert spec.dim() == 3, "Expected a 3D tensor as input"
B, N, T = spec.shape
# Inverse FFT
ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
ifft = ifft * self.window[None, :, None]
# Overlap and Add
output_size = (T - 1) * self.hop_length + self.win_length
y = torch.nn.functional.fold(
ifft, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length),
)[:, 0, 0, pad:-pad]
# Window envelope
window_sq = self.window.square().expand(1, T, -1).transpose(1, 2)
window_envelope = torch.nn.functional.fold(
window_sq, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length),
).squeeze()[pad:-pad]
# Normalize
assert (window_envelope > 1e-11).all()
y = y / window_envelope
return y
class ISTFTHead(nn.Module):
"""
ISTFT Head module for predicting STFT complex coefficients.
Args:
dim (int): Hidden dimension of the model.
n_fft (int): Size of Fourier transform.
hop_length (int): The distance between neighboring sliding window frames, which should align with
the resolution of the input features.
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
"""
def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"):
super().__init__()
out_dim = n_fft + 2
self.out = torch.nn.Linear(dim, out_dim)
self.istft = ISTFT(n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass of the ISTFTHead module.
Args:
x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
L is the sequence length, and H denotes the model dimension.
Returns:
Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
"""
x = self.out(x).transpose(1, 2)
mag, p = x.chunk(2, dim=1)
mag = torch.exp(mag)
mag = torch.clip(mag, max=1e2) # safeguard to prevent excessively large magnitudes
        # produce the real and imaginary parts directly from the predicted phase;
        # recomputing the phase with atan2 and then exp(1j * phase) would give the
        # same result and only cost extra time:
        #   phase = torch.atan2(y, x); S = mag * torch.exp(phase * 1j)
        x = torch.cos(p)
        y = torch.sin(p)
S = mag * (x + 1j * y)
audio = self.istft(S)
return audio
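# Illustrative length check for the "same"-padding ISTFT above, run with the classes
# above in scope (e.g. at the bottom of models/head.py) and the default MelConfig values:
# output_size = (T - 1) * hop_length + win_length = 39 * 512 + 2048 = 22016, and trimming
# pad = (2048 - 512) // 2 = 768 samples on each side leaves 40 * 512 = 20480 samples,
# i.e. exactly hop_length output samples per input frame.
import torch

head = ISTFTHead(dim=512, n_fft=2048, hop_length=512)
audio = head(torch.randn(2, 40, 512))
assert audio.shape == (2, 40 * 512)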
import torch
def feature_loss(fmap_r, fmap_g):
    """L1 feature-matching loss over discriminator feature maps (HiFi-GAN style)."""
    loss = 0
    for dr, dg in zip(fmap_r, fmap_g):
        for rl, gl in zip(dr, dg):
            loss += torch.mean(torch.abs(rl - gl))
    return loss * 2
def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    """Least-squares GAN loss for the discriminators: real scores pushed to 1, fake scores to 0."""
loss = 0
r_losses = []
g_losses = []
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
r_loss = torch.mean((1-dr)**2)
g_loss = torch.mean(dg**2)
loss += (r_loss + g_loss)
r_losses.append(r_loss.item())
g_losses.append(g_loss.item())
return loss, r_losses, g_losses
def generator_loss(disc_outputs):
    """Least-squares GAN loss for the generator: fake scores pushed towards 1."""
loss = 0
gen_losses = []
for dg in disc_outputs:
l = torch.mean((1-dg)**2)
gen_losses.append(l)
loss += l
return loss, gen_losses
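# Illustrative sanity check for the least-squares GAN losses above: with a perfectly
# confident discriminator (real scores 1, fake scores 0) the discriminator loss is 0
# and the generator loss is 1 per discriminator scale.
import torch

real_scores = [torch.ones(4, 10)]
fake_scores = [torch.zeros(4, 10)]
d_loss, _, _ = discriminator_loss(real_scores, fake_scores)
g_loss, _ = generator_loss(fake_scores)
assert d_loss.item() == 0.0 and g_loss.item() == 1.0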
from dataclasses import dataclass, asdict
import torch
from torch import nn
from torch import Tensor
from .head import ISTFTHead
from .backbone import VocosBackbone
from config import MelConfig, VocosConfig
class Vocos(nn.Module):
def __init__(self, vocos_config: VocosConfig, mel_config: MelConfig):
super().__init__()
self.backbone = VocosBackbone(**asdict(vocos_config))
self.head = ISTFTHead(vocos_config.dim, mel_config.n_fft, mel_config.hop_length)
def forward(self, x: Tensor) -> Tensor:
x = self.backbone(x)
x = self.head(x)
return x
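# Illustrative end-to-end sketch, assuming the class above lives in models/model.py:
# with the default configs the generator turns a log-mel spectrogram of shape
# (B, n_mels, T) into a waveform of shape (B, T * hop_length).
import torch
from config import MelConfig, VocosConfig
from models.model import Vocos

mel_config = MelConfig()
model = Vocos(VocosConfig(), mel_config)
wav = model(torch.randn(1, mel_config.n_mels, 40))
assert wav.shape == (1, 40 * mel_config.hop_length)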
import torch
from torch import nn
class ConvNeXtBlock(nn.Module):
"""ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.
Args:
dim (int): Number of input channels.
intermediate_dim (int): Dimensionality of the intermediate layer.
        layer_scale_init_value (float): Initial value for the layer scale. Values <= 0 disable scaling.
    """
def __init__(
self,
dim: int,
intermediate_dim: int,
layer_scale_init_value: float,
):
super().__init__()
self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
self.norm = nn.LayerNorm(dim, eps=1e-6)
self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers
self.act = nn.GELU()
self.pwconv2 = nn.Linear(intermediate_dim, dim)
self.gamma = (
nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
if layer_scale_init_value > 0
else None
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
residual = x
x = self.dwconv(x)
x = x.transpose(1, 2) # (B, C, T) -> (B, T, C)
x = self.norm(x)
x = self.pwconv1(x)
x = self.act(x)
x = self.pwconv2(x)
if self.gamma is not None:
x = self.gamma * x
x = x.transpose(1, 2) # (B, T, C) -> (B, C, T)
x = residual + x
return x
import glob
import os
from tqdm import tqdm
from dataclasses import dataclass
import torch
from torch import Tensor
from torch.multiprocessing import Pool, set_start_method
import torchaudio
from config import MelConfig, TrainConfig
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
@dataclass
class DataConfig:
audio_dir = './audios' # path to audios
output_dir = './vocos_datasets' # path to save processed audios
filelist_path = './filelists/filelist.txt' # path to save filelist
data_config = DataConfig()
train_config = TrainConfig()
mel_config = MelConfig()
audio_dir = data_config.audio_dir
output_dir = data_config.output_dir
filelist_path = data_config.filelist_path
segment_size = train_config.segment_size
output_audio_dir = os.path.join(output_dir, 'audios')
# Ensure output directories exist
os.makedirs(output_audio_dir, exist_ok=True)
os.makedirs(os.path.dirname(filelist_path), exist_ok=True)
def load_and_resample_audio(audio_path, target_sr, segment_size, device='cpu') -> Tensor:
try:
y, sr = torchaudio.load(audio_path)
except Exception as e:
print(str(e))
return None
    y = y.to(device)
# Convert to mono
if y.size(0) > 1:
y = y[0, :].unsqueeze(0) # shape: [2, time] -> [time] -> [1, time]
# resample audio to target sample_rate
if sr != target_sr:
y = torchaudio.functional.resample(y, sr, target_sr)
if y.size(-1) < segment_size:
y = torch.nn.functional.pad(y, (0, segment_size - y.size(-1)), "constant", 0)
return y
def find_audio_files(directory) -> list:
extensions = ['wav', 'mp3', 'flac']
files_found = []
for extension in extensions:
pattern = os.path.join(directory, '**', f'*.{extension}')
files_found.extend(glob.glob(pattern, recursive=True))
return files_found
@ torch.inference_mode()
def process_audio(audio_path):
audio = load_and_resample_audio(audio_path, mel_config.sample_rate, segment_size, device=device) # shape: [1, time]
if audio is not None:
# get output path
audio_name, _ = os.path.splitext(os.path.basename(audio_path))
output_audio_path = os.path.join(output_audio_dir, audio_name + '.wav')
        # save the resampled audio
torchaudio.save(output_audio_path, audio.cpu(), mel_config.sample_rate)
return output_audio_path
def main():
set_start_method('spawn') # CUDA must use spawn method
audio_files = find_audio_files(audio_dir)
results = []
with Pool(processes=8) as pool:
for result in tqdm(pool.imap(process_audio, audio_files), total=len(audio_files)):
if result is not None:
results.append(f'{result}\n')
# save filelist
with open(filelist_path, 'w', encoding='utf-8') as f:
f.writelines(results)
print(f"filelist file has been saved to {filelist_path}")
# limiting torch threads is faster here and uses much less CPU
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == '__main__':
main()
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '6,7'
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import itertools
from models.model import Vocos
from dataset import VocosDataset
from models.discriminator import MultiPeriodDiscriminator, MultiResolutionDiscriminator
from models.loss import feature_loss, generator_loss, discriminator_loss
from utils.audio import LogMelSpectrogram
from config import MelConfig, VocosConfig, TrainConfig
from utils.scheduler import get_cosine_schedule_with_warmup
from utils.load import continue_training
torch.backends.cudnn.benchmark = True
def setup(rank, world_size):
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
dist.init_process_group("gloo" if os.name == "nt" else "nccl", rank=rank, world_size=world_size)
def cleanup():
dist.destroy_process_group()
def _init_config(vocos_config: VocosConfig, mel_config: MelConfig, train_config: TrainConfig):
if vocos_config.input_channels != mel_config.n_mels:
raise ValueError("input_channels and n_mels must be equal.")
if not os.path.exists(train_config.model_save_path):
print(f'Creating {train_config.model_save_path}')
os.makedirs(train_config.model_save_path, exist_ok=True)
def train(rank, world_size):
setup(rank, world_size)
torch.cuda.set_device(rank)
vocos_config = VocosConfig()
mel_config = MelConfig()
train_config = TrainConfig()
_init_config(vocos_config, mel_config, train_config)
generator = Vocos(vocos_config, mel_config).to(rank)
mpd = MultiPeriodDiscriminator().to(rank)
mrd = MultiResolutionDiscriminator().to(rank)
mel_extractor = LogMelSpectrogram(mel_config).to(rank)
generator = DDP(generator, device_ids=[rank])
mpd = DDP(mpd, device_ids=[rank])
mrd = DDP(mrd, device_ids=[rank])
train_dataset = VocosDataset(train_config.train_dataset_path, train_config.segment_size, mel_config)
train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank)
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=train_config.batch_size, num_workers=4, pin_memory=False)
if rank == 0:
writer = SummaryWriter(train_config.log_dir)
optimizer_g = optim.AdamW(generator.parameters(), lr=train_config.learning_rate)
optimizer_d = optim.AdamW(itertools.chain(mpd.parameters(), mrd.parameters()), lr=train_config.learning_rate)
scheduler_g = get_cosine_schedule_with_warmup(optimizer_g, num_warmup_steps=int(train_config.warmup_steps), num_training_steps=train_config.num_epochs * len(train_dataloader))
scheduler_d = get_cosine_schedule_with_warmup(optimizer_d, num_warmup_steps=int(train_config.warmup_steps), num_training_steps=train_config.num_epochs * len(train_dataloader))
# load latest checkpoints if possible
current_epoch = continue_training(train_config.model_save_path, generator, mpd, mrd, optimizer_d, optimizer_g)
generator.train()
mpd.train()
mrd.train()
for epoch in range(current_epoch, train_config.num_epochs): # loop over the train_dataset multiple times
train_dataloader.sampler.set_epoch(epoch)
if rank == 0:
dataloader = tqdm(train_dataloader)
else:
dataloader = train_dataloader
for batch_idx, datas in enumerate(dataloader):
datas = [data.to(rank, non_blocking=True) for data in datas]
audios, mels = datas
audios_fake = generator(mels).unsqueeze(1) # shape: [batch_size, 1, segment_size]
mels_fake = mel_extractor(audios_fake).squeeze(1) # shape: [batch_size, n_mels, segment_size // hop_length]
optimizer_d.zero_grad()
# MPD
y_df_hat_r, y_df_hat_g, _, _ = mpd(audios,audios_fake.detach())
loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g)
# MRD
y_ds_hat_r, y_ds_hat_g, _, _ = mrd(audios,audios_fake.detach())
loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g)
loss_disc_all = loss_disc_s + loss_disc_f
loss_disc_all.backward()
grad_norm_mpd = torch.nn.utils.clip_grad_norm_(mpd.parameters(), 1000)
grad_norm_mrd = torch.nn.utils.clip_grad_norm_(mrd.parameters(), 1000)
optimizer_d.step()
scheduler_d.step()
# generator
optimizer_g.zero_grad()
loss_mel = torch.nn.functional.l1_loss(mels, mels_fake) * 45
# MPD loss
y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(audios,audios_fake)
loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)
# MRD loss
y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = mrd(audios,audios_fake)
loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)
loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel
loss_gen_all.backward()
grad_norm_g = torch.nn.utils.clip_grad_norm_(generator.parameters(), 1000)
optimizer_g.step()
scheduler_g.step()
if rank == 0 and batch_idx % train_config.log_interval == 0:
steps = epoch * len(dataloader) + batch_idx
writer.add_scalar("training/gen_loss_total", loss_gen_all, steps)
writer.add_scalar("training/fm_loss_mpd", loss_fm_f.item(), steps)
writer.add_scalar("training/gen_loss_mpd", loss_gen_f.item(), steps)
writer.add_scalar("training/disc_loss_mpd", loss_disc_f.item(), steps)
writer.add_scalar("training/fm_loss_mrd", loss_fm_s.item(), steps)
writer.add_scalar("training/gen_loss_mrd", loss_gen_s.item(), steps)
writer.add_scalar("training/disc_loss_mrd", loss_disc_s.item(), steps)
writer.add_scalar("training/mel_loss", loss_mel.item(), steps)
writer.add_scalar("grad_norm/grad_norm_mpd", grad_norm_mpd, steps)
writer.add_scalar("grad_norm/grad_norm_mrd", grad_norm_mrd, steps)
writer.add_scalar("grad_norm/grad_norm_g", grad_norm_g, steps)
writer.add_scalar("learning_rate/learning_rate_d", scheduler_d.get_last_lr()[0], steps)
writer.add_scalar("learning_rate/learning_rate_g", scheduler_g.get_last_lr()[0], steps)
if rank == 0:
torch.save(generator.module.state_dict(), os.path.join(train_config.model_save_path, f'generator_{epoch}.pt'))
torch.save(mpd.module.state_dict(), os.path.join(train_config.model_save_path, f'mpd_{epoch}.pt'))
torch.save(mrd.module.state_dict(), os.path.join(train_config.model_save_path, f'mrd_{epoch}.pt'))
torch.save(optimizer_d.state_dict(), os.path.join(train_config.model_save_path, f'optimizerd_{epoch}.pt'))
torch.save(optimizer_g.state_dict(), os.path.join(train_config.model_save_path, f'optimizerg_{epoch}.pt'))
print(f"Rank {rank}, Epoch {epoch}, Loss {loss_gen_all.item()}")
cleanup()
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
world_size = torch.cuda.device_count()
torch.multiprocessing.spawn(train, args=(world_size,), nprocs=world_size)
from dataclasses import dataclass, asdict
import torch
from torch import Tensor
import torch.nn as nn
import torchaudio
import torchaudio.transforms
from config import MelConfig
class LogMelSpectrogram(nn.Module):
def __init__(self, config: MelConfig):
super().__init__()
self.spec = torchaudio.transforms.MelSpectrogram(**asdict(config))
def forward(self, x: Tensor) -> Tensor:
return self.compress(self.spec(x))
def compress(self, x: Tensor) -> Tensor:
return torch.log(torch.clamp(x, min=1e-5))
def decompress(self, x: Tensor) -> Tensor:
return torch.exp(x)
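# Illustrative sketch, run with the class above in scope (the repo imports it from
# utils/audio.py): with the default MelConfig, waveforms of shape (B, 1, 20480) yield
# log-mel features of shape (B, 1, 128, 40); magnitudes are clamped at 1e-5 before the
# log, so decompress(compress(x)) only round-trips values >= 1e-5.
import torch

mel_extractor = LogMelSpectrogram(MelConfig())
mels = mel_extractor(torch.randn(2, 1, 20480))
assert mels.shape == (2, 1, 128, 40)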
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
def continue_training(checkpoint_path, generator: DDP, mpd: DDP, mrd: DDP, optimizer_d: optim.Optimizer, optimizer_g: optim.Optimizer) -> int:
"""load the latest checkpoints and optimizers"""
generator_dict = {}
mpd_dict = {}
mrd_dict = {}
optimizer_d_dict = {}
optimizer_g_dict = {}
    # glob all the checkpoints in the directory
for file in os.listdir(checkpoint_path):
if file.endswith(".pt"):
name, epoch_str = file.rsplit('_', 1)
epoch = int(epoch_str.split('.')[0])
if name.startswith("generator"):
generator_dict[epoch] = file
elif name.startswith("mpd"):
mpd_dict[epoch] = file
elif name.startswith("mrd"):
mrd_dict[epoch] = file
elif name.startswith("optimizerd"):
optimizer_d_dict[epoch] = file
elif name.startswith("optimizerg"):
optimizer_g_dict[epoch] = file
# get the largest epoch
common_epochs = set(generator_dict.keys()) & set(mpd_dict.keys()) & set(mrd_dict.keys()) & set(optimizer_d_dict.keys()) & set(optimizer_g_dict.keys())
if common_epochs:
max_epoch = max(common_epochs)
generator_path = os.path.join(checkpoint_path, generator_dict[max_epoch])
mpd_path = os.path.join(checkpoint_path, mpd_dict[max_epoch])
mrd_path = os.path.join(checkpoint_path, mrd_dict[max_epoch])
optimizer_d_path = os.path.join(checkpoint_path, optimizer_d_dict[max_epoch])
optimizer_g_path = os.path.join(checkpoint_path, optimizer_g_dict[max_epoch])
# load model and optimizer
generator.module.load_state_dict(torch.load(generator_path, map_location='cpu'))
mpd.module.load_state_dict(torch.load(mpd_path, map_location='cpu'))
mrd.module.load_state_dict(torch.load(mrd_path, map_location='cpu'))
optimizer_d.load_state_dict(torch.load(optimizer_d_path, map_location='cpu'))
optimizer_g.load_state_dict(torch.load(optimizer_g_path, map_location='cpu'))
        print(f'resuming model and optimizer state from epoch {max_epoch}')
return max_epoch + 1
else:
return 0
# modified from transformers.optimization
import math
from functools import partial
import torch
from torch import nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau
def _get_constant_lambda(_=None):
return 1
def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1):
"""
Create a schedule with a constant learning rate, using the learning rate set in optimizer.
Args:
optimizer ([`~torch.optim.Optimizer`]):
The optimizer for which to schedule the learning rate.
last_epoch (`int`, *optional*, defaults to -1):
The index of the last epoch when resuming training.
Return:
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
"""
return LambdaLR(optimizer, _get_constant_lambda, last_epoch=last_epoch)
def get_reduce_on_plateau_schedule(optimizer: Optimizer, **kwargs):
"""
Create a schedule with a constant learning rate that decreases when a metric has stopped improving.
Args:
optimizer ([`~torch.optim.Optimizer`]):
The optimizer for which to schedule the learning rate.
kwargs (`dict`, *optional*):
Extra parameters to be passed to the scheduler. See `torch.optim.lr_scheduler.ReduceLROnPlateau`
for possible parameters.
Return:
`torch.optim.lr_scheduler.ReduceLROnPlateau` with the appropriate schedule.
"""
return ReduceLROnPlateau(optimizer, **kwargs)
def _get_constant_schedule_with_warmup_lr_lambda(current_step: int, *, num_warmup_steps: int):
if current_step < num_warmup_steps:
return float(current_step) / float(max(1.0, num_warmup_steps))
return 1.0
def get_constant_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, last_epoch: int = -1):
"""
Create a schedule with a constant learning rate preceded by a warmup period during which the learning rate
increases linearly between 0 and the initial lr set in the optimizer.
Args:
optimizer ([`~torch.optim.Optimizer`]):
The optimizer for which to schedule the learning rate.
num_warmup_steps (`int`):
The number of steps for the warmup phase.
last_epoch (`int`, *optional*, defaults to -1):
The index of the last epoch when resuming training.
Return:
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
"""
lr_lambda = partial(_get_constant_schedule_with_warmup_lr_lambda, num_warmup_steps=num_warmup_steps)
return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
def _get_linear_schedule_with_warmup_lr_lambda(current_step: int, *, num_warmup_steps: int, num_training_steps: int):
if current_step < num_warmup_steps:
return float(current_step) / float(max(1, num_warmup_steps))
return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)))
def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
"""
Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
Args:
optimizer ([`~torch.optim.Optimizer`]):
The optimizer for which to schedule the learning rate.
num_warmup_steps (`int`):
The number of steps for the warmup phase.
num_training_steps (`int`):
The total number of training steps.
last_epoch (`int`, *optional*, defaults to -1):
The index of the last epoch when resuming training.
Return:
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
"""
lr_lambda = partial(
_get_linear_schedule_with_warmup_lr_lambda,
num_warmup_steps=num_warmup_steps,
num_training_steps=num_training_steps,
)
return LambdaLR(optimizer, lr_lambda, last_epoch)
def _get_cosine_schedule_with_warmup_lr_lambda(
current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_cycles: float
):
if current_step < num_warmup_steps:
return float(current_step) / float(max(1, num_warmup_steps))
progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
def get_cosine_schedule_with_warmup(
optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1
):
"""
Create a schedule with a learning rate that decreases following the values of the cosine function between the
initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
initial lr set in the optimizer.
Args:
optimizer ([`~torch.optim.Optimizer`]):
The optimizer for which to schedule the learning rate.
num_warmup_steps (`int`):
The number of steps for the warmup phase.
num_training_steps (`int`):
The total number of training steps.
num_cycles (`float`, *optional*, defaults to 0.5):
The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0
following a half-cosine).
last_epoch (`int`, *optional*, defaults to -1):
The index of the last epoch when resuming training.
Return:
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
"""
lr_lambda = partial(
_get_cosine_schedule_with_warmup_lr_lambda,
num_warmup_steps=num_warmup_steps,
num_training_steps=num_training_steps,
num_cycles=num_cycles,
)
return LambdaLR(optimizer, lr_lambda, last_epoch)
def _get_cosine_with_hard_restarts_schedule_with_warmup_lr_lambda(
current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_cycles: int
):
if current_step < num_warmup_steps:
return float(current_step) / float(max(1, num_warmup_steps))
progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
if progress >= 1.0:
return 0.0
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0))))
def get_cosine_with_hard_restarts_schedule_with_warmup(
optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: int = 1, last_epoch: int = -1
):
"""
Create a schedule with a learning rate that decreases following the values of the cosine function between the
initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases
linearly between 0 and the initial lr set in the optimizer.
Args:
optimizer ([`~torch.optim.Optimizer`]):
The optimizer for which to schedule the learning rate.
num_warmup_steps (`int`):
The number of steps for the warmup phase.
num_training_steps (`int`):
The total number of training steps.
num_cycles (`int`, *optional*, defaults to 1):
The number of hard restarts to use.
last_epoch (`int`, *optional*, defaults to -1):
The index of the last epoch when resuming training.
Return:
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
"""
lr_lambda = partial(
_get_cosine_with_hard_restarts_schedule_with_warmup_lr_lambda,
num_warmup_steps=num_warmup_steps,
num_training_steps=num_training_steps,
num_cycles=num_cycles,
)
return LambdaLR(optimizer, lr_lambda, last_epoch)
def _get_polynomial_decay_schedule_with_warmup_lr_lambda(
current_step: int,
*,
num_warmup_steps: int,
num_training_steps: int,
lr_end: float,
power: float,
lr_init: int,
):
if current_step < num_warmup_steps:
return float(current_step) / float(max(1, num_warmup_steps))
elif current_step > num_training_steps:
return lr_end / lr_init # as LambdaLR multiplies by lr_init
else:
lr_range = lr_init - lr_end
decay_steps = num_training_steps - num_warmup_steps
pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps
decay = lr_range * pct_remaining**power + lr_end
return decay / lr_init # as LambdaLR multiplies by lr_init
def get_polynomial_decay_schedule_with_warmup(
optimizer, num_warmup_steps, num_training_steps, lr_end=1e-7, power=1.0, last_epoch=-1
):
"""
Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the
optimizer to end lr defined by *lr_end*, after a warmup period during which it increases linearly from 0 to the
initial lr set in the optimizer.
Args:
optimizer ([`~torch.optim.Optimizer`]):
The optimizer for which to schedule the learning rate.
num_warmup_steps (`int`):
The number of steps for the warmup phase.
num_training_steps (`int`):
The total number of training steps.
lr_end (`float`, *optional*, defaults to 1e-7):
The end LR.
power (`float`, *optional*, defaults to 1.0):
Power factor.
last_epoch (`int`, *optional*, defaults to -1):
The index of the last epoch when resuming training.
Note: *power* defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT
implementation at
https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37
Return:
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
"""
lr_init = optimizer.defaults["lr"]
if not (lr_init > lr_end):
raise ValueError(f"lr_end ({lr_end}) must be be smaller than initial lr ({lr_init})")
lr_lambda = partial(
_get_polynomial_decay_schedule_with_warmup_lr_lambda,
num_warmup_steps=num_warmup_steps,
num_training_steps=num_training_steps,
lr_end=lr_end,
power=power,
lr_init=lr_init,
)
return LambdaLR(optimizer, lr_lambda, last_epoch)
def _get_inverse_sqrt_schedule_lr_lambda(current_step: int, *, num_warmup_steps: int, timescale: int = None):
if current_step < num_warmup_steps:
return float(current_step) / float(max(1, num_warmup_steps))
shift = timescale - num_warmup_steps
decay = 1.0 / math.sqrt((current_step + shift) / timescale)
return decay
def get_inverse_sqrt_schedule(
optimizer: Optimizer, num_warmup_steps: int, timescale: int = None, last_epoch: int = -1
):
"""
Create a schedule with an inverse square-root learning rate, from the initial lr set in the optimizer, after a
warmup period which increases lr linearly from 0 to the initial lr set in the optimizer.
Args:
optimizer ([`~torch.optim.Optimizer`]):
The optimizer for which to schedule the learning rate.
num_warmup_steps (`int`):
The number of steps for the warmup phase.
timescale (`int`, *optional*, defaults to `num_warmup_steps`):
Time scale.
last_epoch (`int`, *optional*, defaults to -1):
The index of the last epoch when resuming training.
Return:
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
"""
# Note: this implementation is adapted from
# https://github.com/google-research/big_vision/blob/f071ce68852d56099437004fd70057597a95f6ef/big_vision/utils.py#L930
if timescale is None:
timescale = num_warmup_steps
lr_lambda = partial(_get_inverse_sqrt_schedule_lr_lambda, num_warmup_steps=num_warmup_steps, timescale=timescale)
return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
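# Illustrative sketch of the schedule train.py uses: get_cosine_schedule_with_warmup
# ramps the learning rate linearly over num_warmup_steps, then decays it along half a
# cosine towards 0 at num_training_steps.
import torch
from torch import optim

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = optim.AdamW(params, lr=1e-4)
scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=200, num_training_steps=10000)
for _ in range(5):
    optimizer.step()
    scheduler.step()
print(scheduler.get_last_lr())  # ~[2.5e-06]: 5 / 200 of the base lr during warmup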
import os
os.environ['TMPDIR'] = './temps' # use a local temp folder in case the system default temp dir is not writable
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com' # use a Hugging Face mirror for users who cannot reach huggingface.co
from dataclasses import asdict
from text import symbols
import torch
import torchaudio
from utils.audio import LogMelSpectrogram
from config import ModelConfig, VocosConfig, MelConfig
from models.model import StableTTS
from vocos_pytorch.models.model import Vocos
from text.mandarin import chinese_to_cnm3
from text.english import english_to_ipa2
from text.japanese import japanese_to_ipa2
from text import cleaned_text_to_sequence
from datas.dataset import intersperse
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
device = 'cuda' if torch.cuda.is_available() else 'cpu'
g2p_mapping = {
'chinese': chinese_to_cnm3,
'japanese': japanese_to_ipa2,
'english': english_to_ipa2,
}
@ torch.inference_mode()
def inference(text: str, ref_audio: str, language: str, checkpoint_path: str, step: int = 10) -> tuple:
global last_checkpoint_path
if checkpoint_path != last_checkpoint_path:
tts_model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
last_checkpoint_path = checkpoint_path
phonemizer = g2p_mapping.get(language)
# prepare input for tts model
x = torch.tensor(intersperse(cleaned_text_to_sequence(phonemizer(text)), item=0), dtype=torch.long, device=device).unsqueeze(0)
x_len = torch.tensor([x.size(-1)], dtype=torch.long, device=device)
waveform, sr = torchaudio.load(ref_audio)
if sr != sample_rate:
waveform = torchaudio.functional.resample(waveform, sr, sample_rate)
y = mel_extractor(waveform).to(device)
# inference
mel = tts_model.synthesise(x, x_len, step, y=y, temperature=1, length_scale=1)['decoder_outputs']
audio = vocoder(mel)
# process output for gradio
audio_output = (sample_rate, (audio.cpu().squeeze(0).numpy() * 32767).astype(np.int16)) # (samplerate, int16 audio) for gr.Audio
mel_output = plot_mel_spectrogram(mel.cpu().squeeze(0).numpy()) # get the plot of mel
return audio_output, mel_output
def get_pipeline(n_vocab: int, tts_model_config: ModelConfig, mel_config: MelConfig, vocoder_config: VocosConfig, tts_checkpoint_path, vocoder_checkpoint_path):
tts_model = StableTTS(n_vocab, mel_config.n_mels, **asdict(tts_model_config))
mel_extractor = LogMelSpectrogram(mel_config)
vocoder = Vocos(vocoder_config, mel_config)
# tts_model.load_state_dict(torch.load(tts_checkpoint_path, map_location='cpu'))
tts_model.to(device)
tts_model.eval()
vocoder.load_state_dict(torch.load(vocoder_checkpoint_path, map_location='cpu'))
vocoder.to(device)
vocoder.eval()
return tts_model, mel_extractor, vocoder
def plot_mel_spectrogram(mel_spectrogram):
plt.close() # prevent memory leak
fig, ax = plt.subplots(figsize=(20, 8))
ax.imshow(mel_spectrogram, aspect='auto', origin='lower')
plt.axis('off')
fig.subplots_adjust(left=0, right=1, top=1, bottom=0) # remove white edges
return fig
def main():
tts_model_config = ModelConfig()
mel_config = MelConfig()
vocoder_config = VocosConfig()
tts_checkpoint_path = './checkpoints' # the folder that contains StableTTS checkpoints
vocoder_checkpoint_path = './checkpoints/vocoder.pt'
global tts_model, mel_extractor, vocoder, sample_rate, last_checkpoint_path
sample_rate = mel_config.sample_rate
last_checkpoint_path = None
tts_model, mel_extractor, vocoder = get_pipeline(len(symbols), tts_model_config, mel_config, vocoder_config, tts_checkpoint_path, vocoder_checkpoint_path)
    tts_checkpoint_path = [path for path in Path(tts_checkpoint_path).rglob('*.pt') if 'optimizer' not in path.name and 'vocoder' not in path.name]
    # gradio webui
gui_title = 'StableTTS'
gui_description = """Next-generation TTS model using flow-matching and DiT, inspired by Stable Diffusion 3."""
with gr.Blocks(analytics_enabled=False) as demo:
with gr.Row():
with gr.Column():
gr.Markdown(f"# {gui_title}")
gr.Markdown(gui_description)
with gr.Row():
with gr.Column():
input_text_gr = gr.Textbox(
label="Input Text",
info="One or two sentences at a time is better. Up to 200 text characters.",
value="你好,世界!",
)
ref_audio_gr = gr.Audio(
label="Reference Speaker",
type="filepath"
)
language_gr = gr.Dropdown(
label='Language',
choices=list(g2p_mapping.keys()),
value = 'chinese'
)
checkpoint_gr = gr.Dropdown(
label='checkpoint',
choices=tts_checkpoint_path,
value = tts_checkpoint_path[0]
)
step_gr = gr.Slider(
label='Step',
minimum=1,
maximum=100,
value=25,
step=1
)
tts_button = gr.Button("Send", elem_id="send-btn", visible=True)
with gr.Column():
mel_gr = gr.Plot(label="Mel Visual")
audio_gr = gr.Audio(label="Synthesised Audio", autoplay=True)
tts_button.click(inference, [input_text_gr, ref_audio_gr, language_gr, checkpoint_gr, step_gr], outputs=[audio_gr, mel_gr])
demo.queue()
demo.launch(debug=True, show_api=True)
if __name__ == '__main__':
main()