"examples/vscode:/vscode.git/clone" did not exist on "73179b764aae0d28b7bea32109dca9e08632c6d9"
Commit ab9c00af authored by yangzhong's avatar yangzhong
Browse files

init submission

parents
Pipeline #3176 failed with stages
in 0 seconds
# Copyright (c) 2022 NVIDIA CORPORATION.
# Licensed under the MIT license.
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
# LICENSE is in incl_licenses directory.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Conv1d, Conv2d, ConvTranspose1d
from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
import indextts.BigVGAN.activations as activations
from indextts.BigVGAN.ECAPA_TDNN import ECAPA_TDNN
from indextts.BigVGAN.utils import get_padding, init_weights
# Negative slope handed to F.leaky_relu by the discriminators defined in this file.
LRELU_SLOPE = 0.1
class AMPBlock1(torch.nn.Module):
    """Residual block made of three (dilated conv, plain conv) pairs.

    Each convolution is preceded by its own ``Activation1d``-wrapped
    Snake/SnakeBeta activation, and the output of every pair is added back
    to the pair's input.

    Args:
        h: hyper-parameter object; queried with ``h.get("use_cuda_kernel", False)``
            and ``h.snake_logscale``.
        channels: channel count used as input and output size of every conv.
        kernel_size: kernel size shared by all convolutions.
        dilation: sequence whose first three entries are the dilations of the
            first convolution of each pair (the second always uses dilation 1).
        activation: ``'snake'`` or ``'snakebeta'``.

    Raises:
        NotImplementedError: if ``activation`` is neither of the two names.
    """

    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5), activation=None):
        super().__init__()
        self.h = h

        def build_conv(rate):
            # Weight-normalised, stride-1 conv padded with get_padding(kernel_size, rate).
            return weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=rate,
                                      padding=get_padding(kernel_size, rate)))

        # Exactly three dilated convs, taken from the first three dilation entries.
        self.convs1 = nn.ModuleList([build_conv(dilation[idx]) for idx in range(3)])
        self.convs1.apply(init_weights)

        # Three matching convs with dilation fixed to 1.
        self.convs2 = nn.ModuleList([build_conv(1) for _ in range(3)])
        self.convs2.apply(init_weights)

        self.num_layers = len(self.convs1) + len(self.convs2)  # total number of conv layers

        # The anti-aliased activation wrapper comes from the CUDA or the torch implementation.
        if self.h.get("use_cuda_kernel", False):
            from indextts.BigVGAN.alias_free_activation.cuda.activation1d import Activation1d
        else:
            from indextts.BigVGAN.alias_free_torch import Activation1d

        if activation == 'snake':  # periodic nonlinearity with snake function and anti-aliasing
            periodic_cls = activations.Snake
        elif activation == 'snakebeta':  # periodic nonlinearity with snakebeta function and anti-aliasing
            periodic_cls = activations.SnakeBeta
        else:
            raise NotImplementedError("activation incorrectly specified. check the config file and look for 'activation'.")

        # One activation per convolution layer.
        self.activations = nn.ModuleList([
            Activation1d(activation=periodic_cls(channels, alpha_logscale=h.snake_logscale))
            for _ in range(self.num_layers)
        ])

    def forward(self, x):
        """Run the three residual pairs in sequence and return the result."""
        # Even-indexed activations precede convs1, odd-indexed ones precede convs2.
        stages = zip(self.convs1, self.convs2, self.activations[::2], self.activations[1::2])
        for dilated_conv, plain_conv, pre_act, mid_act in stages:
            branch = pre_act(x)
            branch = dilated_conv(branch)
            branch = mid_act(branch)
            branch = plain_conv(branch)
            x = branch + x
        return x

    def remove_weight_norm(self):
        """Strip weight normalisation from every convolution of the block."""
        for conv in (*self.convs1, *self.convs2):
            remove_weight_norm(conv)
class AMPBlock2(torch.nn.Module):
    """Residual block made of two dilated convolutions.

    Each convolution is preceded by its own ``Activation1d``-wrapped
    Snake/SnakeBeta activation and its output is added back to its input.

    Args:
        h: hyper-parameter object; queried with ``h.get("use_cuda_kernel", False)``
            and ``h.snake_logscale``.
        channels: channel count used as input and output size of every conv.
        kernel_size: kernel size shared by both convolutions.
        dilation: sequence whose first two entries are the conv dilations.
        activation: ``'snake'`` or ``'snakebeta'``.

    Raises:
        NotImplementedError: if ``activation`` is neither of the two names.
    """

    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3), activation=None):
        super().__init__()
        self.h = h

        # Exactly two weight-normalised convs, using the first two dilation entries.
        self.convs = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[idx],
                               padding=get_padding(kernel_size, dilation[idx])))
            for idx in range(2)
        ])
        self.convs.apply(init_weights)

        self.num_layers = len(self.convs)  # total number of conv layers

        # The anti-aliased activation wrapper comes from the CUDA or the torch implementation.
        if self.h.get("use_cuda_kernel", False):
            from indextts.BigVGAN.alias_free_activation.cuda.activation1d import Activation1d
        else:
            from indextts.BigVGAN.alias_free_torch import Activation1d

        if activation == 'snake':  # periodic nonlinearity with snake function and anti-aliasing
            periodic_cls = activations.Snake
        elif activation == 'snakebeta':  # periodic nonlinearity with snakebeta function and anti-aliasing
            periodic_cls = activations.SnakeBeta
        else:
            raise NotImplementedError("activation incorrectly specified. check the config file and look for 'activation'.")

        # One activation per convolution layer.
        self.activations = nn.ModuleList([
            Activation1d(activation=periodic_cls(channels, alpha_logscale=h.snake_logscale))
            for _ in range(self.num_layers)
        ])

    def forward(self, x):
        """Apply activation -> conv -> residual add for every convolution."""
        for conv, act in zip(self.convs, self.activations):
            branch = conv(act(x))
            x = branch + x
        return x

    def remove_weight_norm(self):
        """Strip weight normalisation from both convolutions."""
        for conv in self.convs:
            remove_weight_norm(conv)
