Commit ab9c00af authored by yangzhong

init submission
import warnings
from pathlib import Path
import argbind
import numpy as np
import torch
from audiotools import AudioSignal
from tqdm import tqdm
from dac import DACFile
from dac.utils import load_model
warnings.filterwarnings("ignore", category=UserWarning)
@argbind.bind(group="decode", positional=True, without_prefix=True)
@torch.inference_mode()
@torch.no_grad()
def decode(
input: str,
output: str = "",
weights_path: str = "",
model_tag: str = "latest",
model_bitrate: str = "8kbps",
device: str = "cuda",
model_type: str = "44khz",
verbose: bool = False,
):
"""Decode audio from codes.
Parameters
----------
input : str
Path to input directory or file
output : str, optional
Path to output directory, by default "".
If `input` is a directory, the directory sub-tree relative to `input` is re-created in `output`.
weights_path : str, optional
Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet using the
model_tag and model_type.
model_tag : str, optional
Tag of the model to use, by default "latest". Ignored if `weights_path` is specified.
model_bitrate : str, optional
Bitrate of the model. Must be one of "8kbps" or "16kbps". Defaults to "8kbps".
device : str, optional
Device to use, by default "cuda". If "cpu", the model will be loaded on the CPU.
model_type : str, optional
The type of model to use. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". Ignored if `weights_path` is specified.
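verbose : bool, optional
If True, show progress while decompressing each file, by default False.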
"""
generator = load_model(
model_type=model_type,
model_bitrate=model_bitrate,
tag=model_tag,
load_path=weights_path,
)
generator.to(device)
generator.eval()
# Find all .dac files in input directory
_input = Path(input)
input_files = list(_input.glob("**/*.dac"))
# If input is a .dac file, add it to the list
if _input.suffix == ".dac":
input_files.append(_input)
# Create output directory
output = Path(output)
output.mkdir(parents=True, exist_ok=True)
for i in tqdm(range(len(input_files)), desc="Decoding files"):
# Load file
artifact = DACFile.load(input_files[i])
# Reconstruct audio from codes
recons = generator.decompress(artifact, verbose=verbose)
# Compute output path
relative_path = input_files[i].relative_to(input)
output_dir = output / relative_path.parent
if not relative_path.name:
output_dir = output
relative_path = input_files[i]
output_name = relative_path.with_suffix(".wav").name
output_path = output_dir / output_name
output_path.parent.mkdir(parents=True, exist_ok=True)
# Write to file
recons.write(output_path)
if __name__ == "__main__":
args = argbind.parse_args()
with argbind.scope(args):
decode()
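# Minimal usage sketch for the decode entry point above. Paths are hypothetical and
# the function is called here directly with explicit keyword arguments (the same
# options are exposed on the command line through argbind). Wrapped in a helper so
# nothing runs at import time.
def _decode_example():
    decode(
        input="codes/",           # directory containing .dac files (or a single .dac file)
        output="reconstructed/",  # .wav files are written here, mirroring the input tree
        model_type="44khz",
        model_bitrate="8kbps",
        device="cpu",             # use "cuda" when a GPU is available
    )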
import math
import warnings
from pathlib import Path
import argbind
import numpy as np
import torch
from audiotools import AudioSignal
from audiotools.core import util
from tqdm import tqdm
from dac.utils import load_model
warnings.filterwarnings("ignore", category=UserWarning)
@argbind.bind(group="encode", positional=True, without_prefix=True)
@torch.inference_mode()
@torch.no_grad()
def encode(
input: str,
output: str = "",
weights_path: str = "",
model_tag: str = "latest",
model_bitrate: str = "8kbps",
n_quantizers: int = None,
device: str = "cuda",
model_type: str = "44khz",
win_duration: float = 5.0,
verbose: bool = False,
):
"""Encode audio files in input path to .dac format.
Parameters
----------
input : str
Path to input audio file or directory
output : str, optional
Path to output directory, by default "". If `input` is a directory, the directory sub-tree relative to `input` is re-created in `output`.
weights_path : str, optional
Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet using the
model_tag and model_type.
model_tag : str, optional
Tag of the model to use, by default "latest". Ignored if `weights_path` is specified.
model_bitrate : str, optional
Bitrate of the model. Must be one of "8kbps" or "16kbps". Defaults to "8kbps".
n_quantizers : int, optional
Number of quantizers to use, by default None. If not specified, all the quantizers will be used and the model will compress at maximum bitrate.
device : str, optional
Device to use, by default "cuda"
model_type : str, optional
The type of model to use. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". Ignored if `weights_path` is specified.
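win_duration : float, optional
Duration in seconds of the window used to chunk the audio during compression, by default 5.0.
verbose : bool, optional
If True, show progress while compressing each file, by default False.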
"""
generator = load_model(
model_type=model_type,
model_bitrate=model_bitrate,
tag=model_tag,
load_path=weights_path,
)
generator.to(device)
generator.eval()
kwargs = {"n_quantizers": n_quantizers}
# Find all audio files in input path
input = Path(input)
audio_files = util.find_audio(input)
output = Path(output)
output.mkdir(parents=True, exist_ok=True)
for i in tqdm(range(len(audio_files)), desc="Encoding files"):
# Load file
signal = AudioSignal(audio_files[i])
# Encode audio to .dac format
artifact = generator.compress(signal, win_duration, verbose=verbose, **kwargs)
# Compute output path
relative_path = audio_files[i].relative_to(input)
output_dir = output / relative_path.parent
if not relative_path.name:
output_dir = output
relative_path = audio_files[i]
output_name = relative_path.with_suffix(".dac").name
output_path = output_dir / output_name
output_path.parent.mkdir(parents=True, exist_ok=True)
artifact.save(output_path)
if __name__ == "__main__":
args = argbind.parse_args()
with argbind.scope(args):
encode()
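# Minimal usage sketch for the encode entry point above. Paths are hypothetical and
# the function is called directly with explicit keyword arguments; leaving
# n_quantizers unset uses every quantizer, i.e. the model's maximum bitrate.
def _encode_example():
    encode(
        input="wav/",      # directory of audio files (anything audiotools can read)
        output="codes/",   # .dac artifacts are written here, mirroring the input tree
        model_type="44khz",
        model_bitrate="8kbps",
        win_duration=5.0,  # compress in 5-second windows
        device="cpu",
    )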
import os
from huggingface_hub import hf_hub_download
def load_custom_model_from_hf(repo_id, model_filename="pytorch_model.bin", config_filename="config.yml"):
os.makedirs("./checkpoints", exist_ok=True)
model_path = hf_hub_download(repo_id=repo_id, filename=model_filename, cache_dir="./checkpoints")
if config_filename is None:
return model_path
config_path = hf_hub_download(repo_id=repo_id, filename=config_filename, cache_dir="./checkpoints")
return model_path, config_path
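# Usage sketch for the helper above. The repo id is a placeholder, not a real
# checkpoint; with config_filename=None only the model path is returned.
def _load_from_hf_example():
    model_path, config_path = load_custom_model_from_hf(
        "someuser/some-vc-checkpoint",       # hypothetical Hugging Face repo id
        model_filename="pytorch_model.bin",
        config_filename="config.yml",
    )
    return model_path, config_path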
import numpy as np
import torch
import torch.utils.data
from librosa.filters import mel as librosa_mel_fn
from scipy.io.wavfile import read
MAX_WAV_VALUE = 32768.0
def load_wav(full_path):
sampling_rate, data = read(full_path)
return data, sampling_rate
def dynamic_range_compression(x, C=1, clip_val=1e-5):
return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
def dynamic_range_decompression(x, C=1):
return np.exp(x) / C
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
return torch.log(torch.clamp(x, min=clip_val) * C)
def dynamic_range_decompression_torch(x, C=1):
return torch.exp(x) / C
def spectral_normalize_torch(magnitudes):
output = dynamic_range_compression_torch(magnitudes)
return output
def spectral_de_normalize_torch(magnitudes):
output = dynamic_range_decompression_torch(magnitudes)
return output
mel_basis = {}
hann_window = {}
def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
# if torch.min(y) < -1.0:
# print("min value is ", torch.min(y))
# if torch.max(y) > 1.0:
# print("max value is ", torch.max(y))
global mel_basis, hann_window # pylint: disable=global-statement
if f"{str(sampling_rate)}_{str(fmax)}_{str(y.device)}" not in mel_basis:
mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
mel_basis[str(sampling_rate) + "_" + str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
hann_window[str(sampling_rate) + "_" + str(y.device)] = torch.hann_window(win_size).to(y.device)
y = torch.nn.functional.pad(
y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
)
y = y.squeeze(1)
spec = torch.view_as_real(
torch.stft(
y,
n_fft,
hop_length=hop_size,
win_length=win_size,
window=hann_window[str(sampling_rate) + "_" + str(y.device)],
center=center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=True,
)
)
spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
spec = torch.matmul(mel_basis[str(sampling_rate) + "_" + str(fmax) + "_" + str(y.device)], spec)
spec = spectral_normalize_torch(spec)
return spec
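# Example of the mel pipeline above on random audio. The STFT/mel sizes are
# illustrative values, not taken from any config in this repo; with center=False and
# hop_size=256, one second of 22.05 kHz audio yields roughly 86 frames.
def _mel_spectrogram_example():
    y = torch.randn(1, 22050)  # (batch, samples), roughly in [-1, 1]
    mel = mel_spectrogram(
        y, n_fft=1024, num_mels=80, sampling_rate=22050,
        hop_size=256, win_size=1024, fmin=0, fmax=8000, center=False,
    )
    return mel  # (batch, num_mels, frames), log-compressed magnitudes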
import math
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from munch import Munch
import json
import argparse
from torch.nn.parallel import DistributedDataParallel as DDP
def str2bool(v):
if isinstance(v, bool):
return v
if v.lower() in ("yes", "true", "t", "y", "1"):
return True
elif v.lower() in ("no", "false", "f", "n", "0"):
return False
else:
raise argparse.ArgumentTypeError("Boolean value expected.")
class AttrDict(dict):
def __init__(self, *args, **kwargs):
super(AttrDict, self).__init__(*args, **kwargs)
self.__dict__ = self
def init_weights(m, mean=0.0, std=0.01):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
m.weight.data.normal_(mean, std)
def get_padding(kernel_size, dilation=1):
return int((kernel_size * dilation - dilation) / 2)
def convert_pad_shape(pad_shape):
l = pad_shape[::-1]
pad_shape = [item for sublist in l for item in sublist]
return pad_shape
def intersperse(lst, item):
result = [item] * (len(lst) * 2 + 1)
result[1::2] = lst
return result
def kl_divergence(m_p, logs_p, m_q, logs_q):
"""KL(P||Q)"""
kl = (logs_q - logs_p) - 0.5
kl += (
0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
)
return kl
def rand_gumbel(shape):
"""Sample from the Gumbel distribution, protect from overflows."""
uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
return -torch.log(-torch.log(uniform_samples))
def rand_gumbel_like(x):
g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
return g
def slice_segments(x, ids_str, segment_size=4):
ret = torch.zeros_like(x[:, :, :segment_size])
for i in range(x.size(0)):
idx_str = ids_str[i]
idx_end = idx_str + segment_size
ret[i] = x[i, :, idx_str:idx_end]
return ret
def slice_segments_audio(x, ids_str, segment_size=4):
ret = torch.zeros_like(x[:, :segment_size])
for i in range(x.size(0)):
idx_str = ids_str[i]
idx_end = idx_str + segment_size
ret[i] = x[i, idx_str:idx_end]
return ret
def rand_slice_segments(x, x_lengths=None, segment_size=4):
b, d, t = x.size()
if x_lengths is None:
x_lengths = t
ids_str_max = x_lengths - segment_size + 1
ids_str = ((torch.rand([b]).to(device=x.device) * ids_str_max).clip(0)).to(
dtype=torch.long
)
ret = slice_segments(x, ids_str, segment_size)
return ret, ids_str
def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
position = torch.arange(length, dtype=torch.float)
num_timescales = channels // 2
log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
num_timescales - 1
)
inv_timescales = min_timescale * torch.exp(
torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
)
scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
signal = F.pad(signal, [0, 0, 0, channels % 2])
signal = signal.view(1, channels, length)
return signal
def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
b, channels, length = x.size()
signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
return x + signal.to(dtype=x.dtype, device=x.device)
def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
b, channels, length = x.size()
signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
def subsequent_mask(length):
mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
return mask
@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
n_channels_int = n_channels[0]
in_act = input_a + input_b
t_act = torch.tanh(in_act[:, :n_channels_int, :])
s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
acts = t_act * s_act
return acts
def convert_pad_shape(pad_shape):
l = pad_shape[::-1]
pad_shape = [item for sublist in l for item in sublist]
return pad_shape
def shift_1d(x):
x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
return x
def sequence_mask(length, max_length=None):
if max_length is None:
max_length = length.max()
x = torch.arange(max_length, dtype=length.dtype, device=length.device)
return x.unsqueeze(0) < length.unsqueeze(1)
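# Tiny illustration of sequence_mask: lengths [2, 4] over a max length of 5 give a
# boolean mask where True marks valid (non-padded) positions.
def _sequence_mask_example():
    lengths = torch.tensor([2, 4])
    mask = sequence_mask(lengths, max_length=5)
    # tensor([[ True,  True, False, False, False],
    #         [ True,  True,  True,  True, False]])
    return mask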
def avg_with_mask(x, mask):
assert mask.dtype == torch.float, "Mask should be float"
if mask.ndim == 2:
mask = mask.unsqueeze(1)
if mask.shape[1] == 1:
mask = mask.expand_as(x)
return (x * mask).sum() / mask.sum()
def generate_path(duration, mask):
"""
duration: [b, 1, t_x]
mask: [b, 1, t_y, t_x]
"""
device = duration.device
b, _, t_y, t_x = mask.shape
cum_duration = torch.cumsum(duration, -1)
cum_duration_flat = cum_duration.view(b * t_x)
path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
path = path.view(b, t_x, t_y)
path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
path = path.unsqueeze(1).transpose(2, 3) * mask
return path
def clip_grad_value_(parameters, clip_value, norm_type=2):
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
parameters = list(filter(lambda p: p.grad is not None, parameters))
norm_type = float(norm_type)
if clip_value is not None:
clip_value = float(clip_value)
total_norm = 0
for p in parameters:
param_norm = p.grad.data.norm(norm_type)
total_norm += param_norm.item() ** norm_type
if clip_value is not None:
p.grad.data.clamp_(min=-clip_value, max=clip_value)
total_norm = total_norm ** (1.0 / norm_type)
return total_norm
def log_norm(x, mean=-4, std=4, dim=2):
"""
normalized log mel -> mel -> norm -> log(norm)
"""
x = torch.log(torch.exp(x * std + mean).norm(dim=dim))
return x
def load_F0_models(path):
# load F0 model
from .JDC.model import JDCNet
F0_model = JDCNet(num_class=1, seq_len=192)
params = torch.load(path, map_location="cpu")["net"]
F0_model.load_state_dict(params)
_ = F0_model.train()
return F0_model
def modify_w2v_forward(self, output_layer=15):
"""
Change the forward method of the w2v encoder so that it returns the output of an
intermediate layer.
:param self: the encoder module whose forward is being replaced
:param output_layer: 1-based index of the last encoder layer to run
:return: the replacement forward function
"""
from transformers.modeling_outputs import BaseModelOutput
def forward(
hidden_states,
attention_mask=None,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
):
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
conv_attention_mask = attention_mask
if attention_mask is not None:
# make sure padded tokens output 0
hidden_states = hidden_states.masked_fill(
~attention_mask.bool().unsqueeze(-1), 0.0
)
# extend attention_mask
attention_mask = 1.0 - attention_mask[:, None, None, :].to(
dtype=hidden_states.dtype
)
attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
attention_mask = attention_mask.expand(
attention_mask.shape[0],
1,
attention_mask.shape[-1],
attention_mask.shape[-1],
)
hidden_states = self.dropout(hidden_states)
if self.embed_positions is not None:
relative_position_embeddings = self.embed_positions(hidden_states)
else:
relative_position_embeddings = None
deepspeed_zero3_is_enabled = False
for i, layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
dropout_probability = torch.rand([])
skip_the_layer = (
True
if self.training and (dropout_probability < self.config.layerdrop)
else False
)
if not skip_the_layer or deepspeed_zero3_is_enabled:
# under deepspeed zero3 all gpus must run in sync
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer.__call__,
hidden_states,
attention_mask,
relative_position_embeddings,
output_attentions,
conv_attention_mask,
)
else:
layer_outputs = layer(
hidden_states,
attention_mask=attention_mask,
relative_position_embeddings=relative_position_embeddings,
output_attentions=output_attentions,
conv_attention_mask=conv_attention_mask,
)
hidden_states = layer_outputs[0]
if skip_the_layer:
layer_outputs = (None, None)
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if i == output_layer - 1:
break
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(
v
for v in [hidden_states, all_hidden_states, all_self_attentions]
if v is not None
)
return BaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
return forward
MATPLOTLIB_FLAG = False
def plot_spectrogram_to_numpy(spectrogram):
global MATPLOTLIB_FLAG
if not MATPLOTLIB_FLAG:
import matplotlib
import logging
matplotlib.use("Agg")
MATPLOTLIB_FLAG = True
mpl_logger = logging.getLogger("matplotlib")
mpl_logger.setLevel(logging.WARNING)
import matplotlib.pylab as plt
import numpy as np
fig, ax = plt.subplots(figsize=(10, 2))
im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
plt.colorbar(im, ax=ax)
plt.xlabel("Frames")
plt.ylabel("Channels")
plt.tight_layout()
fig.canvas.draw()
data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
plt.close()
return data
def normalize_f0(f0_sequence):
# Remove unvoiced frames (replace with -1)
voiced_indices = np.where(f0_sequence > 0)[0]
f0_voiced = f0_sequence[voiced_indices]
# Convert to log scale
log_f0 = np.log2(f0_voiced)
# Calculate mean and standard deviation
mean_f0 = np.mean(log_f0)
std_f0 = np.std(log_f0)
# Normalize the F0 sequence
normalized_f0 = (log_f0 - mean_f0) / std_f0
# Create the normalized F0 sequence with unvoiced frames
normalized_sequence = np.zeros_like(f0_sequence)
normalized_sequence[voiced_indices] = normalized_f0
normalized_sequence[f0_sequence <= 0] = -1 # Assign -1 to unvoiced frames
return normalized_sequence
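# Quick example of normalize_f0 on a synthetic contour: voiced frames (f0 > 0) are
# z-scored in log2 space and unvoiced frames are set to -1.
def _normalize_f0_example():
    f0 = np.array([0.0, 220.0, 233.0, 0.0, 246.9, 0.0])
    norm = normalize_f0(f0)
    return norm  # entries 0, 3 and 5 are -1; the voiced entries have zero mean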
class MyModel(nn.Module):
def __init__(self,args):
super(MyModel, self).__init__()
from modules.flow_matching import CFM
from modules.length_regulator import InterpolateRegulator
length_regulator = InterpolateRegulator(
channels=args.length_regulator.channels,
sampling_ratios=args.length_regulator.sampling_ratios,
is_discrete=args.length_regulator.is_discrete,
in_channels=args.length_regulator.in_channels if hasattr(args.length_regulator, "in_channels") else None,
vector_quantize=args.length_regulator.vector_quantize if hasattr(args.length_regulator, "vector_quantize") else False,
codebook_size=args.length_regulator.content_codebook_size,
n_codebooks=args.length_regulator.n_codebooks if hasattr(args.length_regulator, "n_codebooks") else 1,
quantizer_dropout=args.length_regulator.quantizer_dropout if hasattr(args.length_regulator, "quantizer_dropout") else 0.0,
f0_condition=args.length_regulator.f0_condition if hasattr(args.length_regulator, "f0_condition") else False,
n_f0_bins=args.length_regulator.n_f0_bins if hasattr(args.length_regulator, "n_f0_bins") else 512,
)
self.models = nn.ModuleDict({
'cfm': CFM(args),
'length_regulator': length_regulator
})
def forward(self, x, target_lengths, prompt_len, cond, y):
x = self.models['cfm'](x, target_lengths, prompt_len, cond, y)
return x
def forward2(self, S_ori,target_lengths,F0_ori):
x = self.models['length_regulator'](S_ori, ylens=target_lengths, f0=F0_ori)
return x
def build_model(args, stage="DiT"):
if stage == "DiT":
from modules.flow_matching import CFM
from modules.length_regulator import InterpolateRegulator
length_regulator = InterpolateRegulator(
channels=args.length_regulator.channels,
sampling_ratios=args.length_regulator.sampling_ratios,
is_discrete=args.length_regulator.is_discrete,
in_channels=args.length_regulator.in_channels if hasattr(args.length_regulator, "in_channels") else None,
vector_quantize=args.length_regulator.vector_quantize if hasattr(args.length_regulator, "vector_quantize") else False,
codebook_size=args.length_regulator.content_codebook_size,
n_codebooks=args.length_regulator.n_codebooks if hasattr(args.length_regulator, "n_codebooks") else 1,
quantizer_dropout=args.length_regulator.quantizer_dropout if hasattr(args.length_regulator, "quantizer_dropout") else 0.0,
f0_condition=args.length_regulator.f0_condition if hasattr(args.length_regulator, "f0_condition") else False,
n_f0_bins=args.length_regulator.n_f0_bins if hasattr(args.length_regulator, "n_f0_bins") else 512,
)
cfm = CFM(args)
nets = Munch(
cfm=cfm,
length_regulator=length_regulator,
)
elif stage == 'codec':
from dac.model.dac import Encoder
from modules.quantize import (
FAquantizer,
)
encoder = Encoder(
d_model=args.DAC.encoder_dim,
strides=args.DAC.encoder_rates,
d_latent=1024,
causal=args.causal,
lstm=args.lstm,
)
quantizer = FAquantizer(
in_dim=1024,
n_p_codebooks=1,
n_c_codebooks=args.n_c_codebooks,
n_t_codebooks=2,
n_r_codebooks=3,
codebook_size=1024,
codebook_dim=8,
quantizer_dropout=0.5,
causal=args.causal,
separate_prosody_encoder=args.separate_prosody_encoder,
timbre_norm=args.timbre_norm,
)
nets = Munch(
encoder=encoder,
quantizer=quantizer,
)
elif stage == "mel_vocos":
from modules.vocos import Vocos
decoder = Vocos(args)
nets = Munch(
decoder=decoder,
)
else:
raise ValueError(f"Unknown stage: {stage}")
return nets
def load_checkpoint(
model,
optimizer,
path,
load_only_params=True,
ignore_modules=[],
is_distributed=False,
load_ema=False,
):
state = torch.load(path, map_location="cpu")
params = state["net"]
if load_ema and "ema" in state:
print("Loading EMA")
for key in model:
i = 0
for param_name in params[key]:
if "input_pos" in param_name:
continue
assert params[key][param_name].shape == state["ema"][key][0][i].shape
params[key][param_name] = state["ema"][key][0][i].clone()
i += 1
for key in model:
if key in params and key not in ignore_modules:
if not is_distributed:
# strip prefix of DDP (module.), create a new OrderedDict that does not contain the prefix
for k in list(params[key].keys()):
if k.startswith("module."):
params[key][k[len("module.") :]] = params[key][k]
del params[key][k]
model_state_dict = model[key].state_dict()
# keep only entries whose shapes match the model's state dict
filtered_state_dict = {
k: v
for k, v in params[key].items()
if k in model_state_dict and v.shape == model_state_dict[k].shape
}
skipped_keys = set(params[key].keys()) - set(filtered_state_dict.keys())
if skipped_keys:
print(
f"Warning: Skipped loading some keys due to shape mismatch: {skipped_keys}"
)
print("%s loaded" % key)
model[key].load_state_dict(filtered_state_dict, strict=False)
_ = [model[key].eval() for key in model]
if not load_only_params:
epoch = state["epoch"] + 1
iters = state["iters"]
optimizer.load_state_dict(state["optimizer"])
optimizer.load_scheduler_state_dict(state["scheduler"])
else:
epoch = 0
iters = 0
return model, optimizer, epoch, iters
def load_checkpoint2(
model,
optimizer,
path,
load_only_params=True,
ignore_modules=[],
is_distributed=False,
load_ema=False,
):
state = torch.load(path, map_location="cpu")
params = state["net"]
if load_ema and "ema" in state:
print("Loading EMA")
for key in model.models:
i = 0
for param_name in params[key]:
if "input_pos" in param_name:
continue
assert params[key][param_name].shape == state["ema"][key][0][i].shape
params[key][param_name] = state["ema"][key][0][i].clone()
i += 1
for key in model.models:
if key in params and key not in ignore_modules:
if not is_distributed:
# strip prefix of DDP (module.), create a new OrderedDict that does not contain the prefix
for k in list(params[key].keys()):
if k.startswith("module."):
params[key][k[len("module.") :]] = params[key][k]
del params[key][k]
model_state_dict = model.models[key].state_dict()
# keep only entries whose shapes match the model's state dict
filtered_state_dict = {
k: v
for k, v in params[key].items()
if k in model_state_dict and v.shape == model_state_dict[k].shape
}
skipped_keys = set(params[key].keys()) - set(filtered_state_dict.keys())
if skipped_keys:
print(
f"Warning: Skipped loading some keys due to shape mismatch: {skipped_keys}"
)
print("%s loaded" % key)
model.models[key].load_state_dict(filtered_state_dict, strict=False)
model.eval()
# _ = [model[key].eval() for key in model]
if not load_only_params:
epoch = state["epoch"] + 1
iters = state["iters"]
optimizer.load_state_dict(state["optimizer"])
optimizer.load_scheduler_state_dict(state["scheduler"])
else:
epoch = 0
iters = 0
return model, optimizer, epoch, iters
def recursive_munch(d):
if isinstance(d, dict):
return Munch((k, recursive_munch(v)) for k, v in d.items())
elif isinstance(d, list):
return [recursive_munch(v) for v in d]
else:
return d
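# Small example of recursive_munch: a plain nested dict (e.g. a parsed YAML config)
# becomes attribute-accessible Munch objects, which is how `args` is consumed by
# build_model above. The keys here are placeholders, not a real config.
def _recursive_munch_example():
    cfg = recursive_munch({"DiT": {"hidden_dim": 512, "num_heads": 8}, "stages": [{"name": "DiT"}]})
    return cfg.DiT.hidden_dim, cfg.stages[0].name  # (512, "DiT")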
import torch
from torch import nn
import math
from modules.gpt_fast.model import ModelArgs, Transformer
# from modules.torchscript_modules.gpt_fast_model import ModelArgs, Transformer
from modules.wavenet import WN
from modules.commons import sequence_mask
from torch.nn.utils import weight_norm
def modulate(x, shift, scale):
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
#################################################################################
# Embedding Layers for Timesteps and Class Labels #
#################################################################################
class TimestepEmbedder(nn.Module):
"""
Embeds scalar timesteps into vector representations.
"""
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
self.max_period = 10000
self.scale = 1000
half = frequency_embedding_size // 2
freqs = torch.exp(
-math.log(self.max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
)
self.register_buffer("freqs", freqs)
def timestep_embedding(self, t):
"""
Create sinusoidal timestep embeddings.
:param t: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
The output dimension is self.frequency_embedding_size and self.max_period
controls the minimum frequency of the embeddings.
:return: an (N, D) Tensor of positional embeddings.
"""
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
args = self.scale * t[:, None].float() * self.freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if self.frequency_embedding_size % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def forward(self, t):
t_freq = self.timestep_embedding(t)
t_emb = self.mlp(t_freq)
return t_emb
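# Shape check for TimestepEmbedder: a batch of 4 fractional timesteps is mapped to a
# (4, hidden_size) embedding. hidden_size=512 is an illustrative value.
def _timestep_embedder_example():
    embedder = TimestepEmbedder(hidden_size=512)
    t = torch.rand(4)   # one timestep in [0, 1] per batch element
    emb = embedder(t)   # -> torch.Size([4, 512])
    return emb.shape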
class StyleEmbedder(nn.Module):
"""
Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
"""
def __init__(self, input_size, hidden_size, dropout_prob):
super().__init__()
use_cfg_embedding = dropout_prob > 0
self.embedding_table = nn.Embedding(int(use_cfg_embedding), hidden_size)
self.style_in = weight_norm(nn.Linear(input_size, hidden_size, bias=True))
self.input_size = input_size
self.dropout_prob = dropout_prob
def forward(self, labels, train, force_drop_ids=None):
use_dropout = self.dropout_prob > 0
if (train and use_dropout) or (force_drop_ids is not None):
labels = self.token_drop(labels, force_drop_ids)
else:
labels = self.style_in(labels)
embeddings = labels
return embeddings
class FinalLayer(nn.Module):
"""
The final layer of DiT.
"""
def __init__(self, hidden_size, patch_size, out_channels):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = weight_norm(nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True))
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(hidden_size, 2 * hidden_size, bias=True)
)
def forward(self, x, c):
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
class DiT(torch.nn.Module):
def __init__(
self,
args
):
super(DiT, self).__init__()
self.time_as_token = args.DiT.time_as_token if hasattr(args.DiT, 'time_as_token') else False
self.style_as_token = args.DiT.style_as_token if hasattr(args.DiT, 'style_as_token') else False
self.uvit_skip_connection = args.DiT.uvit_skip_connection if hasattr(args.DiT, 'uvit_skip_connection') else False
model_args = ModelArgs(
block_size=16384,  # args.DiT.block_size
n_layer=args.DiT.depth,
n_head=args.DiT.num_heads,
dim=args.DiT.hidden_dim,
head_dim=args.DiT.hidden_dim // args.DiT.num_heads,
vocab_size=1024,
uvit_skip_connection=self.uvit_skip_connection,
time_as_token=self.time_as_token,
)
self.transformer = Transformer(model_args)
self.in_channels = args.DiT.in_channels
self.out_channels = args.DiT.in_channels
self.num_heads = args.DiT.num_heads
self.x_embedder = weight_norm(nn.Linear(args.DiT.in_channels, args.DiT.hidden_dim, bias=True))
self.content_type = args.DiT.content_type # 'discrete' or 'continuous'
self.content_codebook_size = args.DiT.content_codebook_size # for discrete content
self.content_dim = args.DiT.content_dim # for continuous content
self.cond_embedder = nn.Embedding(args.DiT.content_codebook_size, args.DiT.hidden_dim) # discrete content
self.cond_projection = nn.Linear(args.DiT.content_dim, args.DiT.hidden_dim, bias=True) # continuous content
self.is_causal = args.DiT.is_causal
self.t_embedder = TimestepEmbedder(args.DiT.hidden_dim)
# self.style_embedder1 = weight_norm(nn.Linear(1024, args.DiT.hidden_dim, bias=True))
# self.style_embedder2 = weight_norm(nn.Linear(1024, args.style_encoder.dim, bias=True))
input_pos = torch.arange(16384)
self.register_buffer("input_pos", input_pos)
self.final_layer_type = args.DiT.final_layer_type # mlp or wavenet
if self.final_layer_type == 'wavenet':
self.t_embedder2 = TimestepEmbedder(args.wavenet.hidden_dim)
self.conv1 = nn.Linear(args.DiT.hidden_dim, args.wavenet.hidden_dim)
self.conv2 = nn.Conv1d(args.wavenet.hidden_dim, args.DiT.in_channels, 1)
self.wavenet = WN(hidden_channels=args.wavenet.hidden_dim,
kernel_size=args.wavenet.kernel_size,
dilation_rate=args.wavenet.dilation_rate,
n_layers=args.wavenet.num_layers,
gin_channels=args.wavenet.hidden_dim,
p_dropout=args.wavenet.p_dropout,
causal=False)
self.final_layer = FinalLayer(args.wavenet.hidden_dim, 1, args.wavenet.hidden_dim)
self.res_projection = nn.Linear(args.DiT.hidden_dim,
args.wavenet.hidden_dim)  # residual connection from transformer output to final output
self.wavenet_style_condition = args.wavenet.style_condition
assert args.DiT.style_condition == args.wavenet.style_condition
else:
self.final_mlp = nn.Sequential(
nn.Linear(args.DiT.hidden_dim, args.DiT.hidden_dim),
nn.SiLU(),
nn.Linear(args.DiT.hidden_dim, args.DiT.in_channels),
)
self.transformer_style_condition = args.DiT.style_condition
self.class_dropout_prob = args.DiT.class_dropout_prob
self.content_mask_embedder = nn.Embedding(1, args.DiT.hidden_dim)
self.long_skip_connection = args.DiT.long_skip_connection
self.skip_linear = nn.Linear(args.DiT.hidden_dim + args.DiT.in_channels, args.DiT.hidden_dim)
self.cond_x_merge_linear = nn.Linear(args.DiT.hidden_dim + args.DiT.in_channels * 2 +
args.style_encoder.dim * self.transformer_style_condition * (not self.style_as_token),
args.DiT.hidden_dim)
if self.style_as_token:
self.style_in = nn.Linear(args.style_encoder.dim, args.DiT.hidden_dim)
def setup_caches(self, max_batch_size, max_seq_length):
self.transformer.setup_caches(max_batch_size, max_seq_length, use_kv_cache=False)
def forward(self, x, prompt_x, x_lens, t, style, cond, mask_content=False):
"""
x (torch.Tensor): random noise
prompt_x (torch.Tensor): reference mel + zero mel
shape: (batch_size, 80, 795+1068)
x_lens (torch.Tensor): number of valid mel frames per batch item
shape: (batch_size,)
t (torch.Tensor): diffusion timestep
shape: (batch_size,)
style (torch.Tensor): reference global style
shape: (batch_size, 192)
cond (torch.Tensor): semantic info of reference audio and altered audio
shape: (batch_size, mel_timesteps(795+1069), 512)
"""
class_dropout = False
if self.training and torch.rand(1) < self.class_dropout_prob:
class_dropout = True
if not self.training and mask_content:
class_dropout = True
# cond_in_module = self.cond_embedder if self.content_type == 'discrete' else self.cond_projection
cond_in_module = self.cond_projection
B, _, T = x.size()
t1 = self.t_embedder(t) # (N, D) # t1 [2, 512]
cond = cond_in_module(cond) # cond [2,1863,512]->[2,1863,512]
x = x.transpose(1, 2) # [2,1863,80]
prompt_x = prompt_x.transpose(1, 2) # [2,1863,80]
x_in = torch.cat([x, prompt_x, cond], dim=-1) # 80+80+512=672 [2, 1863, 672]
if self.transformer_style_condition and not self.style_as_token: # True and True
x_in = torch.cat([x_in, style[:, None, :].repeat(1, T, 1)], dim=-1) #[2, 1863, 864]
if class_dropout: #False
x_in[..., self.in_channels:] = x_in[..., self.in_channels:] * 0  # zero out everything after the first 80 (mel) dims
x_in = self.cond_x_merge_linear(x_in) # (N, T, D) [2, 1863, 512]
if self.style_as_token: # False
style = self.style_in(style)
style = torch.zeros_like(style) if class_dropout else style
x_in = torch.cat([style.unsqueeze(1), x_in], dim=1)
if self.time_as_token: # False
x_in = torch.cat([t1.unsqueeze(1), x_in], dim=1)
x_mask = sequence_mask(x_lens + self.style_as_token + self.time_as_token).to(x.device).unsqueeze(1)  # torch.Size([1, 1, 1863])
input_pos = self.input_pos[:x_in.size(1)] # (T,) range(0,1863)
x_mask_expanded = x_mask[:, None, :].repeat(1, 1, x_in.size(1), 1) if not self.is_causal else None  # torch.Size([1, 1, 1863, 1863])
x_res = self.transformer(x_in, t1.unsqueeze(1), input_pos, x_mask_expanded) # [2, 1863, 512]
x_res = x_res[:, 1:] if self.time_as_token else x_res
x_res = x_res[:, 1:] if self.style_as_token else x_res
if self.long_skip_connection: #True
x_res = self.skip_linear(torch.cat([x_res, x], dim=-1))
if self.final_layer_type == 'wavenet':
x = self.conv1(x_res)
x = x.transpose(1, 2)
t2 = self.t_embedder2(t)
x = self.wavenet(x, x_mask, g=t2.unsqueeze(2)).transpose(1, 2) + self.res_projection(
x_res) # long residual connection
x = self.final_layer(x, t1).transpose(1, 2)
x = self.conv2(x)
else:
x = self.final_mlp(x_res)
x = x.transpose(1, 2)
# x [2,80,1863]
return x
from abc import ABC
import torch
import torch.nn.functional as F
from modules.diffusion_transformer import DiT
from modules.commons import sequence_mask
from tqdm import tqdm
class BASECFM(torch.nn.Module, ABC):
def __init__(
self,
args,
):
super().__init__()
self.sigma_min = 1e-6
self.estimator = None
self.in_channels = args.DiT.in_channels
self.criterion = torch.nn.MSELoss() if args.reg_loss_type == "l2" else torch.nn.L1Loss()
if hasattr(args.DiT, 'zero_prompt_speech_token'):
self.zero_prompt_speech_token = args.DiT.zero_prompt_speech_token
else:
self.zero_prompt_speech_token = False
@torch.inference_mode()
def inference(self, mu, x_lens, prompt, style, f0, n_timesteps, temperature=1.0, inference_cfg_rate=0.5):
"""Forward diffusion
Args:
mu (torch.Tensor): semantic info of reference audio and altered audio
shape: (batch_size, mel_timesteps(795+1069), 512)
x_lens (torch.Tensor): number of valid mel frames per batch item
shape: (batch_size,)
prompt (torch.Tensor): reference mel
shape: (batch_size, 80, 795)
style (torch.Tensor): reference global style
shape: (batch_size, 192)
f0: None
n_timesteps (int): number of diffusion steps
temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
Returns:
sample: generated mel-spectrogram
shape: (batch_size, 80, mel_timesteps)
"""
B, T = mu.size(0), mu.size(1)
z = torch.randn([B, self.in_channels, T], device=mu.device) * temperature
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
# t_span = t_span + (-1) * (torch.cos(torch.pi / 2 * t_span) - 1 + t_span)
return self.solve_euler(z, x_lens, prompt, mu, style, f0, t_span, inference_cfg_rate)
def solve_euler(self, x, x_lens, prompt, mu, style, f0, t_span, inference_cfg_rate=0.5):
"""
Fixed euler solver for ODEs.
Args:
x (torch.Tensor): random noise
t_span (torch.Tensor): n_timesteps interpolated
shape: (n_timesteps + 1,)
mu (torch.Tensor): semantic info of reference audio and altered audio
shape: (batch_size, mel_timesteps(795+1069), 512)
x_lens (torch.Tensor): number of valid mel frames per batch item
shape: (batch_size,)
prompt (torch.Tensor): reference mel
shape: (batch_size, 80, 795)
style (torch.Tensor): reference global style
shape: (batch_size, 192)
"""
t, _, _ = t_span[0], t_span[-1], t_span[1] - t_span[0]
# I am storing this because I can later plot it by putting a debugger here and saving it to a file
# Or in future might add like a return_all_steps flag
sol = []
# apply prompt
prompt_len = prompt.size(-1)
prompt_x = torch.zeros_like(x)
prompt_x[..., :prompt_len] = prompt[..., :prompt_len]
x[..., :prompt_len] = 0
if self.zero_prompt_speech_token:
mu[..., :prompt_len] = 0
for step in tqdm(range(1, len(t_span))):
dt = t_span[step] - t_span[step - 1]
if inference_cfg_rate > 0:
# Stack original and CFG (null) inputs for batched processing
stacked_prompt_x = torch.cat([prompt_x, torch.zeros_like(prompt_x)], dim=0)
stacked_style = torch.cat([style, torch.zeros_like(style)], dim=0)
stacked_mu = torch.cat([mu, torch.zeros_like(mu)], dim=0)
stacked_x = torch.cat([x, x], dim=0)
stacked_t = torch.cat([t.unsqueeze(0), t.unsqueeze(0)], dim=0)
# Perform a single forward pass for both original and CFG inputs
stacked_dphi_dt = self.estimator(
stacked_x, stacked_prompt_x, x_lens, stacked_t, stacked_style, stacked_mu,
)
# Split the output back into the original and CFG components
dphi_dt, cfg_dphi_dt = stacked_dphi_dt.chunk(2, dim=0)
# Apply CFG formula
dphi_dt = (1.0 + inference_cfg_rate) * dphi_dt - inference_cfg_rate * cfg_dphi_dt
else:
dphi_dt = self.estimator(x, prompt_x, x_lens, t.unsqueeze(0), style, mu)
x = x + dt * dphi_dt
t = t + dt
sol.append(x)
if step < len(t_span) - 1:
dt = t_span[step + 1] - t
x[:, :, :prompt_len] = 0
return sol[-1]
def forward(self, x1, x_lens, prompt_lens, mu, style):
"""Computes diffusion loss
Args:
mu (torch.Tensor): semantic info of reference audio and altered audio
shape: (batch_size, mel_timesteps(795+1069), 512)
x1 (torch.Tensor): target mel-spectrogram
shape: (batch_size, n_feats, mel_timesteps)
x_lens (torch.Tensor): number of valid mel frames per batch item
shape: (batch_size,)
prompt (torch.Tensor): reference mel
shape: (batch_size, 80, 795)
style (torch.Tensor): reference global style
shape: (batch_size, 192)
Returns:
loss: conditional flow matching loss
y: conditional flow
shape: (batch_size, n_feats, mel_timesteps)
"""
b, _, t = x1.shape
# random timestep
t = torch.rand([b, 1, 1], device=mu.device, dtype=x1.dtype)
# sample noise p(x_0)
z = torch.randn_like(x1)
y = (1 - (1 - self.sigma_min) * t) * z + t * x1
u = x1 - (1 - self.sigma_min) * z
prompt = torch.zeros_like(x1)
for bib in range(b):
prompt[bib, :, :prompt_lens[bib]] = x1[bib, :, :prompt_lens[bib]]
# range covered by prompt are set to 0
y[bib, :, :prompt_lens[bib]] = 0
if self.zero_prompt_speech_token:
mu[bib, :, :prompt_lens[bib]] = 0
estimator_out = self.estimator(y, prompt, x_lens, t.squeeze(1).squeeze(1), style, mu, prompt_lens)
loss = 0
for bib in range(b):
loss += self.criterion(estimator_out[bib, :, prompt_lens[bib]:x_lens[bib]], u[bib, :, prompt_lens[bib]:x_lens[bib]])
loss /= b
return loss, estimator_out + (1 - self.sigma_min) * z
class CFM(BASECFM):
def __init__(self, args):
super().__init__(
args
)
if args.dit_type == "DiT":
self.estimator = DiT(args)
else:
raise NotImplementedError(f"Unknown diffusion type {args.dit_type}")
from typing import Tuple
import torch
import torch.nn as nn
from torch.nn import functional as F
from modules.commons import sequence_mask
import numpy as np
from dac.nn.quantize import VectorQuantize
# f0_bin = 256
f0_max = 1100.0
f0_min = 50.0
f0_mel_min = 1127 * np.log(1 + f0_min / 700)
f0_mel_max = 1127 * np.log(1 + f0_max / 700)
def f0_to_coarse(f0, f0_bin):
f0_mel = 1127 * (1 + f0 / 700).log()
a = (f0_bin - 2) / (f0_mel_max - f0_mel_min)
b = f0_mel_min * a - 1.
f0_mel = torch.where(f0_mel > 0, f0_mel * a - b, f0_mel)
# torch.clip_(f0_mel, min=1., max=float(f0_bin - 1))
f0_coarse = torch.round(f0_mel).long()
f0_coarse = f0_coarse * (f0_coarse > 0)
f0_coarse = f0_coarse + ((f0_coarse < 1) * 1)
f0_coarse = f0_coarse * (f0_coarse < f0_bin)
f0_coarse = f0_coarse + ((f0_coarse >= f0_bin) * (f0_bin - 1))
return f0_coarse
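# Example of f0_to_coarse: Hz values are mapped onto mel-scaled integer bins in
# [1, f0_bin - 1]; zeros (unvoiced frames) land in bin 1. f0_bin=256 matches the
# commented-out default above.
def _f0_to_coarse_example():
    f0 = torch.tensor([0.0, 100.0, 220.0, 440.0, 880.0])
    return f0_to_coarse(f0, f0_bin=256)  # long tensor, same shape as f0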
class InterpolateRegulator(nn.Module):
def __init__(
self,
channels: int,
sampling_ratios: Tuple,
is_discrete: bool = False,
in_channels: int = None, # only applies to continuous input
vector_quantize: bool = False, # whether to use vector quantization, only applies to continuous input
codebook_size: int = 1024, # for discrete only
out_channels: int = None,
groups: int = 1,
n_codebooks: int = 1, # number of codebooks
quantizer_dropout: float = 0.0, # dropout for quantizer
f0_condition: bool = False,
n_f0_bins: int = 512,
):
super().__init__()
self.sampling_ratios = sampling_ratios
out_channels = out_channels or channels
model = nn.ModuleList([])
if len(sampling_ratios) > 0:
self.interpolate = True
for _ in sampling_ratios:
module = nn.Conv1d(channels, channels, 3, 1, 1)
norm = nn.GroupNorm(groups, channels)
act = nn.Mish()
model.extend([module, norm, act])
else:
self.interpolate = False
model.append(
nn.Conv1d(channels, out_channels, 1, 1)
)
self.model = nn.Sequential(*model)
self.embedding = nn.Embedding(codebook_size, channels)
self.is_discrete = is_discrete
self.mask_token = nn.Parameter(torch.zeros(1, channels))
self.n_codebooks = n_codebooks
if n_codebooks > 1:
self.extra_codebooks = nn.ModuleList([
nn.Embedding(codebook_size, channels) for _ in range(n_codebooks - 1)
])
self.extra_codebook_mask_tokens = nn.ParameterList([
nn.Parameter(torch.zeros(1, channels)) for _ in range(n_codebooks - 1)
])
self.quantizer_dropout = quantizer_dropout
if f0_condition:
self.f0_embedding = nn.Embedding(n_f0_bins, channels)
self.f0_condition = f0_condition
self.n_f0_bins = n_f0_bins
self.f0_bins = torch.arange(2, 1024, 1024 // n_f0_bins)
self.f0_mask = nn.Parameter(torch.zeros(1, channels))
else:
self.f0_condition = False
if not is_discrete:
self.content_in_proj = nn.Linear(in_channels, channels)
if vector_quantize:
self.vq = VectorQuantize(channels, codebook_size, 8)
def forward(self, x, ylens=None, n_quantizers=None, f0=None):
# apply token drop
if self.training:
n_quantizers = torch.ones((x.shape[0],)) * self.n_codebooks
dropout = torch.randint(1, self.n_codebooks + 1, (x.shape[0],))
n_dropout = int(x.shape[0] * self.quantizer_dropout)
n_quantizers[:n_dropout] = dropout[:n_dropout]
n_quantizers = n_quantizers.to(x.device)
# decide whether to drop for each sample in batch
else:
n_quantizers = torch.ones((x.shape[0],), device=x.device) * (self.n_codebooks if n_quantizers is None else n_quantizers)
if self.is_discrete:
if self.n_codebooks > 1:
assert len(x.size()) == 3
x_emb = self.embedding(x[:, 0])
for i, emb in enumerate(self.extra_codebooks):
x_emb = x_emb + (n_quantizers > i+1)[..., None, None] * emb(x[:, i+1])
# add mask token if not using this codebook
# x_emb = x_emb + (n_quantizers <= i+1)[..., None, None] * self.extra_codebook_mask_tokens[i]
x = x_emb
elif self.n_codebooks == 1:
if len(x.size()) == 2:
x = self.embedding(x)
else:
x = self.embedding(x[:, 0])
else:
x = self.content_in_proj(x)
# x in (B, T, D)
mask = sequence_mask(ylens).unsqueeze(-1)
if self.interpolate:
x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest')
else:
x = x.transpose(1, 2).contiguous()
mask = mask[:, :x.size(2), :]
ylens = ylens.clamp(max=x.size(2)).long()
if self.f0_condition:
if f0 is None:
x = x + self.f0_mask.unsqueeze(-1)
else:
#quantized_f0 = torch.bucketize(f0, self.f0_bins.to(f0.device)) # (N, T)
quantized_f0 = f0_to_coarse(f0, self.n_f0_bins)
quantized_f0 = quantized_f0.clamp(0, self.n_f0_bins - 1).long()
f0_emb = self.f0_embedding(quantized_f0)
f0_emb = F.interpolate(f0_emb.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest')
x = x + f0_emb
out = self.model(x).transpose(1, 2).contiguous()
if hasattr(self, 'vq'):
out_q, commitment_loss, codebook_loss, codes, out, = self.vq(out.transpose(1, 2))
out_q = out_q.transpose(1, 2)
return out_q * mask, ylens, codes, commitment_loss, codebook_loss
olens = ylens
return out * mask, olens, None, None, None
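# Minimal sketch of InterpolateRegulator on discrete content tokens. The sizes are
# illustrative, not taken from a config: one codebook of 1024 entries and an empty
# sampling_ratios tuple, so no interpolation stack is built.
def _length_regulator_example():
    reg = InterpolateRegulator(
        channels=256,
        sampling_ratios=(),  # empty -> single 1x1 Conv1d head, no interpolation
        is_discrete=True,
        codebook_size=1024,
    ).eval()
    tokens = torch.randint(0, 1024, (2, 50))  # (batch, content length)
    ylens = torch.tensor([50, 40])            # target mel lengths per item
    out, olens, codes, commit_loss, codebook_loss = reg(tokens, ylens=ylens)
    return out.shape, olens  # (torch.Size([2, 50, 256]), tensor([50, 40]))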
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
from .filter import *
from .resample import *
from .act import *
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
import torch.nn as nn
from .resample import UpSample1d, DownSample1d
class Activation1d(nn.Module):
def __init__(
self,
activation,
up_ratio: int = 2,
down_ratio: int = 2,
up_kernel_size: int = 12,
down_kernel_size: int = 12,
):
super().__init__()
self.up_ratio = up_ratio
self.down_ratio = down_ratio
self.act = activation
self.upsample = UpSample1d(up_ratio, up_kernel_size)
self.downsample = DownSample1d(down_ratio, down_kernel_size)
# x: [B,C,T]
def forward(self, x):
x = self.upsample(x)
x = self.act(x)
x = self.downsample(x)
return x
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
if "sinc" in dir(torch):
sinc = torch.sinc
else:
# This code is adopted from adefossez's julius.core.sinc under the MIT License
# https://adefossez.github.io/julius/julius/core.html
def sinc(x: torch.Tensor):
"""
Implementation of sinc, i.e. sin(pi * x) / (pi * x)
__Warning__: Unlike julius.sinc, the input here is multiplied by `pi`!
"""
return torch.where(
x == 0,
torch.tensor(1.0, device=x.device, dtype=x.dtype),
torch.sin(math.pi * x) / math.pi / x,
)
# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
# https://adefossez.github.io/julius/julius/lowpass.html
def kaiser_sinc_filter1d(
cutoff, half_width, kernel_size
): # return filter [1,1,kernel_size]
even = kernel_size % 2 == 0
half_size = kernel_size // 2
# For kaiser window
delta_f = 4 * half_width
A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
if A > 50.0:
beta = 0.1102 * (A - 8.7)
elif A >= 21.0:
beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
else:
beta = 0.0
window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
# ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
if even:
time = torch.arange(-half_size, half_size) + 0.5
else:
time = torch.arange(kernel_size) - half_size
if cutoff == 0:
filter_ = torch.zeros_like(time)
else:
filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
# Normalize filter to have sum = 1, otherwise we will have a small leakage
# of the constant component in the input signal.
filter_ /= filter_.sum()
filter = filter_.view(1, 1, kernel_size)
return filter
class LowPassFilter1d(nn.Module):
def __init__(
self,
cutoff=0.5,
half_width=0.6,
stride: int = 1,
padding: bool = True,
padding_mode: str = "replicate",
kernel_size: int = 12,
):
# kernel_size should be an even number for the StyleGAN3 setup;
# in this implementation, an odd number is also possible.
super().__init__()
if cutoff < -0.0:
raise ValueError("Minimum cutoff must be larger than zero.")
if cutoff > 0.5:
raise ValueError("A cutoff above 0.5 does not make sense.")
self.kernel_size = kernel_size
self.even = kernel_size % 2 == 0
self.pad_left = kernel_size // 2 - int(self.even)
self.pad_right = kernel_size // 2
self.stride = stride
self.padding = padding
self.padding_mode = padding_mode
filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
self.register_buffer("filter", filter)
# input [B, C, T]
def forward(self, x):
_, C, _ = x.shape
if self.padding:
x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
out = F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
return out
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
import torch.nn as nn
from torch.nn import functional as F
from .filter import LowPassFilter1d
from .filter import kaiser_sinc_filter1d
class UpSample1d(nn.Module):
def __init__(self, ratio=2, kernel_size=None):
super().__init__()
self.ratio = ratio
self.kernel_size = (
int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
)
self.stride = ratio
self.pad = self.kernel_size // ratio - 1
self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
self.pad_right = (
self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
)
filter = kaiser_sinc_filter1d(
cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size
)
self.register_buffer("filter", filter)
# x: [B, C, T]
def forward(self, x):
_, C, _ = x.shape
x = F.pad(x, (self.pad, self.pad), mode="replicate")
x = self.ratio * F.conv_transpose1d(
x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C
)
x = x[..., self.pad_left : -self.pad_right]
return x
class DownSample1d(nn.Module):
def __init__(self, ratio=2, kernel_size=None):
super().__init__()
self.ratio = ratio
self.kernel_size = (
int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
)
self.lowpass = LowPassFilter1d(
cutoff=0.5 / ratio,
half_width=0.6 / ratio,
stride=ratio,
kernel_size=self.kernel_size,
)
def forward(self, x):
xx = self.lowpass(x)
return xx
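# Round-trip sketch with the alias-free resamplers above: upsampling by 2 doubles the
# time dimension and downsampling by 2 restores it (up to the anti-aliasing filters).
def _resample_example():
    import torch  # only torch.nn is imported in this module
    x = torch.randn(1, 4, 1000)       # (batch, channels, time)
    up = UpSample1d(ratio=2)(x)       # -> (1, 4, 2000)
    down = DownSample1d(ratio=2)(up)  # -> (1, 4, 1000)
    return up.shape, down.shape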