Commit ab9c00af authored by yangzhong

init submission
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import random
from torch.utils.data import ConcatDataset, Dataset
from torch.utils.data.sampler import (
BatchSampler,
RandomSampler,
Sampler,
SequentialSampler,
)
class ScheduledSampler(Sampler):
"""A sampler that samples data from a given concat-dataset.
Args:
concat_dataset (ConcatDataset): a concatenated dataset consisting of all datasets
batch_size (int): batch size
holistic_shuffle (bool): whether to shuffle the whole dataset or not
logger (logging.Logger): logger to print warning message
Usage:
For cfg.train.batch_size = 3, cfg.train.holistic_shuffle = False, cfg.train.drop_last = True:
        >>> list(ScheduledSampler(ConcatDataset([[0, 1, 2], [3, 4, 5], [6, 7, 8]]), 3, False))
[3, 4, 5, 0, 1, 2, 6, 7, 8]
"""
def __init__(
self, concat_dataset, batch_size, holistic_shuffle, logger=None, type="train"
):
        # The ``type`` argument below shadows the builtin ``type``, so report
        # offending classes via ``__class__`` instead of calling ``type(...)``.
        if not isinstance(concat_dataset, ConcatDataset):
            raise ValueError(
                "concat_dataset must be an instance of ConcatDataset, but got {}".format(
                    concat_dataset.__class__
                )
            )
        if not isinstance(batch_size, int):
            raise ValueError(
                "batch_size must be an integer, but got {}".format(batch_size.__class__)
            )
        if not isinstance(holistic_shuffle, bool):
            raise ValueError(
                "holistic_shuffle must be a boolean, but got {}".format(
                    holistic_shuffle.__class__
                )
            )
self.concat_dataset = concat_dataset
self.batch_size = batch_size
self.holistic_shuffle = holistic_shuffle
affected_dataset_name = []
affected_dataset_len = []
for dataset in concat_dataset.datasets:
dataset_len = len(dataset)
dataset_name = dataset.get_dataset_name()
if dataset_len < batch_size:
affected_dataset_name.append(dataset_name)
affected_dataset_len.append(dataset_len)
        self.type = type
        for dataset_name, dataset_len in zip(
            affected_dataset_name, affected_dataset_len
        ):
            if type != "valid" and logger is not None:
                logger.warning(
                    "The {} dataset {} has a length of {}, which is smaller than the batch size {}. This may cause unexpected behavior.".format(
                        type, dataset_name, dataset_len, batch_size
                    )
                )
def __len__(self):
# the number of batches with drop last
num_of_batches = sum(
[
math.floor(len(dataset) / self.batch_size)
for dataset in self.concat_dataset.datasets
]
)
return num_of_batches * self.batch_size
def __iter__(self):
iters = []
for dataset in self.concat_dataset.datasets:
            # Shuffle within each dataset only when holistic_shuffle is enabled;
            # otherwise keep each dataset's internal order (see the docstring example).
            iters.append(
                RandomSampler(dataset).__iter__()
                if self.holistic_shuffle
                else SequentialSampler(dataset).__iter__()
            )
init_indices = [0] + self.concat_dataset.cumulative_sizes[:-1]
output_batches = []
for dataset_idx in range(len(self.concat_dataset.datasets)):
cur_batch = []
for idx in iters[dataset_idx]:
cur_batch.append(idx + init_indices[dataset_idx])
if len(cur_batch) == self.batch_size:
output_batches.append(cur_batch)
cur_batch = []
if self.type == "valid" and len(cur_batch) > 0:
output_batches.append(cur_batch)
cur_batch = []
# force drop last in training
random.shuffle(output_batches)
output_indices = [item for sublist in output_batches for item in sublist]
return iter(output_indices)
def build_samplers(concat_dataset: Dataset, cfg, logger, type):
sampler = ScheduledSampler(
concat_dataset,
cfg.train.batch_size,
cfg.train.sampler.holistic_shuffle,
logger,
type,
)
batch_sampler = BatchSampler(
sampler,
cfg.train.batch_size,
cfg.train.sampler.drop_last if not type == "valid" else False,
)
return sampler, batch_sampler
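# Example usage (a minimal sketch, assuming ``cfg`` exposes the fields read above
# and that each member dataset implements ``get_dataset_name()``):
# >>> train_set = ConcatDataset([dataset_a, dataset_b])
# >>> sampler, batch_sampler = build_samplers(train_set, cfg, logger, "train")
# >>> loader = torch.utils.data.DataLoader(
# ...     train_set, batch_sampler=batch_sampler, collate_fn=collate_fn
# ... )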
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import random
from pathlib import Path
import re
import accelerate
import json5
import numpy as np
import torch
from accelerate.utils import ProjectConfiguration
from torch.utils.data import DataLoader
from tqdm import tqdm
from models.codec.codec_sampler import build_samplers
class CodecTrainer:
def __init__(self):
super().__init__()
def _init_accelerator(self):
"""Initialize the accelerator components."""
self.exp_dir = os.path.join(
os.path.abspath(self.cfg.log_dir), self.args.exp_name
)
project_config = ProjectConfiguration(
project_dir=self.exp_dir, logging_dir=os.path.join(self.exp_dir, "log")
)
self.accelerator = accelerate.Accelerator(
gradient_accumulation_steps=self.cfg.train.gradient_accumulation_step,
log_with=self.cfg.train.tracker,
project_config=project_config,
)
if self.accelerator.is_main_process:
os.makedirs(project_config.project_dir, exist_ok=True)
os.makedirs(project_config.logging_dir, exist_ok=True)
with self.accelerator.main_process_first():
self.accelerator.init_trackers(self.args.exp_name)
def _build_dataset(self):
pass
def _build_criterion(self):
pass
def _build_model(self):
pass
def _build_dataloader(self):
"""Build dataloader which merges a series of datasets."""
# Build dataset instance for each dataset and combine them by ConcatDataset
Dataset, Collator = self._build_dataset()
# Build train set
train_dataset = Dataset(self.cfg, self.cfg.dataset, is_valid=False)
train_collate = Collator(self.cfg)
sampler = torch.utils.data.distributed.DistributedSampler(
train_dataset,
num_replicas=self.accelerator.num_processes,
rank=self.accelerator.local_process_index,
shuffle=True,
seed=self.cfg.train.random_seed,
)
train_loader = DataLoader(
train_dataset,
batch_size=self.cfg.train.batch_size,
collate_fn=train_collate,
sampler=sampler,
num_workers=self.cfg.train.dataloader.num_worker,
pin_memory=self.cfg.train.dataloader.pin_memory,
)
return train_loader, None
def _build_optimizer(self):
pass
def _build_scheduler(self):
pass
def _load_model(self, checkpoint_dir, checkpoint_path=None, resume_type="resume"):
"""Load model from checkpoint. If a folder is given, it will
load the latest checkpoint in checkpoint_dir. If a path is given
it will load the checkpoint specified by checkpoint_path.
**Only use this method after** ``accelerator.prepare()``.
"""
if checkpoint_path is None:
ls = [str(i) for i in Path(checkpoint_dir).glob("*")]
ls.sort(key=lambda x: int(x.split("_")[-3].split("-")[-1]), reverse=True)
checkpoint_path = ls[0]
if resume_type == "resume":
self.accelerator.load_state(checkpoint_path)
elif resume_type == "finetune":
accelerate.load_checkpoint_and_dispatch(
self.accelerator.unwrap_model(self.model),
os.path.join(checkpoint_path, "pytorch_model.bin"),
)
self.logger.info("Load model weights for finetune SUCCESS!")
else:
raise ValueError("Unsupported resume type: {}".format(resume_type))
self.epoch = int(checkpoint_path.split("_")[-3].split("-")[-1]) + 1
self.step = int(checkpoint_path.split("_")[-2].split("-")[-1]) + 1
return checkpoint_path
def train_loop(self):
pass
def _train_epoch(self):
pass
def _valid_epoch(self):
pass
def _train_step(self):
pass
def _valid_step(self):
pass
def _inference(self):
pass
def _set_random_seed(self, seed):
"""Set random seed for all possible random modules."""
random.seed(seed)
np.random.seed(seed)
torch.random.manual_seed(seed)
def _check_nan(self, loss):
if torch.any(torch.isnan(loss)):
self.logger.fatal("Fatal Error: NaN!")
self.logger.error("loss = {:.6f}".format(loss.item()), in_order=True)
def _check_basic_configs(self):
if self.cfg.train.gradient_accumulation_step <= 0:
self.logger.fatal("Invalid gradient_accumulation_step value!")
self.logger.error(
f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive."
)
self.accelerator.end_training()
raise ValueError(
f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive."
)
def _count_parameters(self):
pass
def _dump_cfg(self, path):
os.makedirs(os.path.dirname(path), exist_ok=True)
json5.dump(
self.cfg,
open(path, "w"),
indent=4,
sort_keys=True,
ensure_ascii=False,
quote_keys=True,
)
def _is_valid_pattern(self, directory_name):
directory_name = str(directory_name)
pattern = r"^epoch-\d{4}_step-\d{7}_loss-\d{1}\.\d{6}"
return re.match(pattern, directory_name) is not None
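# Example (a sketch): checkpoint directory names that follow the
# "epoch-XXXX_step-XXXXXXX_loss-X.XXXXXX" convention match the pattern above.
# >>> CodecTrainer._is_valid_pattern(None, "epoch-0003_step-0001000_loss-0.123456")
# True
# >>> CodecTrainer._is_valid_pattern(None, "latest_checkpoint.pth")
# False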
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
from .filter import *
from .resample import *
from .act import *
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
import torch.nn as nn
from .resample import UpSample1d, DownSample1d
class Activation1d(nn.Module):
def __init__(
self,
activation,
up_ratio: int = 2,
down_ratio: int = 2,
up_kernel_size: int = 12,
down_kernel_size: int = 12,
):
super().__init__()
self.up_ratio = up_ratio
self.down_ratio = down_ratio
self.act = activation
self.upsample = UpSample1d(up_ratio, up_kernel_size)
self.downsample = DownSample1d(down_ratio, down_kernel_size)
# x: [B,C,T]
def forward(self, x):
x = self.upsample(x)
x = self.act(x)
x = self.downsample(x)
return x
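# Minimal sanity check (a sketch, not part of the original module). Wrapping a
# pointwise activation keeps the [B, C, T] shape because the up/down ratios cancel.
# Run as a module (e.g. ``python -m <package>.act``) so the relative imports resolve.
if __name__ == "__main__":
    import torch
    act = Activation1d(activation=nn.SiLU())
    x = torch.randn(2, 8, 100)
    print(act(x).shape)  # torch.Size([2, 8, 100])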
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
if "sinc" in dir(torch):
sinc = torch.sinc
else:
# This code is adopted from adefossez's julius.core.sinc under the MIT License
# https://adefossez.github.io/julius/julius/core.html
def sinc(x: torch.Tensor):
"""
Implementation of sinc, i.e. sin(pi * x) / (pi * x)
__Warning__: Different to julius.sinc, the input is multiplied by `pi`!
"""
return torch.where(
x == 0,
torch.tensor(1.0, device=x.device, dtype=x.dtype),
torch.sin(math.pi * x) / math.pi / x,
)
# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
# https://adefossez.github.io/julius/julius/lowpass.html
def kaiser_sinc_filter1d(
cutoff, half_width, kernel_size
): # return filter [1,1,kernel_size]
even = kernel_size % 2 == 0
half_size = kernel_size // 2
# For kaiser window
delta_f = 4 * half_width
A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
if A > 50.0:
beta = 0.1102 * (A - 8.7)
elif A >= 21.0:
beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
else:
beta = 0.0
window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
# ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
if even:
time = torch.arange(-half_size, half_size) + 0.5
else:
time = torch.arange(kernel_size) - half_size
if cutoff == 0:
filter_ = torch.zeros_like(time)
else:
filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
# Normalize filter to have sum = 1, otherwise we will have a small leakage
# of the constant component in the input signal.
filter_ /= filter_.sum()
filter = filter_.view(1, 1, kernel_size)
return filter
class LowPassFilter1d(nn.Module):
def __init__(
self,
cutoff=0.5,
half_width=0.6,
stride: int = 1,
padding: bool = True,
padding_mode: str = "replicate",
kernel_size: int = 12,
):
# kernel_size should be even number for stylegan3 setup,
# in this implementation, odd number is also possible.
super().__init__()
        if cutoff < 0.0:
raise ValueError("Minimum cutoff must be larger than zero.")
if cutoff > 0.5:
raise ValueError("A cutoff above 0.5 does not make sense.")
self.kernel_size = kernel_size
self.even = kernel_size % 2 == 0
self.pad_left = kernel_size // 2 - int(self.even)
self.pad_right = kernel_size // 2
self.stride = stride
self.padding = padding
self.padding_mode = padding_mode
filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
self.register_buffer("filter", filter)
# input [B, C, T]
def forward(self, x):
_, C, _ = x.shape
if self.padding:
x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
out = F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
return out
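# Minimal sanity check (a sketch, not part of the original module): with the
# default stride of 1 and padding enabled, the filter preserves the input length.
if __name__ == "__main__":
    x = torch.randn(1, 2, 128)
    lpf = LowPassFilter1d(cutoff=0.25, half_width=0.3, kernel_size=12)
    print(lpf(x).shape)  # torch.Size([1, 2, 128])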
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
import torch.nn as nn
from torch.nn import functional as F
from .filter import LowPassFilter1d
from .filter import kaiser_sinc_filter1d
class UpSample1d(nn.Module):
def __init__(self, ratio=2, kernel_size=None):
super().__init__()
self.ratio = ratio
self.kernel_size = (
int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
)
self.stride = ratio
self.pad = self.kernel_size // ratio - 1
self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
self.pad_right = (
self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
)
filter = kaiser_sinc_filter1d(
cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size
)
self.register_buffer("filter", filter)
# x: [B, C, T]
def forward(self, x):
_, C, _ = x.shape
x = F.pad(x, (self.pad, self.pad), mode="replicate")
x = self.ratio * F.conv_transpose1d(
x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C
)
x = x[..., self.pad_left : -self.pad_right]
return x
class DownSample1d(nn.Module):
def __init__(self, ratio=2, kernel_size=None):
super().__init__()
self.ratio = ratio
self.kernel_size = (
int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
)
self.lowpass = LowPassFilter1d(
cutoff=0.5 / ratio,
half_width=0.6 / ratio,
stride=ratio,
kernel_size=self.kernel_size,
)
def forward(self, x):
xx = self.lowpass(x)
return xx
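# Minimal sanity check (a sketch, not part of the original module): a 2x upsample
# followed by a 2x downsample restores the original [B, C, T] shape. Run as a
# module (e.g. ``python -m <package>.resample``) so the relative imports resolve.
if __name__ == "__main__":
    import torch
    x = torch.randn(1, 4, 160)
    up, down = UpSample1d(ratio=2), DownSample1d(ratio=2)
    y = up(x)
    print(y.shape, down(y).shape)  # torch.Size([1, 4, 320]) torch.Size([1, 4, 160])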
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import torch
import random
import numpy as np
import torchaudio
import librosa
from torch.nn import functional as F
from torch.nn.utils.rnn import pad_sequence
from utils.data_utils import *
from models.codec.codec_dataset import CodecDataset
class FAcodecDataset(torch.utils.data.Dataset):
def __init__(self, cfg, dataset, is_valid=False):
"""
Args:
cfg: config
dataset: dataset name
is_valid: whether to use train or valid dataset
"""
self.data_root_dir = cfg.dataset
self.data_list = []
        # walk through the dataset directory recursively and collect all files ending with .wav/.mp3/.opus/.flac/.m4a
for root, _, files in os.walk(self.data_root_dir):
for file in files:
if file.endswith((".wav", ".mp3", ".opus", ".flac", ".m4a")):
self.data_list.append(os.path.join(root, file))
self.sr = cfg.preprocess_params.sr
self.duration_range = cfg.preprocess_params.duration_range
self.to_mel = torchaudio.transforms.MelSpectrogram(
n_mels=cfg.preprocess_params.spect_params.n_mels,
n_fft=cfg.preprocess_params.spect_params.n_fft,
win_length=cfg.preprocess_params.spect_params.win_length,
hop_length=cfg.preprocess_params.spect_params.hop_length,
)
self.mean, self.std = -4, 4
def preprocess(self, wave):
wave_tensor = (
torch.from_numpy(wave).float() if isinstance(wave, np.ndarray) else wave
)
mel_tensor = self.to_mel(wave_tensor)
mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - self.mean) / self.std
return mel_tensor
    def __len__(self):
        return len(self.data_list)
    def __getitem__(self, index):
        wave, _ = librosa.load(self.data_list[index], sr=self.sr)
        wave = wave / np.max(np.abs(wave))
        mel = self.preprocess(wave).squeeze(0)
        wave = torch.from_numpy(wave).float()
        return wave, mel
class FAcodecCollator(object):
"""Zero-pads model inputs and targets based on number of frames per step"""
def __init__(self, cfg):
self.cfg = cfg
def __call__(self, batch):
        # each item in batch is a (wave, mel) pair
batch_size = len(batch)
# sort by mel length
lengths = [b[1].shape[1] for b in batch]
batch_indexes = np.argsort(lengths)[::-1]
batch = [batch[bid] for bid in batch_indexes]
nmels = batch[0][1].size(0)
max_mel_length = max([b[1].shape[1] for b in batch])
max_wave_length = max([b[0].size(0) for b in batch])
mels = torch.zeros((batch_size, nmels, max_mel_length)).float() - 10
waves = torch.zeros((batch_size, max_wave_length)).float()
mel_lengths = torch.zeros(batch_size).long()
wave_lengths = torch.zeros(batch_size).long()
for bid, (wave, mel) in enumerate(batch):
mel_size = mel.size(1)
mels[bid, :, :mel_size] = mel
waves[bid, : wave.size(0)] = wave
mel_lengths[bid] = mel_size
wave_lengths[bid] = wave.size(0)
return waves, mels, wave_lengths, mel_lengths
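# Example usage (a minimal sketch; ``cfg`` is assumed to carry the ``dataset`` root
# and ``preprocess_params`` fields read above):
# >>> dataset = FAcodecDataset(cfg, cfg.dataset, is_valid=False)
# >>> collator = FAcodecCollator(cfg)
# >>> loader = torch.utils.data.DataLoader(
# ...     dataset, batch_size=8, shuffle=True, collate_fn=collator
# ... )
# >>> waves, mels, wave_lengths, mel_lengths = next(iter(loader))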
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import shutil
import warnings
import argparse
import torch
import os
import yaml
warnings.simplefilter("ignore")
from .modules.commons import *
import time
import torchaudio
import librosa
from collections import OrderedDict
class FAcodecInference(object):
def __init__(self, args=None, cfg=None):
self.args = args
self.cfg = cfg
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.model = self._build_model()
self._load_checkpoint()
def _build_model(self):
model = build_model(self.cfg.model_params)
_ = [model[key].to(self.device) for key in model]
return model
def _load_checkpoint(self):
sd = torch.load(self.args.checkpoint_path, map_location="cpu")
sd = sd["net"] if "net" in sd else sd
new_params = dict()
for key, state_dict in sd.items():
new_state_dict = OrderedDict()
for k, v in state_dict.items():
if k.startswith("module."):
k = k[7:]
new_state_dict[k] = v
new_params[key] = new_state_dict
for key in new_params:
if key in self.model:
self.model[key].load_state_dict(new_params[key])
_ = [self.model[key].eval() for key in self.model]
@torch.no_grad()
def inference(self, source, output_dir):
source_audio = librosa.load(source, sr=self.cfg.preprocess_params.sr)[0]
source_audio = torch.tensor(source_audio).unsqueeze(0).float().to(self.device)
z = self.model.encoder(source_audio[None, ...].to(self.device).float())
(
z,
quantized,
commitment_loss,
codebook_loss,
timbre,
codes,
) = self.model.quantizer(
z,
source_audio[None, ...].to(self.device).float(),
n_c=self.cfg.model_params.n_c_codebooks,
return_codes=True,
)
full_pred_wave = self.model.decoder(z)
os.makedirs(output_dir, exist_ok=True)
source_name = source.split("/")[-1].split(".")[0]
torchaudio.save(
f"{output_dir}/reconstructed_{source_name}.wav",
full_pred_wave[0].cpu(),
self.cfg.preprocess_params.sr,
)
print(
"Reconstructed audio saved as: ",
f"{output_dir}/reconstructed_{source_name}.wav",
)
return quantized, codes
@torch.no_grad()
def voice_conversion(self, source, reference, output_dir):
source_audio = librosa.load(source, sr=self.cfg.preprocess_params.sr)[0]
source_audio = torch.tensor(source_audio).unsqueeze(0).float().to(self.device)
reference_audio = librosa.load(reference, sr=self.cfg.preprocess_params.sr)[0]
reference_audio = (
torch.tensor(reference_audio).unsqueeze(0).float().to(self.device)
)
z = self.model.encoder(source_audio[None, ...].to(self.device).float())
z, quantized, commitment_loss, codebook_loss, timbre = self.model.quantizer(
z,
source_audio[None, ...].to(self.device).float(),
n_c=self.cfg.model_params.n_c_codebooks,
)
z_ref = self.model.encoder(reference_audio[None, ...].to(self.device).float())
(
z_ref,
quantized_ref,
commitment_loss_ref,
codebook_loss_ref,
timbre_ref,
) = self.model.quantizer(
z_ref,
reference_audio[None, ...].to(self.device).float(),
n_c=self.cfg.model_params.n_c_codebooks,
)
z_conv = self.model.quantizer.voice_conversion(
quantized[0] + quantized[1],
reference_audio[None, ...].to(self.device).float(),
)
full_pred_wave = self.model.decoder(z_conv)
os.makedirs(output_dir, exist_ok=True)
source_name = source.split("/")[-1].split(".")[0]
reference_name = reference.split("/")[-1].split(".")[0]
torchaudio.save(
f"{output_dir}/converted_{source_name}_to_{reference_name}.wav",
full_pred_wave[0].cpu(),
self.cfg.preprocess_params.sr,
)
print(
"Voice conversion results saved as: ",
f"{output_dir}/converted_{source_name}_to_{reference_name}.wav",
)
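# Example usage (a minimal sketch; ``args`` must provide ``checkpoint_path`` and
# ``cfg`` the ``model_params``/``preprocess_params`` fields used above):
# >>> engine = FAcodecInference(args=args, cfg=cfg)
# >>> quantized, codes = engine.inference("source.wav", "outputs")
# >>> engine.voice_conversion("source.wav", "reference.wav", "outputs")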
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import time
import random
from pathlib import Path
import re
import glob
import accelerate
import json
import numpy as np
import torch
from accelerate.utils import ProjectConfiguration
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch.nn.functional as F
import torchaudio
from accelerate.logging import get_logger
from models.codec.facodec.facodec_dataset import FAcodecDataset, FAcodecCollator
from models.codec.codec_sampler import build_samplers
from models.codec.codec_trainer import CodecTrainer
from modules.dac.nn.loss import (
MultiScaleSTFTLoss,
MelSpectrogramLoss,
GANLoss,
L1Loss,
FocalLoss,
)
from audiotools import AudioSignal
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
try:
import nemo.collections.asr as nemo_asr
except ImportError:
print(
"Unable to import nemo_asr, titanet outputs will be set to random values, you may only run debugging mode. DO NOT USE THIS FOR TRAINING"
)
nemo_asr = None
from models.codec.facodec.modules.commons import (
build_model,
load_checkpoint,
load_F0_models,
log_norm,
)
from models.codec.facodec.optimizer import build_optimizer
class FAcodecTrainer(CodecTrainer):
def __init__(self, args, cfg):
super().__init__()
self.args = args
self.cfg = cfg
cfg.exp_name = args.exp_name
# Init accelerator
self._init_accelerator()
self.accelerator.wait_for_everyone()
# Init logger
with self.accelerator.main_process_first():
self.logger = get_logger(args.exp_name, log_level=args.log_level)
self.logger.info("=" * 56)
self.logger.info("||\t\t" + "New training process started." + "\t\t||")
self.logger.info("=" * 56)
self.logger.info("\n")
self.logger.debug(f"Using {args.log_level.upper()} logging level.")
self.logger.info(f"Experiment name: {args.exp_name}")
self.logger.info(f"Experiment directory: {self.exp_dir}")
self.checkpoint_dir = os.path.join(self.exp_dir, "checkpoint")
if self.accelerator.is_main_process:
os.makedirs(self.checkpoint_dir, exist_ok=True)
self.logger.debug(f"Checkpoint directory: {self.checkpoint_dir}")
# Init training status
self.batch_count: int = 0
self.step: int = 0
self.epoch: int = 0
self.max_epoch = (
self.cfg.train.max_epoch if self.cfg.train.max_epoch > 0 else float("inf")
)
self.logger.info(
"Max epoch: {}".format(
self.max_epoch if self.max_epoch < float("inf") else "Unlimited"
)
)
        # Check potential errors
if self.accelerator.is_main_process:
self._check_basic_configs()
self.save_checkpoint_stride = self.cfg.train.save_checkpoint_stride
self.checkpoints_path = [
[] for _ in range(len(self.save_checkpoint_stride))
]
self.run_eval = self.cfg.train.run_eval
# Set random seed
with self.accelerator.main_process_first():
start = time.monotonic_ns()
self._set_random_seed(self.cfg.train.random_seed)
end = time.monotonic_ns()
self.logger.debug(
f"Setting random seed done in {(end - start) / 1e6:.2f}ms"
)
self.logger.debug(f"Random seed: {self.cfg.train.random_seed}")
# Build dataloader
with self.accelerator.main_process_first():
self.logger.info("Building dataset...")
start = time.monotonic_ns()
self.train_dataloader, self.valid_dataloader = self._build_dataloader()
end = time.monotonic_ns()
self.logger.info(f"Building dataset done in {(end - start) / 1e6:.2f}ms")
# Build model
with self.accelerator.main_process_first():
self.logger.info("Building model...")
start = time.monotonic_ns()
self.model = self._build_model()
end = time.monotonic_ns()
for _, model in self.model.items():
self.logger.debug(model)
self.logger.info(f"Building model done in {(end - start) / 1e6:.2f}ms")
self.logger.info(f"Model parameters: {self._count_parameters()/1e6:.2f}M")
# Build optimizers and schedulers
with self.accelerator.main_process_first():
self.logger.info("Building optimizer and scheduler...")
start = time.monotonic_ns()
self.optimizer = self._build_optimizer()
end = time.monotonic_ns()
self.logger.info(
f"Building optimizer and scheduler done in {(end - start) / 1e6:.2f}ms"
)
# Build helper models
with self.accelerator.main_process_first():
self.logger.info("Building helper models...")
start = time.monotonic_ns()
            self._build_helper_models()
end = time.monotonic_ns()
self.logger.info(
f"Building helper models done in {(end - start) / 1e6:.2f}ms"
)
# Accelerator preparing
self.logger.info("Initializing accelerate...")
start = time.monotonic_ns()
for k in self.model:
self.model[k] = self.accelerator.prepare(self.model[k])
for k, v in self.optimizer.optimizers.items():
self.optimizer.optimizers[k] = self.accelerator.prepare(
self.optimizer.optimizers[k]
)
self.optimizer.schedulers[k] = self.accelerator.prepare(
self.optimizer.schedulers[k]
)
end = time.monotonic_ns()
self.logger.info(f"Initializing accelerate done in {(end - start) / 1e6:.2f}ms")
# Build criterions
with self.accelerator.main_process_first():
self.logger.info("Building criterion...")
start = time.monotonic_ns()
self.criterions = self._build_criterion()
end = time.monotonic_ns()
self.logger.info(f"Building criterion done in {(end - start) / 1e6:.2f}ms")
# Resume checkpoints
with self.accelerator.main_process_first():
self.checkpoint_dir = os.path.join(self.exp_dir, "checkpoint")
if args.resume_type:
self.logger.info("Resuming from checkpoint...")
start = time.monotonic_ns()
ckpt_path = Path(args.checkpoint)
if self._is_valid_pattern(ckpt_path.parts[-1]):
ckpt_path = self._load_model(args.checkpoint, args.resume_type)
else:
ckpt_path = self._load_model(
args.checkpoint, resume_type=args.resume_type
)
end = time.monotonic_ns()
self.logger.info(
f"Resuming from checkpoint done in {(end - start) / 1e6:.2f}ms"
)
self.checkpoints_path = json.load(
open(os.path.join(ckpt_path, "ckpts.json"), "r")
)
if self.accelerator.is_main_process:
os.makedirs(self.checkpoint_dir, exist_ok=True)
self.logger.debug(f"Checkpoint directory: {self.checkpoint_dir}")
# Save config
self.config_save_path = os.path.join(self.exp_dir, "args.json")
def _build_dataset(self):
return FAcodecDataset, FAcodecCollator
def _build_criterion(self):
criterions = dict()
stft_criterion = MultiScaleSTFTLoss()
mel_criterion = MelSpectrogramLoss(
n_mels=[5, 10, 20, 40, 80, 160, 320],
window_lengths=[32, 64, 128, 256, 512, 1024, 2048],
mel_fmin=[0, 0, 0, 0, 0, 0, 0],
mel_fmax=[None, None, None, None, None, None, None],
pow=1.0,
mag_weight=0.0,
clamp_eps=1e-5,
)
content_criterion = FocalLoss(gamma=2)
l1_criterion = L1Loss()
criterions["stft"] = stft_criterion
criterions["mel"] = mel_criterion
criterions["l1"] = l1_criterion
criterions["content"] = content_criterion
return criterions
def _build_model(self):
model = build_model(self.cfg.model_params)
_ = [model[key].to(self.accelerator.device) for key in model]
return model
    def _build_helper_models(self):
device = self.accelerator.device
self.pitch_extractor = load_F0_models(self.cfg.F0_path).to(device)
# load model and processor
self.w2v_processor = Wav2Vec2Processor.from_pretrained(
"facebook/wav2vec2-xlsr-53-espeak-cv-ft"
)
self.w2v_model = Wav2Vec2ForCTC.from_pretrained(
"facebook/wav2vec2-xlsr-53-espeak-cv-ft"
).to(device)
self.w2v_model.eval()
if nemo_asr is None:
self.speaker_model = None
else:
self.speaker_model = (
nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained(
"nvidia/speakerverification_en_titanet_large"
)
)
self.speaker_model = self.speaker_model.to(device)
self.speaker_model.eval()
def _build_optimizer(self):
scheduler_params = {
"warmup_steps": self.cfg.loss_params.warmup_steps,
"base_lr": self.cfg.loss_params.base_lr,
}
optimizer = build_optimizer(
{key: self.model[key] for key in self.model},
scheduler_params_dict={key: scheduler_params.copy() for key in self.model},
lr=float(scheduler_params["base_lr"]),
)
return optimizer
def train_loop(self):
"""Training process"""
self.accelerator.wait_for_everyone()
# Dump config
if self.accelerator.is_main_process:
self._dump_cfg(self.config_save_path)
_ = [self.model[key].train() for key in self.model]
self.optimizer.zero_grad()
# Sync and start training
self.accelerator.wait_for_everyone()
while self.epoch < self.max_epoch:
self.logger.info("\n")
self.logger.info("-" * 32)
self.logger.info("Epoch {}: ".format(self.epoch))
# Train and Validate
train_total_loss, train_losses = self._train_epoch()
for key, loss in train_losses.items():
self.logger.info(" |- Train/{} Loss: {:.6f}".format(key, loss))
self.accelerator.log(
{"Epoch/Train {} Loss".format(key): loss},
step=self.epoch,
)
self.accelerator.log(
{
"Epoch/Train Total Loss": train_total_loss,
},
step=self.epoch,
)
# Update scheduler
self.accelerator.wait_for_everyone()
# Check save checkpoint interval
run_eval = False
if self.accelerator.is_main_process:
save_checkpoint = False
for i, num in enumerate(self.save_checkpoint_stride):
if self.epoch % num == 0:
save_checkpoint = True
run_eval |= self.run_eval[i]
# Save checkpoints
self.accelerator.wait_for_everyone()
if self.accelerator.is_main_process and save_checkpoint:
print("Saving..")
state = {
"net": {key: self.model[key].state_dict() for key in self.model},
"optimizer": self.optimizer.state_dict(),
"scheduler": self.optimizer.scheduler_state_dict(),
"iters": self.step,
"epoch": self.epoch,
}
save_path = os.path.join(
self.checkpoint_dir,
"FAcodec_epoch_%05d_step_%05d.pth" % (self.epoch, self.iters),
)
torch.save(state, save_path)
json.dump(
self.checkpoints_path,
open(os.path.join(self.checkpoint_dir, "ckpts.json"), "w"),
ensure_ascii=False,
indent=4,
)
self.accelerator.wait_for_everyone()
self.epoch += 1
# Finish training
self.accelerator.wait_for_everyone()
if self.accelerator.is_main_process:
path = os.path.join(
self.checkpoint_dir,
"epoch-{:04d}_step-{:07d}".format(
self.epoch,
self.step,
),
)
print("Saving..")
state = {
"net": {key: self.model[key].state_dict() for key in self.model},
"optimizer": self.optimizer.state_dict(),
"scheduler": self.optimizer.scheduler_state_dict(),
"iters": self.step,
"epoch": self.epoch,
}
save_path = os.path.join(
self.checkpoint_dir,
"FAcodec_epoch_%05d_step_%05d.pth" % (self.epoch, self.iters),
)
torch.save(state, save_path)
def _train_epoch(self):
"""Training epoch. Should return average loss of a batch (sample) over
one epoch. See ``train_loop`` for usage.
"""
_ = [self.model[key].train() for key in self.model]
epoch_losses: dict = {}
        epoch_total_loss: float = 0.0
for batch in tqdm(
self.train_dataloader,
desc=f"Training Epoch {self.epoch}",
unit="batch",
colour="GREEN",
leave=False,
dynamic_ncols=True,
smoothing=0.04,
disable=not self.accelerator.is_main_process,
):
# Get losses
total_loss, losses = self._train_step(batch)
self.batch_count += 1
# Log info
if self.batch_count % self.cfg.train.gradient_accumulation_step == 0:
self.accelerator.log(
{
"Step/Learning Rate": (
self.optimizer.schedulers["encoder"].get_last_lr()[0]
if self.step != 0
else 0
)
},
step=self.step,
)
for key, _ in losses.items():
self.accelerator.log(
{
"Step/Train {} Loss".format(key): losses[key],
},
step=self.step,
)
if not epoch_losses:
epoch_losses = losses
else:
for key, value in losses.items():
epoch_losses[key] += value
epoch_total_loss += total_loss
self.step += 1
# Get and log total losses
self.accelerator.wait_for_everyone()
epoch_total_loss = (
epoch_total_loss
/ len(self.train_dataloader)
* self.cfg.train.gradient_accumulation_step
)
for key in epoch_losses.keys():
epoch_losses[key] = (
epoch_losses[key]
/ len(self.train_dataloader)
* self.cfg.train.gradient_accumulation_step
)
return epoch_total_loss, epoch_losses
def _train_step(self, data):
"""Training forward step. Should return average loss of a sample over
one batch. Provoke ``_forward_step`` is recommended except for special case.
See ``_train_epoch`` for usage.
"""
# Init losses
train_losses = {}
total_loss = 0
# Use input feature to get predictions
data = [b.to(self.accelerator.device, non_blocking=True) for b in data]
waves, mels, wave_lengths, mel_input_length = data
# extract semantic latent with w2v model
waves_16k = torchaudio.functional.resample(waves, 24000, 16000)
w2v_input = self.w2v_processor(
waves_16k, sampling_rate=16000, return_tensors="pt"
).input_values.to(self.accelerator.device)
with torch.no_grad():
w2v_outputs = self.w2v_model(w2v_input.squeeze(0)).logits
predicted_ids = torch.argmax(w2v_outputs, dim=-1)
phone_ids = (
F.interpolate(
predicted_ids.unsqueeze(0).float(), mels.size(-1), mode="nearest"
)
.long()
.squeeze(0)
)
# get clips
mel_seg_len = min(
[int(mel_input_length.min().item()), self.cfg.train.max_frame_len]
)
gt_mel_seg = []
wav_seg = []
w2v_seg = []
for bib in range(len(mel_input_length)):
mel_length = int(mel_input_length[bib].item())
random_start = (
np.random.randint(0, mel_length - mel_seg_len)
if mel_length != mel_seg_len
else 0
)
gt_mel_seg.append(mels[bib, :, random_start : random_start + mel_seg_len])
# w2v_seg.append(w2v_latent[bib, :, random_start:random_start + mel_seg_len])
w2v_seg.append(phone_ids[bib, random_start : random_start + mel_seg_len])
y = waves[bib][random_start * 300 : (random_start + mel_seg_len) * 300]
wav_seg.append(y.to(self.accelerator.device))
gt_mel_seg = torch.stack(gt_mel_seg).detach()
wav_seg = torch.stack(wav_seg).float().detach().unsqueeze(1)
w2v_seg = torch.stack(w2v_seg).float().detach()
with torch.no_grad():
real_norm = log_norm(gt_mel_seg.unsqueeze(1)).squeeze(1).detach()
F0_real, _, _ = self.pitch_extractor(gt_mel_seg.unsqueeze(1))
        # normalize f0
        # Remove unvoiced frames (replace with -10)
gt_glob_f0s = []
f0_targets = []
for bib in range(len(F0_real)):
voiced_indices = F0_real[bib] > 5.0
f0_voiced = F0_real[bib][voiced_indices]
if len(f0_voiced) != 0:
# Convert to log scale
log_f0 = f0_voiced.log2()
# Calculate mean and standard deviation
mean_f0 = log_f0.mean()
std_f0 = log_f0.std()
# Normalize the F0 sequence
normalized_f0 = (log_f0 - mean_f0) / std_f0
# Create the normalized F0 sequence with unvoiced frames
normalized_sequence = torch.zeros_like(F0_real[bib])
normalized_sequence[voiced_indices] = normalized_f0
normalized_sequence[~voiced_indices] = (
-10
) # Assign -10 to unvoiced frames
gt_glob_f0s.append(mean_f0)
else:
normalized_sequence = torch.zeros_like(F0_real[bib]) - 10.0
gt_glob_f0s.append(torch.tensor(0.0).to(self.accelerator.device))
# f0_targets.append(normalized_sequence[single_side_context // 200:-single_side_context // 200])
f0_targets.append(normalized_sequence)
f0_targets = torch.stack(f0_targets).to(self.accelerator.device)
# fill nan with -10
f0_targets[torch.isnan(f0_targets)] = -10.0
# fill inf with -10
f0_targets[torch.isinf(f0_targets)] = -10.0
# if frame_rate not equal to 80, interpolate f0 from frame rate of 80 to target frame rate
if self.cfg.preprocess_params.frame_rate != 80:
f0_targets = F.interpolate(
f0_targets.unsqueeze(1),
mel_seg_len // 80 * self.cfg.preprocess_params.frame_rate,
mode="nearest",
).squeeze(1)
w2v_seg = F.interpolate(
w2v_seg,
mel_seg_len // 80 * self.cfg.preprocess_params.frame_rate,
mode="nearest",
)
wav_seg_input = wav_seg
wav_seg_target = wav_seg
z = self.model.encoder(wav_seg_input)
z, quantized, commitment_loss, codebook_loss, timbre = self.model.quantizer(
z, wav_seg_input, n_c=2, full_waves=waves, wave_lens=wave_lengths
)
preds, rev_preds = self.model.fa_predictors(quantized, timbre)
pred_wave = self.model.decoder(z)
len_diff = wav_seg_target.size(-1) - pred_wave.size(-1)
if len_diff > 0:
wav_seg_target = wav_seg_target[..., len_diff // 2 : -len_diff // 2]
# discriminator loss
d_fake = self.model.discriminator(pred_wave.detach())
d_real = self.model.discriminator(wav_seg_target)
loss_d = 0
for x_fake, x_real in zip(d_fake, d_real):
loss_d += torch.mean(x_fake[-1] ** 2)
loss_d += torch.mean((1 - x_real[-1]) ** 2)
self.optimizer.zero_grad()
self.accelerator.backward(loss_d)
grad_norm_d = torch.nn.utils.clip_grad_norm_(
self.model.discriminator.parameters(), 10.0
)
self.optimizer.step("discriminator")
self.optimizer.scheduler(key="discriminator")
# generator loss
signal = AudioSignal(wav_seg_target, sample_rate=24000)
recons = AudioSignal(pred_wave, sample_rate=24000)
stft_loss = self.criterions["stft"](recons, signal)
mel_loss = self.criterions["mel"](recons, signal)
waveform_loss = self.criterions["l1"](recons, signal)
d_fake = self.model.discriminator(pred_wave)
d_real = self.model.discriminator(wav_seg_target)
loss_g = 0
for x_fake in d_fake:
loss_g += torch.mean((1 - x_fake[-1]) ** 2)
loss_feature = 0
for i in range(len(d_fake)):
for j in range(len(d_fake[i]) - 1):
loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach())
pred_f0, pred_uv = preds["f0"], preds["uv"]
rev_pred_f0, rev_pred_uv = rev_preds["rev_f0"], rev_preds["rev_uv"]
common_min_size = min(pred_f0.size(-2), f0_targets.size(-1))
f0_targets = f0_targets[..., :common_min_size]
real_norm = real_norm[..., :common_min_size]
f0_loss = F.smooth_l1_loss(
f0_targets, pred_f0.squeeze(-1)[..., :common_min_size]
)
uv_loss = F.smooth_l1_loss(
real_norm, pred_uv.squeeze(-1)[..., :common_min_size]
)
rev_f0_loss = (
F.smooth_l1_loss(f0_targets, rev_pred_f0.squeeze(-1)[..., :common_min_size])
if rev_pred_f0 is not None
else torch.FloatTensor([0]).to(self.accelerator.device)
)
rev_uv_loss = (
F.smooth_l1_loss(real_norm, rev_pred_uv.squeeze(-1)[..., :common_min_size])
if rev_pred_uv is not None
else torch.FloatTensor([0]).to(self.accelerator.device)
)
tot_f0_loss = f0_loss + rev_f0_loss
tot_uv_loss = uv_loss + rev_uv_loss
pred_content = preds["content"]
rev_pred_content = rev_preds["rev_content"]
target_content_latents = w2v_seg[..., :common_min_size]
content_loss = self.criterions["content"](
pred_content.transpose(1, 2)[..., :common_min_size],
target_content_latents.long(),
)
rev_content_loss = (
self.criterions["content"](
rev_pred_content.transpose(1, 2)[..., :common_min_size],
target_content_latents.long(),
)
if rev_pred_content is not None
else torch.FloatTensor([0]).to(self.accelerator.device)
)
tot_content_loss = content_loss + rev_content_loss
if self.speaker_model is not None:
spk_logits = torch.cat(
[
self.speaker_model.infer_segment(w16.cpu()[..., :wl])[1]
for w16, wl in zip(waves_16k, wave_lengths)
],
dim=0,
)
spk_labels = spk_logits.argmax(dim=-1)
else:
spk_labels = torch.zeros([len(waves_16k)], dtype=torch.long).to(
self.accelerator.device
)
spk_pred_logits = preds["timbre"]
spk_loss = F.cross_entropy(spk_pred_logits, spk_labels)
x_spk_pred_logits = rev_preds["x_timbre"]
x_spk_loss = (
F.cross_entropy(x_spk_pred_logits, spk_labels)
if x_spk_pred_logits is not None
else torch.FloatTensor([0]).to(self.accelerator.device)
)
tot_spk_loss = spk_loss + x_spk_loss
loss_gen_all = (
mel_loss * 15.0
+ loss_feature * 1.0
+ loss_g * 1.0
+ commitment_loss * 0.25
+ codebook_loss * 1.0
+ tot_f0_loss * 1.0
+ tot_uv_loss * 1.0
+ tot_content_loss * 5.0
+ tot_spk_loss * 5.0
)
self.optimizer.zero_grad()
self.accelerator.backward(loss_gen_all)
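        # NOTE: no generator-side optimizer update appears between this backward
        # pass and the loss bookkeeping below. A hedged sketch mirroring the
        # discriminator update above (module keys assumed from ``build_model``):
        # for key in ("encoder", "quantizer", "decoder", "fa_predictors"):
        #     self.optimizer.step(key)
        #     self.optimizer.scheduler(key=key)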
with torch.no_grad():
total_loss = loss_gen_all.item()
train_losses["stft"] = stft_loss.item()
train_losses["mel"] = mel_loss.item()
train_losses["l1"] = waveform_loss.item()
train_losses["f0"] = f0_loss.item()
train_losses["uv"] = uv_loss.item()
train_losses["content"] = content_loss.item()
train_losses["speaker"] = spk_loss.item()
train_losses["rev_f0"] = rev_f0_loss.item()
train_losses["rev_uv"] = rev_uv_loss.item()
train_losses["rev_content"] = rev_content_loss.item()
train_losses["rev_speaker"] = x_spk_loss.item()
train_losses["feature"] = loss_feature.item()
train_losses["generator"] = loss_g.item()
train_losses["commitment"] = commitment_loss.item()
train_losses["codebook"] = codebook_loss.item()
# discriminators
train_losses["discriminator"] = loss_d.item()
return total_loss, train_losses
def _inference(self, eval_wave):
"""Inference during training for test audios."""
z = self.model.encoder(
eval_wave[None, None, ...].to(self.accelerator.device).float()
)
z, quantized, commitment_loss, codebook_loss, timbre = self.model.quantizer(
z, eval_wave[None, None, ...], n_c=self.cfg.model_params.n_c_codebooks
)
full_pred_wave = self.model.decoder(z)
return full_pred_wave[0]
def _load_model(self, checkpoint_path=None, resume_type="resume"):
"""Load model from checkpoint. If checkpoint_path is None, it will
load the latest checkpoint in checkpoint_dir. If checkpoint_path is not
None, it will load the checkpoint specified by checkpoint_path. **Only use this
method after** ``accelerator.prepare()``.
"""
if resume_type == "resume":
if checkpoint_path is None:
available_checkpoints = glob.glob(
                    os.path.join(self.checkpoint_dir, "FAcodec_epoch_*_step_*.pth")
)
# find the checkpoint that has the highest step number
latest_checkpoint = max(
available_checkpoints,
key=lambda x: int(x.split("_")[-1].split(".")[0]),
)
earliest_checkpoint = min(
available_checkpoints,
key=lambda x: int(x.split("_")[-1].split(".")[0]),
)
# delete the earliest checkpoint
if (
earliest_checkpoint != latest_checkpoint
and self.accelerator.is_main_process
and len(available_checkpoints) > 4
):
os.remove(earliest_checkpoint)
print(f"Removed {earliest_checkpoint}")
else:
latest_checkpoint = checkpoint_path
self.model, self.optimizer, self.epoch, self.step = load_checkpoint(
self.model,
self.optimizer,
latest_checkpoint,
load_only_params=False,
ignore_modules=[],
is_distributed=self.accelerator.num_processes > 1,
)
else:
raise ValueError("Invalid resume type")
return checkpoint_path
def _count_parameters(self):
total_num = sum(
sum(p.numel() for p in self.model[key].parameters()) for key in self.model
)
# trainable_num = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
return total_num
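# Example usage (a minimal sketch; ``args`` provides exp_name / log_level /
# checkpoint / resume_type and ``cfg`` the training config consumed in __init__):
# >>> trainer = FAcodecTrainer(args, cfg)
# >>> trainer.train_loop()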
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# This code is borrowed from https://github.com/yl4579/PitchExtractor/blob/main/model.py
"""
Implementation of model from:
Kum et al. - "Joint Detection and Classification of Singing Voice Melody Using
Convolutional Recurrent Neural Networks" (2019)
Link: https://www.semanticscholar.org/paper/Joint-Detection-and-Classification-of-Singing-Voice-Kum-Nam/60a2ad4c7db43bace75805054603747fcd062c0d
"""
import torch
from torch import nn
class JDCNet(nn.Module):
"""
Joint Detection and Classification Network model for singing voice melody.
"""
def __init__(self, num_class=722, seq_len=31, leaky_relu_slope=0.01):
super().__init__()
self.num_class = num_class
# input = (b, 1, 31, 513), b = batch size
self.conv_block = nn.Sequential(
nn.Conv2d(
in_channels=1, out_channels=64, kernel_size=3, padding=1, bias=False
), # out: (b, 64, 31, 513)
nn.BatchNorm2d(num_features=64),
nn.LeakyReLU(leaky_relu_slope, inplace=True),
nn.Conv2d(64, 64, 3, padding=1, bias=False), # (b, 64, 31, 513)
)
# res blocks
self.res_block1 = ResBlock(
in_channels=64, out_channels=128
) # (b, 128, 31, 128)
self.res_block2 = ResBlock(
in_channels=128, out_channels=192
) # (b, 192, 31, 32)
self.res_block3 = ResBlock(in_channels=192, out_channels=256) # (b, 256, 31, 8)
# pool block
self.pool_block = nn.Sequential(
nn.BatchNorm2d(num_features=256),
nn.LeakyReLU(leaky_relu_slope, inplace=True),
nn.MaxPool2d(kernel_size=(1, 4)), # (b, 256, 31, 2)
nn.Dropout(p=0.2),
)
# maxpool layers (for auxiliary network inputs)
# in = (b, 128, 31, 513) from conv_block, out = (b, 128, 31, 2)
self.maxpool1 = nn.MaxPool2d(kernel_size=(1, 40))
# in = (b, 128, 31, 128) from res_block1, out = (b, 128, 31, 2)
self.maxpool2 = nn.MaxPool2d(kernel_size=(1, 20))
# in = (b, 128, 31, 32) from res_block2, out = (b, 128, 31, 2)
self.maxpool3 = nn.MaxPool2d(kernel_size=(1, 10))
# in = (b, 640, 31, 2), out = (b, 256, 31, 2)
self.detector_conv = nn.Sequential(
nn.Conv2d(640, 256, 1, bias=False),
nn.BatchNorm2d(256),
nn.LeakyReLU(leaky_relu_slope, inplace=True),
nn.Dropout(p=0.2),
)
# input: (b, 31, 512) - resized from (b, 256, 31, 2)
self.bilstm_classifier = nn.LSTM(
input_size=512, hidden_size=256, batch_first=True, bidirectional=True
) # (b, 31, 512)
# input: (b, 31, 512) - resized from (b, 256, 31, 2)
self.bilstm_detector = nn.LSTM(
input_size=512, hidden_size=256, batch_first=True, bidirectional=True
) # (b, 31, 512)
# input: (b * 31, 512)
self.classifier = nn.Linear(
in_features=512, out_features=self.num_class
) # (b * 31, num_class)
# input: (b * 31, 512)
self.detector = nn.Linear(
in_features=512, out_features=2
) # (b * 31, 2) - binary classifier
# initialize weights
self.apply(self.init_weights)
def get_feature_GAN(self, x):
seq_len = x.shape[-2]
x = x.float().transpose(-1, -2)
convblock_out = self.conv_block(x)
resblock1_out = self.res_block1(convblock_out)
resblock2_out = self.res_block2(resblock1_out)
resblock3_out = self.res_block3(resblock2_out)
poolblock_out = self.pool_block[0](resblock3_out)
poolblock_out = self.pool_block[1](poolblock_out)
return poolblock_out.transpose(-1, -2)
def get_feature(self, x):
seq_len = x.shape[-2]
x = x.float().transpose(-1, -2)
convblock_out = self.conv_block(x)
resblock1_out = self.res_block1(convblock_out)
resblock2_out = self.res_block2(resblock1_out)
resblock3_out = self.res_block3(resblock2_out)
poolblock_out = self.pool_block[0](resblock3_out)
poolblock_out = self.pool_block[1](poolblock_out)
return self.pool_block[2](poolblock_out)
def forward(self, x):
"""
Returns:
classification_prediction, detection_prediction
sizes: (b, 31, 722), (b, 31, 2)
"""
###############################
# forward pass for classifier #
###############################
seq_len = x.shape[-1]
x = x.float().transpose(-1, -2)
convblock_out = self.conv_block(x)
resblock1_out = self.res_block1(convblock_out)
resblock2_out = self.res_block2(resblock1_out)
resblock3_out = self.res_block3(resblock2_out)
poolblock_out = self.pool_block[0](resblock3_out)
poolblock_out = self.pool_block[1](poolblock_out)
GAN_feature = poolblock_out.transpose(-1, -2)
poolblock_out = self.pool_block[2](poolblock_out)
# (b, 256, 31, 2) => (b, 31, 256, 2) => (b, 31, 512)
classifier_out = (
poolblock_out.permute(0, 2, 1, 3).contiguous().view((-1, seq_len, 512))
)
classifier_out, _ = self.bilstm_classifier(
classifier_out
) # ignore the hidden states
classifier_out = classifier_out.contiguous().view((-1, 512)) # (b * 31, 512)
classifier_out = self.classifier(classifier_out)
classifier_out = classifier_out.view(
(-1, seq_len, self.num_class)
) # (b, 31, num_class)
        # classifier output consists of predicted pitch class logits per frame,
        # shape (b, seq_len, num_class); the voiced/unvoiced detector branch is
        # not used in this forward pass
return torch.abs(classifier_out.squeeze(-1)), GAN_feature, poolblock_out
@staticmethod
def init_weights(m):
if isinstance(m, nn.Linear):
nn.init.kaiming_uniform_(m.weight)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Conv2d):
nn.init.xavier_normal_(m.weight)
elif isinstance(m, nn.LSTM) or isinstance(m, nn.LSTMCell):
for p in m.parameters():
if p.data is None:
continue
if len(p.shape) >= 2:
nn.init.orthogonal_(p.data)
else:
nn.init.normal_(p.data)
class ResBlock(nn.Module):
def __init__(self, in_channels: int, out_channels: int, leaky_relu_slope=0.01):
super().__init__()
self.downsample = in_channels != out_channels
# BN / LReLU / MaxPool layer before the conv layer - see Figure 1b in the paper
self.pre_conv = nn.Sequential(
nn.BatchNorm2d(num_features=in_channels),
nn.LeakyReLU(leaky_relu_slope, inplace=True),
nn.MaxPool2d(kernel_size=(1, 2)), # apply downsampling on the y axis only
)
# conv layers
self.conv = nn.Sequential(
nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
padding=1,
bias=False,
),
nn.BatchNorm2d(out_channels),
nn.LeakyReLU(leaky_relu_slope, inplace=True),
nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
)
# 1 x 1 convolution layer to match the feature dimensions
self.conv1by1 = None
if self.downsample:
self.conv1by1 = nn.Conv2d(in_channels, out_channels, 1, bias=False)
def forward(self, x):
x = self.pre_conv(x)
if self.downsample:
x = self.conv(x) + self.conv1by1(x)
else:
x = self.conv(x) + x
return x
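# Minimal shape check (a sketch; the 80-bin mel input mirrors how the FAcodec
# trainer calls this model, rather than the 513-bin spectrogram cited in the
# comments above):
if __name__ == "__main__":
    mel = torch.randn(2, 1, 80, 31)  # (batch, 1, n_mels, frames)
    net = JDCNet(num_class=722)
    pitch, gan_feature, pooled = net(mel)
    print(pitch.shape, gan_feature.shape, pooled.shape)
    # torch.Size([2, 31, 722]) torch.Size([2, 256, 10, 31]) torch.Size([2, 256, 31, 2])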
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# This code is modified from https://github.com/sh-lee-prml/HierSpeechpp/blob/main/ttv_v1/attentions.py
import copy
import math
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from . import commons
class LayerNorm(nn.Module):
def __init__(self, channels, eps=1e-5):
super().__init__()
self.channels = channels
self.eps = eps
self.gamma = nn.Parameter(torch.ones(channels))
self.beta = nn.Parameter(torch.zeros(channels))
def forward(self, x):
x = x.transpose(1, -1)
x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
return x.transpose(1, -1)
class Encoder(nn.Module):
def __init__(
self,
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size=1,
p_dropout=0.0,
window_size=4,
**kwargs
):
super().__init__()
self.hidden_channels = hidden_channels
self.filter_channels = filter_channels
self.n_heads = n_heads
self.n_layers = n_layers
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.window_size = window_size
self.drop = nn.Dropout(p_dropout)
self.attn_layers = nn.ModuleList()
self.norm_layers_1 = nn.ModuleList()
self.ffn_layers = nn.ModuleList()
self.norm_layers_2 = nn.ModuleList()
for i in range(self.n_layers):
self.attn_layers.append(
MultiHeadAttention(
hidden_channels,
hidden_channels,
n_heads,
p_dropout=p_dropout,
window_size=window_size,
)
)
self.norm_layers_1.append(LayerNorm(hidden_channels))
self.ffn_layers.append(
FFN(
hidden_channels,
hidden_channels,
filter_channels,
kernel_size,
p_dropout=p_dropout,
)
)
self.norm_layers_2.append(LayerNorm(hidden_channels))
def forward(self, x, x_mask):
attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
x = x * x_mask
for i in range(self.n_layers):
y = self.attn_layers[i](x, x, attn_mask)
y = self.drop(y)
x = self.norm_layers_1[i](x + y)
y = self.ffn_layers[i](x, x_mask)
y = self.drop(y)
x = self.norm_layers_2[i](x + y)
x = x * x_mask
return x
class Decoder(nn.Module):
def __init__(
self,
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size=1,
p_dropout=0.0,
proximal_bias=False,
proximal_init=True,
**kwargs
):
super().__init__()
self.hidden_channels = hidden_channels
self.filter_channels = filter_channels
self.n_heads = n_heads
self.n_layers = n_layers
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.proximal_bias = proximal_bias
self.proximal_init = proximal_init
self.drop = nn.Dropout(p_dropout)
self.self_attn_layers = nn.ModuleList()
self.norm_layers_0 = nn.ModuleList()
self.encdec_attn_layers = nn.ModuleList()
self.norm_layers_1 = nn.ModuleList()
self.ffn_layers = nn.ModuleList()
self.norm_layers_2 = nn.ModuleList()
for i in range(self.n_layers):
self.self_attn_layers.append(
MultiHeadAttention(
hidden_channels,
hidden_channels,
n_heads,
p_dropout=p_dropout,
proximal_bias=proximal_bias,
proximal_init=proximal_init,
)
)
self.norm_layers_0.append(LayerNorm(hidden_channels))
self.encdec_attn_layers.append(
MultiHeadAttention(
hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
)
)
self.norm_layers_1.append(LayerNorm(hidden_channels))
self.ffn_layers.append(
FFN(
hidden_channels,
hidden_channels,
filter_channels,
kernel_size,
p_dropout=p_dropout,
causal=True,
)
)
self.norm_layers_2.append(LayerNorm(hidden_channels))
def forward(self, x, x_mask, h, h_mask):
"""
x: decoder input
h: encoder output
"""
self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
device=x.device, dtype=x.dtype
)
encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
x = x * x_mask
for i in range(self.n_layers):
y = self.self_attn_layers[i](x, x, self_attn_mask)
y = self.drop(y)
x = self.norm_layers_0[i](x + y)
y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
y = self.drop(y)
x = self.norm_layers_1[i](x + y)
y = self.ffn_layers[i](x, x_mask)
y = self.drop(y)
x = self.norm_layers_2[i](x + y)
x = x * x_mask
return x
class MultiHeadAttention(nn.Module):
def __init__(
self,
channels,
out_channels,
n_heads,
p_dropout=0.0,
window_size=None,
heads_share=True,
block_length=None,
proximal_bias=False,
proximal_init=False,
):
super().__init__()
assert channels % n_heads == 0
self.channels = channels
self.out_channels = out_channels
self.n_heads = n_heads
self.p_dropout = p_dropout
self.window_size = window_size
self.heads_share = heads_share
self.block_length = block_length
self.proximal_bias = proximal_bias
self.proximal_init = proximal_init
self.attn = None
self.k_channels = channels // n_heads
self.conv_q = nn.Conv1d(channels, channels, 1)
self.conv_k = nn.Conv1d(channels, channels, 1)
self.conv_v = nn.Conv1d(channels, channels, 1)
self.conv_o = nn.Conv1d(channels, out_channels, 1)
self.drop = nn.Dropout(p_dropout)
if window_size is not None:
n_heads_rel = 1 if heads_share else n_heads
rel_stddev = self.k_channels**-0.5
self.emb_rel_k = nn.Parameter(
torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
* rel_stddev
)
self.emb_rel_v = nn.Parameter(
torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
* rel_stddev
)
nn.init.xavier_uniform_(self.conv_q.weight)
nn.init.xavier_uniform_(self.conv_k.weight)
nn.init.xavier_uniform_(self.conv_v.weight)
if proximal_init:
with torch.no_grad():
self.conv_k.weight.copy_(self.conv_q.weight)
self.conv_k.bias.copy_(self.conv_q.bias)
def forward(self, x, c, attn_mask=None):
q = self.conv_q(x)
k = self.conv_k(c)
v = self.conv_v(c)
x, self.attn = self.attention(q, k, v, mask=attn_mask)
x = self.conv_o(x)
return x
def attention(self, query, key, value, mask=None):
# reshape [b, d, t] -> [b, n_h, t, d_k]
b, d, t_s, t_t = (*key.size(), query.size(2))
query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
if self.window_size is not None:
assert (
t_s == t_t
), "Relative attention is only available for self-attention."
key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
rel_logits = self._matmul_with_relative_keys(
query / math.sqrt(self.k_channels), key_relative_embeddings
)
scores_local = self._relative_position_to_absolute_position(rel_logits)
scores = scores + scores_local
if self.proximal_bias:
assert t_s == t_t, "Proximal bias is only available for self-attention."
scores = scores + self._attention_bias_proximal(t_s).to(
device=scores.device, dtype=scores.dtype
)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e4)
if self.block_length is not None:
assert (
t_s == t_t
), "Local attention is only available for self-attention."
block_mask = (
torch.ones_like(scores)
.triu(-self.block_length)
.tril(self.block_length)
)
scores = scores.masked_fill(block_mask == 0, -1e4)
p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
p_attn = self.drop(p_attn)
output = torch.matmul(p_attn, value)
if self.window_size is not None:
relative_weights = self._absolute_position_to_relative_position(p_attn)
value_relative_embeddings = self._get_relative_embeddings(
self.emb_rel_v, t_s
)
output = output + self._matmul_with_relative_values(
relative_weights, value_relative_embeddings
)
output = (
output.transpose(2, 3).contiguous().view(b, d, t_t)
) # [b, n_h, t_t, d_k] -> [b, d, t_t]
return output, p_attn
def _matmul_with_relative_values(self, x, y):
"""
x: [b, h, l, m]
y: [h or 1, m, d]
ret: [b, h, l, d]
"""
ret = torch.matmul(x, y.unsqueeze(0))
return ret
def _matmul_with_relative_keys(self, x, y):
"""
x: [b, h, l, d]
y: [h or 1, m, d]
ret: [b, h, l, m]
"""
ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
return ret
def _get_relative_embeddings(self, relative_embeddings, length):
max_relative_position = 2 * self.window_size + 1
# Pad first before slice to avoid using cond ops.
pad_length = max(length - (self.window_size + 1), 0)
slice_start_position = max((self.window_size + 1) - length, 0)
slice_end_position = slice_start_position + 2 * length - 1
if pad_length > 0:
padded_relative_embeddings = F.pad(
relative_embeddings,
commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
)
else:
padded_relative_embeddings = relative_embeddings
used_relative_embeddings = padded_relative_embeddings[
:, slice_start_position:slice_end_position
]
return used_relative_embeddings
def _relative_position_to_absolute_position(self, x):
"""
x: [b, h, l, 2*l-1]
ret: [b, h, l, l]
"""
batch, heads, length, _ = x.size()
# Concat columns of pad to shift from relative to absolute indexing.
x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
        # Concat extra elements so as to add up to shape (len+1, 2*len-1).
x_flat = x.view([batch, heads, length * 2 * length])
x_flat = F.pad(
x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
)
# Reshape and slice out the padded elements.
x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
:, :, :length, length - 1 :
]
return x_final
def _absolute_position_to_relative_position(self, x):
"""
x: [b, h, l, l]
ret: [b, h, l, 2*l-1]
"""
batch, heads, length, _ = x.size()
        # pad along column
x = F.pad(
x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
)
x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
# add 0's in the beginning that will skew the elements after reshape
x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
return x_final
def _attention_bias_proximal(self, length):
"""Bias for self-attention to encourage attention to close positions.
Args:
length: an integer scalar.
Returns:
a Tensor with shape [1, 1, length, length]
"""
r = torch.arange(length, dtype=torch.float32)
diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
class FFN(nn.Module):
def __init__(
self,
in_channels,
out_channels,
filter_channels,
kernel_size,
p_dropout=0.0,
activation=None,
causal=False,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.filter_channels = filter_channels
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.activation = activation
self.causal = causal
if causal:
self.padding = self._causal_padding
else:
self.padding = self._same_padding
self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
self.drop = nn.Dropout(p_dropout)
def forward(self, x, x_mask):
x = self.conv_1(self.padding(x * x_mask))
if self.activation == "gelu":
x = x * torch.sigmoid(1.702 * x)
else:
x = torch.relu(x)
x = self.drop(x)
x = self.conv_2(self.padding(x * x_mask))
return x * x_mask
def _causal_padding(self, x):
if self.kernel_size == 1:
return x
pad_l = self.kernel_size - 1
pad_r = 0
padding = [[0, 0], [0, 0], [pad_l, pad_r]]
x = F.pad(x, commons.convert_pad_shape(padding))
return x
def _same_padding(self, x):
if self.kernel_size == 1:
return x
pad_l = (self.kernel_size - 1) // 2
pad_r = self.kernel_size // 2
padding = [[0, 0], [0, 0], [pad_l, pad_r]]
x = F.pad(x, commons.convert_pad_shape(padding))
return x
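# Illustrative usage sketch (not part of the original module). It assumes the
# `commons` import at the top of this file; the channel sizes below are
# hypothetical and only show that same-padding keeps the time axis intact.
if __name__ == "__main__":
    import torch
    _ffn = FFN(in_channels=192, out_channels=192, filter_channels=768, kernel_size=3)
    _x = torch.randn(2, 192, 50)   # [batch, channels, time]
    _mask = torch.ones(2, 1, 50)   # no padded frames in this toy batch
    assert _ffn(_x, _mask).shape == (2, 192, 50)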
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import os.path
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from munch import Munch
import json
class AttrDict(dict):
def __init__(self, *args, **kwargs):
super(AttrDict, self).__init__(*args, **kwargs)
self.__dict__ = self
def init_weights(m, mean=0.0, std=0.01):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
m.weight.data.normal_(mean, std)
def get_padding(kernel_size, dilation=1):
return int((kernel_size * dilation - dilation) / 2)
def convert_pad_shape(pad_shape):
l = pad_shape[::-1]
pad_shape = [item for sublist in l for item in sublist]
return pad_shape
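# Illustrative check (not part of the original module): convert_pad_shape turns
# per-dimension [left, right] pads (outermost dimension first) into the flat,
# last-dimension-first list that F.pad expects.
if __name__ == "__main__":
    assert convert_pad_shape([[0, 0], [0, 0], [1, 2]]) == [1, 2, 0, 0, 0, 0]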
def intersperse(lst, item):
result = [item] * (len(lst) * 2 + 1)
result[1::2] = lst
return result
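# Illustrative check (not part of the original module): intersperse places
# `item` between and around the original elements.
if __name__ == "__main__":
    assert intersperse([1, 2, 3], 0) == [0, 1, 0, 2, 0, 3, 0]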
def kl_divergence(m_p, logs_p, m_q, logs_q):
"""KL(P||Q)"""
kl = (logs_q - logs_p) - 0.5
kl += (
0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
)
return kl
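# Illustrative sanity check (not part of the original module): the KL
# divergence between two identical diagonal Gaussians is zero.
if __name__ == "__main__":
    import torch
    _m, _logs = torch.zeros(3), torch.zeros(3)
    assert torch.allclose(kl_divergence(_m, _logs, _m, _logs), torch.zeros(3))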
def rand_gumbel(shape):
"""Sample from the Gumbel distribution, protect from overflows."""
uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
return -torch.log(-torch.log(uniform_samples))
def rand_gumbel_like(x):
g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
return g
def slice_segments(x, ids_str, segment_size=4):
ret = torch.zeros_like(x[:, :, :segment_size])
for i in range(x.size(0)):
idx_str = ids_str[i]
idx_end = idx_str + segment_size
ret[i] = x[i, :, idx_str:idx_end]
return ret
def slice_segments_audio(x, ids_str, segment_size=4):
ret = torch.zeros_like(x[:, :segment_size])
for i in range(x.size(0)):
idx_str = ids_str[i]
idx_end = idx_str + segment_size
ret[i] = x[i, idx_str:idx_end]
return ret
def rand_slice_segments(x, x_lengths=None, segment_size=4):
b, d, t = x.size()
if x_lengths is None:
x_lengths = t
ids_str_max = x_lengths - segment_size + 1
ids_str = ((torch.rand([b]).to(device=x.device) * ids_str_max).clip(0)).to(
dtype=torch.long
)
ret = slice_segments(x, ids_str, segment_size)
return ret, ids_str
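# Illustrative usage sketch (not part of the original module); tensor sizes are
# hypothetical. Every sampled slice stays within its utterance length.
if __name__ == "__main__":
    import torch
    _x = torch.randn(4, 80, 100)             # [batch, mel_bins, frames]
    _lens = torch.tensor([100, 90, 80, 70])
    _seg, _ids = rand_slice_segments(_x, _lens, segment_size=32)
    assert _seg.shape == (4, 80, 32)
    assert bool(((_ids + 32) <= _lens).all())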
def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
position = torch.arange(length, dtype=torch.float)
num_timescales = channels // 2
log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
num_timescales - 1
)
inv_timescales = min_timescale * torch.exp(
torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
)
scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
signal = F.pad(signal, [0, 0, 0, channels % 2])
signal = signal.view(1, channels, length)
return signal
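# Illustrative check (not part of the original module): the sinusoidal timing
# signal is returned channel-first with a leading batch dimension of one.
if __name__ == "__main__":
    assert get_timing_signal_1d(length=100, channels=192).shape == (1, 192, 100)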
def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
b, channels, length = x.size()
signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
return x + signal.to(dtype=x.dtype, device=x.device)
def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
b, channels, length = x.size()
signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
def subsequent_mask(length):
mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
return mask
@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
n_channels_int = n_channels[0]
in_act = input_a + input_b
t_act = torch.tanh(in_act[:, :n_channels_int, :])
s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
acts = t_act * s_act
return acts
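# Illustrative check (not part of the original module), mirroring how WN calls
# this fused gate: the first n channels go through tanh, the rest through
# sigmoid, and their elementwise product halves the channel count.
if __name__ == "__main__":
    import torch
    _a, _b = torch.randn(2, 8, 10), torch.zeros(2, 8, 10)
    _out = fused_add_tanh_sigmoid_multiply(_a, _b, torch.IntTensor([4]))
    assert _out.shape == (2, 4, 10)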
def convert_pad_shape(pad_shape):
l = pad_shape[::-1]
pad_shape = [item for sublist in l for item in sublist]
return pad_shape
def shift_1d(x):
x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
return x
def sequence_mask(length, max_length=None):
if max_length is None:
max_length = length.max()
x = torch.arange(max_length, dtype=length.dtype, device=length.device)
return x.unsqueeze(0) < length.unsqueeze(1)
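# Illustrative check (not part of the original module): sequence_mask turns
# per-utterance lengths into a boolean [batch, max_len] padding mask.
if __name__ == "__main__":
    import torch
    _mask = sequence_mask(torch.tensor([3, 1]), max_length=4)
    assert _mask.tolist() == [[True, True, True, False], [True, False, False, False]]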
def generate_path(duration, mask):
"""
duration: [b, 1, t_x]
mask: [b, 1, t_y, t_x]
"""
device = duration.device
b, _, t_y, t_x = mask.shape
cum_duration = torch.cumsum(duration, -1)
cum_duration_flat = cum_duration.view(b * t_x)
path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
path = path.view(b, t_x, t_y)
path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
path = path.unsqueeze(1).transpose(2, 3) * mask
return path
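# Illustrative check (not part of the original module): generate_path expands
# integer durations into a hard monotonic alignment from t_x text positions to
# t_y frames. The sizes below are hypothetical.
if __name__ == "__main__":
    import torch
    _dur = torch.tensor([[[2.0, 3.0]]])   # [b, 1, t_x]
    _mask = torch.ones(1, 1, 5, 2)        # [b, 1, t_y, t_x]
    _path = generate_path(_dur, _mask)
    assert _path[0, 0].tolist() == [
        [1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0], [0.0, 1.0]
    ]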
def clip_grad_value_(parameters, clip_value, norm_type=2):
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
parameters = list(filter(lambda p: p.grad is not None, parameters))
norm_type = float(norm_type)
if clip_value is not None:
clip_value = float(clip_value)
total_norm = 0
for p in parameters:
param_norm = p.grad.data.norm(norm_type)
total_norm += param_norm.item() ** norm_type
if clip_value is not None:
p.grad.data.clamp_(min=-clip_value, max=clip_value)
total_norm = total_norm ** (1.0 / norm_type)
return total_norm
def log_norm(x, mean=-4, std=4, dim=2):
"""
normalized log mel -> mel -> norm -> log(norm)
"""
x = torch.log(torch.exp(x * std + mean).norm(dim=dim))
return x
from huggingface_hub import hf_hub_download
def load_F0_models(path):
# load F0 model
from .JDC.model import JDCNet
F0_model = JDCNet(num_class=1, seq_len=192)
if not os.path.exists(path):
path = hf_hub_download(repo_id="Plachta/JDCnet", filename="bst.t7")
params = torch.load(path, map_location="cpu")["net"]
F0_model.load_state_dict(params)
_ = F0_model.train()
return F0_model
# Generators
from modules.dac.model.dac import Encoder, Decoder
from .quantize import FAquantizer, FApredictors
# Discriminators
from modules.dac.model.discriminator import Discriminator
def build_model(args):
encoder = Encoder(
d_model=args.DAC.encoder_dim,
strides=args.DAC.encoder_rates,
d_latent=1024,
causal=args.causal,
lstm=args.lstm,
)
quantizer = FAquantizer(
in_dim=1024,
n_p_codebooks=1,
n_c_codebooks=args.n_c_codebooks,
n_t_codebooks=2,
n_r_codebooks=3,
codebook_size=1024,
codebook_dim=8,
quantizer_dropout=0.5,
causal=args.causal,
separate_prosody_encoder=args.separate_prosody_encoder,
timbre_norm=args.timbre_norm,
)
fa_predictors = FApredictors(
in_dim=1024,
use_gr_content_f0=args.use_gr_content_f0,
use_gr_prosody_phone=args.use_gr_prosody_phone,
use_gr_residual_f0=True,
use_gr_residual_phone=True,
use_gr_timbre_content=True,
use_gr_timbre_prosody=args.use_gr_timbre_prosody,
use_gr_x_timbre=True,
norm_f0=args.norm_f0,
timbre_norm=args.timbre_norm,
use_gr_content_global_f0=args.use_gr_content_global_f0,
)
decoder = Decoder(
input_channel=1024,
channels=args.DAC.decoder_dim,
rates=args.DAC.decoder_rates,
causal=args.causal,
lstm=args.lstm,
)
discriminator = Discriminator(
rates=[],
periods=[2, 3, 5, 7, 11],
fft_sizes=[2048, 1024, 512],
sample_rate=args.DAC.sr,
bands=[(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)],
)
nets = Munch(
encoder=encoder,
quantizer=quantizer,
decoder=decoder,
discriminator=discriminator,
fa_predictors=fa_predictors,
)
return nets
def load_checkpoint(
model,
optimizer,
path,
load_only_params=True,
ignore_modules=[],
is_distributed=False,
):
state = torch.load(path, map_location="cpu")
params = state["net"]
for key in model:
if key in params and key not in ignore_modules:
if not is_distributed:
# strip the DDP "module." prefix from parameter names so the weights load into a non-distributed model
for k in list(params[key].keys()):
if k.startswith("module."):
params[key][k[len("module.") :]] = params[key][k]
del params[key][k]
print("%s loaded" % key)
model[key].load_state_dict(params[key], strict=True)
_ = [model[key].eval() for key in model]
if not load_only_params:
epoch = state["epoch"] + 1
iters = state["iters"]
optimizer.load_state_dict(state["optimizer"])
optimizer.load_scheduler_state_dict(state["scheduler"])
else:
epoch = state["epoch"] + 1
iters = state["iters"]
return model, optimizer, epoch, iters
def recursive_munch(d):
if isinstance(d, dict):
return Munch((k, recursive_munch(v)) for k, v in d.items())
elif isinstance(d, list):
return [recursive_munch(v) for v in d]
else:
return d
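# Illustrative usage sketch (not part of the original module): recursive_munch
# turns nested dicts into attribute-accessible Munch objects, which is the form
# build_model above expects for its `args`. The keys used here are hypothetical.
if __name__ == "__main__":
    _cfg = recursive_munch({"DAC": {"encoder_dim": 64, "encoder_rates": [2, 4, 5, 6]}})
    assert _cfg.DAC.encoder_dim == 64
    assert _cfg.DAC.encoder_rates == [2, 4, 5, 6]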
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from torch.autograd import Function
import torch
from torch import nn
class GradientReversal(Function):
@staticmethod
def forward(ctx, x, alpha):
ctx.save_for_backward(x, alpha)
return x
@staticmethod
def backward(ctx, grad_output):
grad_input = None
_, alpha = ctx.saved_tensors
if ctx.needs_input_grad[0]:
grad_input = -alpha * grad_output
return grad_input, None
revgrad = GradientReversal.apply
class GradientReversal(nn.Module):
def __init__(self, alpha):
super().__init__()
self.alpha = torch.tensor(alpha, requires_grad=False)
def forward(self, x):
return revgrad(x, self.alpha)
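# Illustrative check (not part of the original module): the reversal layer is
# the identity in the forward pass but multiplies gradients by -alpha on the
# way back, which is what the adversarial predictors elsewhere rely on.
if __name__ == "__main__":
    import torch
    _rev = GradientReversal(alpha=1.0)
    _x = torch.ones(3, requires_grad=True)
    _rev(_x).sum().backward()
    assert torch.allclose(_x.grad, -torch.ones(3))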
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
from torch import nn
from typing import Optional, Any
from torch import Tensor
import torch.nn.functional as F
import torchaudio
import torchaudio.functional as audio_F
import random
random.seed(0)
def _get_activation_fn(activ):
if activ == "relu":
return nn.ReLU()
elif activ == "lrelu":
return nn.LeakyReLU(0.2)
elif activ == "swish":
return lambda x: x * torch.sigmoid(x)
else:
raise RuntimeError(
"Unexpected activ type %s, expected [relu, lrelu, swish]" % activ
)
class LinearNorm(torch.nn.Module):
def __init__(self, in_dim, out_dim, bias=True, w_init_gain="linear"):
super(LinearNorm, self).__init__()
self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
torch.nn.init.xavier_uniform_(
self.linear_layer.weight, gain=torch.nn.init.calculate_gain(w_init_gain)
)
def forward(self, x):
return self.linear_layer(x)
class ConvNorm(torch.nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size=1,
stride=1,
padding=None,
dilation=1,
bias=True,
w_init_gain="linear",
param=None,
):
super(ConvNorm, self).__init__()
if padding is None:
assert kernel_size % 2 == 1
padding = int(dilation * (kernel_size - 1) / 2)
self.conv = torch.nn.Conv1d(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
bias=bias,
)
torch.nn.init.xavier_uniform_(
self.conv.weight,
gain=torch.nn.init.calculate_gain(w_init_gain, param=param),
)
def forward(self, signal):
conv_signal = self.conv(signal)
return conv_signal
class CausualConv(nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size=1,
stride=1,
padding=1,
dilation=1,
bias=True,
w_init_gain="linear",
param=None,
):
super(CausualConv, self).__init__()
if padding is None:
    assert kernel_size % 2 == 1
    # double the "same" padding and trim the tail in forward() to keep the conv causal
    self.padding = int(dilation * (kernel_size - 1) / 2) * 2
else:
    self.padding = padding * 2
self.conv = nn.Conv1d(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=self.padding,
dilation=dilation,
bias=bias,
)
torch.nn.init.xavier_uniform_(
self.conv.weight,
gain=torch.nn.init.calculate_gain(w_init_gain, param=param),
)
def forward(self, x):
x = self.conv(x)
x = x[:, :, : -self.padding]
return x
class CausualBlock(nn.Module):
def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ="lrelu"):
super(CausualBlock, self).__init__()
self.blocks = nn.ModuleList(
[
self._get_conv(
hidden_dim, dilation=3**i, activ=activ, dropout_p=dropout_p
)
for i in range(n_conv)
]
)
def forward(self, x):
for block in self.blocks:
res = x
x = block(x)
x += res
return x
def _get_conv(self, hidden_dim, dilation, activ="lrelu", dropout_p=0.2):
layers = [
CausualConv(
hidden_dim,
hidden_dim,
kernel_size=3,
padding=dilation,
dilation=dilation,
),
_get_activation_fn(activ),
nn.BatchNorm1d(hidden_dim),
nn.Dropout(p=dropout_p),
CausualConv(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1),
_get_activation_fn(activ),
nn.Dropout(p=dropout_p),
]
return nn.Sequential(*layers)
class ConvBlock(nn.Module):
def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ="relu"):
super().__init__()
self._n_groups = 8
self.blocks = nn.ModuleList(
[
self._get_conv(
hidden_dim, dilation=3**i, activ=activ, dropout_p=dropout_p
)
for i in range(n_conv)
]
)
def forward(self, x):
for block in self.blocks:
res = x
x = block(x)
x += res
return x
def _get_conv(self, hidden_dim, dilation, activ="relu", dropout_p=0.2):
layers = [
ConvNorm(
hidden_dim,
hidden_dim,
kernel_size=3,
padding=dilation,
dilation=dilation,
),
_get_activation_fn(activ),
nn.GroupNorm(num_groups=self._n_groups, num_channels=hidden_dim),
nn.Dropout(p=dropout_p),
ConvNorm(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1),
_get_activation_fn(activ),
nn.Dropout(p=dropout_p),
]
return nn.Sequential(*layers)
class LocationLayer(nn.Module):
def __init__(self, attention_n_filters, attention_kernel_size, attention_dim):
super(LocationLayer, self).__init__()
padding = int((attention_kernel_size - 1) / 2)
self.location_conv = ConvNorm(
2,
attention_n_filters,
kernel_size=attention_kernel_size,
padding=padding,
bias=False,
stride=1,
dilation=1,
)
self.location_dense = LinearNorm(
attention_n_filters, attention_dim, bias=False, w_init_gain="tanh"
)
def forward(self, attention_weights_cat):
processed_attention = self.location_conv(attention_weights_cat)
processed_attention = processed_attention.transpose(1, 2)
processed_attention = self.location_dense(processed_attention)
return processed_attention
class Attention(nn.Module):
def __init__(
self,
attention_rnn_dim,
embedding_dim,
attention_dim,
attention_location_n_filters,
attention_location_kernel_size,
):
super(Attention, self).__init__()
self.query_layer = LinearNorm(
attention_rnn_dim, attention_dim, bias=False, w_init_gain="tanh"
)
self.memory_layer = LinearNorm(
embedding_dim, attention_dim, bias=False, w_init_gain="tanh"
)
self.v = LinearNorm(attention_dim, 1, bias=False)
self.location_layer = LocationLayer(
attention_location_n_filters, attention_location_kernel_size, attention_dim
)
self.score_mask_value = -float("inf")
def get_alignment_energies(self, query, processed_memory, attention_weights_cat):
"""
PARAMS
------
query: decoder output (batch, n_mel_channels * n_frames_per_step)
processed_memory: processed encoder outputs (B, T_in, attention_dim)
attention_weights_cat: cumulative and prev. att weights (B, 2, max_time)
RETURNS
-------
alignment (batch, max_time)
"""
processed_query = self.query_layer(query.unsqueeze(1))
processed_attention_weights = self.location_layer(attention_weights_cat)
energies = self.v(
torch.tanh(processed_query + processed_attention_weights + processed_memory)
)
energies = energies.squeeze(-1)
return energies
def forward(
self,
attention_hidden_state,
memory,
processed_memory,
attention_weights_cat,
mask,
):
"""
PARAMS
------
attention_hidden_state: attention rnn last output
memory: encoder outputs
processed_memory: processed encoder outputs
attention_weights_cat: previous and cumulative attention weights
mask: binary mask for padded data
"""
alignment = self.get_alignment_energies(
attention_hidden_state, processed_memory, attention_weights_cat
)
if mask is not None:
alignment.data.masked_fill_(mask, self.score_mask_value)
attention_weights = F.softmax(alignment, dim=1)
attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
attention_context = attention_context.squeeze(1)
return attention_context, attention_weights
class ForwardAttentionV2(nn.Module):
def __init__(
self,
attention_rnn_dim,
embedding_dim,
attention_dim,
attention_location_n_filters,
attention_location_kernel_size,
):
super(ForwardAttentionV2, self).__init__()
self.query_layer = LinearNorm(
attention_rnn_dim, attention_dim, bias=False, w_init_gain="tanh"
)
self.memory_layer = LinearNorm(
embedding_dim, attention_dim, bias=False, w_init_gain="tanh"
)
self.v = LinearNorm(attention_dim, 1, bias=False)
self.location_layer = LocationLayer(
attention_location_n_filters, attention_location_kernel_size, attention_dim
)
self.score_mask_value = -float(1e20)
def get_alignment_energies(self, query, processed_memory, attention_weights_cat):
"""
PARAMS
------
query: decoder output (batch, n_mel_channels * n_frames_per_step)
processed_memory: processed encoder outputs (B, T_in, attention_dim)
attention_weights_cat: prev. and cumulative att weights (B, 2, max_time)
RETURNS
-------
alignment (batch, max_time)
"""
processed_query = self.query_layer(query.unsqueeze(1))
processed_attention_weights = self.location_layer(attention_weights_cat)
energies = self.v(
torch.tanh(processed_query + processed_attention_weights + processed_memory)
)
energies = energies.squeeze(-1)
return energies
def forward(
self,
attention_hidden_state,
memory,
processed_memory,
attention_weights_cat,
mask,
log_alpha,
):
"""
PARAMS
------
attention_hidden_state: attention rnn last output
memory: encoder outputs
processed_memory: processed encoder outputs
attention_weights_cat: previous and cumulative attention weights
mask: binary mask for padded data
"""
log_energy = self.get_alignment_energies(
attention_hidden_state, processed_memory, attention_weights_cat
)
# log_energy =
if mask is not None:
log_energy.data.masked_fill_(mask, self.score_mask_value)
# attention_weights = F.softmax(alignment, dim=1)
# content_score = log_energy.unsqueeze(1) #[B, MAX_TIME] -> [B, 1, MAX_TIME]
# log_alpha = log_alpha.unsqueeze(2) #[B, MAX_TIME] -> [B, MAX_TIME, 1]
# log_total_score = log_alpha + content_score
# previous_attention_weights = attention_weights_cat[:,0,:]
log_alpha_shift_padded = []
max_time = log_energy.size(1)
for sft in range(2):
shifted = log_alpha[:, : max_time - sft]
shift_padded = F.pad(shifted, (sft, 0), "constant", self.score_mask_value)
log_alpha_shift_padded.append(shift_padded.unsqueeze(2))
biased = torch.logsumexp(torch.cat(log_alpha_shift_padded, 2), 2)
log_alpha_new = biased + log_energy
attention_weights = F.softmax(log_alpha_new, dim=1)
attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
attention_context = attention_context.squeeze(1)
return attention_context, attention_weights, log_alpha_new
class PhaseShuffle2d(nn.Module):
def __init__(self, n=2):
super(PhaseShuffle2d, self).__init__()
self.n = n
self.random = random.Random(1)
def forward(self, x, move=None):
# x.size = (B, C, M, L)
if move is None:
move = self.random.randint(-self.n, self.n)
if move == 0:
return x
else:
left = x[:, :, :, :move]
right = x[:, :, :, move:]
shuffled = torch.cat([right, left], dim=3)
return shuffled
class PhaseShuffle1d(nn.Module):
def __init__(self, n=2):
super(PhaseShuffle1d, self).__init__()
self.n = n
self.random = random.Random(1)
def forward(self, x, move=None):
# x.size = (B, C, L)
if move is None:
move = self.random.randint(-self.n, self.n)
if move == 0:
return x
else:
left = x[:, :, :move]
right = x[:, :, move:]
shuffled = torch.cat([right, left], dim=2)
return shuffled
class MFCC(nn.Module):
def __init__(self, n_mfcc=40, n_mels=80):
super(MFCC, self).__init__()
self.n_mfcc = n_mfcc
self.n_mels = n_mels
self.norm = "ortho"
dct_mat = audio_F.create_dct(self.n_mfcc, self.n_mels, self.norm)
self.register_buffer("dct_mat", dct_mat)
def forward(self, mel_specgram):
if len(mel_specgram.shape) == 2:
mel_specgram = mel_specgram.unsqueeze(0)
unsqueezed = True
else:
unsqueezed = False
# (channel, n_mels, time).transpose(...) dot (n_mels, n_mfcc)
# -> (channel, time, n_mfcc).transpose(...)
mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2)
# unpack batch
if unsqueezed:
mfcc = mfcc.squeeze(0)
return mfcc
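# Illustrative usage sketch (not part of the original module): the registered
# DCT matrix maps an 80-bin mel spectrogram to 40 MFCC coefficients while
# keeping the frame axis. Sizes below are hypothetical.
if __name__ == "__main__":
    import torch
    _mfcc = MFCC(n_mfcc=40, n_mels=80)
    assert _mfcc(torch.randn(2, 80, 120)).shape == (2, 40, 120)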
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from modules.dac.nn.quantize import ResidualVectorQuantize
from torch import nn
from .wavenet import WN
from .style_encoder import StyleEncoder
from .gradient_reversal import GradientReversal
import torch
import torchaudio
import torchaudio.functional as audio_F
import numpy as np
from ..alias_free_torch import *
from torch.nn.utils import weight_norm
from torch import nn, sin, pow
from einops.layers.torch import Rearrange
from modules.dac.model.encodec import SConv1d
def init_weights(m):
if isinstance(m, nn.Conv1d):
nn.init.trunc_normal_(m.weight, std=0.02)
nn.init.constant_(m.bias, 0)
def WNConv1d(*args, **kwargs):
return weight_norm(nn.Conv1d(*args, **kwargs))
def WNConvTranspose1d(*args, **kwargs):
return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
class SnakeBeta(nn.Module):
"""
A modified Snake function which uses separate parameters for the magnitude of the periodic components
Shape:
- Input: (B, C, T)
- Output: (B, C, T), same shape as the input
Parameters:
- alpha - trainable parameter that controls frequency
- beta - trainable parameter that controls magnitude
References:
- This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
https://arxiv.org/abs/2006.08195
Examples:
>>> a1 = SnakeBeta(256)
>>> x = torch.randn(256)
>>> x = a1(x)
"""
def __init__(
self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
):
"""
Initialization.
INPUT:
- in_features: shape of the input
- alpha - trainable parameter that controls frequency
- beta - trainable parameter that controls magnitude
alpha is initialized to 1 by default, higher values = higher-frequency.
beta is initialized to 1 by default, higher values = higher-magnitude.
alpha and beta will be trained along with the rest of your model.
"""
super(SnakeBeta, self).__init__()
self.in_features = in_features
# initialize alpha
self.alpha_logscale = alpha_logscale
if self.alpha_logscale: # log scale alphas initialized to zeros
self.alpha = nn.Parameter(torch.zeros(in_features) * alpha)
self.beta = nn.Parameter(torch.zeros(in_features) * alpha)
else: # linear scale alphas initialized to ones
self.alpha = nn.Parameter(torch.ones(in_features) * alpha)
self.beta = nn.Parameter(torch.ones(in_features) * alpha)
self.alpha.requires_grad = alpha_trainable
self.beta.requires_grad = alpha_trainable
self.no_div_by_zero = 0.000000001
def forward(self, x):
"""
Forward pass of the function.
Applies the function to the input elementwise.
SnakeBeta := x + 1/b * sin^2 (xa)
"""
alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
beta = self.beta.unsqueeze(0).unsqueeze(-1)
if self.alpha_logscale:
alpha = torch.exp(alpha)
beta = torch.exp(beta)
x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
return x
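# Illustrative check (not part of the original module): SnakeBeta acts
# elementwise on [B, C, T] tensors and preserves their shape. Sizes below are
# hypothetical.
if __name__ == "__main__":
    import torch
    _act = SnakeBeta(in_features=16, alpha_logscale=True)
    assert _act(torch.randn(2, 16, 100)).shape == (2, 16, 100)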
class ResidualUnit(nn.Module):
def __init__(self, dim: int = 16, dilation: int = 1):
super().__init__()
pad = ((7 - 1) * dilation) // 2
self.block = nn.Sequential(
Activation1d(activation=SnakeBeta(dim, alpha_logscale=True)),
WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
Activation1d(activation=SnakeBeta(dim, alpha_logscale=True)),
WNConv1d(dim, dim, kernel_size=1),
)
def forward(self, x):
return x + self.block(x)
class CNNLSTM(nn.Module):
def __init__(self, indim, outdim, head, global_pred=False):
super().__init__()
self.global_pred = global_pred
self.model = nn.Sequential(
ResidualUnit(indim, dilation=1),
ResidualUnit(indim, dilation=2),
ResidualUnit(indim, dilation=3),
Activation1d(activation=SnakeBeta(indim, alpha_logscale=True)),
Rearrange("b c t -> b t c"),
)
self.heads = nn.ModuleList([nn.Linear(indim, outdim) for i in range(head)])
def forward(self, x):
# x: [B, C, T]
x = self.model(x)
if self.global_pred:
x = torch.mean(x, dim=1, keepdim=False)
outs = [head(x) for head in self.heads]
return outs
def sequence_mask(length, max_length=None):
if max_length is None:
max_length = length.max()
x = torch.arange(max_length, dtype=length.dtype, device=length.device)
return x.unsqueeze(0) < length.unsqueeze(1)
class MFCC(nn.Module):
def __init__(self, n_mfcc=40, n_mels=80):
super(MFCC, self).__init__()
self.n_mfcc = n_mfcc
self.n_mels = n_mels
self.norm = "ortho"
dct_mat = audio_F.create_dct(self.n_mfcc, self.n_mels, self.norm)
self.register_buffer("dct_mat", dct_mat)
def forward(self, mel_specgram):
if len(mel_specgram.shape) == 2:
mel_specgram = mel_specgram.unsqueeze(0)
unsqueezed = True
else:
unsqueezed = False
# (channel, n_mels, time).transpose(...) dot (n_mels, n_mfcc)
# -> (channel, time, n_mfcc).transpose(...)
mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2)
# unpack batch
if unsqueezed:
mfcc = mfcc.squeeze(0)
return mfcc
class FAquantizer(nn.Module):
def __init__(
self,
in_dim=1024,
n_p_codebooks=1,
n_c_codebooks=2,
n_t_codebooks=2,
n_r_codebooks=3,
codebook_size=1024,
codebook_dim=8,
quantizer_dropout=0.5,
causal=False,
separate_prosody_encoder=False,
timbre_norm=False,
):
super(FAquantizer, self).__init__()
conv1d_type = SConv1d # if causal else nn.Conv1d
self.prosody_quantizer = ResidualVectorQuantize(
input_dim=in_dim,
n_codebooks=n_p_codebooks,
codebook_size=codebook_size,
codebook_dim=codebook_dim,
quantizer_dropout=quantizer_dropout,
)
self.content_quantizer = ResidualVectorQuantize(
input_dim=in_dim,
n_codebooks=n_c_codebooks,
codebook_size=codebook_size,
codebook_dim=codebook_dim,
quantizer_dropout=quantizer_dropout,
)
if not timbre_norm:
self.timbre_quantizer = ResidualVectorQuantize(
input_dim=in_dim,
n_codebooks=n_t_codebooks,
codebook_size=codebook_size,
codebook_dim=codebook_dim,
quantizer_dropout=quantizer_dropout,
)
else:
self.timbre_encoder = StyleEncoder(
in_dim=80, hidden_dim=512, out_dim=in_dim
)
self.timbre_linear = nn.Linear(1024, 1024 * 2)
self.timbre_linear.bias.data[:1024] = 1
self.timbre_linear.bias.data[1024:] = 0
self.timbre_norm = nn.LayerNorm(1024, elementwise_affine=False)
self.residual_quantizer = ResidualVectorQuantize(
input_dim=in_dim,
n_codebooks=n_r_codebooks,
codebook_size=codebook_size,
codebook_dim=codebook_dim,
quantizer_dropout=quantizer_dropout,
)
if separate_prosody_encoder:
self.melspec_linear = conv1d_type(
in_channels=20, out_channels=256, kernel_size=1, causal=causal
)
self.melspec_encoder = WN(
hidden_channels=256,
kernel_size=5,
dilation_rate=1,
n_layers=8,
gin_channels=0,
p_dropout=0.2,
causal=causal,
)
self.melspec_linear2 = conv1d_type(
in_channels=256, out_channels=1024, kernel_size=1, causal=causal
)
else:
pass
self.separate_prosody_encoder = separate_prosody_encoder
self.prob_random_mask_residual = 0.75
SPECT_PARAMS = {
"n_fft": 2048,
"win_length": 1200,
"hop_length": 300,
}
MEL_PARAMS = {
"n_mels": 80,
}
self.to_mel = torchaudio.transforms.MelSpectrogram(
n_mels=MEL_PARAMS["n_mels"], sample_rate=24000, **SPECT_PARAMS
)
self.mel_mean, self.mel_std = -4, 4
self.frame_rate = 24000 / 300
self.hop_length = 300
self.is_timbre_norm = timbre_norm
if timbre_norm:
self.forward = self.forward_v2
def preprocess(self, wave_tensor, n_bins=20):
mel_tensor = self.to_mel(wave_tensor.squeeze(1))
mel_tensor = (torch.log(1e-5 + mel_tensor) - self.mel_mean) / self.mel_std
return mel_tensor[:, :n_bins, : int(wave_tensor.size(-1) / self.hop_length)]
@torch.no_grad()
def decode(self, codes):
code_c, code_p, code_t = codes.split([1, 1, 2], dim=1)
z_c = self.content_quantizer.from_codes(code_c)[0]
z_p = self.prosody_quantizer.from_codes(code_p)[0]
z_t = self.timbre_quantizer.from_codes(code_t)[0]
z = z_c + z_p + z_t
return z, [z_c, z_p, z_t]
@torch.no_grad()
def encode(self, x, wave_segments, n_c=1):
outs = 0
if self.separate_prosody_encoder:
prosody_feature = self.preprocess(wave_segments)
f0_input = prosody_feature  # (B, 20, T)
f0_input = self.melspec_linear(f0_input)
f0_input = self.melspec_encoder(
f0_input,
torch.ones(f0_input.shape[0], 1, f0_input.shape[2])
.to(f0_input.device)
.bool(),
)
f0_input = self.melspec_linear2(f0_input)
common_min_size = min(f0_input.size(2), x.size(2))
f0_input = f0_input[:, :, :common_min_size]
x = x[:, :, :common_min_size]
(
z_p,
codes_p,
latents_p,
commitment_loss_p,
codebook_loss_p,
) = self.prosody_quantizer(f0_input, 1)
outs += z_p.detach()
else:
(
z_p,
codes_p,
latents_p,
commitment_loss_p,
codebook_loss_p,
) = self.prosody_quantizer(x, 1)
outs += z_p.detach()
(
z_c,
codes_c,
latents_c,
commitment_loss_c,
codebook_loss_c,
) = self.content_quantizer(x, n_c)
outs += z_c.detach()
timbre_residual_feature = x - z_p.detach() - z_c.detach()
(
z_t,
codes_t,
latents_t,
commitment_loss_t,
codebook_loss_t,
) = self.timbre_quantizer(timbre_residual_feature, 2)
outs += z_t # we should not detach timbre
residual_feature = timbre_residual_feature - z_t
(
z_r,
codes_r,
latents_r,
commitment_loss_r,
codebook_loss_r,
) = self.residual_quantizer(residual_feature, 3)
return [codes_c, codes_p, codes_t, codes_r], [z_c, z_p, z_t, z_r]
def forward(
self, x, wave_segments, noise_added_flags, recon_noisy_flags, n_c=2, n_t=2
):
# timbre = self.timbre_encoder(mels, sequence_mask(mel_lens, mels.size(-1)).unsqueeze(1))
# timbre = self.timbre_encoder(mel_segments, torch.ones(mel_segments.size(0), 1, mel_segments.size(2)).bool().to(mel_segments.device))
outs = 0
if self.separate_prosody_encoder:
prosody_feature = self.preprocess(wave_segments)
f0_input = prosody_feature  # (B, 20, T)
f0_input = self.melspec_linear(f0_input)
f0_input = self.melspec_encoder(
f0_input,
torch.ones(f0_input.shape[0], 1, f0_input.shape[2])
.to(f0_input.device)
.bool(),
)
f0_input = self.melspec_linear2(f0_input)
common_min_size = min(f0_input.size(2), x.size(2))
f0_input = f0_input[:, :, :common_min_size]
x = x[:, :, :common_min_size]
(
z_p,
codes_p,
latents_p,
commitment_loss_p,
codebook_loss_p,
) = self.prosody_quantizer(f0_input, 1)
outs += z_p.detach()
else:
(
z_p,
codes_p,
latents_p,
commitment_loss_p,
codebook_loss_p,
) = self.prosody_quantizer(x, 1)
outs += z_p.detach()
(
z_c,
codes_c,
latents_c,
commitment_loss_c,
codebook_loss_c,
) = self.content_quantizer(x, n_c)
outs += z_c.detach()
timbre_residual_feature = x - z_p.detach() - z_c.detach()
(
z_t,
codes_t,
latents_t,
commitment_loss_t,
codebook_loss_t,
) = self.timbre_quantizer(timbre_residual_feature, n_t)
outs += z_t # we should not detach timbre
residual_feature = timbre_residual_feature - z_t
(
z_r,
codes_r,
latents_r,
commitment_loss_r,
codebook_loss_r,
) = self.residual_quantizer(residual_feature, 3)
bsz = z_r.shape[0]
res_mask = np.random.choice(
[0, 1],
size=bsz,
p=[
self.prob_random_mask_residual,
1 - self.prob_random_mask_residual,
],
)
res_mask = torch.from_numpy(res_mask).unsqueeze(1).unsqueeze(1) # (B, 1, 1)
res_mask = res_mask.to(device=z_r.device, dtype=z_r.dtype)
noise_must_on = noise_added_flags * recon_noisy_flags
noise_must_off = noise_added_flags * (~recon_noisy_flags)
res_mask[noise_must_on] = 1
res_mask[noise_must_off] = 0
outs += z_r * res_mask
quantized = [z_p, z_c, z_t, z_r]
commitment_losses = (
commitment_loss_p
+ commitment_loss_c
+ commitment_loss_t
+ commitment_loss_r
)
codebook_losses = (
codebook_loss_p + codebook_loss_c + codebook_loss_t + codebook_loss_r
)
return outs, quantized, commitment_losses, codebook_losses
def forward_v2(
self,
x,
wave_segments,
n_c=1,
n_t=2,
full_waves=None,
wave_lens=None,
return_codes=False,
):
# timbre = self.timbre_encoder(x, sequence_mask(mel_lens, mels.size(-1)).unsqueeze(1))
if full_waves is None:
mel = self.preprocess(wave_segments, n_bins=80)
timbre = self.timbre_encoder(
mel, torch.ones(mel.size(0), 1, mel.size(2)).bool().to(mel.device)
)
else:
mel = self.preprocess(full_waves, n_bins=80)
timbre = self.timbre_encoder(
mel,
sequence_mask(wave_lens // self.hop_length, mel.size(-1)).unsqueeze(1),
)
outs = 0
if self.separate_prosody_encoder:
prosody_feature = self.preprocess(wave_segments)
f0_input = prosody_feature  # (B, 20, T)
f0_input = self.melspec_linear(f0_input)
f0_input = self.melspec_encoder(
f0_input,
torch.ones(f0_input.shape[0], 1, f0_input.shape[2])
.to(f0_input.device)
.bool(),
)
f0_input = self.melspec_linear2(f0_input)
common_min_size = min(f0_input.size(2), x.size(2))
f0_input = f0_input[:, :, :common_min_size]
x = x[:, :, :common_min_size]
(
z_p,
codes_p,
latents_p,
commitment_loss_p,
codebook_loss_p,
) = self.prosody_quantizer(f0_input, 1)
outs += z_p.detach()
else:
(
z_p,
codes_p,
latents_p,
commitment_loss_p,
codebook_loss_p,
) = self.prosody_quantizer(x, 1)
outs += z_p.detach()
(
z_c,
codes_c,
latents_c,
commitment_loss_c,
codebook_loss_c,
) = self.content_quantizer(x, n_c)
outs += z_c.detach()
residual_feature = x - z_p.detach() - z_c.detach()
(
z_r,
codes_r,
latents_r,
commitment_loss_r,
codebook_loss_r,
) = self.residual_quantizer(residual_feature, 3)
bsz = z_r.shape[0]
res_mask = np.random.choice(
[0, 1],
size=bsz,
p=[
self.prob_random_mask_residual,
1 - self.prob_random_mask_residual,
],
)
res_mask = torch.from_numpy(res_mask).unsqueeze(1).unsqueeze(1) # (B, 1, 1)
res_mask = res_mask.to(device=z_r.device, dtype=z_r.dtype)
if not self.training:
res_mask = torch.ones_like(res_mask)
outs += z_r * res_mask
quantized = [z_p, z_c, z_r]
codes = [codes_p, codes_c, codes_r]
commitment_losses = commitment_loss_p + commitment_loss_c + commitment_loss_r
codebook_losses = codebook_loss_p + codebook_loss_c + codebook_loss_r
style = self.timbre_linear(timbre).unsqueeze(2) # (B, 2d, 1)
gamma, beta = style.chunk(2, 1) # (B, d, 1)
outs = outs.transpose(1, 2)
outs = self.timbre_norm(outs)
outs = outs.transpose(1, 2)
outs = outs * gamma + beta
if return_codes:
return outs, quantized, commitment_losses, codebook_losses, timbre, codes
else:
return outs, quantized, commitment_losses, codebook_losses, timbre
def voice_conversion(self, z, ref_wave):
ref_mel = self.preprocess(ref_wave, n_bins=80)
ref_timbre = self.timbre_encoder(
ref_mel,
sequence_mask(
torch.LongTensor([ref_wave.size(-1)]).to(z.device) // self.hop_length,
ref_mel.size(-1),
).unsqueeze(1),
)
style = self.timbre_linear(ref_timbre).unsqueeze(2) # (B, 2d, 1)
gamma, beta = style.chunk(2, 1) # (B, d, 1)
outs = z.transpose(1, 2)
outs = self.timbre_norm(outs)
outs = outs.transpose(1, 2)
outs = outs * gamma + beta
return outs
class FApredictors(nn.Module):
def __init__(
self,
in_dim=1024,
use_gr_content_f0=False,
use_gr_prosody_phone=False,
use_gr_residual_f0=False,
use_gr_residual_phone=False,
use_gr_timbre_content=True,
use_gr_timbre_prosody=True,
use_gr_x_timbre=False,
norm_f0=True,
timbre_norm=False,
use_gr_content_global_f0=False,
):
super(FApredictors, self).__init__()
self.f0_predictor = CNNLSTM(in_dim, 1, 2)
self.phone_predictor = CNNLSTM(in_dim, 1024, 1)
if timbre_norm:
self.timbre_predictor = nn.Linear(in_dim, 20000)
else:
self.timbre_predictor = CNNLSTM(in_dim, 20000, 1, global_pred=True)
self.use_gr_content_f0 = use_gr_content_f0
self.use_gr_prosody_phone = use_gr_prosody_phone
self.use_gr_residual_f0 = use_gr_residual_f0
self.use_gr_residual_phone = use_gr_residual_phone
self.use_gr_timbre_content = use_gr_timbre_content
self.use_gr_timbre_prosody = use_gr_timbre_prosody
self.use_gr_x_timbre = use_gr_x_timbre
self.rev_f0_predictor = nn.Sequential(
GradientReversal(alpha=1.0), CNNLSTM(in_dim, 1, 2)
)
self.rev_content_predictor = nn.Sequential(
GradientReversal(alpha=1.0), CNNLSTM(in_dim, 1024, 1)
)
self.rev_timbre_predictor = nn.Sequential(
GradientReversal(alpha=1.0), CNNLSTM(in_dim, 20000, 1, global_pred=True)
)
self.norm_f0 = norm_f0
self.timbre_norm = timbre_norm
if timbre_norm:
self.forward = self.forward_v2
self.global_f0_predictor = nn.Linear(in_dim, 1)
self.use_gr_content_global_f0 = use_gr_content_global_f0
if use_gr_content_global_f0:
self.rev_global_f0_predictor = nn.Sequential(
GradientReversal(alpha=1.0), CNNLSTM(in_dim, 1, 1, global_pred=True)
)
def forward(self, quantized):
prosody_latent = quantized[0]
content_latent = quantized[1]
timbre_latent = quantized[2]
residual_latent = quantized[3]
content_pred = self.phone_predictor(content_latent)[0]
if self.norm_f0:
spk_pred = self.timbre_predictor(timbre_latent)[0]
f0_pred, uv_pred = self.f0_predictor(prosody_latent)
else:
spk_pred = self.timbre_predictor(timbre_latent + prosody_latent)[0]
f0_pred, uv_pred = self.f0_predictor(prosody_latent + timbre_latent)
prosody_rev_latent = torch.zeros_like(quantized[0])
if self.use_gr_content_f0:
prosody_rev_latent += quantized[1]
if self.use_gr_timbre_prosody:
prosody_rev_latent += quantized[2]
if self.use_gr_residual_f0:
prosody_rev_latent += quantized[3]
rev_f0_pred, rev_uv_pred = self.rev_f0_predictor(prosody_rev_latent)
content_rev_latent = torch.zeros_like(quantized[1])
if self.use_gr_prosody_phone:
content_rev_latent += quantized[0]
if self.use_gr_timbre_content:
content_rev_latent += quantized[2]
if self.use_gr_residual_phone:
content_rev_latent += quantized[3]
rev_content_pred = self.rev_content_predictor(content_rev_latent)[0]
if self.norm_f0:
timbre_rev_latent = quantized[0] + quantized[1] + quantized[3]
else:
timbre_rev_latent = quantized[1] + quantized[3]
if self.use_gr_x_timbre:
x_spk_pred = self.rev_timbre_predictor(timbre_rev_latent)[0]
else:
x_spk_pred = None
preds = {
"f0": f0_pred,
"uv": uv_pred,
"content": content_pred,
"timbre": spk_pred,
}
rev_preds = {
"rev_f0": rev_f0_pred,
"rev_uv": rev_uv_pred,
"rev_content": rev_content_pred,
"x_timbre": x_spk_pred,
}
return preds, rev_preds
def forward_v2(self, quantized, timbre):
prosody_latent = quantized[0]
content_latent = quantized[1]
residual_latent = quantized[2]
content_pred = self.phone_predictor(content_latent)[0]
spk_pred = self.timbre_predictor(timbre)
f0_pred, uv_pred = self.f0_predictor(prosody_latent)
prosody_rev_latent = torch.zeros_like(prosody_latent)
if self.use_gr_content_f0:
prosody_rev_latent += content_latent
if self.use_gr_residual_f0:
prosody_rev_latent += residual_latent
rev_f0_pred, rev_uv_pred = self.rev_f0_predictor(prosody_rev_latent)
content_rev_latent = torch.zeros_like(content_latent)
if self.use_gr_prosody_phone:
content_rev_latent += prosody_latent
if self.use_gr_residual_phone:
content_rev_latent += residual_latent
rev_content_pred = self.rev_content_predictor(content_rev_latent)[0]
timbre_rev_latent = prosody_latent + content_latent + residual_latent
if self.use_gr_x_timbre:
x_spk_pred = self.rev_timbre_predictor(timbre_rev_latent)[0]
else:
x_spk_pred = None
preds = {
"f0": f0_pred,
"uv": uv_pred,
"content": content_pred,
"timbre": spk_pred,
}
rev_preds = {
"rev_f0": rev_f0_pred,
"rev_uv": rev_uv_pred,
"rev_content": rev_content_pred,
"x_timbre": x_spk_pred,
}
return preds, rev_preds
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# This code is modified from https://github.com/sh-lee-prml/HierSpeechpp/blob/main/ttv_v1/styleencoder.py
from . import attentions
from torch import nn
import torch
from torch.nn import functional as F
class Mish(nn.Module):
def __init__(self):
super(Mish, self).__init__()
def forward(self, x):
return x * torch.tanh(F.softplus(x))
class Conv1dGLU(nn.Module):
"""
Conv1d + GLU(Gated Linear Unit) with residual connection.
For GLU refer to https://arxiv.org/abs/1612.08083 paper.
"""
def __init__(self, in_channels, out_channels, kernel_size, dropout):
super(Conv1dGLU, self).__init__()
self.out_channels = out_channels
self.conv1 = nn.Conv1d(
in_channels, 2 * out_channels, kernel_size=kernel_size, padding=2
)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
residual = x
x = self.conv1(x)
x1, x2 = torch.split(x, split_size_or_sections=self.out_channels, dim=1)
x = x1 * torch.sigmoid(x2)
x = residual + self.dropout(x)
return x
class StyleEncoder(torch.nn.Module):
def __init__(self, in_dim=513, hidden_dim=128, out_dim=256):
super().__init__()
self.in_dim = in_dim # Linear 513 wav2vec 2.0 1024
self.hidden_dim = hidden_dim
self.out_dim = out_dim
self.kernel_size = 5
self.n_head = 2
self.dropout = 0.1
self.spectral = nn.Sequential(
nn.Conv1d(self.in_dim, self.hidden_dim, 1),
Mish(),
nn.Dropout(self.dropout),
nn.Conv1d(self.hidden_dim, self.hidden_dim, 1),
Mish(),
nn.Dropout(self.dropout),
)
self.temporal = nn.Sequential(
Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout),
Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout),
)
self.slf_attn = attentions.MultiHeadAttention(
self.hidden_dim,
self.hidden_dim,
self.n_head,
p_dropout=self.dropout,
proximal_bias=False,
proximal_init=True,
)
self.atten_drop = nn.Dropout(self.dropout)
self.fc = nn.Conv1d(self.hidden_dim, self.out_dim, 1)
def forward(self, x, mask=None):
# spectral
x = self.spectral(x) * mask
# temporal
x = self.temporal(x) * mask
# self-attention
attn_mask = mask.unsqueeze(2) * mask.unsqueeze(-1)
y = self.slf_attn(x, x, attn_mask=attn_mask)
x = x + self.atten_drop(y)
# fc
x = self.fc(x)
# temporal average pooling
w = self.temporal_avg_pool(x, mask=mask)
return w
def temporal_avg_pool(self, x, mask=None):
if mask is None:
out = torch.mean(x, dim=2)
else:
len_ = mask.sum(dim=2)
x = x.sum(dim=2)
out = torch.div(x, len_)
return out
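# Illustrative usage sketch (not part of the original module): the style
# encoder pools a masked mel sequence into a single utterance-level embedding,
# as the FAquantizer timbre path does. Sizes below are hypothetical.
if __name__ == "__main__":
    import torch
    _enc = StyleEncoder(in_dim=80, hidden_dim=512, out_dim=1024)
    _mel = torch.randn(2, 80, 120)   # [batch, n_mels, frames]
    _mask = torch.ones(2, 1, 120)    # all frames valid in this toy batch
    assert _enc(_mel, _mask).shape == (2, 1024)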
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# This code is modified from https://github.com/sh-lee-prml/HierSpeechpp/blob/main/ttv_v1/modules.py
import math
import torch
from torch import nn
from torch.nn import functional as F
from modules.dac.model.encodec import SConv1d
from . import commons
LRELU_SLOPE = 0.1
class LayerNorm(nn.Module):
def __init__(self, channels, eps=1e-5):
super().__init__()
self.channels = channels
self.eps = eps
self.gamma = nn.Parameter(torch.ones(channels))
self.beta = nn.Parameter(torch.zeros(channels))
def forward(self, x):
x = x.transpose(1, -1)
x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
return x.transpose(1, -1)
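# Illustrative check (not part of the original module): this LayerNorm
# normalizes over the channel axis of channel-first [B, C, T] inputs and
# preserves their shape.
if __name__ == "__main__":
    import torch
    _ln = LayerNorm(channels=8)
    assert _ln(torch.randn(2, 8, 30)).shape == (2, 8, 30)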
class ConvReluNorm(nn.Module):
def __init__(
self,
in_channels,
hidden_channels,
out_channels,
kernel_size,
n_layers,
p_dropout,
):
super().__init__()
self.in_channels = in_channels
self.hidden_channels = hidden_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.n_layers = n_layers
self.p_dropout = p_dropout
assert n_layers > 1, "Number of layers should be larger than 1."
self.conv_layers = nn.ModuleList()
self.norm_layers = nn.ModuleList()
self.conv_layers.append(
nn.Conv1d(
in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
)
)
self.norm_layers.append(LayerNorm(hidden_channels))
self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
for _ in range(n_layers - 1):
self.conv_layers.append(
nn.Conv1d(
hidden_channels,
hidden_channels,
kernel_size,
padding=kernel_size // 2,
)
)
self.norm_layers.append(LayerNorm(hidden_channels))
self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
self.proj.weight.data.zero_()
self.proj.bias.data.zero_()
def forward(self, x, x_mask):
x_org = x
for i in range(self.n_layers):
x = self.conv_layers[i](x * x_mask)
x = self.norm_layers[i](x)
x = self.relu_drop(x)
x = x_org + self.proj(x)
return x * x_mask
class DDSConv(nn.Module):
"""
Dilated and Depth-Separable Convolution
"""
def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
super().__init__()
self.channels = channels
self.kernel_size = kernel_size
self.n_layers = n_layers
self.p_dropout = p_dropout
self.drop = nn.Dropout(p_dropout)
self.convs_sep = nn.ModuleList()
self.convs_1x1 = nn.ModuleList()
self.norms_1 = nn.ModuleList()
self.norms_2 = nn.ModuleList()
for i in range(n_layers):
dilation = kernel_size**i
padding = (kernel_size * dilation - dilation) // 2
self.convs_sep.append(
nn.Conv1d(
channels,
channels,
kernel_size,
groups=channels,
dilation=dilation,
padding=padding,
)
)
self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
self.norms_1.append(LayerNorm(channels))
self.norms_2.append(LayerNorm(channels))
def forward(self, x, x_mask, g=None):
if g is not None:
x = x + g
for i in range(self.n_layers):
y = self.convs_sep[i](x * x_mask)
y = self.norms_1[i](y)
y = F.gelu(y)
y = self.convs_1x1[i](y)
y = self.norms_2[i](y)
y = F.gelu(y)
y = self.drop(y)
x = x + y
return x * x_mask
class WN(torch.nn.Module):
def __init__(
self,
hidden_channels,
kernel_size,
dilation_rate,
n_layers,
gin_channels=0,
p_dropout=0,
causal=False,
):
super(WN, self).__init__()
conv1d_type = SConv1d
assert kernel_size % 2 == 1
self.hidden_channels = hidden_channels
self.kernel_size = (kernel_size,)
self.dilation_rate = dilation_rate
self.n_layers = n_layers
self.gin_channels = gin_channels
self.p_dropout = p_dropout
self.in_layers = torch.nn.ModuleList()
self.res_skip_layers = torch.nn.ModuleList()
self.drop = nn.Dropout(p_dropout)
if gin_channels != 0:
self.cond_layer = conv1d_type(
gin_channels, 2 * hidden_channels * n_layers, 1, norm="weight_norm"
)
for i in range(n_layers):
dilation = dilation_rate**i
padding = int((kernel_size * dilation - dilation) / 2)
in_layer = conv1d_type(
hidden_channels,
2 * hidden_channels,
kernel_size,
dilation=dilation,
padding=padding,
norm="weight_norm",
causal=causal,
)
self.in_layers.append(in_layer)
# last one is not necessary
if i < n_layers - 1:
res_skip_channels = 2 * hidden_channels
else:
res_skip_channels = hidden_channels
res_skip_layer = conv1d_type(
hidden_channels, res_skip_channels, 1, norm="weight_norm", causal=causal
)
self.res_skip_layers.append(res_skip_layer)
def forward(self, x, x_mask, g=None, **kwargs):
output = torch.zeros_like(x)
n_channels_tensor = torch.IntTensor([self.hidden_channels])
if g is not None:
g = self.cond_layer(g)
for i in range(self.n_layers):
x_in = self.in_layers[i](x)
if g is not None:
cond_offset = i * 2 * self.hidden_channels
g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
else:
g_l = torch.zeros_like(x_in)
acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
acts = self.drop(acts)
res_skip_acts = self.res_skip_layers[i](acts)
if i < self.n_layers - 1:
res_acts = res_skip_acts[:, : self.hidden_channels, :]
x = (x + res_acts) * x_mask
output = output + res_skip_acts[:, self.hidden_channels :, :]
else:
output = output + res_skip_acts
return output * x_mask
def remove_weight_norm(self):
if self.gin_channels != 0:
torch.nn.utils.remove_weight_norm(self.cond_layer)
for l in self.in_layers:
torch.nn.utils.remove_weight_norm(l)
for l in self.res_skip_layers:
torch.nn.utils.remove_weight_norm(l)