class BigVGAN(torch.nn.Module):
    """Main BigVGAN generator; applies anti-aliased periodic activations in its resblocks.

    The network is: pre-conv -> speaker conditioning -> repeated
    (transposed-conv upsampling, optional speaker conditioning, averaged AMP
    resblocks) -> anti-aliased activation -> post-conv -> tanh.
    A speaker embedding is computed from reference mels by an ``ECAPA_TDNN``.
    """

    def __init__(self, h, use_cuda_kernel=False):
        """
        Args:
            h (dict): hyper-parameters; accessed both by attribute and by item
                assignment, and mutated in place (``h["use_cuda_kernel"]`` is set).
            use_cuda_kernel (bool): whether to use custom cuda kernel for anti-aliased activation
        """
        super().__init__()
        self.h = h
        self.h["use_cuda_kernel"] = use_cuda_kernel

        self.num_kernels = len(h.resblock_kernel_sizes)
        self.num_upsamples = len(h.upsample_rates)

        self.feat_upsample = h.feat_upsample
        self.cond_in_each_up_layer = h.cond_d_vector_in_each_upsampling_layer

        # pre conv
        self.conv_pre = weight_norm(Conv1d(h.gpt_dim, h.upsample_initial_channel, 7, 1, padding=3))

        # define which AMPBlock to use. BigVGAN uses AMPBlock1 as default
        resblock = AMPBlock1 if h.resblock == "1" else AMPBlock2

        # transposed conv-based upsamplers. does not apply anti-aliasing
        # Channel count halves at every upsampling stage.
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
            self.ups.append(nn.ModuleList([
                weight_norm(ConvTranspose1d(h.upsample_initial_channel // (2 ** i),
                                            h.upsample_initial_channel // (2 ** (i + 1)),
                                            k, u, padding=(k - u) // 2))
            ]))

        # residual blocks using anti-aliased multi-periodicity composition modules (AMP)
        # Stored flat: stage i owns entries [i * num_kernels, (i + 1) * num_kernels).
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = h.upsample_initial_channel // (2 ** (i + 1))
            for k, d in zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes):
                self.resblocks.append(resblock(self.h, ch, k, d, activation=h.activation))

        if use_cuda_kernel:
            from indextts.BigVGAN.alias_free_activation.cuda.activation1d import Activation1d
        else:
            from indextts.BigVGAN.alias_free_torch import Activation1d

        # post conv
        # ``ch`` still holds the channel count of the last upsampling stage.
        if h.activation == "snake":  # periodic nonlinearity with snake function and anti-aliasing
            activation_post = activations.Snake(ch, alpha_logscale=h.snake_logscale)
            self.activation_post = Activation1d(activation=activation_post)
        elif h.activation == "snakebeta":  # periodic nonlinearity with snakebeta function and anti-aliasing
            activation_post = activations.SnakeBeta(ch, alpha_logscale=h.snake_logscale)
            self.activation_post = Activation1d(activation=activation_post)
        else:
            raise NotImplementedError("activation incorrectly specified. check the config file and look for 'activation'.")

        self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))

        # weight initialization
        for i in range(len(self.ups)):
            self.ups[i].apply(init_weights)
        self.conv_post.apply(init_weights)

        self.speaker_encoder = ECAPA_TDNN(h.num_mels, lin_neurons=h.speaker_embedding_dim)
        self.cond_layer = nn.Conv1d(h.speaker_embedding_dim, h.upsample_initial_channel, 1)
        if self.cond_in_each_up_layer:
            self.conds = nn.ModuleList()
            for i in range(len(self.ups)):
                ch = h.upsample_initial_channel // (2 ** (i + 1))
                self.conds.append(nn.Conv1d(h.speaker_embedding_dim, ch, 1))

        self._register_logit_scale()

    def _register_logit_scale(self):
        """Create ``self.logit_scale``, the log-temperature read by ``forward``.

        ``forward`` calls ``self.logit_scale.exp()`` in its contrastive-loss
        branch, which previously raised AttributeError because the attribute
        was never created. It is registered as a non-persistent buffer so the
        state_dict keys are unchanged, and it follows ``.to()``/``.half()``.
        NOTE(review): the value is fixed at log(1 / 0.07); make it an
        nn.Parameter if the temperature is meant to be learned.
        """
        initial = torch.ones([]) * torch.log(torch.tensor(1.0 / 0.07))
        self.register_buffer("logit_scale", initial, persistent=False)

    def forward(self, x, mel_ref, lens=None):
        """Generate the output signal from features ``x`` and reference mels.

        Args:
            x: input features; transposed with ``transpose(1, 2)`` before ``conv_pre``
                (presumably (batch, time, h.gpt_dim) - TODO confirm).
            mel_ref: iterable of reference mel inputs; each is passed to the
                speaker encoder and the resulting embeddings are averaged.
            lens: optional lengths forwarded unchanged to the speaker encoder.

        Returns:
            tuple: ``(x, contrastive_loss)`` where ``x`` is the tanh-bounded
            output of ``conv_post`` and ``contrastive_loss`` is ``None`` unless
            the speaker-embedding batch is exactly twice the batch of ``x``.
        """
        speaker_embedding = []
        for mel_ref_ in mel_ref:
            speaker_embedding_ = self.speaker_encoder(mel_ref_, lens)
            speaker_embedding.append(speaker_embedding_)
        # Average the embeddings of all references.
        speaker_embedding = torch.stack(speaker_embedding).sum(dim=0)
        speaker_embedding = speaker_embedding / len(mel_ref)

        n_batch = x.size(0)
        contrastive_loss = None
        # A doubled embedding batch is treated as two paired halves for the CLIP-style loss.
        if n_batch * 2 == speaker_embedding.size(0):
            spe_emb_chunk1, spe_emb_chunk2 = speaker_embedding[:n_batch, :, :], speaker_embedding[n_batch:, :, :]
            contrastive_loss = self.cal_clip_loss(spe_emb_chunk1.squeeze(1), spe_emb_chunk2.squeeze(1), self.logit_scale.exp())

            # Only the first half conditions the generator.
            speaker_embedding = speaker_embedding[:n_batch, :, :]
        speaker_embedding = speaker_embedding.transpose(1, 2)

        # upsample feat
        if self.feat_upsample:
            x = torch.nn.functional.interpolate(
                x.transpose(1, 2),
                scale_factor=[4],
                mode="linear",
            ).squeeze(1)
        else:
            x = x.transpose(1, 2)

        ### bigVGAN ###
        # pre conv
        x = self.conv_pre(x)

        x = x + self.cond_layer(speaker_embedding)

        for i in range(self.num_upsamples):
            # upsampling
            for i_up in range(len(self.ups[i])):
                x = self.ups[i][i_up](x)

            if self.cond_in_each_up_layer:
                x = x + self.conds[i](speaker_embedding)

            # AMP blocks
            # All resblocks of this stage see the same input; their outputs are averaged.
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels

        # post conv
        x = self.activation_post(x)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x, contrastive_loss

    def remove_weight_norm(self):
        """Strip weight normalisation from the upsamplers, resblocks, pre- and post-conv."""
        print('Removing weight norm...')
        for l in self.ups:
            for l_i in l:
                remove_weight_norm(l_i)
        for l in self.resblocks:
            l.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)

    def cal_clip_loss(self, image_features, text_features, logit_scale):
        """Return the mean of the two cross-entropy losses over the pairwise logits.

        Row ``i`` of one feature matrix is treated as the match of row ``i`` of the other.
        """
        device = image_features.device
        logits_per_image, logits_per_text = self.get_logits(image_features, text_features, logit_scale)
        labels = torch.arange(logits_per_image.shape[0], device=device, dtype=torch.long)
        total_loss = (
            F.cross_entropy(logits_per_image, labels) +
            F.cross_entropy(logits_per_text, labels)
        ) / 2
        return total_loss

    def get_logits(self, image_features, text_features, logit_scale):
        """Return the scaled similarity matrices in both directions."""
        logits_per_image = logit_scale * image_features @ text_features.T
        logits_per_text = logit_scale * text_features @ image_features.T
        return logits_per_image, logits_per_text
class DiscriminatorP(torch.nn.Module):
    """Period discriminator: folds the signal by ``period`` and applies 2-D convs.

    Args:
        h: hyper-parameters; ``h.discriminator_channel_mult`` scales channel widths.
        period: fold width used to reshape the 1-D signal into 2-D.
        kernel_size: kernel height of the convolutions in ``self.convs``.
        stride: stride (along the first spatial axis) of the first four convolutions.
        use_spectral_norm: spectral norm instead of weight norm unless this equals False.
    """

    def __init__(self, h, period, kernel_size=5, stride=3, use_spectral_norm=False):
        super().__init__()
        self.period = period
        self.d_mult = h.discriminator_channel_mult

        # Same "== False" test as always: only values equal to False select weight_norm.
        if use_spectral_norm == False:
            norm_f = weight_norm
        else:
            norm_f = spectral_norm

        # Channel progression 1 -> 32 -> 128 -> 512 -> 1024 (scaled by d_mult).
        widths = [1] + [int(base * self.d_mult) for base in (32, 128, 512, 1024)]
        strided_layers = [
            norm_f(Conv2d(c_in, c_out, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0)))
            for c_in, c_out in zip(widths[:-1], widths[1:])
        ]
        # Final width-preserving layer uses stride 1 and a fixed padding of 2.
        closing_layer = norm_f(Conv2d(widths[-1], widths[-1], (kernel_size, 1), 1, padding=(2, 0)))
        self.convs = nn.ModuleList(strided_layers + [closing_layer])
        self.conv_post = norm_f(Conv2d(int(1024 * self.d_mult), 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x):
        """Return ``(flattened_score, feature_maps)`` for a (batch, channel, time) input."""
        feature_maps = []

        # 1d to 2d
        batch, chans, length = x.shape
        leftover = length % self.period
        if leftover != 0:  # pad first
            missing = self.period - leftover
            x = F.pad(x, (0, missing), "reflect")
            length = length + missing
        x = x.view(batch, chans, length // self.period, self.period)

        for layer in self.convs:
            x = F.leaky_relu(layer(x), LRELU_SLOPE)
            feature_maps.append(x)
        x = self.conv_post(x)
        feature_maps.append(x)

        return torch.flatten(x, 1, -1), feature_maps
class MultiPeriodDiscriminator(torch.nn.Module):
    """Bundle of ``DiscriminatorP`` instances, one per entry of ``h.mpd_reshapes``."""

    def __init__(self, h):
        super().__init__()
        self.mpd_reshapes = h.mpd_reshapes
        print(f"mpd_reshapes: {self.mpd_reshapes}")
        self.discriminators = nn.ModuleList([
            DiscriminatorP(h, period, use_spectral_norm=h.use_spectral_norm)
            for period in self.mpd_reshapes
        ])

    def forward(self, y, y_hat):
        """Score ``y`` and ``y_hat`` with every sub-discriminator.

        Returns:
            tuple of four lists, each with one entry per sub-discriminator:
            scores for ``y``, scores for ``y_hat``, feature maps for ``y``,
            feature maps for ``y_hat``.
        """
        real_scores, fake_scores = [], []
        real_fmaps, fake_fmaps = [], []
        for disc in self.discriminators:
            score_r, feats_r = disc(y)
            score_g, feats_g = disc(y_hat)
            real_scores.append(score_r)
            real_fmaps.append(feats_r)
            fake_scores.append(score_g)
            fake_fmaps.append(feats_g)
        return real_scores, fake_scores, real_fmaps, fake_fmaps
class DiscriminatorR(nn.Module):
    """Resolution discriminator operating on an STFT magnitude.

    Args:
        cfg: configuration; reads ``use_spectral_norm`` and
            ``discriminator_channel_mult`` and, when present, the overrides
            ``mrd_use_spectral_norm`` and ``mrd_channel_mult``.
        resolution: three values unpacked as ``(n_fft, hop_length, win_length)``.
    """

    def __init__(self, cfg, resolution):
        super().__init__()

        self.resolution = resolution
        assert len(self.resolution) == 3, \
            "MRD layer requires list with len=3, got {}".format(self.resolution)
        self.lrelu_slope = LRELU_SLOPE

        # Global setting first, then the optional MRD-specific override.
        use_sn = cfg.use_spectral_norm
        if hasattr(cfg, "mrd_use_spectral_norm"):
            print(f"INFO: overriding MRD use_spectral_norm as {cfg.mrd_use_spectral_norm}")
            use_sn = cfg.mrd_use_spectral_norm
        norm_f = weight_norm if use_sn == False else spectral_norm

        self.d_mult = cfg.discriminator_channel_mult
        if hasattr(cfg, "mrd_channel_mult"):
            print(f"INFO: overriding mrd channel multiplier as {cfg.mrd_channel_mult}")
            self.d_mult = cfg.mrd_channel_mult

        width = int(32 * self.d_mult)
        layers = [norm_f(nn.Conv2d(1, width, (3, 9), padding=(1, 4)))]
        # Three layers striding by 2 along the last axis.
        layers.extend(
            norm_f(nn.Conv2d(width, width, (3, 9), stride=(1, 2), padding=(1, 4)))
            for _ in range(3)
        )
        layers.append(norm_f(nn.Conv2d(width, width, (3, 3), padding=(1, 1))))
        self.convs = nn.ModuleList(layers)
        self.conv_post = norm_f(nn.Conv2d(width, 1, (3, 3), padding=(1, 1)))

    def forward(self, x):
        """Return ``(flattened_score, feature_maps)`` for the input signal ``x``."""
        feature_maps = []

        # The magnitude gets a singleton channel axis for the 2-D convolutions.
        x = self.spectrogram(x).unsqueeze(1)
        for layer in self.convs:
            x = F.leaky_relu(layer(x), self.lrelu_slope)
            feature_maps.append(x)
        x = self.conv_post(x)
        feature_maps.append(x)

        return torch.flatten(x, 1, -1), feature_maps

    def spectrogram(self, x):
        """Return the STFT magnitude of ``x`` at this discriminator's resolution.

        ``x`` is reflect-padded by ``(n_fft - hop_length) / 2`` on both sides,
        its dim 1 is squeezed, and the STFT is taken with ``center=False``.
        """
        n_fft, hop_length, win_length = self.resolution
        side = int((n_fft - hop_length) / 2)
        padded = F.pad(x, (side, side), mode='reflect').squeeze(1)
        spec = torch.stft(padded, n_fft=n_fft, hop_length=hop_length, win_length=win_length,
                          center=False, return_complex=True)
        real_imag = torch.view_as_real(spec)  # [B, F, TT, 2]
        return torch.norm(real_imag, p=2, dim=-1)  # [B, F, TT]
class MultiResolutionDiscriminator(nn.Module):
    """Bundle of ``DiscriminatorR`` instances, one per entry of ``cfg.resolutions``.

    Args:
        cfg: configuration providing ``resolutions`` (exactly three entries)
            plus whatever ``DiscriminatorR`` reads.
        debug: accepted but not used by this class.
    """

    def __init__(self, cfg, debug=False):
        super().__init__()
        self.resolutions = cfg.resolutions
        assert len(self.resolutions) == 3, \
            "MRD requires list of list with len=3, each element having a list with len=3. got {}".\
            format(self.resolutions)
        sub_discs = [DiscriminatorR(cfg, res) for res in self.resolutions]
        self.discriminators = nn.ModuleList(sub_discs)

    def forward(self, y, y_hat):
        """Score ``y`` and ``y_hat`` with every sub-discriminator.

        Returns:
            tuple of four lists, each with one entry per sub-discriminator:
            scores for ``y``, scores for ``y_hat``, feature maps for ``y``,
            feature maps for ``y_hat``.
        """
        real_scores, fake_scores = [], []
        real_fmaps, fake_fmaps = [], []
        for disc in self.discriminators:
            score_r, feats_r = disc(x=y)
            score_g, feats_g = disc(x=y_hat)
            real_scores.append(score_r)
            real_fmaps.append(feats_r)
            fake_scores.append(score_g)
            fake_fmaps.append(feats_g)
        return real_scores, fake_scores, real_fmaps, fake_fmaps
def feature_loss(fmap_r, fmap_g):
    """Return twice the summed mean absolute difference between paired feature maps.

    ``fmap_r`` and ``fmap_g`` are parallel lists (one entry per discriminator)
    of lists of tensors (one per layer). Returns ``0`` when there is nothing to compare.
    """
    per_layer = (
        (real_layer - fake_layer).abs().mean()
        for real_maps, fake_maps in zip(fmap_r, fmap_g)
        for real_layer, fake_layer in zip(real_maps, fake_maps)
    )
    return sum(per_layer) * 2
def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    """Least-squares discriminator loss: real scores pushed to 1, generated ones to 0.

    Returns:
        tuple: ``(total, r_losses, g_losses)`` where ``total`` is the sum over
        all discriminators and the two lists hold the per-discriminator real
        and generated terms as Python floats.
    """
    total = 0
    real_terms, fake_terms = [], []
    for real_out, fake_out in zip(disc_real_outputs, disc_generated_outputs):
        real_term = ((1 - real_out) ** 2).mean()
        fake_term = (fake_out ** 2).mean()
        total = total + (real_term + fake_term)
        real_terms.append(real_term.item())
        fake_terms.append(fake_term.item())
    return total, real_terms, fake_terms
def generator_loss(disc_outputs):
    """Least-squares generator loss: discriminator scores pushed to 1.

    Returns:
        tuple: ``(total, gen_losses)`` with the summed loss and the list of
        per-discriminator loss tensors.
    """
    gen_losses = [((1 - score) ** 2).mean() for score in disc_outputs]
    total = 0
    for term in gen_losses:
        total = total + term
    return total, gen_losses
"""Library implementing convolutional neural networks.
Authors
* Mirco Ravanelli 2020
* Jianyuan Zhong 2020
* Cem Subakan 2021
* Davide Borra 2021
* Andreas Nautsch 2022
* Sarthak Yadav 2022
"""
import logging
import math
from typing import Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
class SincConv(nn.Module):
    """Convolution whose kernels are learnable band-pass sinc filters (SincNet).

    M. Ravanelli, Y. Bengio, "Speaker Recognition from raw waveform with
    SincNet", in Proc. of SLT 2018 (https://arxiv.org/abs/1808.00158)

    Arguments
    ---------
    out_channels : int
        Number of output channels (filters).
    kernel_size: int
        Length of each filter.
    input_shape : tuple
        Shape of the input. Alternatively use ``in_channels``.
    in_channels : int
        Number of input channels. Alternatively use ``input_shape``.
    stride : int
        Stride of the convolution.
    dilation : int
        Dilation of the convolution.
    padding : str
        One of "same", "valid" or "causal". "valid" adds no padding,
        "causal" left-pads with ``(kernel_size - 1) * dilation`` zeros.
    padding_mode : str
        Mode given to ``F.pad`` when ``padding`` is "same".
    sample_rate : int
        Sampling rate used to build the filters.
    min_low_hz : float
        Lowest possible low cut-off frequency (in Hz) of a filter.
    min_band_hz : float
        Lowest possible bandwidth (in Hz) of a filter.

    Example
    -------
    >>> inp_tensor = torch.rand([10, 16000])
    >>> conv = SincConv(input_shape=inp_tensor.shape, out_channels=25, kernel_size=11)
    >>> out_tensor = conv(inp_tensor)
    >>> out_tensor.shape
    torch.Size([10, 16000, 25])
    """

    def __init__(
        self,
        out_channels,
        kernel_size,
        input_shape=None,
        in_channels=None,
        stride=1,
        dilation=1,
        padding="same",
        padding_mode="reflect",
        sample_rate=16000,
        min_low_hz=50,
        min_band_hz=50,
    ):
        super().__init__()
        # Convolution geometry.
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation
        self.padding = padding
        self.padding_mode = padding_mode
        # Filter-bank settings.
        self.sample_rate = sample_rate
        self.min_low_hz = min_low_hz
        self.min_band_hz = min_band_hz

        # input shape inference
        if input_shape is None and self.in_channels is None:
            raise ValueError("Must provide one of input_shape or in_channels")
        if self.in_channels is None:
            self.in_channels = self._check_input_shape(input_shape)

        # The grouped convolution needs an integer number of filters per input channel.
        if self.out_channels % self.in_channels != 0:
            raise ValueError(
                "Number of output channels must be divisible by in_channels"
            )

        # Initialize Sinc filters
        self._init_sinc_conv()

    def forward(self, x):
        """Returns the output of the convolution.

        Arguments
        ---------
        x : torch.Tensor (batch, time, channel)
            input to convolve. 2d or 3d tensors are expected.

        Returns
        -------
        wx : torch.Tensor
            The convolved outputs.
        """
        x = x.transpose(1, -1)
        # Remembered so the filter constants can be moved next to the input.
        self.device = x.device

        was_2d = x.ndim == 2
        if was_2d:
            x = x.unsqueeze(1)

        mode = self.padding
        if mode == "same":
            x = self._manage_padding(
                x, self.kernel_size, self.dilation, self.stride
            )
        elif mode == "causal":
            left_pad = (self.kernel_size - 1) * self.dilation
            x = F.pad(x, (left_pad, 0))
        elif mode != "valid":
            raise ValueError(
                "Padding must be 'same', 'valid' or 'causal'. Got %s."
                % (self.padding)
            )

        kernels = self._get_sinc_filters()
        out = F.conv1d(
            x,
            kernels,
            stride=self.stride,
            padding=0,
            dilation=self.dilation,
            groups=self.in_channels,
        )

        if was_2d:
            out = out.squeeze(1)
        return out.transpose(1, -1)

    def _check_input_shape(self, shape):
        """Checks the input shape and returns the number of input channels."""
        rank = len(shape)
        if rank == 2:
            in_channels = 1
        elif rank == 3:
            in_channels = shape[-1]
        else:
            raise ValueError(
                "sincconv expects 2d or 3d inputs. Got " + str(len(shape))
            )

        # Kernel size must be odd
        if self.kernel_size % 2 == 0:
            raise ValueError(
                "The field kernel size must be an odd number. Got %s."
                % (self.kernel_size)
            )
        return in_channels

    def _get_sinc_filters(self):
        """Builds the sinc filters from the current low/band parameters."""
        # Low cut-off of every filter, kept above min_low_hz.
        low = self.min_low_hz + torch.abs(self.low_hz_)

        # High cut-off: at least min_band_hz above low, clamped to [min_low_hz, Nyquist].
        high = torch.clamp(
            low + self.min_band_hz + torch.abs(self.band_hz_),
            self.min_low_hz,
            self.sample_rate / 2,
        )
        band = (high - low)[:, 0]

        # Move the constant time axis and window to the input's device.
        self.n_ = self.n_.to(self.device)
        self.window_ = self.window_.to(self.device)
        phase_low = torch.matmul(low, self.n_)
        phase_high = torch.matmul(high, self.n_)

        # Left half of the filters.
        left_half = (
            (torch.sin(phase_high) - torch.sin(phase_low))
            / (self.n_ / 2)
        ) * self.window_

        # Central tap.
        center_tap = 2 * band.view(-1, 1)

        # Right half mirrors the left (sinc filters are symmetric).
        right_half = torch.flip(left_half, dims=[1])

        band_pass = torch.cat([left_half, center_tap, right_half], dim=1)

        # Amplitude normalization
        band_pass = band_pass / (2 * band[:, None])

        return band_pass.view(self.out_channels, 1, self.kernel_size)

    def _init_sinc_conv(self):
        """Initializes the parameters of the sinc_conv layer."""
        # Filter edges equally spaced on the mel scale.
        top_hz = self.sample_rate / 2 - (self.min_low_hz + self.min_band_hz)
        mel_points = torch.linspace(
            self._to_mel(self.min_low_hz),
            self._to_mel(top_hz),
            self.out_channels + 1,
        )
        hz_points = self._to_hz(mel_points)

        # Learnable low cut-offs and bandwidths, one row per filter.
        self.low_hz_ = nn.Parameter(hz_points[:-1].unsqueeze(1))
        self.band_hz_ = nn.Parameter((hz_points[1:] - hz_points[:-1]).unsqueeze(1))

        # Hamming window over half of the kernel.
        half_axis = torch.linspace(
            0, (self.kernel_size / 2) - 1, steps=int((self.kernel_size / 2))
        )
        self.window_ = 0.54 - 0.46 * torch.cos(
            2 * math.pi * half_axis / self.kernel_size
        )

        # Time axis (only half is needed due to symmetry)
        half_len = (self.kernel_size - 1) / 2.0
        self.n_ = (
            2 * math.pi * torch.arange(-half_len, 0).view(1, -1) / self.sample_rate
        )

    def _to_mel(self, hz):
        """Converts frequency in Hz to the mel scale."""
        return 2595 * np.log10(1 + hz / 700)

    def _to_hz(self, mel):
        """Converts frequency in the mel scale to Hz."""
        return 700 * (10 ** (mel / 2595) - 1)

    def _manage_padding(self, x, kernel_size: int, dilation: int, stride: int):
        """Pads the last axis of ``x`` using ``self.padding_mode``.

        The pad widths come from ``get_padding_elem`` and are meant to keep
        the length unchanged after the convolution.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor.
        kernel_size : int
            Size of kernel.
        dilation : int
            Dilation used.
        stride : int
            Stride.

        Returns
        -------
        x : torch.Tensor
            The padded tensor.
        """
        # The channel count is what gets passed as the length argument here.
        pad_widths = get_padding_elem(self.in_channels, stride, kernel_size, dilation)
        return F.pad(x, pad_widths, mode=self.padding_mode)
class Conv1d(nn.Module):
    """1d convolution wrapper with "same"/"valid"/"causal" padding handling.

    Arguments
    ---------
    out_channels : int
        Number of output channels.
    kernel_size : int
        Kernel size of the convolutional filters.
    input_shape : tuple
        Shape of the input. Alternatively use ``in_channels``.
    in_channels : int
        Number of input channels. Alternatively use ``input_shape``.
    stride : int
        Stride of the convolution.
    dilation : int
        Dilation of the convolution.
    padding : str
        One of "same", "valid" or "causal". "valid" adds no padding,
        "causal" left-pads with ``(kernel_size - 1) * dilation`` zeros.
    groups : int
        Number of blocked connections from input channels to output channels.
    bias : bool
        Whether to add a bias term to the convolution.
    padding_mode : str
        Mode given to ``F.pad`` when ``padding`` is "same".
    skip_transpose : bool
        If False, the input is transposed with ``transpose(1, -1)`` before and
        after the convolution (batch x time x channel convention).
        If True, the input is used as is (batch x channel x time convention).
    weight_norm : bool
        If True, wrap the convolution in weight normalization; it can be
        removed with ``self.remove_weight_norm()``.
    conv_init : str
        "kaiming", "zero" or "normal" to re-initialise the weights; any other
        value keeps the ``nn.Conv1d`` default.
    default_padding: str or int
        ``padding`` argument handed to the underlying ``nn.Conv1d``.

    Example
    -------
    >>> inp_tensor = torch.rand([10, 40, 16])
    >>> cnn_1d = Conv1d(
    ...     input_shape=inp_tensor.shape, out_channels=8, kernel_size=5
    ... )
    >>> out_tensor = cnn_1d(inp_tensor)
    >>> out_tensor.shape
    torch.Size([10, 40, 8])
    """

    def __init__(
        self,
        out_channels,
        kernel_size,
        input_shape=None,
        in_channels=None,
        stride=1,
        dilation=1,
        padding="same",
        groups=1,
        bias=True,
        padding_mode="reflect",
        skip_transpose=False,
        weight_norm=False,
        conv_init=None,
        default_padding=0,
    ):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation
        self.padding = padding
        self.padding_mode = padding_mode
        # May be switched on by _check_input_shape for 2d inputs.
        self.unsqueeze = False
        self.skip_transpose = skip_transpose

        if input_shape is None and in_channels is None:
            raise ValueError("Must provide one of input_shape or in_channels")
        if in_channels is None:
            in_channels = self._check_input_shape(input_shape)
        self.in_channels = in_channels

        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            self.kernel_size,
            stride=self.stride,
            dilation=self.dilation,
            padding=default_padding,
            groups=groups,
            bias=bias,
        )

        # Optional weight re-initialisation.
        if conv_init == "kaiming":
            nn.init.kaiming_normal_(self.conv.weight)
        elif conv_init == "zero":
            nn.init.zeros_(self.conv.weight)
        elif conv_init == "normal":
            nn.init.normal_(self.conv.weight, std=1e-6)

        if weight_norm:
            self.conv = nn.utils.weight_norm(self.conv)

    def forward(self, x):
        """Returns the output of the convolution.

        Arguments
        ---------
        x : torch.Tensor (batch, time, channel)
            input to convolve.

        Returns
        -------
        wx : torch.Tensor
            The convolved outputs.
        """
        transpose_needed = not self.skip_transpose
        if transpose_needed:
            x = x.transpose(1, -1)
        if self.unsqueeze:
            x = x.unsqueeze(1)

        mode = self.padding
        if mode == "same":
            x = self._manage_padding(
                x, self.kernel_size, self.dilation, self.stride
            )
        elif mode == "causal":
            left_pad = (self.kernel_size - 1) * self.dilation
            x = F.pad(x, (left_pad, 0))
        elif mode != "valid":
            raise ValueError(
                "Padding must be 'same', 'valid' or 'causal'. Got "
                + self.padding
            )

        out = self.conv(x)

        if self.unsqueeze:
            out = out.squeeze(1)
        if transpose_needed:
            out = out.transpose(1, -1)
        return out

    def _manage_padding(self, x, kernel_size: int, dilation: int, stride: int):
        """Pads the last axis of ``x`` using ``self.padding_mode``.

        The pad widths come from ``get_padding_elem`` and are meant to keep
        the length unchanged after the convolution.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor.
        kernel_size : int
            Size of kernel.
        dilation : int
            Dilation used.
        stride : int
            Stride.

        Returns
        -------
        x : torch.Tensor
            The padded outputs.
        """
        # The channel count is what gets passed as the length argument here.
        pad_widths = get_padding_elem(self.in_channels, stride, kernel_size, dilation)
        return F.pad(x, pad_widths, mode=self.padding_mode)

    def _check_input_shape(self, shape):
        """Checks the input shape and returns the number of input channels."""
        rank = len(shape)
        if rank == 2:
            # 2d inputs get a singleton channel axis in forward().
            self.unsqueeze = True
            in_channels = 1
        elif self.skip_transpose:
            in_channels = shape[1]
        elif rank == 3:
            in_channels = shape[2]
        else:
            raise ValueError(
                "conv1d expects 2d, 3d inputs. Got " + str(len(shape))
            )

        # Kernel size must be odd
        if not self.padding == "valid" and self.kernel_size % 2 == 0:
            raise ValueError(
                "The field kernel size must be an odd number. Got %s."
                % (self.kernel_size)
            )
        return in_channels

    def remove_weight_norm(self):
        """Removes weight normalization at inference if used during training."""
        self.conv = nn.utils.remove_weight_norm(self.conv)
def get_padding_elem(L_in: int, stride: int, kernel_size: int, dilation: int):
    """Computes the left/right pad widths used before a 1d convolution.

    Arguments
    ---------
    L_in : int
        Length fed to the output-length formula (ignored when stride > 1).
    stride: int
    kernel_size : int
    dilation : int

    Returns
    -------
    padding : list
        Two equal integers: the padding added on each side.
    """
    if stride > 1:
        # Strided case: half the kernel on each side, regardless of dilation.
        per_side = math.floor(kernel_size / 2)
    else:
        # Length the un-padded convolution would produce; half the shortfall goes on each side.
        L_out = (
            math.floor((L_in - dilation * (kernel_size - 1) - 1) / stride) + 1
        )
        per_side = math.floor((L_in - L_out) / 2)
    return [per_side, per_side]
"""Library implementing linear transformation.
Authors
* Mirco Ravanelli 2020
* Davide Borra 2021
"""
import logging
import torch
import torch.nn as nn
class Linear(torch.nn.Module):
    """Computes a linear transformation y = wx + b.

    Arguments
    ---------
    n_neurons : int
        Number of output neurons (i.e. the dimensionality of the output).
    input_shape : tuple
        Shape of the input tensor; used to infer the input size when
        ``input_size`` is not given.
    input_size : int
        Size of the last input dimension.
    bias : bool
        If True, the additive bias b is adopted.
    max_norm : float
        If set, the weight rows are renormalised to this maximum L2 norm on
        every forward call.
    combine_dims : bool
        If True and the input is 4D, combine 3rd and 4th dimensions of input.

    Example
    -------
    >>> inputs = torch.rand(10, 50, 40)
    >>> lin_t = Linear(input_shape=(10, 50, 40), n_neurons=100)
    >>> output = lin_t(inputs)
    >>> output.shape
    torch.Size([10, 50, 100])
    """

    def __init__(
        self,
        n_neurons,
        input_shape=None,
        input_size=None,
        bias=True,
        max_norm=None,
        combine_dims=False,
    ):
        super().__init__()
        self.max_norm = max_norm
        self.combine_dims = combine_dims

        if input_shape is None and input_size is None:
            raise ValueError("Expected one of input_shape or input_size")

        if input_size is None:
            # Infer from the shape; a 4D shape with combine_dims merges its last two axes.
            input_size = input_shape[-1]
            merge_last_two = len(input_shape) == 4 and self.combine_dims
            if merge_last_two:
                input_size = input_shape[2] * input_shape[3]

        # Weights are initialized following pytorch approach
        self.w = nn.Linear(input_size, n_neurons, bias=bias)

    def forward(self, x):
        """Returns the linear transformation of input tensor.

        Arguments
        ---------
        x : torch.Tensor
            Input to transform linearly.

        Returns
        -------
        wx : torch.Tensor
            The linearly transformed outputs.
        """
        if x.ndim == 4 and self.combine_dims:
            dim0, dim1, dim2, dim3 = x.shape
            x = x.reshape(dim0, dim1, dim2 * dim3)

        if self.max_norm is not None:
            # Clip the L2 norm of each weight row in place before applying it.
            weights = self.w.weight
            weights.data = torch.renorm(
                weights.data, p=2, dim=0, maxnorm=self.max_norm
            )

        return self.w(x)
This diff is collapsed.
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
# LICENSE is in incl_licenses directory.
import glob
import os
import matplotlib
import matplotlib.pylab as plt
import torch
from scipy.io.wavfile import write
from torch.nn.utils import weight_norm
# Select matplotlib's "Agg" backend for the plotting helpers below.
matplotlib.use("Agg")
# Factor save_audio multiplies samples by before casting them to int16.
MAX_WAV_VALUE = 32768.0
def plot_spectrogram(spectrogram):
    """Draw ``spectrogram`` as an image with a colorbar and return the figure.

    The figure is rendered via ``fig.canvas.draw()`` and pyplot's current
    figure is closed before returning.
    """
    figure, axis = plt.subplots(figsize=(10, 2))
    image = axis.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
    plt.colorbar(image, ax=axis)

    figure.canvas.draw()
    plt.close()
    return figure
def plot_spectrogram_clipped(spectrogram, clip_max=2.0):
    """Draw ``spectrogram`` with the color range limited to ``[1e-6, clip_max]``.

    The figure is rendered via ``fig.canvas.draw()`` and pyplot's current
    figure is closed before returning it.
    """
    figure, axis = plt.subplots(figsize=(10, 2))
    image = axis.imshow(
        spectrogram,
        aspect="auto",
        origin="lower",
        interpolation="none",
        vmin=1e-6,
        vmax=clip_max,
    )
    plt.colorbar(image, ax=axis)

    figure.canvas.draw()
    plt.close()
    return figure
def init_weights(m, mean=0.0, std=0.01):
    """Fill ``m.weight`` with normal(mean, std) samples if the class name of ``m`` contains "Conv"."""
    if "Conv" in m.__class__.__name__:
        m.weight.data.normal_(mean, std)
def apply_weight_norm(m):
    """Wrap ``m`` in weight normalisation if its class name contains "Conv"."""
    if "Conv" in m.__class__.__name__:
        weight_norm(m)
def get_padding(kernel_size, dilation=1):
    """Return ``int((kernel_size * dilation - dilation) / 2)``, the per-side conv padding."""
    dilated_span = kernel_size * dilation - dilation
    return int(dilated_span / 2)
def load_checkpoint(filepath, device):
    """Load a checkpoint file with ``torch.load`` and return its contents.

    Args:
        filepath: path of the checkpoint file.
        device: passed to ``torch.load`` as ``map_location``.

    Returns:
        Whatever object was stored in the file.

    Raises:
        FileNotFoundError: if ``filepath`` is not an existing file. This is an
            explicit raise (not an ``assert``) so the check survives ``python -O``.
    """
    if not os.path.isfile(filepath):
        raise FileNotFoundError(f"Checkpoint file not found: '{filepath}'")
    print(f"Loading '{filepath}'")
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    checkpoint_dict = torch.load(filepath, map_location=device)
    print("Complete.")
    return checkpoint_dict
def save_checkpoint(filepath, obj):
    """Serialise ``obj`` to ``filepath`` with ``torch.save``, printing progress messages."""
    print(f"Saving checkpoint to {filepath}")
    torch.save(obj=obj, f=filepath)
    print("Complete.")
def scan_checkpoint(cp_dir, prefix, renamed_file=None):
    """Find the checkpoint to resume from inside ``cp_dir``.

    Files named ``prefix`` followed by exactly eight characters are preferred;
    the one sorting last is returned. If there is none and ``renamed_file``
    names an existing file in ``cp_dir``, that path is returned instead.

    Returns:
        The chosen path, or ``None`` when no checkpoint is found.
    """
    # Pattern-based checkpoints take priority.
    numbered = glob.glob(os.path.join(cp_dir, prefix + "????????"))
    if numbered:
        latest = max(numbered)
        print(f"[INFO] Resuming from checkpoint: '{latest}'")
        return latest

    # Otherwise fall back to the explicitly named file, if it exists.
    if not renamed_file:
        return None
    fallback = os.path.join(cp_dir, renamed_file)
    if not os.path.isfile(fallback):
        return None
    print(f"[INFO] Resuming from renamed checkpoint: '{renamed_file}'")
    return fallback
def save_audio(audio, path, sr):
    """Write ``audio`` to ``path`` as a 16-bit wav file at sample rate ``sr``.

    Args:
        audio: torch tensor of samples (wav: torch with 1d shape); it is
            multiplied by ``MAX_WAV_VALUE`` before conversion.
        path: destination handed to ``scipy.io.wavfile.write``.
        sr: sample rate handed to ``scipy.io.wavfile.write``.
    """
    audio = audio * MAX_WAV_VALUE
    # Clamp to the int16 range first: a sample of +1.0 scales to 32768, which
    # does not fit in int16 and would otherwise wrap around on the cast.
    audio = audio.clamp(-MAX_WAV_VALUE, MAX_WAV_VALUE - 1)
    audio = audio.cpu().numpy().astype("int16")
    write(path, sr, audio)
import os
import sys
import warnings
# Suppress warnings from tensorflow and other libraries
# These filters are installed when this module is imported and ignore every
# UserWarning and FutureWarning, whichever library raises them.
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
def main():
    """Command-line entry point: synthesise one utterance with IndexTTS.

    Parses the arguments, validates the text, the voice prompt, the config
    file and the output path, picks a device when none is given, then runs
    ``IndexTTS.infer``. Every validation failure prints a message plus the
    help text and exits with status 1.
    """
    import argparse

    parser = argparse.ArgumentParser(description="IndexTTS Command Line")
    parser.add_argument("text", type=str, help="Text to be synthesized")
    parser.add_argument("-v", "--voice", type=str, required=True, help="Path to the audio prompt file (wav format)")
    parser.add_argument("-o", "--output_path", type=str, default="gen.wav", help="Path to the output wav file")
    parser.add_argument("-c", "--config", type=str, default="checkpoints/config.yaml", help="Path to the config file. Default is 'checkpoints/config.yaml'")
    parser.add_argument("--model_dir", type=str, default="checkpoints", help="Path to the model directory. Default is 'checkpoints'")
    parser.add_argument("--fp16", action="store_true", default=True, help="Use FP16 for inference if available")
    # --fp16 defaults to True, so on its own it can never switch FP16 off;
    # --no_fp16 writes False into the same destination.
    parser.add_argument("--no_fp16", dest="fp16", action="store_false", help="Disable FP16 inference")
    parser.add_argument("-f", "--force", action="store_true", default=False, help="Force to overwrite the output file if it exists")
    parser.add_argument("-d", "--device", type=str, default=None, help="Device to run the model on (cpu, cuda, mps)." )
    args = parser.parse_args()

    # Cheap validations run before the heavy imports below.
    if len(args.text.strip()) == 0:
        print("ERROR: Text is empty.")
        parser.print_help()
        sys.exit(1)
    if not os.path.exists(args.voice):
        print(f"Audio prompt file {args.voice} does not exist.")
        parser.print_help()
        sys.exit(1)
    if not os.path.exists(args.config):
        print(f"Config file {args.config} does not exist.")
        parser.print_help()
        sys.exit(1)

    output_path = args.output_path
    if os.path.exists(output_path):
        if not args.force:
            print(f"ERROR: Output file {output_path} already exists. Use --force to overwrite.")
            parser.print_help()
            sys.exit(1)
        else:
            os.remove(output_path)

    try:
        import torch
    except ImportError:
        print("ERROR: PyTorch is not installed. Please install it first.")
        sys.exit(1)

    if args.device is None:
        # torch.backends.mps is the documented availability check; the getattr
        # guard covers torch builds that do not ship the MPS backend module.
        mps_backend = getattr(torch.backends, "mps", None)
        if torch.cuda.is_available():
            args.device = "cuda:0"
        elif mps_backend is not None and mps_backend.is_available():
            args.device = "mps"
        else:
            args.device = "cpu"
            args.fp16 = False  # Disable FP16 on CPU
            print("WARNING: Running on CPU may be slow.")

    # Imported late so argument errors are reported without loading the model code.
    from indextts.infer import IndexTTS
    tts = IndexTTS(cfg_path=args.config, model_dir=args.model_dir, is_fp16=args.fp16, device=args.device)
    tts.infer(audio_prompt=args.voice, text=args.text.strip(), output_path=output_path)
# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment