Commit ab9c00af authored by yangzhong's avatar yangzhong
Browse files

init submission

parents
Pipeline #3176 failed with stages
in 0 seconds
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import time
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
from tokenizer import get_tokenizer
try:
from GPTQ import GenericGPTQRunner, InputRecorder
from eval import get_task_dict, evaluate, lm_eval
except:
pass
from model import Transformer
##### Quantization Primitives ######
def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype):
# assumes symmetric quantization
# assumes axis == 0
# assumes dense memory format
# TODO(future): relax ^ as needed
# default setup for affine quantization of activations
eps = torch.finfo(torch.float32).eps
# get min and max
min_val, max_val = torch.aminmax(x, dim=1)
# calculate scales and zero_points based on min and max
# reference: https://fburl.com/code/srbiybme
min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
device = min_val_neg.device
# reference: https://fburl.com/code/4wll53rk
max_val_pos = torch.max(-min_val_neg, max_val_pos)
scales = max_val_pos / (float(quant_max - quant_min) / 2)
# ensure scales is the same dtype as the original tensor
scales = torch.clamp(scales, min=eps).to(x.dtype)
zero_points = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)
# quantize based on qmin/qmax/scales/zp
# reference: https://www.internalfb.com/code/fbsource/[8edc275012b1]/fbcode/caffe2/torch/ao/quantization/fx/_decomposed.py?lines=63
x_div = x / scales.unsqueeze(-1)
x_round = torch.round(x_div)
x_zp = x_round + zero_points.unsqueeze(-1)
quant = torch.clamp(x_zp, quant_min, quant_max).to(target_dtype)
return quant, scales, zero_points
def get_group_qparams(w, n_bit=4, groupsize=128):
# needed for GPTQ with padding
if groupsize > w.shape[-1]:
groupsize = w.shape[-1]
assert groupsize > 1
assert w.shape[-1] % groupsize == 0
assert w.dim() == 2
to_quant = w.reshape(-1, groupsize)
assert torch.isnan(to_quant).sum() == 0
max_val = to_quant.amax(dim=1, keepdim=True)
min_val = to_quant.amin(dim=1, keepdim=True)
max_int = 2**n_bit - 1
scales = (max_val - min_val).clamp(min=1e-6) / max_int
zeros = min_val + scales * (2 ** (n_bit - 1))
return scales.to(torch.bfloat16).reshape(w.shape[0], -1), zeros.to(
torch.bfloat16
).reshape(w.shape[0], -1)
def pack_scales_and_zeros(scales, zeros):
assert scales.shape == zeros.shape
assert scales.dtype == torch.bfloat16
assert zeros.dtype == torch.bfloat16
return (
torch.cat(
[
scales.reshape(scales.size(0), scales.size(1), 1),
zeros.reshape(zeros.size(0), zeros.size(1), 1),
],
2,
)
.transpose(0, 1)
.contiguous()
)
def unpack_scales_and_zeros(scales_and_zeros):
assert len(scales_and_zeros.shape) == 3 and scales_and_zeros.shape[2] == 2
assert scales_and_zeros.dtype == torch.float
return torch.split(scales_and_zeros.transpose(0, 1), 1, 2)
def group_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=128):
assert groupsize > 1
# needed for GPTQ single column quantize
if groupsize > w.shape[-1] and scales.shape[-1] == 1:
groupsize = w.shape[-1]
assert w.shape[-1] % groupsize == 0
assert w.dim() == 2
to_quant = w.reshape(-1, groupsize)
assert torch.isnan(to_quant).sum() == 0
scales = scales.reshape(-1, 1)
zeros = zeros.reshape(-1, 1)
min_val = zeros - scales * (2 ** (n_bit - 1))
max_int = 2**n_bit - 1
min_int = 0
w_int32 = (
to_quant.sub(min_val)
.div(scales)
.round()
.clamp_(min_int, max_int)
.to(torch.int32)
.reshape_as(w)
)
return w_int32
def group_quantize_tensor(w, n_bit=4, groupsize=128):
scales, zeros = get_group_qparams(w, n_bit, groupsize)
w_int32 = group_quantize_tensor_from_qparams(w, scales, zeros, n_bit, groupsize)
scales_and_zeros = pack_scales_and_zeros(scales, zeros)
return w_int32, scales_and_zeros
def group_dequantize_tensor_from_qparams(
w_int32, scales, zeros, n_bit=4, groupsize=128
):
assert groupsize > 1
# needed for GPTQ single column dequantize
if groupsize > w_int32.shape[-1] and scales.shape[-1] == 1:
groupsize = w_int32.shape[-1]
assert w_int32.shape[-1] % groupsize == 0
assert w_int32.dim() == 2
w_int32_grouped = w_int32.reshape(-1, groupsize)
scales = scales.reshape(-1, 1)
zeros = zeros.reshape(-1, 1)
w_dq = (
w_int32_grouped.sub(2 ** (n_bit - 1)).mul(scales).add(zeros).reshape_as(w_int32)
)
return w_dq
def group_dequantize_tensor(w_int32, scales_and_zeros, n_bit=4, groupsize=128):
scales, zeros = unpack_scales_and_zeros(scales_and_zeros)
return group_dequantize_tensor_from_qparams(
w_int32, scales, zeros, n_bit, groupsize
)
class QuantHandler:
def __init__(self, mod):
self.mod = mod
def create_quantized_state_dict(self) -> "StateDict":
pass
def convert_for_runtime(self) -> "nn.Module":
pass
class GPTQQuantHandler(QuantHandler):
"""
This class implements a GPTQ QuantHandler that can be used to apply GPTQ to a model in concert with the GenericGPTQRunner class.
Unlike the base QuantHandler class, the user does not need to implement the create_quantized_state_dict, instead they have to reimplement
__init__ such that it defines the functions for the quantization mode. User is expected to reimplement convert_for_runtime.
The following functions (which must be defined in __init__) are used to define the quantization mode for both GPTQ and
create_quantized_state_dict. Here is a description of each function.
get_qparams_func:
A function that calculates the quantization qparams for an input tensor.
Args:
weight: A 2d weight tensor with non-integer dtype.
Returns:
qparams: it can have any format but will need to be handled by the other defined functions below.
quantize_func:
A function that applies quantization to an input tensor. It should be noted
that this function needs to be able to handle quantizing the entire weight tensor, a single group,
or a single column.
Args:
weight: A 2d weight tensor with non-integer dtype.
qparams: the output from get_qparams_func
Returns:
quantized_weight: A 2d quantized weight tensor (generally with an integer dtype)
dequantize_func:
A function that dequantizes an input quantized weight tensor. It should be noted
that this function needs to be able to handle dequantizing the entire weight tensor, a single group,
or a single column.
Args:
quantized_weight: A 2d quantized weight tensor (generally with an integer dtype)
qparams: the output from get_qparams_func
Returns:
weight: A 2d weight tensor with non-integer dtype.
combine_qparams_list_func:
A function that combines several qparams into one qparam.
Args:
qparams_list: a list of qparams objects, each obtained by calling get_qparams_func
on a single group from a weight tensor
Returns:
qparams: an object of the same format as the qparams above.
skip_layer_func:
A function that determines which linear layers should be skipped during GPTQ
Args:
weight: A 2d weight tensor with non-integer dtype.
Returns:
skip: boolean indicating whether layer should be skipped
make_names_and_values_dict_func:
A function that prepares the qparams and quantized_weight and creates a dictionary indicating how they
should be inserted into the state_dict. Generally any packing of the weight and qparams should be done here.
Args:
quantized_weight: A 2d quantized weight tensor (generally with an integer dtype)
qparams: the output from get_qparams_func
Returns:
names_and_values_dict: a dictionary mapping the name of the parameters of the quantized module to the
corresponding quantized weights and qparams.
"""
def __init__(self):
assert self.mod is not None
assert self.get_qparams_func is not None
assert self.quantize_func is not None
assert self.dequantize_func is not None
assert self.combine_qparams_list_func is not None
assert self.make_names_and_values_dict_func is not None
@staticmethod
def get_inputs(model, tokenizer, calibration_tasks, calibration_limit, calibration_seq_length, pad_calibration_inputs) -> "MultiInput":
input_recorder = InputRecorder(
model,
tokenizer,
calibration_seq_length,
pad_calibration_inputs,
)
try:
lm_eval.tasks.initialize_tasks()
except:
pass
task_dict = get_task_dict(calibration_tasks)
print("Obtaining GPTQ calibration inputs on: ", calibration_tasks)
evaluate(
input_recorder,
task_dict,
limit=calibration_limit,
)
inputs = input_recorder.get_recorded_inputs()
assert inputs is not None, (
f"No inputs were collected, use a task other than {calibration_tasks}, "+
f"use option pad_calibration_inputs, or decrease calibration_sequence_length (currently "+
f"{calibration_seq_length})"
)
print(f"Obtained {len(inputs[0].values)} calibration samples")
return inputs
@torch.no_grad()
def create_quantized_state_dict(
self,
tokenizer,
blocksize,
percdamp,
groupsize,
calibration_tasks,
calibration_limit,
calibration_seq_length,
pad_calibration_inputs,
) -> "StateDict":
inputs = GPTQQuantHandler.get_inputs(self.mod, tokenizer, calibration_tasks, calibration_limit, calibration_seq_length, pad_calibration_inputs)
print("Tracing model for GPTQ")
GPTQ_runner = GenericGPTQRunner(
self.mod,
inputs,
blocksize,
percdamp,
groupsize,
).configure_quantization_mode(
self.get_qparams_func,
self.quantize_func,
self.dequantize_func,
self.combine_qparams_list_func,
self.make_names_and_values_dict_func,
self.skip_layer_func
)
print("Applying GPTQ to weights")
GPTQ_runner.run()
return GPTQ_runner.get_quantized_state_dict()
def convert_for_runtime(self) -> "nn.Module":
pass
##### Weight-only int8 per-channel quantized code ######
def replace_linear_weight_only_int8_per_channel(module):
for name, child in module.named_children():
if isinstance(child, nn.Linear):
setattr(module, name, WeightOnlyInt8Linear(child.in_features, child.out_features))
else:
replace_linear_weight_only_int8_per_channel(child)
class WeightOnlyInt8QuantHandler:
def __init__(self, mod):
self.mod = mod
@torch.no_grad()
def create_quantized_state_dict(self):
cur_state_dict = self.mod.state_dict()
for fqn, mod in self.mod.named_modules():
if isinstance(mod, torch.nn.Linear):
int8_weight, scales, _ = dynamically_quantize_per_channel(mod.weight.float(), -128, 127, torch.int8)
cur_state_dict[f"{fqn}.weight"] = int8_weight
cur_state_dict[f"{fqn}.scales"] = scales.to(mod.weight.dtype)
return cur_state_dict
def convert_for_runtime(self):
replace_linear_weight_only_int8_per_channel(self.mod)
return self.mod
class WeightOnlyInt8Linear(torch.nn.Module):
__constants__ = ['in_features', 'out_features']
in_features: int
out_features: int
weight: torch.Tensor
def __init__(self, in_features: int, out_features: int, bias: bool = True,
device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.register_buffer("weight", torch.empty((out_features, in_features), dtype=torch.int8))
self.register_buffer("scales", torch.ones(out_features, dtype=torch.bfloat16))
def forward(self, input: torch.Tensor) -> torch.Tensor:
return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales
##### weight only int4 per channel groupwise quantized code ######
def prepare_int4_weight_and_scales_and_zeros(weight_bf16, groupsize, inner_k_tiles):
weight_int32, scales_and_zeros = group_quantize_tensor(
weight_bf16, n_bit=4, groupsize=groupsize
)
weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(weight_int32, inner_k_tiles)
return weight_int4pack, scales_and_zeros
def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
origin_x_size = x.size()
x = x.reshape(-1, origin_x_size[-1])
c = torch.ops.aten._weight_int4pack_mm(x, weight_int4pack, groupsize, scales_and_zeros)
new_shape = origin_x_size[:-1] + (out_features,)
c = c.reshape(new_shape)
return c
def _check_linear_int4_k(k, groupsize = 1, inner_k_tiles = 1):
return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0
def replace_linear_int4(module, groupsize, inner_k_tiles, padding):
for name, child in module.named_children():
if isinstance(child, nn.Linear):
if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles):
setattr(module, name, WeightOnlyInt4Linear(
child.in_features, child.out_features, bias=False,
groupsize=groupsize, inner_k_tiles=inner_k_tiles, padding=False,
))
elif padding:
setattr(module, name, WeightOnlyInt4Linear(
child.in_features, child.out_features, bias=False,
groupsize=groupsize, inner_k_tiles=inner_k_tiles, padding=True,
))
else:
replace_linear_int4(child, groupsize, inner_k_tiles, padding)
class WeightOnlyInt4QuantHandler:
def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
self.mod = mod
self.groupsize = groupsize
self.inner_k_tiles = inner_k_tiles
self.padding = padding
assert groupsize in [32, 64, 128, 256]
assert inner_k_tiles in [2, 4, 8]
@torch.no_grad()
def create_quantized_state_dict(self, use_cuda = True):
if use_cuda:
device="cuda"
else:
device="cpu"
cur_state_dict = self.mod.state_dict()
for fqn, mod in self.mod.named_modules():
if isinstance(mod, torch.nn.Linear):
assert not mod.bias
out_features = mod.out_features
in_features = mod.in_features
assert out_features % 8 == 0, "require out_features % 8 == 0"
print(f"linear: {fqn}, in={in_features}, out={out_features}")
weight = mod.weight.data
if not _check_linear_int4_k(in_features, self.groupsize, self.inner_k_tiles):
if self.padding:
from model import find_multiple
import torch.nn.functional as F
print(f"warning: {fqn} is padded to satisfy in_features % 1024 == 0")
padded_in_features = find_multiple(in_features, 1024)
weight = F.pad(weight, pad=(0, padded_in_features - in_features))
else:
print(f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, " +
"and that groupsize and inner_k_tiles*16 evenly divide into it")
continue
weight_int4pack, scales_and_zeros = prepare_int4_weight_and_scales_and_zeros(
weight.to(torch.bfloat16).to(device=device), self.groupsize, self.inner_k_tiles
)
cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to('cpu')
cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to('cpu')
return cur_state_dict
def convert_for_runtime(self):
replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding)
return self.mod
class WeightOnlyInt4GPTQQuantHandler(GPTQQuantHandler):
def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
from model import find_multiple
self.mod = mod
self.groupsize = groupsize
self.inner_k_tiles = inner_k_tiles
self.padding = padding
self.get_qparams_func = lambda w: get_group_qparams(w, 4, groupsize)
self.quantize_func = lambda w, qparams: \
group_quantize_tensor_from_qparams(w, qparams[0], qparams[1], 4, groupsize)
self.dequantize_func = lambda q, qparams: \
group_dequantize_tensor_from_qparams(q, qparams[0], qparams[1], 4, groupsize).float()
self.combine_qparams_list_func = lambda qparams_list: \
[torch.cat(x, dim=1) for x in zip(*qparams_list)]
# skip unless padding=True or its correctly sized
self.skip_layer_func = lambda linear_weight: not (
_check_linear_int4_k(linear_weight.shape[-1], groupsize, inner_k_tiles) or padding
)
# we need to do the padding here, both for q and the qparams if necessary
def make_names_and_values_dict_func(q, qparams):
k = q.shape[1]
new_k = find_multiple(k, 1024)
# how much we need to pad the weight
delta_k = new_k - q.shape[1]
final_q = torch.ops.aten._convert_weight_to_int4pack(F.pad(q, pad=(0, delta_k)), inner_k_tiles)
scales_and_zeros = pack_scales_and_zeros(*qparams)
# how many new groups we need for padded weight
delta_groups = new_k // groupsize - scales_and_zeros.shape[0]
final_s_and_z = F.pad(scales_and_zeros, pad=(0,0,0,0,0, delta_groups), value=1)
return {"weight": final_q, "scales_and_zeros": final_s_and_z}
self.make_names_and_values_dict_func = make_names_and_values_dict_func
super().__init__()
def convert_for_runtime(self):
replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding)
return self.mod
class WeightOnlyInt4Linear(torch.nn.Module):
__constants__ = ['in_features', 'out_features']
in_features: int
out_features: int
weight: torch.Tensor
def __init__(
self, in_features: int, out_features: int,
bias=True, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8, padding: bool = True,
) -> None:
super().__init__()
self.padding = padding
if padding:
from model import find_multiple
self.origin_in_features = in_features
in_features = find_multiple(in_features, 1024)
self.in_features = in_features
self.out_features = out_features
assert not bias, "require bias=False"
self.groupsize = groupsize
self.inner_k_tiles = inner_k_tiles
assert out_features % 8 == 0, "require out_features % 8 == 0"
assert in_features % (inner_k_tiles * 16) == 0, "require in_features % (innerKTiles * 16) == 0"
self.register_buffer(
"weight",
torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32)
)
self.register_buffer(
"scales_and_zeros",
torch.empty((in_features // groupsize, out_features, 2), dtype=torch.bfloat16)
)
def forward(self, input: torch.Tensor) -> torch.Tensor:
input = input.to(torch.bfloat16)
if self.padding:
import torch.nn.functional as F
input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
return linear_forward_int4(
input,
self.weight, self.scales_and_zeros, self.out_features, self.groupsize
)
def quantize(
checkpoint_path: Path = Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"),
mode: str = 'int8',
# following arguments only available when setting int4 quantization.
groupsize: int = 128,
# following arguments only used for GPTQ
calibration_tasks: list = ["hellaswag"],
calibration_limit: int = 1000,
calibration_seq_length: int = 100,
pad_calibration_inputs: bool = False,
percdamp: float = .01,
blocksize: int = 128,
label: str = '',
) -> None:
assert checkpoint_path.is_file(), checkpoint_path
device = 'cpu'
precision = torch.bfloat16
print("Loading model ...")
t0 = time.time()
with torch.device('meta'):
model = Transformer.from_name(checkpoint_path.parent.name)
checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
model.load_state_dict(checkpoint, assign=True)
model = model.to(dtype=precision, device=device)
if mode == 'int8':
print("Quantizing model weights for int8 weight-only symmetric per-channel quantization")
quant_handler = WeightOnlyInt8QuantHandler(model)
quantized_state_dict = quant_handler.create_quantized_state_dict()
dir_name = checkpoint_path.parent
base_name = checkpoint_path.name
new_base_name = base_name.replace('.pth', f'{label}int8.pth')
elif mode == 'int4':
print("Quantizing model weights for int4 weight-only affine per-channel groupwise quantization")
quant_handler = WeightOnlyInt4QuantHandler(model, groupsize)
quantized_state_dict = quant_handler.create_quantized_state_dict()
dir_name = checkpoint_path.parent
base_name = checkpoint_path.name
new_base_name = base_name.replace('.pth', f"{label}int4.g{groupsize}.pth")
elif mode == 'int4-gptq':
print("Quantizing model weights for int4 weight-only affine per-channel groupwise quantization using GPTQ...")
quant_handler = WeightOnlyInt4GPTQQuantHandler(model, groupsize)
tokenizer_path = checkpoint_path.parent / "tokenizer.model"
assert tokenizer_path.is_file(), str(tokenizer_path)
tokenizer = get_tokenizer(tokenizer_path, checkpoint_path)
quantized_state_dict = quant_handler.create_quantized_state_dict(
tokenizer,
blocksize,
percdamp,
groupsize,
calibration_tasks,
calibration_limit,
calibration_seq_length,
pad_calibration_inputs
)
dir_name = checkpoint_path.parent
base_name = checkpoint_path.name
new_base_name = base_name.replace('.pth', f"{label}int4-gptq.g{groupsize}.pth")
else:
raise ValueError(f"Invalid quantization mode {mode} needs to be one of [int8, int4, int4-gpptq]")
quantize_path = dir_name / new_base_name
print(f"Writing quantized weights to {quantize_path}")
quantize_path.unlink(missing_ok=True) # remove existing file if one already there
torch.save(quantized_state_dict, quantize_path)
print(f"Quantization complete took {time.time() - t0:.02f} seconds")
return
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser(description='Quantize a model.')
parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Path to the model checkpoint to be quantized.')
parser.add_argument('--mode', '-q', type=str, default='int8', choices=['int8', 'int4', 'int4-gptq'], help='type of quantization to perform')
parser.add_argument('--groupsize', type=int, default=32, help='Group size for int4 quantization.')
parser.add_argument('--calibration_tasks', type=str, nargs='+', default=['wikitext'], help='tasks to do gptq calibration on, if doing gptq')
parser.add_argument('--calibration_limit', type=int, default=1000, help='number of samples to use for gptq calibration')
parser.add_argument('--calibration_seq_length', type=int, default=100, help='length of sequences to use for gptq calibration')
parser.add_argument('--pad_calibration_inputs', type=bool, default=False, help='pads sequences shorter than calibration_seq_length to that length, yielding more calibration inputs but running much slower')
parser.add_argument('--percdamp', type=float, default=.01, help='gptq percentage dampening')
parser.add_argument('--blocksize', type=int, default=128, help='blocksize for gptq')
parser.add_argument('--label', type=str, default='_', help='label to add to output filename')
args = parser.parse_args()
quantize(args.checkpoint_path, args.mode, args.groupsize, args.calibration_tasks, args.calibration_limit, args.calibration_seq_length, args.pad_calibration_inputs, args.percdamp, args.blocksize, args.label)
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from torch.nn.utils import weight_norm
class ConvRNNF0Predictor(nn.Module):
def __init__(self,
num_class: int = 1,
in_channels: int = 80,
cond_channels: int = 512
):
super().__init__()
self.num_class = num_class
self.condnet = nn.Sequential(
weight_norm(
nn.Conv1d(in_channels, cond_channels, kernel_size=3, padding=1)
),
nn.ELU(),
weight_norm(
nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
),
nn.ELU(),
weight_norm(
nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
),
nn.ELU(),
weight_norm(
nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
),
nn.ELU(),
weight_norm(
nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
),
nn.ELU(),
)
self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.condnet(x)
x = x.transpose(1, 2)
return torch.abs(self.classifier(x).squeeze(-1))
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""HIFI-GAN"""
import typing as tp
import numpy as np
from scipy.signal import get_window
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Conv1d
from torch.nn import ConvTranspose1d
from torch.nn.utils import remove_weight_norm
from torch.nn.utils import weight_norm
from torch.distributions.uniform import Uniform
from torch import sin
from torch.nn.parameter import Parameter
"""hifigan based generator implementation.
This code is modified from https://github.com/jik876/hifi-gan
,https://github.com/kan-bayashi/ParallelWaveGAN and
https://github.com/NVIDIA/BigVGAN
"""
class Snake(nn.Module):
'''
Implementation of a sine-based periodic activation function
Shape:
- Input: (B, C, T)
- Output: (B, C, T), same shape as the input
Parameters:
- alpha - trainable parameter
References:
- This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
https://arxiv.org/abs/2006.08195
Examples:
>>> a1 = snake(256)
>>> x = torch.randn(256)
>>> x = a1(x)
'''
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
'''
Initialization.
INPUT:
- in_features: shape of the input
- alpha: trainable parameter
alpha is initialized to 1 by default, higher values = higher-frequency.
alpha will be trained along with the rest of your model.
'''
super(Snake, self).__init__()
self.in_features = in_features
# initialize alpha
self.alpha_logscale = alpha_logscale
if self.alpha_logscale: # log scale alphas initialized to zeros
self.alpha = Parameter(torch.zeros(in_features) * alpha)
else: # linear scale alphas initialized to ones
self.alpha = Parameter(torch.ones(in_features) * alpha)
self.alpha.requires_grad = alpha_trainable
self.no_div_by_zero = 0.000000001
def forward(self, x):
'''
Forward pass of the function.
Applies the function to the input elementwise.
Snake ∶= x + 1/a * sin^2 (xa)
'''
alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
if self.alpha_logscale:
alpha = torch.exp(alpha)
x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
return x
def get_padding(kernel_size, dilation=1):
return int((kernel_size * dilation - dilation) / 2)
def init_weights(m, mean=0.0, std=0.01):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
m.weight.data.normal_(mean, std)
class ResBlock(torch.nn.Module):
"""Residual block module in HiFiGAN/BigVGAN."""
def __init__(
self,
channels: int = 512,
kernel_size: int = 3,
dilations: tp.List[int] = [1, 3, 5],
):
super(ResBlock, self).__init__()
self.convs1 = nn.ModuleList()
self.convs2 = nn.ModuleList()
for dilation in dilations:
self.convs1.append(
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation,
padding=get_padding(kernel_size, dilation)
)
)
)
self.convs2.append(
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1)
)
)
)
self.convs1.apply(init_weights)
self.convs2.apply(init_weights)
self.activations1 = nn.ModuleList([
Snake(channels, alpha_logscale=False)
for _ in range(len(self.convs1))
])
self.activations2 = nn.ModuleList([
Snake(channels, alpha_logscale=False)
for _ in range(len(self.convs2))
])
def forward(self, x: torch.Tensor) -> torch.Tensor:
for idx in range(len(self.convs1)):
xt = self.activations1[idx](x)
xt = self.convs1[idx](xt)
xt = self.activations2[idx](xt)
xt = self.convs2[idx](xt)
x = xt + x
return x
def remove_weight_norm(self):
for idx in range(len(self.convs1)):
remove_weight_norm(self.convs1[idx])
remove_weight_norm(self.convs2[idx])
class SineGen(torch.nn.Module):
""" Definition of sine generator
SineGen(samp_rate, harmonic_num = 0,
sine_amp = 0.1, noise_std = 0.003,
voiced_threshold = 0,
flag_for_pulse=False)
samp_rate: sampling rate in Hz
harmonic_num: number of harmonic overtones (default 0)
sine_amp: amplitude of sine-wavefrom (default 0.1)
noise_std: std of Gaussian noise (default 0.003)
voiced_thoreshold: F0 threshold for U/V classification (default 0)
flag_for_pulse: this SinGen is used inside PulseGen (default False)
Note: when flag_for_pulse is True, the first time step of a voiced
segment is always sin(np.pi) or cos(0)
"""
def __init__(self, samp_rate, harmonic_num=0,
sine_amp=0.1, noise_std=0.003,
voiced_threshold=0):
super(SineGen, self).__init__()
self.sine_amp = sine_amp
self.noise_std = noise_std
self.harmonic_num = harmonic_num
self.sampling_rate = samp_rate
self.voiced_threshold = voiced_threshold
def _f02uv(self, f0):
# generate uv signal
uv = (f0 > self.voiced_threshold).type(torch.float32)
return uv
@torch.no_grad()
def forward(self, f0):
"""
:param f0: [B, 1, sample_len], Hz
:return: [B, 1, sample_len]
"""
F_mat = torch.zeros((f0.size(0), self.harmonic_num + 1, f0.size(-1))).to(f0.device)
for i in range(self.harmonic_num + 1):
F_mat[:, i: i + 1, :] = f0 * (i + 1) / self.sampling_rate
theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1)
u_dist = Uniform(low=-np.pi, high=np.pi)
phase_vec = u_dist.sample(sample_shape=(f0.size(0), self.harmonic_num + 1, 1)).to(F_mat.device)
phase_vec[:, 0, :] = 0
# generate sine waveforms
sine_waves = self.sine_amp * torch.sin(theta_mat + phase_vec)
# generate uv signal
uv = self._f02uv(f0)
# noise: for unvoiced should be similar to sine_amp
# std = self.sine_amp/3 -> max value ~ self.sine_amp
# . for voiced regions is self.noise_std
noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
noise = noise_amp * torch.randn_like(sine_waves)
# first: set the unvoiced part to 0 by uv
# then: additive noise
sine_waves = sine_waves * uv + noise
return sine_waves, uv, noise
class SourceModuleHnNSF(torch.nn.Module):
""" SourceModule for hn-nsf
SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
add_noise_std=0.003, voiced_threshod=0)
sampling_rate: sampling_rate in Hz
harmonic_num: number of harmonic above F0 (default: 0)
sine_amp: amplitude of sine source signal (default: 0.1)
add_noise_std: std of additive Gaussian noise (default: 0.003)
note that amplitude of noise in unvoiced is decided
by sine_amp
voiced_threshold: threhold to set U/V given F0 (default: 0)
Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
F0_sampled (batchsize, length, 1)
Sine_source (batchsize, length, 1)
noise_source (batchsize, length 1)
uv (batchsize, length, 1)
"""
def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
add_noise_std=0.003, voiced_threshod=0):
super(SourceModuleHnNSF, self).__init__()
self.sine_amp = sine_amp
self.noise_std = add_noise_std
# to produce sine waveforms
self.l_sin_gen = SineGen(sampling_rate, harmonic_num,
sine_amp, add_noise_std, voiced_threshod)
# to merge source harmonics into a single excitation
self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
self.l_tanh = torch.nn.Tanh()
def forward(self, x):
"""
Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
F0_sampled (batchsize, length, 1)
Sine_source (batchsize, length, 1)
noise_source (batchsize, length 1)
"""
# source for harmonic branch
with torch.no_grad():
sine_wavs, uv, _ = self.l_sin_gen(x.transpose(1, 2))
sine_wavs = sine_wavs.transpose(1, 2)
uv = uv.transpose(1, 2)
sine_merge = self.l_tanh(self.l_linear(sine_wavs))
# source for noise branch, in the same shape as uv
noise = torch.randn_like(uv) * self.sine_amp / 3
return sine_merge, noise, uv
class HiFTGenerator(nn.Module):
"""
HiFTNet Generator: Neural Source Filter + ISTFTNet
https://arxiv.org/abs/2309.09493
"""
def __init__(
self,
in_channels: int = 80,
base_channels: int = 512,
nb_harmonics: int = 8,
sampling_rate: int = 22050,
nsf_alpha: float = 0.1,
nsf_sigma: float = 0.003,
nsf_voiced_threshold: float = 10,
upsample_rates: tp.List[int] = [8, 8],
upsample_kernel_sizes: tp.List[int] = [16, 16],
istft_params: tp.Dict[str, int] = {"n_fft": 16, "hop_len": 4},
resblock_kernel_sizes: tp.List[int] = [3, 7, 11],
resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
source_resblock_kernel_sizes: tp.List[int] = [7, 11],
source_resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5]],
lrelu_slope: float = 0.1,
audio_limit: float = 0.99,
f0_predictor: torch.nn.Module = None,
):
super(HiFTGenerator, self).__init__()
self.out_channels = 1
self.nb_harmonics = nb_harmonics
self.sampling_rate = sampling_rate
self.istft_params = istft_params
self.lrelu_slope = lrelu_slope
self.audio_limit = audio_limit
self.num_kernels = len(resblock_kernel_sizes)
self.num_upsamples = len(upsample_rates)
self.m_source = SourceModuleHnNSF(
sampling_rate=sampling_rate,
upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
harmonic_num=nb_harmonics,
sine_amp=nsf_alpha,
add_noise_std=nsf_sigma,
voiced_threshod=nsf_voiced_threshold)
self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"])
self.conv_pre = weight_norm(
Conv1d(in_channels, base_channels, 7, 1, padding=3)
)
# Up
self.ups = nn.ModuleList()
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
self.ups.append(
weight_norm(
ConvTranspose1d(
base_channels // (2**i),
base_channels // (2**(i + 1)),
k,
u,
padding=(k - u) // 2,
)
)
)
# Down
self.source_downs = nn.ModuleList()
self.source_resblocks = nn.ModuleList()
downsample_rates = [1] + upsample_rates[::-1][:-1]
downsample_cum_rates = np.cumprod(downsample_rates)
for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes,
source_resblock_dilation_sizes)):
if u == 1:
self.source_downs.append(
Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1)
)
else:
self.source_downs.append(
Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u * 2, u, padding=(u // 2))
)
self.source_resblocks.append(
ResBlock(base_channels // (2 ** (i + 1)), k, d)
)
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch = base_channels // (2**(i + 1))
for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
self.resblocks.append(ResBlock(ch, k, d))
self.conv_post = weight_norm(Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3))
self.ups.apply(init_weights)
self.conv_post.apply(init_weights)
self.reflection_pad = nn.ReflectionPad1d((1, 0))
self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
self.f0_predictor = f0_predictor
def _f02source(self, f0: torch.Tensor) -> torch.Tensor:
f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
har_source, _, _ = self.m_source(f0)
return har_source.transpose(1, 2)
def _stft(self, x):
spec = torch.stft(
x,
self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(x.device),
return_complex=True)
spec = torch.view_as_real(spec) # [B, F, TT, 2]
return spec[..., 0], spec[..., 1]
def _istft(self, magnitude, phase):
magnitude = torch.clip(magnitude, max=1e2)
real = magnitude * torch.cos(phase)
img = magnitude * torch.sin(phase)
inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
return inverse_transform
def forward(self, x: torch.Tensor, f0=None) -> torch.Tensor:
if f0 is None:
f0 = self.f0_predictor(x)
s = self._f02source(f0)
s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
x = self.conv_pre(x)
for i in range(self.num_upsamples):
x = F.leaky_relu(x, self.lrelu_slope)
x = self.ups[i](x)
if i == self.num_upsamples - 1:
x = self.reflection_pad(x)
# fusion
si = self.source_downs[i](s_stft)
si = self.source_resblocks[i](si)
x = x + si
xs = None
for j in range(self.num_kernels):
if xs is None:
xs = self.resblocks[i * self.num_kernels + j](x)
else:
xs += self.resblocks[i * self.num_kernels + j](x)
x = xs / self.num_kernels
x = F.leaky_relu(x)
x = self.conv_post(x)
magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :])
phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :]) # actually, sin is redundancy
x = self._istft(magnitude, phase)
x = torch.clamp(x, -self.audio_limit, self.audio_limit)
return x
def remove_weight_norm(self):
print('Removing weight norm...')
for l in self.ups:
remove_weight_norm(l)
for l in self.resblocks:
l.remove_weight_norm()
remove_weight_norm(self.conv_pre)
remove_weight_norm(self.conv_post)
self.source_module.remove_weight_norm()
for l in self.source_downs:
remove_weight_norm(l)
for l in self.source_resblocks:
l.remove_weight_norm()
@torch.inference_mode()
def inference(self, mel: torch.Tensor, f0=None) -> torch.Tensor:
return self.forward(x=mel, f0=f0)
import math
import torch
from torch import nn
from typing import Optional, Any
from torch import Tensor
import torch.nn.functional as F
import torchaudio
import torchaudio.functional as audio_F
import random
random.seed(0)
def _get_activation_fn(activ):
if activ == 'relu':
return nn.ReLU()
elif activ == 'lrelu':
return nn.LeakyReLU(0.2)
elif activ == 'swish':
return lambda x: x*torch.sigmoid(x)
else:
raise RuntimeError('Unexpected activ type %s, expected [relu, lrelu, swish]' % activ)
class LinearNorm(torch.nn.Module):
def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
super(LinearNorm, self).__init__()
self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
torch.nn.init.xavier_uniform_(
self.linear_layer.weight,
gain=torch.nn.init.calculate_gain(w_init_gain))
def forward(self, x):
return self.linear_layer(x)
class ConvNorm(torch.nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
padding=None, dilation=1, bias=True, w_init_gain='linear', param=None):
super(ConvNorm, self).__init__()
if padding is None:
assert(kernel_size % 2 == 1)
padding = int(dilation * (kernel_size - 1) / 2)
self.conv = torch.nn.Conv1d(in_channels, out_channels,
kernel_size=kernel_size, stride=stride,
padding=padding, dilation=dilation,
bias=bias)
torch.nn.init.xavier_uniform_(
self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain, param=param))
def forward(self, signal):
conv_signal = self.conv(signal)
return conv_signal
class CausualConv(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=1, dilation=1, bias=True, w_init_gain='linear', param=None):
super(CausualConv, self).__init__()
if padding is None:
assert(kernel_size % 2 == 1)
padding = int(dilation * (kernel_size - 1) / 2) * 2
else:
self.padding = padding * 2
self.conv = nn.Conv1d(in_channels, out_channels,
kernel_size=kernel_size, stride=stride,
padding=self.padding,
dilation=dilation,
bias=bias)
torch.nn.init.xavier_uniform_(
self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain, param=param))
def forward(self, x):
x = self.conv(x)
x = x[:, :, :-self.padding]
return x
class CausualBlock(nn.Module):
def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ='lrelu'):
super(CausualBlock, self).__init__()
self.blocks = nn.ModuleList([
self._get_conv(hidden_dim, dilation=3**i, activ=activ, dropout_p=dropout_p)
for i in range(n_conv)])
def forward(self, x):
for block in self.blocks:
res = x
x = block(x)
x += res
return x
def _get_conv(self, hidden_dim, dilation, activ='lrelu', dropout_p=0.2):
layers = [
CausualConv(hidden_dim, hidden_dim, kernel_size=3, padding=dilation, dilation=dilation),
_get_activation_fn(activ),
nn.BatchNorm1d(hidden_dim),
nn.Dropout(p=dropout_p),
CausualConv(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1),
_get_activation_fn(activ),
nn.Dropout(p=dropout_p)
]
return nn.Sequential(*layers)
class ConvBlock(nn.Module):
def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ='relu'):
super().__init__()
self._n_groups = 8
self.blocks = nn.ModuleList([
self._get_conv(hidden_dim, dilation=3**i, activ=activ, dropout_p=dropout_p)
for i in range(n_conv)])
def forward(self, x):
for block in self.blocks:
res = x
x = block(x)
x += res
return x
def _get_conv(self, hidden_dim, dilation, activ='relu', dropout_p=0.2):
layers = [
ConvNorm(hidden_dim, hidden_dim, kernel_size=3, padding=dilation, dilation=dilation),
_get_activation_fn(activ),
nn.GroupNorm(num_groups=self._n_groups, num_channels=hidden_dim),
nn.Dropout(p=dropout_p),
ConvNorm(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1),
_get_activation_fn(activ),
nn.Dropout(p=dropout_p)
]
return nn.Sequential(*layers)
class LocationLayer(nn.Module):
def __init__(self, attention_n_filters, attention_kernel_size,
attention_dim):
super(LocationLayer, self).__init__()
padding = int((attention_kernel_size - 1) / 2)
self.location_conv = ConvNorm(2, attention_n_filters,
kernel_size=attention_kernel_size,
padding=padding, bias=False, stride=1,
dilation=1)
self.location_dense = LinearNorm(attention_n_filters, attention_dim,
bias=False, w_init_gain='tanh')
def forward(self, attention_weights_cat):
processed_attention = self.location_conv(attention_weights_cat)
processed_attention = processed_attention.transpose(1, 2)
processed_attention = self.location_dense(processed_attention)
return processed_attention
class Attention(nn.Module):
def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
attention_location_n_filters, attention_location_kernel_size):
super(Attention, self).__init__()
self.query_layer = LinearNorm(attention_rnn_dim, attention_dim,
bias=False, w_init_gain='tanh')
self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False,
w_init_gain='tanh')
self.v = LinearNorm(attention_dim, 1, bias=False)
self.location_layer = LocationLayer(attention_location_n_filters,
attention_location_kernel_size,
attention_dim)
self.score_mask_value = -float("inf")
def get_alignment_energies(self, query, processed_memory,
attention_weights_cat):
"""
PARAMS
------
query: decoder output (batch, n_mel_channels * n_frames_per_step)
processed_memory: processed encoder outputs (B, T_in, attention_dim)
attention_weights_cat: cumulative and prev. att weights (B, 2, max_time)
RETURNS
-------
alignment (batch, max_time)
"""
processed_query = self.query_layer(query.unsqueeze(1))
processed_attention_weights = self.location_layer(attention_weights_cat)
energies = self.v(torch.tanh(
processed_query + processed_attention_weights + processed_memory))
energies = energies.squeeze(-1)
return energies
def forward(self, attention_hidden_state, memory, processed_memory,
attention_weights_cat, mask):
"""
PARAMS
------
attention_hidden_state: attention rnn last output
memory: encoder outputs
processed_memory: processed encoder outputs
attention_weights_cat: previous and cummulative attention weights
mask: binary mask for padded data
"""
alignment = self.get_alignment_energies(
attention_hidden_state, processed_memory, attention_weights_cat)
if mask is not None:
alignment.data.masked_fill_(mask, self.score_mask_value)
attention_weights = F.softmax(alignment, dim=1)
attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
attention_context = attention_context.squeeze(1)
return attention_context, attention_weights
class ForwardAttentionV2(nn.Module):
def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
attention_location_n_filters, attention_location_kernel_size):
super(ForwardAttentionV2, self).__init__()
self.query_layer = LinearNorm(attention_rnn_dim, attention_dim,
bias=False, w_init_gain='tanh')
self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False,
w_init_gain='tanh')
self.v = LinearNorm(attention_dim, 1, bias=False)
self.location_layer = LocationLayer(attention_location_n_filters,
attention_location_kernel_size,
attention_dim)
self.score_mask_value = -float(1e20)
def get_alignment_energies(self, query, processed_memory,
attention_weights_cat):
"""
PARAMS
------
query: decoder output (batch, n_mel_channels * n_frames_per_step)
processed_memory: processed encoder outputs (B, T_in, attention_dim)
attention_weights_cat: prev. and cumulative att weights (B, 2, max_time)
RETURNS
-------
alignment (batch, max_time)
"""
processed_query = self.query_layer(query.unsqueeze(1))
processed_attention_weights = self.location_layer(attention_weights_cat)
energies = self.v(torch.tanh(
processed_query + processed_attention_weights + processed_memory))
energies = energies.squeeze(-1)
return energies
def forward(self, attention_hidden_state, memory, processed_memory,
attention_weights_cat, mask, log_alpha):
"""
PARAMS
------
attention_hidden_state: attention rnn last output
memory: encoder outputs
processed_memory: processed encoder outputs
attention_weights_cat: previous and cummulative attention weights
mask: binary mask for padded data
"""
log_energy = self.get_alignment_energies(
attention_hidden_state, processed_memory, attention_weights_cat)
#log_energy =
if mask is not None:
log_energy.data.masked_fill_(mask, self.score_mask_value)
#attention_weights = F.softmax(alignment, dim=1)
#content_score = log_energy.unsqueeze(1) #[B, MAX_TIME] -> [B, 1, MAX_TIME]
#log_alpha = log_alpha.unsqueeze(2) #[B, MAX_TIME] -> [B, MAX_TIME, 1]
#log_total_score = log_alpha + content_score
#previous_attention_weights = attention_weights_cat[:,0,:]
log_alpha_shift_padded = []
max_time = log_energy.size(1)
for sft in range(2):
shifted = log_alpha[:,:max_time-sft]
shift_padded = F.pad(shifted, (sft,0), 'constant', self.score_mask_value)
log_alpha_shift_padded.append(shift_padded.unsqueeze(2))
biased = torch.logsumexp(torch.cat(log_alpha_shift_padded,2), 2)
log_alpha_new = biased + log_energy
attention_weights = F.softmax(log_alpha_new, dim=1)
attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
attention_context = attention_context.squeeze(1)
return attention_context, attention_weights, log_alpha_new
class PhaseShuffle2d(nn.Module):
def __init__(self, n=2):
super(PhaseShuffle2d, self).__init__()
self.n = n
self.random = random.Random(1)
def forward(self, x, move=None):
# x.size = (B, C, M, L)
if move is None:
move = self.random.randint(-self.n, self.n)
if move == 0:
return x
else:
left = x[:, :, :, :move]
right = x[:, :, :, move:]
shuffled = torch.cat([right, left], dim=3)
return shuffled
class PhaseShuffle1d(nn.Module):
def __init__(self, n=2):
super(PhaseShuffle1d, self).__init__()
self.n = n
self.random = random.Random(1)
def forward(self, x, move=None):
# x.size = (B, C, M, L)
if move is None:
move = self.random.randint(-self.n, self.n)
if move == 0:
return x
else:
left = x[:, :, :move]
right = x[:, :, move:]
shuffled = torch.cat([right, left], dim=2)
return shuffled
class MFCC(nn.Module):
def __init__(self, n_mfcc=40, n_mels=80):
super(MFCC, self).__init__()
self.n_mfcc = n_mfcc
self.n_mels = n_mels
self.norm = 'ortho'
dct_mat = audio_F.create_dct(self.n_mfcc, self.n_mels, self.norm)
self.register_buffer('dct_mat', dct_mat)
def forward(self, mel_specgram):
if len(mel_specgram.shape) == 2:
mel_specgram = mel_specgram.unsqueeze(0)
unsqueezed = True
else:
unsqueezed = False
# (channel, n_mels, time).tranpose(...) dot (n_mels, n_mfcc)
# -> (channel, time, n_mfcc).tranpose(...)
mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2)
# unpack batch
if unsqueezed:
mfcc = mfcc.squeeze(0)
return mfcc
from typing import Tuple
import torch
import torch.nn as nn
from torch.nn import functional as F
from indextts.s2mel.modules.commons import sequence_mask
import numpy as np
from indextts.s2mel.dac.nn.quantize import VectorQuantize
# f0_bin = 256
f0_max = 1100.0
f0_min = 50.0
f0_mel_min = 1127 * np.log(1 + f0_min / 700)
f0_mel_max = 1127 * np.log(1 + f0_max / 700)
def f0_to_coarse(f0, f0_bin):
f0_mel = 1127 * (1 + f0 / 700).log()
a = (f0_bin - 2) / (f0_mel_max - f0_mel_min)
b = f0_mel_min * a - 1.
f0_mel = torch.where(f0_mel > 0, f0_mel * a - b, f0_mel)
# torch.clip_(f0_mel, min=1., max=float(f0_bin - 1))
f0_coarse = torch.round(f0_mel).long()
f0_coarse = f0_coarse * (f0_coarse > 0)
f0_coarse = f0_coarse + ((f0_coarse < 1) * 1)
f0_coarse = f0_coarse * (f0_coarse < f0_bin)
f0_coarse = f0_coarse + ((f0_coarse >= f0_bin) * (f0_bin - 1))
return f0_coarse
class InterpolateRegulator(nn.Module):
def __init__(
self,
channels: int,
sampling_ratios: Tuple,
is_discrete: bool = False,
in_channels: int = None, # only applies to continuous input
vector_quantize: bool = False, # whether to use vector quantization, only applies to continuous input
codebook_size: int = 1024, # for discrete only
out_channels: int = None,
groups: int = 1,
n_codebooks: int = 1, # number of codebooks
quantizer_dropout: float = 0.0, # dropout for quantizer
f0_condition: bool = False,
n_f0_bins: int = 512,
):
super().__init__()
self.sampling_ratios = sampling_ratios
out_channels = out_channels or channels
model = nn.ModuleList([])
if len(sampling_ratios) > 0:
self.interpolate = True
for _ in sampling_ratios:
module = nn.Conv1d(channels, channels, 3, 1, 1)
norm = nn.GroupNorm(groups, channels)
act = nn.Mish()
model.extend([module, norm, act])
else:
self.interpolate = False
model.append(
nn.Conv1d(channels, out_channels, 1, 1)
)
self.model = nn.Sequential(*model)
self.embedding = nn.Embedding(codebook_size, channels)
self.is_discrete = is_discrete
self.mask_token = nn.Parameter(torch.zeros(1, channels))
self.n_codebooks = n_codebooks
if n_codebooks > 1:
self.extra_codebooks = nn.ModuleList([
nn.Embedding(codebook_size, channels) for _ in range(n_codebooks - 1)
])
self.extra_codebook_mask_tokens = nn.ParameterList([
nn.Parameter(torch.zeros(1, channels)) for _ in range(n_codebooks - 1)
])
self.quantizer_dropout = quantizer_dropout
if f0_condition:
self.f0_embedding = nn.Embedding(n_f0_bins, channels)
self.f0_condition = f0_condition
self.n_f0_bins = n_f0_bins
self.f0_bins = torch.arange(2, 1024, 1024 // n_f0_bins)
self.f0_mask = nn.Parameter(torch.zeros(1, channels))
else:
self.f0_condition = False
if not is_discrete:
self.content_in_proj = nn.Linear(in_channels, channels)
if vector_quantize:
self.vq = VectorQuantize(channels, codebook_size, 8)
def forward(self, x, ylens=None, n_quantizers=None, f0=None):
# apply token drop
if self.training:
n_quantizers = torch.ones((x.shape[0],)) * self.n_codebooks
dropout = torch.randint(1, self.n_codebooks + 1, (x.shape[0],))
n_dropout = int(x.shape[0] * self.quantizer_dropout)
n_quantizers[:n_dropout] = dropout[:n_dropout]
n_quantizers = n_quantizers.to(x.device)
# decide whether to drop for each sample in batch
else:
n_quantizers = torch.ones((x.shape[0],), device=x.device) * (self.n_codebooks if n_quantizers is None else n_quantizers)
if self.is_discrete:
if self.n_codebooks > 1:
assert len(x.size()) == 3
x_emb = self.embedding(x[:, 0])
for i, emb in enumerate(self.extra_codebooks):
x_emb = x_emb + (n_quantizers > i+1)[..., None, None] * emb(x[:, i+1])
# add mask token if not using this codebook
# x_emb = x_emb + (n_quantizers <= i+1)[..., None, None] * self.extra_codebook_mask_tokens[i]
x = x_emb
elif self.n_codebooks == 1:
if len(x.size()) == 2:
x = self.embedding(x)
else:
x = self.embedding(x[:, 0])
else:
x = self.content_in_proj(x)
# x in (B, T, D)
mask = sequence_mask(ylens).unsqueeze(-1)
if self.interpolate:
x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest')
else:
x = x.transpose(1, 2).contiguous()
mask = mask[:, :x.size(2), :]
ylens = ylens.clamp(max=x.size(2)).long()
if self.f0_condition:
if f0 is None:
x = x + self.f0_mask.unsqueeze(-1)
else:
#quantized_f0 = torch.bucketize(f0, self.f0_bins.to(f0.device)) # (N, T)
quantized_f0 = f0_to_coarse(f0, self.n_f0_bins)
quantized_f0 = quantized_f0.clamp(0, self.n_f0_bins - 1).long()
f0_emb = self.f0_embedding(quantized_f0)
f0_emb = F.interpolate(f0_emb.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest')
x = x + f0_emb
out = self.model(x).transpose(1, 2).contiguous()
if hasattr(self, 'vq'):
out_q, commitment_loss, codebook_loss, codes, out, = self.vq(out.transpose(1, 2))
out_q = out_q.transpose(1, 2)
return out_q * mask, ylens, codes, commitment_loss, codebook_loss
olens = ylens
return out * mask, olens, None, None, None
import torch
import numpy as np
import re
import soundfile
from . import utils
from . import commons
import os
import librosa
# from openvoice.text import text_to_sequence
from .mel_processing import spectrogram_torch
from .models import SynthesizerTrn
class OpenVoiceBaseClass(object):
def __init__(self,
config_path,
device='cuda:0'):
if 'cuda' in device:
assert torch.cuda.is_available()
hps = utils.get_hparams_from_file(config_path)
model = SynthesizerTrn(
len(getattr(hps, 'symbols', [])),
hps.data.filter_length // 2 + 1,
n_speakers=hps.data.n_speakers,
**hps.model,
).to(device)
model.eval()
self.model = model
self.hps = hps
self.device = device
def load_ckpt(self, ckpt_path):
checkpoint_dict = torch.load(ckpt_path, map_location=torch.device(self.device))
a, b = self.model.load_state_dict(checkpoint_dict['model'], strict=False)
print("Loaded checkpoint '{}'".format(ckpt_path))
print('missing/unexpected keys:', a, b)
class BaseSpeakerTTS(OpenVoiceBaseClass):
language_marks = {
"english": "EN",
"chinese": "ZH",
}
@staticmethod
def get_text(text, hps, is_symbol):
text_norm = text_to_sequence(text, hps.symbols, [] if is_symbol else hps.data.text_cleaners)
if hps.data.add_blank:
text_norm = commons.intersperse(text_norm, 0)
text_norm = torch.LongTensor(text_norm)
return text_norm
@staticmethod
def audio_numpy_concat(segment_data_list, sr, speed=1.):
audio_segments = []
for segment_data in segment_data_list:
audio_segments += segment_data.reshape(-1).tolist()
audio_segments += [0] * int((sr * 0.05)/speed)
audio_segments = np.array(audio_segments).astype(np.float32)
return audio_segments
@staticmethod
def split_sentences_into_pieces(text, language_str):
texts = utils.split_sentence(text, language_str=language_str)
print(" > Text splitted to sentences.")
print('\n'.join(texts))
print(" > ===========================")
return texts
def tts(self, text, output_path, speaker, language='English', speed=1.0):
mark = self.language_marks.get(language.lower(), None)
assert mark is not None, f"language {language} is not supported"
texts = self.split_sentences_into_pieces(text, mark)
audio_list = []
for t in texts:
t = re.sub(r'([a-z])([A-Z])', r'\1 \2', t)
t = f'[{mark}]{t}[{mark}]'
stn_tst = self.get_text(t, self.hps, False)
device = self.device
speaker_id = self.hps.speakers[speaker]
with torch.no_grad():
x_tst = stn_tst.unsqueeze(0).to(device)
x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(device)
sid = torch.LongTensor([speaker_id]).to(device)
audio = self.model.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=0.667, noise_scale_w=0.6,
length_scale=1.0 / speed)[0][0, 0].data.cpu().float().numpy()
audio_list.append(audio)
audio = self.audio_numpy_concat(audio_list, sr=self.hps.data.sampling_rate, speed=speed)
if output_path is None:
return audio
else:
soundfile.write(output_path, audio, self.hps.data.sampling_rate)
class ToneColorConverter(OpenVoiceBaseClass):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# if kwargs.get('enable_watermark', True):
# import wavmark
# self.watermark_model = wavmark.load_model().to(self.device)
# else:
# self.watermark_model = None
self.version = getattr(self.hps, '_version_', "v1")
def extract_se(self, waves, wave_lengths):
device = self.device
hps = self.hps
gs = []
for wav_tensor, wav_len in zip(waves, wave_lengths):
y = wav_tensor[:wav_len]
y = y[None, :]
y = spectrogram_torch(y, hps.data.filter_length,
hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length,
center=False).to(device)
with torch.no_grad():
g = self.model.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
gs.append(g.detach())
gs = torch.stack(gs)
gs = gs.squeeze(1).squeeze(-1)
return gs
def convert(self, src_waves, src_wave_lengths, src_se, tgt_se, tau=0.3, message="default"):
hps = self.hps
# load audio
with torch.no_grad():
y = src_waves
spec = spectrogram_torch(y, hps.data.filter_length,
hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length,
center=False).to(self.device)
spec_lengths = src_wave_lengths // hps.data.hop_length
spec_lengths = spec_lengths.clamp(min=1, max=spec.size(2))
audio = self.model.voice_conversion(spec, spec_lengths, sid_src=src_se.unsqueeze(-1), sid_tgt=tgt_se.unsqueeze(-1), tau=tau)[0]
return audio
def add_watermark(self, audio, message):
# if self.watermark_model is None:
return audio
device = self.device
bits = utils.string_to_bits(message).reshape(-1)
n_repeat = len(bits) // 32
K = 16000
coeff = 2
for n in range(n_repeat):
trunck = audio[(coeff * n) * K: (coeff * n + 1) * K]
if len(trunck) != K:
print('Audio too short, fail to add watermark')
break
message_npy = bits[n * 32: (n + 1) * 32]
with torch.no_grad():
signal = torch.FloatTensor(trunck).to(device)[None]
message_tensor = torch.FloatTensor(message_npy).to(device)[None]
signal_wmd_tensor = self.watermark_model.encode(signal, message_tensor)
signal_wmd_npy = signal_wmd_tensor.detach().cpu().squeeze()
audio[(coeff * n) * K: (coeff * n + 1) * K] = signal_wmd_npy
return audio
def detect_watermark(self, audio, n_repeat):
bits = []
K = 16000
coeff = 2
for n in range(n_repeat):
trunck = audio[(coeff * n) * K: (coeff * n + 1) * K]
if len(trunck) != K:
print('Audio too short, fail to detect watermark')
return 'Fail'
with torch.no_grad():
signal = torch.FloatTensor(trunck).to(self.device).unsqueeze(0)
message_decoded_npy = (self.watermark_model.decode(signal) >= 0.5).int().detach().cpu().numpy().squeeze()
bits.append(message_decoded_npy)
bits = np.stack(bits).reshape(-1, 8)
message = utils.bits_to_string(bits)
return message
import math
import torch
from torch import nn
from torch.nn import functional as F
from . import commons
import logging
logger = logging.getLogger(__name__)
class LayerNorm(nn.Module):
def __init__(self, channels, eps=1e-5):
super().__init__()
self.channels = channels
self.eps = eps
self.gamma = nn.Parameter(torch.ones(channels))
self.beta = nn.Parameter(torch.zeros(channels))
def forward(self, x):
x = x.transpose(1, -1)
x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
return x.transpose(1, -1)
@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
n_channels_int = n_channels[0]
in_act = input_a + input_b
t_act = torch.tanh(in_act[:, :n_channels_int, :])
s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
acts = t_act * s_act
return acts
class Encoder(nn.Module):
def __init__(
self,
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size=1,
p_dropout=0.0,
window_size=4,
isflow=True,
**kwargs
):
super().__init__()
self.hidden_channels = hidden_channels
self.filter_channels = filter_channels
self.n_heads = n_heads
self.n_layers = n_layers
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.window_size = window_size
# if isflow:
# cond_layer = torch.nn.Conv1d(256, 2*hidden_channels*n_layers, 1)
# self.cond_pre = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, 1)
# self.cond_layer = weight_norm(cond_layer, name='weight')
# self.gin_channels = 256
self.cond_layer_idx = self.n_layers
if "gin_channels" in kwargs:
self.gin_channels = kwargs["gin_channels"]
if self.gin_channels != 0:
self.spk_emb_linear = nn.Linear(self.gin_channels, self.hidden_channels)
# vits2 says 3rd block, so idx is 2 by default
self.cond_layer_idx = (
kwargs["cond_layer_idx"] if "cond_layer_idx" in kwargs else 2
)
# logging.debug(self.gin_channels, self.cond_layer_idx)
assert (
self.cond_layer_idx < self.n_layers
), "cond_layer_idx should be less than n_layers"
self.drop = nn.Dropout(p_dropout)
self.attn_layers = nn.ModuleList()
self.norm_layers_1 = nn.ModuleList()
self.ffn_layers = nn.ModuleList()
self.norm_layers_2 = nn.ModuleList()
for i in range(self.n_layers):
self.attn_layers.append(
MultiHeadAttention(
hidden_channels,
hidden_channels,
n_heads,
p_dropout=p_dropout,
window_size=window_size,
)
)
self.norm_layers_1.append(LayerNorm(hidden_channels))
self.ffn_layers.append(
FFN(
hidden_channels,
hidden_channels,
filter_channels,
kernel_size,
p_dropout=p_dropout,
)
)
self.norm_layers_2.append(LayerNorm(hidden_channels))
def forward(self, x, x_mask, g=None):
attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
x = x * x_mask
for i in range(self.n_layers):
if i == self.cond_layer_idx and g is not None:
g = self.spk_emb_linear(g.transpose(1, 2))
g = g.transpose(1, 2)
x = x + g
x = x * x_mask
y = self.attn_layers[i](x, x, attn_mask)
y = self.drop(y)
x = self.norm_layers_1[i](x + y)
y = self.ffn_layers[i](x, x_mask)
y = self.drop(y)
x = self.norm_layers_2[i](x + y)
x = x * x_mask
return x
class Decoder(nn.Module):
def __init__(
self,
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size=1,
p_dropout=0.0,
proximal_bias=False,
proximal_init=True,
**kwargs
):
super().__init__()
self.hidden_channels = hidden_channels
self.filter_channels = filter_channels
self.n_heads = n_heads
self.n_layers = n_layers
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.proximal_bias = proximal_bias
self.proximal_init = proximal_init
self.drop = nn.Dropout(p_dropout)
self.self_attn_layers = nn.ModuleList()
self.norm_layers_0 = nn.ModuleList()
self.encdec_attn_layers = nn.ModuleList()
self.norm_layers_1 = nn.ModuleList()
self.ffn_layers = nn.ModuleList()
self.norm_layers_2 = nn.ModuleList()
for i in range(self.n_layers):
self.self_attn_layers.append(
MultiHeadAttention(
hidden_channels,
hidden_channels,
n_heads,
p_dropout=p_dropout,
proximal_bias=proximal_bias,
proximal_init=proximal_init,
)
)
self.norm_layers_0.append(LayerNorm(hidden_channels))
self.encdec_attn_layers.append(
MultiHeadAttention(
hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
)
)
self.norm_layers_1.append(LayerNorm(hidden_channels))
self.ffn_layers.append(
FFN(
hidden_channels,
hidden_channels,
filter_channels,
kernel_size,
p_dropout=p_dropout,
causal=True,
)
)
self.norm_layers_2.append(LayerNorm(hidden_channels))
def forward(self, x, x_mask, h, h_mask):
"""
x: decoder input
h: encoder output
"""
self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
device=x.device, dtype=x.dtype
)
encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
x = x * x_mask
for i in range(self.n_layers):
y = self.self_attn_layers[i](x, x, self_attn_mask)
y = self.drop(y)
x = self.norm_layers_0[i](x + y)
y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
y = self.drop(y)
x = self.norm_layers_1[i](x + y)
y = self.ffn_layers[i](x, x_mask)
y = self.drop(y)
x = self.norm_layers_2[i](x + y)
x = x * x_mask
return x
class MultiHeadAttention(nn.Module):
def __init__(
self,
channels,
out_channels,
n_heads,
p_dropout=0.0,
window_size=None,
heads_share=True,
block_length=None,
proximal_bias=False,
proximal_init=False,
):
super().__init__()
assert channels % n_heads == 0
self.channels = channels
self.out_channels = out_channels
self.n_heads = n_heads
self.p_dropout = p_dropout
self.window_size = window_size
self.heads_share = heads_share
self.block_length = block_length
self.proximal_bias = proximal_bias
self.proximal_init = proximal_init
self.attn = None
self.k_channels = channels // n_heads
self.conv_q = nn.Conv1d(channels, channels, 1)
self.conv_k = nn.Conv1d(channels, channels, 1)
self.conv_v = nn.Conv1d(channels, channels, 1)
self.conv_o = nn.Conv1d(channels, out_channels, 1)
self.drop = nn.Dropout(p_dropout)
if window_size is not None:
n_heads_rel = 1 if heads_share else n_heads
rel_stddev = self.k_channels**-0.5
self.emb_rel_k = nn.Parameter(
torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
* rel_stddev
)
self.emb_rel_v = nn.Parameter(
torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
* rel_stddev
)
nn.init.xavier_uniform_(self.conv_q.weight)
nn.init.xavier_uniform_(self.conv_k.weight)
nn.init.xavier_uniform_(self.conv_v.weight)
if proximal_init:
with torch.no_grad():
self.conv_k.weight.copy_(self.conv_q.weight)
self.conv_k.bias.copy_(self.conv_q.bias)
def forward(self, x, c, attn_mask=None):
q = self.conv_q(x)
k = self.conv_k(c)
v = self.conv_v(c)
x, self.attn = self.attention(q, k, v, mask=attn_mask)
x = self.conv_o(x)
return x
def attention(self, query, key, value, mask=None):
# reshape [b, d, t] -> [b, n_h, t, d_k]
b, d, t_s, t_t = (*key.size(), query.size(2))
query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
if self.window_size is not None:
assert (
t_s == t_t
), "Relative attention is only available for self-attention."
key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
rel_logits = self._matmul_with_relative_keys(
query / math.sqrt(self.k_channels), key_relative_embeddings
)
scores_local = self._relative_position_to_absolute_position(rel_logits)
scores = scores + scores_local
if self.proximal_bias:
assert t_s == t_t, "Proximal bias is only available for self-attention."
scores = scores + self._attention_bias_proximal(t_s).to(
device=scores.device, dtype=scores.dtype
)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e4)
if self.block_length is not None:
assert (
t_s == t_t
), "Local attention is only available for self-attention."
block_mask = (
torch.ones_like(scores)
.triu(-self.block_length)
.tril(self.block_length)
)
scores = scores.masked_fill(block_mask == 0, -1e4)
p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
p_attn = self.drop(p_attn)
output = torch.matmul(p_attn, value)
if self.window_size is not None:
relative_weights = self._absolute_position_to_relative_position(p_attn)
value_relative_embeddings = self._get_relative_embeddings(
self.emb_rel_v, t_s
)
output = output + self._matmul_with_relative_values(
relative_weights, value_relative_embeddings
)
output = (
output.transpose(2, 3).contiguous().view(b, d, t_t)
) # [b, n_h, t_t, d_k] -> [b, d, t_t]
return output, p_attn
def _matmul_with_relative_values(self, x, y):
"""
x: [b, h, l, m]
y: [h or 1, m, d]
ret: [b, h, l, d]
"""
ret = torch.matmul(x, y.unsqueeze(0))
return ret
def _matmul_with_relative_keys(self, x, y):
"""
x: [b, h, l, d]
y: [h or 1, m, d]
ret: [b, h, l, m]
"""
ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
return ret
def _get_relative_embeddings(self, relative_embeddings, length):
2 * self.window_size + 1
# Pad first before slice to avoid using cond ops.
pad_length = max(length - (self.window_size + 1), 0)
slice_start_position = max((self.window_size + 1) - length, 0)
slice_end_position = slice_start_position + 2 * length - 1
if pad_length > 0:
padded_relative_embeddings = F.pad(
relative_embeddings,
commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
)
else:
padded_relative_embeddings = relative_embeddings
used_relative_embeddings = padded_relative_embeddings[
:, slice_start_position:slice_end_position
]
return used_relative_embeddings
def _relative_position_to_absolute_position(self, x):
"""
x: [b, h, l, 2*l-1]
ret: [b, h, l, l]
"""
batch, heads, length, _ = x.size()
# Concat columns of pad to shift from relative to absolute indexing.
x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
# Concat extra elements so to add up to shape (len+1, 2*len-1).
x_flat = x.view([batch, heads, length * 2 * length])
x_flat = F.pad(
x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
)
# Reshape and slice out the padded elements.
x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
:, :, :length, length - 1 :
]
return x_final
def _absolute_position_to_relative_position(self, x):
"""
x: [b, h, l, l]
ret: [b, h, l, 2*l-1]
"""
batch, heads, length, _ = x.size()
# pad along column
x = F.pad(
x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
)
x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
# add 0's in the beginning that will skew the elements after reshape
x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
return x_final
def _attention_bias_proximal(self, length):
"""Bias for self-attention to encourage attention to close positions.
Args:
length: an integer scalar.
Returns:
a Tensor with shape [1, 1, length, length]
"""
r = torch.arange(length, dtype=torch.float32)
diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
class FFN(nn.Module):
def __init__(
self,
in_channels,
out_channels,
filter_channels,
kernel_size,
p_dropout=0.0,
activation=None,
causal=False,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.filter_channels = filter_channels
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.activation = activation
self.causal = causal
if causal:
self.padding = self._causal_padding
else:
self.padding = self._same_padding
self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
self.drop = nn.Dropout(p_dropout)
def forward(self, x, x_mask):
x = self.conv_1(self.padding(x * x_mask))
if self.activation == "gelu":
x = x * torch.sigmoid(1.702 * x)
else:
x = torch.relu(x)
x = self.drop(x)
x = self.conv_2(self.padding(x * x_mask))
return x * x_mask
def _causal_padding(self, x):
if self.kernel_size == 1:
return x
pad_l = self.kernel_size - 1
pad_r = 0
padding = [[0, 0], [0, 0], [pad_l, pad_r]]
x = F.pad(x, commons.convert_pad_shape(padding))
return x
def _same_padding(self, x):
if self.kernel_size == 1:
return x
pad_l = (self.kernel_size - 1) // 2
pad_r = self.kernel_size // 2
padding = [[0, 0], [0, 0], [pad_l, pad_r]]
x = F.pad(x, commons.convert_pad_shape(padding))
return x
{
"_version_": "v2",
"data": {
"sampling_rate": 22050,
"filter_length": 1024,
"hop_length": 256,
"win_length": 1024,
"n_speakers": 0
},
"model": {
"zero_g": true,
"inter_channels": 192,
"hidden_channels": 192,
"filter_channels": 768,
"n_heads": 2,
"n_layers": 6,
"kernel_size": 3,
"p_dropout": 0.1,
"resblock": "1",
"resblock_kernel_sizes": [
3,
7,
11
],
"resblock_dilation_sizes": [
[
1,
3,
5
],
[
1,
3,
5
],
[
1,
3,
5
]
],
"upsample_rates": [
8,
8,
2,
2
],
"upsample_initial_channel": 512,
"upsample_kernel_sizes": [
16,
16,
4,
4
],
"gin_channels": 256
}
}
\ No newline at end of file
import math
import torch
from torch.nn import functional as F
def init_weights(m, mean=0.0, std=0.01):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
m.weight.data.normal_(mean, std)
def get_padding(kernel_size, dilation=1):
return int((kernel_size * dilation - dilation) / 2)
def convert_pad_shape(pad_shape):
layer = pad_shape[::-1]
pad_shape = [item for sublist in layer for item in sublist]
return pad_shape
def intersperse(lst, item):
result = [item] * (len(lst) * 2 + 1)
result[1::2] = lst
return result
def kl_divergence(m_p, logs_p, m_q, logs_q):
"""KL(P||Q)"""
kl = (logs_q - logs_p) - 0.5
kl += (
0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
)
return kl
def rand_gumbel(shape):
"""Sample from the Gumbel distribution, protect from overflows."""
uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
return -torch.log(-torch.log(uniform_samples))
def rand_gumbel_like(x):
g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
return g
def slice_segments(x, ids_str, segment_size=4):
ret = torch.zeros_like(x[:, :, :segment_size])
for i in range(x.size(0)):
idx_str = ids_str[i]
idx_end = idx_str + segment_size
ret[i] = x[i, :, idx_str:idx_end]
return ret
def rand_slice_segments(x, x_lengths=None, segment_size=4):
b, d, t = x.size()
if x_lengths is None:
x_lengths = t
ids_str_max = x_lengths - segment_size + 1
ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
ret = slice_segments(x, ids_str, segment_size)
return ret, ids_str
def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
position = torch.arange(length, dtype=torch.float)
num_timescales = channels // 2
log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
num_timescales - 1
)
inv_timescales = min_timescale * torch.exp(
torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
)
scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
signal = F.pad(signal, [0, 0, 0, channels % 2])
signal = signal.view(1, channels, length)
return signal
def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
b, channels, length = x.size()
signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
return x + signal.to(dtype=x.dtype, device=x.device)
def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
b, channels, length = x.size()
signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
def subsequent_mask(length):
mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
return mask
@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
n_channels_int = n_channels[0]
in_act = input_a + input_b
t_act = torch.tanh(in_act[:, :n_channels_int, :])
s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
acts = t_act * s_act
return acts
def convert_pad_shape(pad_shape):
layer = pad_shape[::-1]
pad_shape = [item for sublist in layer for item in sublist]
return pad_shape
def shift_1d(x):
x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
return x
def sequence_mask(length, max_length=None):
if max_length is None:
max_length = length.max()
x = torch.arange(max_length, dtype=length.dtype, device=length.device)
return x.unsqueeze(0) < length.unsqueeze(1)
def generate_path(duration, mask):
"""
duration: [b, 1, t_x]
mask: [b, 1, t_y, t_x]
"""
b, _, t_y, t_x = mask.shape
cum_duration = torch.cumsum(duration, -1)
cum_duration_flat = cum_duration.view(b * t_x)
path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
path = path.view(b, t_x, t_y)
path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
path = path.unsqueeze(1).transpose(2, 3) * mask
return path
def clip_grad_value_(parameters, clip_value, norm_type=2):
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
parameters = list(filter(lambda p: p.grad is not None, parameters))
norm_type = float(norm_type)
if clip_value is not None:
clip_value = float(clip_value)
total_norm = 0
for p in parameters:
param_norm = p.grad.data.norm(norm_type)
total_norm += param_norm.item() ** norm_type
if clip_value is not None:
p.grad.data.clamp_(min=-clip_value, max=clip_value)
total_norm = total_norm ** (1.0 / norm_type)
return total_norm
import torch
import torch.utils.data
from librosa.filters import mel as librosa_mel_fn
MAX_WAV_VALUE = 32768.0
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
"""
PARAMS
------
C: compression factor
"""
return torch.log(torch.clamp(x, min=clip_val) * C)
def dynamic_range_decompression_torch(x, C=1):
"""
PARAMS
------
C: compression factor used to compress
"""
return torch.exp(x) / C
def spectral_normalize_torch(magnitudes):
output = dynamic_range_compression_torch(magnitudes)
return output
def spectral_de_normalize_torch(magnitudes):
output = dynamic_range_decompression_torch(magnitudes)
return output
mel_basis = {}
hann_window = {}
def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
# if torch.min(y) < -1.1:
# print("min value is ", torch.min(y))
# if torch.max(y) > 1.1:
# print("max value is ", torch.max(y))
global hann_window
dtype_device = str(y.dtype) + "_" + str(y.device)
wnsize_dtype_device = str(win_size) + "_" + dtype_device
if wnsize_dtype_device not in hann_window:
hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
dtype=y.dtype, device=y.device
)
y = torch.nn.functional.pad(
y.unsqueeze(1),
(int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
mode="reflect",
)
y = y.squeeze(1)
spec = torch.stft(
y,
n_fft,
hop_length=hop_size,
win_length=win_size,
window=hann_window[wnsize_dtype_device],
center=center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=False,
)
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
return spec
def spectrogram_torch_conv(y, n_fft, sampling_rate, hop_size, win_size, center=False):
# if torch.min(y) < -1.:
# print('min value is ', torch.min(y))
# if torch.max(y) > 1.:
# print('max value is ', torch.max(y))
global hann_window
dtype_device = str(y.dtype) + '_' + str(y.device)
wnsize_dtype_device = str(win_size) + '_' + dtype_device
if wnsize_dtype_device not in hann_window:
hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
# ******************** original ************************#
# y = y.squeeze(1)
# spec1 = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
# center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)
# ******************** ConvSTFT ************************#
freq_cutoff = n_fft // 2 + 1
fourier_basis = torch.view_as_real(torch.fft.fft(torch.eye(n_fft)))
forward_basis = fourier_basis[:freq_cutoff].permute(2, 0, 1).reshape(-1, 1, fourier_basis.shape[1])
forward_basis = forward_basis * torch.as_tensor(librosa.util.pad_center(torch.hann_window(win_size), size=n_fft)).float()
import torch.nn.functional as F
# if center:
# signal = F.pad(y[:, None, None, :], (n_fft // 2, n_fft // 2, 0, 0), mode = 'reflect').squeeze(1)
assert center is False
forward_transform_squared = F.conv1d(y, forward_basis.to(y.device), stride = hop_size)
spec2 = torch.stack([forward_transform_squared[:, :freq_cutoff, :], forward_transform_squared[:, freq_cutoff:, :]], dim = -1)
# ******************** Verification ************************#
spec1 = torch.stft(y.squeeze(1), n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)
assert torch.allclose(spec1, spec2, atol=1e-4)
spec = torch.sqrt(spec2.pow(2).sum(-1) + 1e-6)
return spec
def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
global mel_basis
dtype_device = str(spec.dtype) + "_" + str(spec.device)
fmax_dtype_device = str(fmax) + "_" + dtype_device
if fmax_dtype_device not in mel_basis:
mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
dtype=spec.dtype, device=spec.device
)
spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
spec = spectral_normalize_torch(spec)
return spec
def mel_spectrogram_torch(
y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False
):
if torch.min(y) < -1.0:
print("min value is ", torch.min(y))
if torch.max(y) > 1.0:
print("max value is ", torch.max(y))
global mel_basis, hann_window
dtype_device = str(y.dtype) + "_" + str(y.device)
fmax_dtype_device = str(fmax) + "_" + dtype_device
wnsize_dtype_device = str(win_size) + "_" + dtype_device
if fmax_dtype_device not in mel_basis:
mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
dtype=y.dtype, device=y.device
)
if wnsize_dtype_device not in hann_window:
hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
dtype=y.dtype, device=y.device
)
y = torch.nn.functional.pad(
y.unsqueeze(1),
(int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
mode="reflect",
)
y = y.squeeze(1)
spec = torch.stft(
y,
n_fft,
hop_length=hop_size,
win_length=win_size,
window=hann_window[wnsize_dtype_device],
center=center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=False,
)
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
spec = spectral_normalize_torch(spec)
return spec
\ No newline at end of file
import math
import torch
from torch import nn
from torch.nn import functional as F
from . import commons
from . import modules
from . import attentions
from torch.nn import Conv1d, ConvTranspose1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
from .commons import init_weights, get_padding
class TextEncoder(nn.Module):
def __init__(self,
n_vocab,
out_channels,
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size,
p_dropout):
super().__init__()
self.n_vocab = n_vocab
self.out_channels = out_channels
self.hidden_channels = hidden_channels
self.filter_channels = filter_channels
self.n_heads = n_heads
self.n_layers = n_layers
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.emb = nn.Embedding(n_vocab, hidden_channels)
nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
self.encoder = attentions.Encoder(
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size,
p_dropout)
self.proj= nn.Conv1d(hidden_channels, out_channels * 2, 1)
def forward(self, x, x_lengths):
x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h]
x = torch.transpose(x, 1, -1) # [b, h, t]
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
x = self.encoder(x * x_mask, x_mask)
stats = self.proj(x) * x_mask
m, logs = torch.split(stats, self.out_channels, dim=1)
return x, m, logs, x_mask
class DurationPredictor(nn.Module):
def __init__(
self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
):
super().__init__()
self.in_channels = in_channels
self.filter_channels = filter_channels
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.gin_channels = gin_channels
self.drop = nn.Dropout(p_dropout)
self.conv_1 = nn.Conv1d(
in_channels, filter_channels, kernel_size, padding=kernel_size // 2
)
self.norm_1 = modules.LayerNorm(filter_channels)
self.conv_2 = nn.Conv1d(
filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
)
self.norm_2 = modules.LayerNorm(filter_channels)
self.proj = nn.Conv1d(filter_channels, 1, 1)
if gin_channels != 0:
self.cond = nn.Conv1d(gin_channels, in_channels, 1)
def forward(self, x, x_mask, g=None):
x = torch.detach(x)
if g is not None:
g = torch.detach(g)
x = x + self.cond(g)
x = self.conv_1(x * x_mask)
x = torch.relu(x)
x = self.norm_1(x)
x = self.drop(x)
x = self.conv_2(x * x_mask)
x = torch.relu(x)
x = self.norm_2(x)
x = self.drop(x)
x = self.proj(x * x_mask)
return x * x_mask
class StochasticDurationPredictor(nn.Module):
def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, n_flows=4, gin_channels=0):
super().__init__()
filter_channels = in_channels # it needs to be removed from future version.
self.in_channels = in_channels
self.filter_channels = filter_channels
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.n_flows = n_flows
self.gin_channels = gin_channels
self.log_flow = modules.Log()
self.flows = nn.ModuleList()
self.flows.append(modules.ElementwiseAffine(2))
for i in range(n_flows):
self.flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
self.flows.append(modules.Flip())
self.post_pre = nn.Conv1d(1, filter_channels, 1)
self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
self.post_convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
self.post_flows = nn.ModuleList()
self.post_flows.append(modules.ElementwiseAffine(2))
for i in range(4):
self.post_flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
self.post_flows.append(modules.Flip())
self.pre = nn.Conv1d(in_channels, filter_channels, 1)
self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
self.convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
if gin_channels != 0:
self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
x = torch.detach(x)
x = self.pre(x)
if g is not None:
g = torch.detach(g)
x = x + self.cond(g)
x = self.convs(x, x_mask)
x = self.proj(x) * x_mask
if not reverse:
flows = self.flows
assert w is not None
logdet_tot_q = 0
h_w = self.post_pre(w)
h_w = self.post_convs(h_w, x_mask)
h_w = self.post_proj(h_w) * x_mask
e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask
z_q = e_q
for flow in self.post_flows:
z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
logdet_tot_q += logdet_q
z_u, z1 = torch.split(z_q, [1, 1], 1)
u = torch.sigmoid(z_u) * x_mask
z0 = (w - u) * x_mask
logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1,2])
logq = torch.sum(-0.5 * (math.log(2*math.pi) + (e_q**2)) * x_mask, [1,2]) - logdet_tot_q
logdet_tot = 0
z0, logdet = self.log_flow(z0, x_mask)
logdet_tot += logdet
z = torch.cat([z0, z1], 1)
for flow in flows:
z, logdet = flow(z, x_mask, g=x, reverse=reverse)
logdet_tot = logdet_tot + logdet
nll = torch.sum(0.5 * (math.log(2*math.pi) + (z**2)) * x_mask, [1,2]) - logdet_tot
return nll + logq # [b]
else:
flows = list(reversed(self.flows))
flows = flows[:-2] + [flows[-1]] # remove a useless vflow
z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale
for flow in flows:
z = flow(z, x_mask, g=x, reverse=reverse)
z0, z1 = torch.split(z, [1, 1], 1)
logw = z0
return logw
class PosteriorEncoder(nn.Module):
def __init__(
self,
in_channels,
out_channels,
hidden_channels,
kernel_size,
dilation_rate,
n_layers,
gin_channels=0,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.hidden_channels = hidden_channels
self.kernel_size = kernel_size
self.dilation_rate = dilation_rate
self.n_layers = n_layers
self.gin_channels = gin_channels
self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
self.enc = modules.WN(
hidden_channels,
kernel_size,
dilation_rate,
n_layers,
gin_channels=gin_channels,
)
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
def forward(self, x, x_lengths, g=None, tau=1.0):
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
x.dtype
)
x = self.pre(x) * x_mask
x = self.enc(x, x_mask, g=g)
stats = self.proj(x) * x_mask
m, logs = torch.split(stats, self.out_channels, dim=1)
z = (m + torch.randn_like(m) * tau * torch.exp(logs)) * x_mask
return z, m, logs, x_mask
class Generator(torch.nn.Module):
def __init__(
self,
initial_channel,
resblock,
resblock_kernel_sizes,
resblock_dilation_sizes,
upsample_rates,
upsample_initial_channel,
upsample_kernel_sizes,
gin_channels=0,
):
super(Generator, self).__init__()
self.num_kernels = len(resblock_kernel_sizes)
self.num_upsamples = len(upsample_rates)
self.conv_pre = Conv1d(
initial_channel, upsample_initial_channel, 7, 1, padding=3
)
resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
self.ups = nn.ModuleList()
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
self.ups.append(
weight_norm(
ConvTranspose1d(
upsample_initial_channel // (2**i),
upsample_initial_channel // (2 ** (i + 1)),
k,
u,
padding=(k - u) // 2,
)
)
)
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch = upsample_initial_channel // (2 ** (i + 1))
for j, (k, d) in enumerate(
zip(resblock_kernel_sizes, resblock_dilation_sizes)
):
self.resblocks.append(resblock(ch, k, d))
self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
self.ups.apply(init_weights)
if gin_channels != 0:
self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
def forward(self, x, g=None):
x = self.conv_pre(x)
if g is not None:
x = x + self.cond(g)
for i in range(self.num_upsamples):
x = F.leaky_relu(x, modules.LRELU_SLOPE)
x = self.ups[i](x)
xs = None
for j in range(self.num_kernels):
if xs is None:
xs = self.resblocks[i * self.num_kernels + j](x)
else:
xs += self.resblocks[i * self.num_kernels + j](x)
x = xs / self.num_kernels
x = F.leaky_relu(x)
x = self.conv_post(x)
x = torch.tanh(x)
return x
def remove_weight_norm(self):
print("Removing weight norm...")
for layer in self.ups:
remove_weight_norm(layer)
for layer in self.resblocks:
layer.remove_weight_norm()
class ReferenceEncoder(nn.Module):
"""
inputs --- [N, Ty/r, n_mels*r] mels
outputs --- [N, ref_enc_gru_size]
"""
def __init__(self, spec_channels, gin_channels=0, layernorm=True):
super().__init__()
self.spec_channels = spec_channels
ref_enc_filters = [32, 32, 64, 64, 128, 128]
K = len(ref_enc_filters)
filters = [1] + ref_enc_filters
convs = [
weight_norm(
nn.Conv2d(
in_channels=filters[i],
out_channels=filters[i + 1],
kernel_size=(3, 3),
stride=(2, 2),
padding=(1, 1),
)
)
for i in range(K)
]
self.convs = nn.ModuleList(convs)
out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K)
self.gru = nn.GRU(
input_size=ref_enc_filters[-1] * out_channels,
hidden_size=256 // 2,
batch_first=True,
)
self.proj = nn.Linear(128, gin_channels)
if layernorm:
self.layernorm = nn.LayerNorm(self.spec_channels)
else:
self.layernorm = None
def forward(self, inputs, mask=None):
N = inputs.size(0)
out = inputs.view(N, 1, -1, self.spec_channels) # [N, 1, Ty, n_freqs]
if self.layernorm is not None:
out = self.layernorm(out)
for conv in self.convs:
out = conv(out)
# out = wn(out)
out = F.relu(out) # [N, 128, Ty//2^K, n_mels//2^K]
out = out.transpose(1, 2) # [N, Ty//2^K, 128, n_mels//2^K]
T = out.size(1)
N = out.size(0)
out = out.contiguous().view(N, T, -1) # [N, Ty//2^K, 128*n_mels//2^K]
self.gru.flatten_parameters()
memory, out = self.gru(out) # out --- [1, N, 128]
return self.proj(out.squeeze(0))
def calculate_channels(self, L, kernel_size, stride, pad, n_convs):
for i in range(n_convs):
L = (L - kernel_size + 2 * pad) // stride + 1
return L
class ResidualCouplingBlock(nn.Module):
def __init__(self,
channels,
hidden_channels,
kernel_size,
dilation_rate,
n_layers,
n_flows=4,
gin_channels=0):
super().__init__()
self.channels = channels
self.hidden_channels = hidden_channels
self.kernel_size = kernel_size
self.dilation_rate = dilation_rate
self.n_layers = n_layers
self.n_flows = n_flows
self.gin_channels = gin_channels
self.flows = nn.ModuleList()
for i in range(n_flows):
self.flows.append(modules.ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels, mean_only=True))
self.flows.append(modules.Flip())
def forward(self, x, x_mask, g=None, reverse=False):
if not reverse:
for flow in self.flows:
x, _ = flow(x, x_mask, g=g, reverse=reverse)
else:
for flow in reversed(self.flows):
x = flow(x, x_mask, g=g, reverse=reverse)
return x
class SynthesizerTrn(nn.Module):
"""
Synthesizer for Training
"""
def __init__(
self,
n_vocab,
spec_channels,
inter_channels,
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size,
p_dropout,
resblock,
resblock_kernel_sizes,
resblock_dilation_sizes,
upsample_rates,
upsample_initial_channel,
upsample_kernel_sizes,
n_speakers=256,
gin_channels=256,
zero_g=False,
**kwargs
):
super().__init__()
self.dec = Generator(
inter_channels,
resblock,
resblock_kernel_sizes,
resblock_dilation_sizes,
upsample_rates,
upsample_initial_channel,
upsample_kernel_sizes,
gin_channels=gin_channels,
)
self.enc_q = PosteriorEncoder(
spec_channels,
inter_channels,
hidden_channels,
5,
1,
16,
gin_channels=gin_channels,
)
self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
self.n_speakers = n_speakers
if n_speakers == 0:
self.ref_enc = ReferenceEncoder(spec_channels, gin_channels)
else:
self.enc_p = TextEncoder(n_vocab,
inter_channels,
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size,
p_dropout)
self.sdp = StochasticDurationPredictor(hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels)
self.dp = DurationPredictor(hidden_channels, 256, 3, 0.5, gin_channels=gin_channels)
self.emb_g = nn.Embedding(n_speakers, gin_channels)
self.zero_g = zero_g
def infer(self, x, x_lengths, sid=None, noise_scale=1, length_scale=1, noise_scale_w=1., sdp_ratio=0.2, max_len=None):
x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)
if self.n_speakers > 0:
g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
else:
g = None
logw = self.sdp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w) * sdp_ratio \
+ self.dp(x, x_mask, g=g) * (1 - sdp_ratio)
w = torch.exp(logw) * x_mask * length_scale
w_ceil = torch.ceil(w)
y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(x_mask.dtype)
attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
attn = commons.generate_path(w_ceil, attn_mask)
m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
z = self.flow(z_p, y_mask, g=g, reverse=True)
o = self.dec((z * y_mask)[:,:,:max_len], g=g)
return o, attn, y_mask, (z, z_p, m_p, logs_p)
def voice_conversion(self, y, y_lengths, sid_src, sid_tgt, tau=1.0):
g_src = sid_src
g_tgt = sid_tgt
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g_src if not self.zero_g else torch.zeros_like(g_src), tau=tau)
z_p = self.flow(z, y_mask, g=g_src)
z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
o_hat = self.dec(z_hat * y_mask, g=g_tgt if not self.zero_g else torch.zeros_like(g_tgt))
return o_hat, y_mask, (z, z_p, z_hat)
import math
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn import Conv1d
from torch.nn.utils import weight_norm, remove_weight_norm
from . import commons
from .commons import init_weights, get_padding
from .transforms import piecewise_rational_quadratic_transform
from .attentions import Encoder
LRELU_SLOPE = 0.1
class LayerNorm(nn.Module):
def __init__(self, channels, eps=1e-5):
super().__init__()
self.channels = channels
self.eps = eps
self.gamma = nn.Parameter(torch.ones(channels))
self.beta = nn.Parameter(torch.zeros(channels))
def forward(self, x):
x = x.transpose(1, -1)
x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
return x.transpose(1, -1)
class ConvReluNorm(nn.Module):
def __init__(
self,
in_channels,
hidden_channels,
out_channels,
kernel_size,
n_layers,
p_dropout,
):
super().__init__()
self.in_channels = in_channels
self.hidden_channels = hidden_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.n_layers = n_layers
self.p_dropout = p_dropout
assert n_layers > 1, "Number of layers should be larger than 0."
self.conv_layers = nn.ModuleList()
self.norm_layers = nn.ModuleList()
self.conv_layers.append(
nn.Conv1d(
in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
)
)
self.norm_layers.append(LayerNorm(hidden_channels))
self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
for _ in range(n_layers - 1):
self.conv_layers.append(
nn.Conv1d(
hidden_channels,
hidden_channels,
kernel_size,
padding=kernel_size // 2,
)
)
self.norm_layers.append(LayerNorm(hidden_channels))
self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
self.proj.weight.data.zero_()
self.proj.bias.data.zero_()
def forward(self, x, x_mask):
x_org = x
for i in range(self.n_layers):
x = self.conv_layers[i](x * x_mask)
x = self.norm_layers[i](x)
x = self.relu_drop(x)
x = x_org + self.proj(x)
return x * x_mask
class DDSConv(nn.Module):
"""
Dilated and Depth-Separable Convolution
"""
def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
super().__init__()
self.channels = channels
self.kernel_size = kernel_size
self.n_layers = n_layers
self.p_dropout = p_dropout
self.drop = nn.Dropout(p_dropout)
self.convs_sep = nn.ModuleList()
self.convs_1x1 = nn.ModuleList()
self.norms_1 = nn.ModuleList()
self.norms_2 = nn.ModuleList()
for i in range(n_layers):
dilation = kernel_size**i
padding = (kernel_size * dilation - dilation) // 2
self.convs_sep.append(
nn.Conv1d(
channels,
channels,
kernel_size,
groups=channels,
dilation=dilation,
padding=padding,
)
)
self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
self.norms_1.append(LayerNorm(channels))
self.norms_2.append(LayerNorm(channels))
def forward(self, x, x_mask, g=None):
if g is not None:
x = x + g
for i in range(self.n_layers):
y = self.convs_sep[i](x * x_mask)
y = self.norms_1[i](y)
y = F.gelu(y)
y = self.convs_1x1[i](y)
y = self.norms_2[i](y)
y = F.gelu(y)
y = self.drop(y)
x = x + y
return x * x_mask
class WN(torch.nn.Module):
def __init__(
self,
hidden_channels,
kernel_size,
dilation_rate,
n_layers,
gin_channels=0,
p_dropout=0,
):
super(WN, self).__init__()
assert kernel_size % 2 == 1
self.hidden_channels = hidden_channels
self.kernel_size = (kernel_size,)
self.dilation_rate = dilation_rate
self.n_layers = n_layers
self.gin_channels = gin_channels
self.p_dropout = p_dropout
self.in_layers = torch.nn.ModuleList()
self.res_skip_layers = torch.nn.ModuleList()
self.drop = nn.Dropout(p_dropout)
if gin_channels != 0:
cond_layer = torch.nn.Conv1d(
gin_channels, 2 * hidden_channels * n_layers, 1
)
self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
for i in range(n_layers):
dilation = dilation_rate**i
padding = int((kernel_size * dilation - dilation) / 2)
in_layer = torch.nn.Conv1d(
hidden_channels,
2 * hidden_channels,
kernel_size,
dilation=dilation,
padding=padding,
)
in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
self.in_layers.append(in_layer)
# last one is not necessary
if i < n_layers - 1:
res_skip_channels = 2 * hidden_channels
else:
res_skip_channels = hidden_channels
res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
self.res_skip_layers.append(res_skip_layer)
def forward(self, x, x_mask, g=None, **kwargs):
output = torch.zeros_like(x)
n_channels_tensor = torch.IntTensor([self.hidden_channels])
if g is not None:
g = self.cond_layer(g)
for i in range(self.n_layers):
x_in = self.in_layers[i](x)
if g is not None:
cond_offset = i * 2 * self.hidden_channels
g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
else:
g_l = torch.zeros_like(x_in)
acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
acts = self.drop(acts)
res_skip_acts = self.res_skip_layers[i](acts)
if i < self.n_layers - 1:
res_acts = res_skip_acts[:, : self.hidden_channels, :]
x = (x + res_acts) * x_mask
output = output + res_skip_acts[:, self.hidden_channels :, :]
else:
output = output + res_skip_acts
return output * x_mask
def remove_weight_norm(self):
if self.gin_channels != 0:
torch.nn.utils.remove_weight_norm(self.cond_layer)
for l in self.in_layers:
torch.nn.utils.remove_weight_norm(l)
for l in self.res_skip_layers:
torch.nn.utils.remove_weight_norm(l)
class ResBlock1(torch.nn.Module):
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
super(ResBlock1, self).__init__()
self.convs1 = nn.ModuleList(
[
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[2],
padding=get_padding(kernel_size, dilation[2]),
)
),
]
)
self.convs1.apply(init_weights)
self.convs2 = nn.ModuleList(
[
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
)
),
]
)
self.convs2.apply(init_weights)
def forward(self, x, x_mask=None):
for c1, c2 in zip(self.convs1, self.convs2):
xt = F.leaky_relu(x, LRELU_SLOPE)
if x_mask is not None:
xt = xt * x_mask
xt = c1(xt)
xt = F.leaky_relu(xt, LRELU_SLOPE)
if x_mask is not None:
xt = xt * x_mask
xt = c2(xt)
x = xt + x
if x_mask is not None:
x = x * x_mask
return x
def remove_weight_norm(self):
for l in self.convs1:
remove_weight_norm(l)
for l in self.convs2:
remove_weight_norm(l)
class ResBlock2(torch.nn.Module):
def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
super(ResBlock2, self).__init__()
self.convs = nn.ModuleList(
[
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]),
)
),
]
)
self.convs.apply(init_weights)
def forward(self, x, x_mask=None):
for c in self.convs:
xt = F.leaky_relu(x, LRELU_SLOPE)
if x_mask is not None:
xt = xt * x_mask
xt = c(xt)
x = xt + x
if x_mask is not None:
x = x * x_mask
return x
def remove_weight_norm(self):
for l in self.convs:
remove_weight_norm(l)
class Log(nn.Module):
def forward(self, x, x_mask, reverse=False, **kwargs):
if not reverse:
y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
logdet = torch.sum(-y, [1, 2])
return y, logdet
else:
x = torch.exp(x) * x_mask
return x
class Flip(nn.Module):
def forward(self, x, *args, reverse=False, **kwargs):
x = torch.flip(x, [1])
if not reverse:
logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
return x, logdet
else:
return x
class ElementwiseAffine(nn.Module):
def __init__(self, channels):
super().__init__()
self.channels = channels
self.m = nn.Parameter(torch.zeros(channels, 1))
self.logs = nn.Parameter(torch.zeros(channels, 1))
def forward(self, x, x_mask, reverse=False, **kwargs):
if not reverse:
y = self.m + torch.exp(self.logs) * x
y = y * x_mask
logdet = torch.sum(self.logs * x_mask, [1, 2])
return y, logdet
else:
x = (x - self.m) * torch.exp(-self.logs) * x_mask
return x
class ResidualCouplingLayer(nn.Module):
def __init__(
self,
channels,
hidden_channels,
kernel_size,
dilation_rate,
n_layers,
p_dropout=0,
gin_channels=0,
mean_only=False,
):
assert channels % 2 == 0, "channels should be divisible by 2"
super().__init__()
self.channels = channels
self.hidden_channels = hidden_channels
self.kernel_size = kernel_size
self.dilation_rate = dilation_rate
self.n_layers = n_layers
self.half_channels = channels // 2
self.mean_only = mean_only
self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
self.enc = WN(
hidden_channels,
kernel_size,
dilation_rate,
n_layers,
p_dropout=p_dropout,
gin_channels=gin_channels,
)
self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
self.post.weight.data.zero_()
self.post.bias.data.zero_()
def forward(self, x, x_mask, g=None, reverse=False):
x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
h = self.pre(x0) * x_mask
h = self.enc(h, x_mask, g=g)
stats = self.post(h) * x_mask
if not self.mean_only:
m, logs = torch.split(stats, [self.half_channels] * 2, 1)
else:
m = stats
logs = torch.zeros_like(m)
if not reverse:
x1 = m + x1 * torch.exp(logs) * x_mask
x = torch.cat([x0, x1], 1)
logdet = torch.sum(logs, [1, 2])
return x, logdet
else:
x1 = (x1 - m) * torch.exp(-logs) * x_mask
x = torch.cat([x0, x1], 1)
return x
class ConvFlow(nn.Module):
def __init__(
self,
in_channels,
filter_channels,
kernel_size,
n_layers,
num_bins=10,
tail_bound=5.0,
):
super().__init__()
self.in_channels = in_channels
self.filter_channels = filter_channels
self.kernel_size = kernel_size
self.n_layers = n_layers
self.num_bins = num_bins
self.tail_bound = tail_bound
self.half_channels = in_channels // 2
self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0)
self.proj = nn.Conv1d(
filter_channels, self.half_channels * (num_bins * 3 - 1), 1
)
self.proj.weight.data.zero_()
self.proj.bias.data.zero_()
def forward(self, x, x_mask, g=None, reverse=False):
x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
h = self.pre(x0)
h = self.convs(h, x_mask, g=g)
h = self.proj(h) * x_mask
b, c, t = x0.shape
h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?]
unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels)
unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(
self.filter_channels
)
unnormalized_derivatives = h[..., 2 * self.num_bins :]
x1, logabsdet = piecewise_rational_quadratic_transform(
x1,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=reverse,
tails="linear",
tail_bound=self.tail_bound,
)
x = torch.cat([x0, x1], 1) * x_mask
logdet = torch.sum(logabsdet * x_mask, [1, 2])
if not reverse:
return x, logdet
else:
return x
class TransformerCouplingLayer(nn.Module):
def __init__(
self,
channels,
hidden_channels,
kernel_size,
n_layers,
n_heads,
p_dropout=0,
filter_channels=0,
mean_only=False,
wn_sharing_parameter=None,
gin_channels=0,
):
assert n_layers == 3, n_layers
assert channels % 2 == 0, "channels should be divisible by 2"
super().__init__()
self.channels = channels
self.hidden_channels = hidden_channels
self.kernel_size = kernel_size
self.n_layers = n_layers
self.half_channels = channels // 2
self.mean_only = mean_only
self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
self.enc = (
Encoder(
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size,
p_dropout,
isflow=True,
gin_channels=gin_channels,
)
if wn_sharing_parameter is None
else wn_sharing_parameter
)
self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
self.post.weight.data.zero_()
self.post.bias.data.zero_()
def forward(self, x, x_mask, g=None, reverse=False):
x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
h = self.pre(x0) * x_mask
h = self.enc(h, x_mask, g=g)
stats = self.post(h) * x_mask
if not self.mean_only:
m, logs = torch.split(stats, [self.half_channels] * 2, 1)
else:
m = stats
logs = torch.zeros_like(m)
if not reverse:
x1 = m + x1 * torch.exp(logs) * x_mask
x = torch.cat([x0, x1], 1)
logdet = torch.sum(logs, [1, 2])
return x, logdet
else:
x1 = (x1 - m) * torch.exp(-logs) * x_mask
x = torch.cat([x0, x1], 1)
return x
x1, logabsdet = piecewise_rational_quadratic_transform(
x1,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=reverse,
tails="linear",
tail_bound=self.tail_bound,
)
x = torch.cat([x0, x1], 1) * x_mask
logdet = torch.sum(logabsdet * x_mask, [1, 2])
if not reverse:
return x, logdet
else:
return x
import os
import torch
import argparse
import gradio as gr
from zipfile import ZipFile
import langid
from . import se_extractor
from .api import BaseSpeakerTTS, ToneColorConverter
parser = argparse.ArgumentParser()
parser.add_argument("--share", action='store_true', default=False, help="make link public")
args = parser.parse_args()
en_ckpt_base = 'checkpoints/base_speakers/EN'
zh_ckpt_base = 'checkpoints/base_speakers/ZH'
ckpt_converter = 'checkpoints/converter'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
output_dir = 'outputs'
os.makedirs(output_dir, exist_ok=True)
# load models
en_base_speaker_tts = BaseSpeakerTTS(f'{en_ckpt_base}/config.json', device=device)
en_base_speaker_tts.load_ckpt(f'{en_ckpt_base}/checkpoint.pth')
zh_base_speaker_tts = BaseSpeakerTTS(f'{zh_ckpt_base}/config.json', device=device)
zh_base_speaker_tts.load_ckpt(f'{zh_ckpt_base}/checkpoint.pth')
tone_color_converter = ToneColorConverter(f'{ckpt_converter}/config.json', device=device)
tone_color_converter.load_ckpt(f'{ckpt_converter}/checkpoint.pth')
# load speaker embeddings
en_source_default_se = torch.load(f'{en_ckpt_base}/en_default_se.pth').to(device)
en_source_style_se = torch.load(f'{en_ckpt_base}/en_style_se.pth').to(device)
zh_source_se = torch.load(f'{zh_ckpt_base}/zh_default_se.pth').to(device)
# This online demo mainly supports English and Chinese
supported_languages = ['zh', 'en']
def predict(prompt, style, audio_file_pth, agree):
# initialize a empty info
text_hint = ''
# agree with the terms
if agree == False:
text_hint += '[ERROR] Please accept the Terms & Condition!\n'
gr.Warning("Please accept the Terms & Condition!")
return (
text_hint,
None,
None,
)
# first detect the input language
language_predicted = langid.classify(prompt)[0].strip()
print(f"Detected language:{language_predicted}")
if language_predicted not in supported_languages:
text_hint += f"[ERROR] The detected language {language_predicted} for your input text is not in our Supported Languages: {supported_languages}\n"
gr.Warning(
f"The detected language {language_predicted} for your input text is not in our Supported Languages: {supported_languages}"
)
return (
text_hint,
None,
None,
)
if language_predicted == "zh":
tts_model = zh_base_speaker_tts
source_se = zh_source_se
language = 'Chinese'
if style not in ['default']:
text_hint += f"[ERROR] The style {style} is not supported for Chinese, which should be in ['default']\n"
gr.Warning(f"The style {style} is not supported for Chinese, which should be in ['default']")
return (
text_hint,
None,
None,
)
else:
tts_model = en_base_speaker_tts
if style == 'default':
source_se = en_source_default_se
else:
source_se = en_source_style_se
language = 'English'
if style not in ['default', 'whispering', 'shouting', 'excited', 'cheerful', 'terrified', 'angry', 'sad', 'friendly']:
text_hint += f"[ERROR] The style {style} is not supported for English, which should be in ['default', 'whispering', 'shouting', 'excited', 'cheerful', 'terrified', 'angry', 'sad', 'friendly']\n"
gr.Warning(f"The style {style} is not supported for English, which should be in ['default', 'whispering', 'shouting', 'excited', 'cheerful', 'terrified', 'angry', 'sad', 'friendly']")
return (
text_hint,
None,
None,
)
speaker_wav = audio_file_pth
if len(prompt) < 2:
text_hint += f"[ERROR] Please give a longer prompt text \n"
gr.Warning("Please give a longer prompt text")
return (
text_hint,
None,
None,
)
if len(prompt) > 200:
text_hint += f"[ERROR] Text length limited to 200 characters for this demo, please try shorter text. You can clone our open-source repo and try for your usage \n"
gr.Warning(
"Text length limited to 200 characters for this demo, please try shorter text. You can clone our open-source repo for your usage"
)
return (
text_hint,
None,
None,
)
# note diffusion_conditioning not used on hifigan (default mode), it will be empty but need to pass it to model.inference
try:
target_se, audio_name = se_extractor.get_se(speaker_wav, tone_color_converter, target_dir='processed', vad=True)
except Exception as e:
text_hint += f"[ERROR] Get target tone color error {str(e)} \n"
gr.Warning(
"[ERROR] Get target tone color error {str(e)} \n"
)
return (
text_hint,
None,
None,
)
src_path = f'{output_dir}/tmp.wav'
tts_model.tts(prompt, src_path, speaker=style, language=language)
save_path = f'{output_dir}/output.wav'
# Run the tone color converter
encode_message = "@MyShell"
tone_color_converter.convert(
audio_src_path=src_path,
src_se=source_se,
tgt_se=target_se,
output_path=save_path,
message=encode_message)
text_hint += f'''Get response successfully \n'''
return (
text_hint,
save_path,
speaker_wav,
)
title = "MyShell OpenVoice"
description = """
We introduce OpenVoice, a versatile instant voice cloning approach that requires only a short audio clip from the reference speaker to replicate their voice and generate speech in multiple languages. OpenVoice enables granular control over voice styles, including emotion, accent, rhythm, pauses, and intonation, in addition to replicating the tone color of the reference speaker. OpenVoice also achieves zero-shot cross-lingual voice cloning for languages not included in the massive-speaker training set.
"""
markdown_table = """
<div align="center" style="margin-bottom: 10px;">
| | | |
| :-----------: | :-----------: | :-----------: |
| **OpenSource Repo** | **Project Page** | **Join the Community** |
| <div style='text-align: center;'><a style="display:inline-block,align:center" href='https://github.com/myshell-ai/OpenVoice'><img src='https://img.shields.io/github/stars/myshell-ai/OpenVoice?style=social' /></a></div> | [OpenVoice](https://research.myshell.ai/open-voice) | [![Discord](https://img.shields.io/discord/1122227993805336617?color=%239B59B6&label=%20Discord%20)](https://discord.gg/myshell) |
</div>
"""
markdown_table_v2 = """
<div align="center" style="margin-bottom: 2px;">
| | | | |
| :-----------: | :-----------: | :-----------: | :-----------: |
| **OpenSource Repo** | <div style='text-align: center;'><a style="display:inline-block,align:center" href='https://github.com/myshell-ai/OpenVoice'><img src='https://img.shields.io/github/stars/myshell-ai/OpenVoice?style=social' /></a></div> | **Project Page** | [OpenVoice](https://research.myshell.ai/open-voice) |
| | |
| :-----------: | :-----------: |
**Join the Community** | [![Discord](https://img.shields.io/discord/1122227993805336617?color=%239B59B6&label=%20Discord%20)](https://discord.gg/myshell) |
</div>
"""
content = """
<div>
<strong>If the generated voice does not sound like the reference voice, please refer to <a href='https://github.com/myshell-ai/OpenVoice/blob/main/docs/QA.md'>this QnA</a>.</strong> <strong>For multi-lingual & cross-lingual examples, please refer to <a href='https://github.com/myshell-ai/OpenVoice/blob/main/demo_part2.ipynb'>this jupyter notebook</a>.</strong>
This online demo mainly supports <strong>English</strong>. The <em>default</em> style also supports <strong>Chinese</strong>. But OpenVoice can adapt to any other language as long as a base speaker is provided.
</div>
"""
wrapped_markdown_content = f"<div style='border: 1px solid #000; padding: 10px;'>{content}</div>"
examples = [
[
"今天天气真好,我们一起出去吃饭吧。",
'default',
"resources/demo_speaker1.mp3",
True,
],[
"This audio is generated by open voice with a half-performance model.",
'whispering',
"resources/demo_speaker2.mp3",
True,
],
[
"He hoped there would be stew for dinner, turnips and carrots and bruised potatoes and fat mutton pieces to be ladled out in thick, peppered, flour-fattened sauce.",
'sad',
"resources/demo_speaker0.mp3",
True,
],
]
with gr.Blocks(analytics_enabled=False) as demo:
with gr.Row():
with gr.Column():
with gr.Row():
gr.Markdown(
"""
## <img src="https://huggingface.co/spaces/myshell-ai/OpenVoice/raw/main/logo.jpg" height="40"/>
"""
)
with gr.Row():
gr.Markdown(markdown_table_v2)
with gr.Row():
gr.Markdown(description)
with gr.Column():
gr.Video('https://github.com/myshell-ai/OpenVoice/assets/40556743/3cba936f-82bf-476c-9e52-09f0f417bb2f', autoplay=True)
with gr.Row():
gr.HTML(wrapped_markdown_content)
with gr.Row():
with gr.Column():
input_text_gr = gr.Textbox(
label="Text Prompt",
info="One or two sentences at a time is better. Up to 200 text characters.",
value="He hoped there would be stew for dinner, turnips and carrots and bruised potatoes and fat mutton pieces to be ladled out in thick, peppered, flour-fattened sauce.",
)
style_gr = gr.Dropdown(
label="Style",
info="Select a style of output audio for the synthesised speech. (Chinese only support 'default' now)",
choices=['default', 'whispering', 'cheerful', 'terrified', 'angry', 'sad', 'friendly'],
max_choices=1,
value="default",
)
ref_gr = gr.Audio(
label="Reference Audio",
info="Click on the ✎ button to upload your own target speaker audio",
type="filepath",
value="resources/demo_speaker2.mp3",
)
tos_gr = gr.Checkbox(
label="Agree",
value=False,
info="I agree to the terms of the cc-by-nc-4.0 license-: https://github.com/myshell-ai/OpenVoice/blob/main/LICENSE",
)
tts_button = gr.Button("Send", elem_id="send-btn", visible=True)
with gr.Column():
out_text_gr = gr.Text(label="Info")
audio_gr = gr.Audio(label="Synthesised Audio", autoplay=True)
ref_audio_gr = gr.Audio(label="Reference Audio Used")
gr.Examples(examples,
label="Examples",
inputs=[input_text_gr, style_gr, ref_gr, tos_gr],
outputs=[out_text_gr, audio_gr, ref_audio_gr],
fn=predict,
cache_examples=False,)
tts_button.click(predict, [input_text_gr, style_gr, ref_gr, tos_gr], outputs=[out_text_gr, audio_gr, ref_audio_gr])
demo.queue()
demo.launch(debug=True, show_api=True, share=args.share)
import os
import glob
import torch
import hashlib
import librosa
import base64
from glob import glob
import numpy as np
from pydub import AudioSegment
from faster_whisper import WhisperModel
import hashlib
import base64
import librosa
# from whisper_timestamped.transcribe import get_audio_tensor, get_vad_segments
model_size = "medium"
# Run on GPU with FP16
model = None
def split_audio_whisper(audio_path, audio_name, target_dir='processed'):
global model
if model is None:
model = WhisperModel(model_size, device="cuda", compute_type="float16")
audio = AudioSegment.from_file(audio_path)
max_len = len(audio)
target_folder = os.path.join(target_dir, audio_name)
segments, info = model.transcribe(audio_path, beam_size=5, word_timestamps=True)
segments = list(segments)
# create directory
os.makedirs(target_folder, exist_ok=True)
wavs_folder = os.path.join(target_folder, 'wavs')
os.makedirs(wavs_folder, exist_ok=True)
# segments
s_ind = 0
start_time = None
for k, w in enumerate(segments):
# process with the time
if k == 0:
start_time = max(0, w.start)
end_time = w.end
# calculate confidence
if len(w.words) > 0:
confidence = sum([s.probability for s in w.words]) / len(w.words)
else:
confidence = 0.
# clean text
text = w.text.replace('...', '')
# left 0.08s for each audios
audio_seg = audio[int( start_time * 1000) : min(max_len, int(end_time * 1000) + 80)]
# segment file name
fname = f"{audio_name}_seg{s_ind}.wav"
# filter out the segment shorter than 1.5s and longer than 20s
save = audio_seg.duration_seconds > 1.5 and \
audio_seg.duration_seconds < 20. and \
len(text) >= 2 and len(text) < 200
if save:
output_file = os.path.join(wavs_folder, fname)
audio_seg.export(output_file, format='wav')
if k < len(segments) - 1:
start_time = max(0, segments[k+1].start - 0.08)
s_ind = s_ind + 1
return wavs_folder
def split_audio_vad(audio_path, audio_name, target_dir, split_seconds=10.0):
SAMPLE_RATE = 16000
audio_vad = get_audio_tensor(audio_path)
segments = get_vad_segments(
audio_vad,
output_sample=True,
min_speech_duration=0.1,
min_silence_duration=1,
method="silero",
)
segments = [(seg["start"], seg["end"]) for seg in segments]
segments = [(float(s) / SAMPLE_RATE, float(e) / SAMPLE_RATE) for s,e in segments]
print(segments)
audio_active = AudioSegment.silent(duration=0)
audio = AudioSegment.from_file(audio_path)
for start_time, end_time in segments:
audio_active += audio[int( start_time * 1000) : int(end_time * 1000)]
audio_dur = audio_active.duration_seconds
print(f'after vad: dur = {audio_dur}')
target_folder = os.path.join(target_dir, audio_name)
wavs_folder = os.path.join(target_folder, 'wavs')
os.makedirs(wavs_folder, exist_ok=True)
start_time = 0.
count = 0
num_splits = int(np.round(audio_dur / split_seconds))
assert num_splits > 0, 'input audio is too short'
interval = audio_dur / num_splits
for i in range(num_splits):
end_time = min(start_time + interval, audio_dur)
if i == num_splits - 1:
end_time = audio_dur
output_file = f"{wavs_folder}/{audio_name}_seg{count}.wav"
audio_seg = audio_active[int(start_time * 1000): int(end_time * 1000)]
audio_seg.export(output_file, format='wav')
start_time = end_time
count += 1
return wavs_folder
def hash_numpy_array(audio_path):
array, _ = librosa.load(audio_path, sr=None, mono=True)
# Convert the array to bytes
array_bytes = array.tobytes()
# Calculate the hash of the array bytes
hash_object = hashlib.sha256(array_bytes)
hash_value = hash_object.digest()
# Convert the hash value to base64
base64_value = base64.b64encode(hash_value)
return base64_value.decode('utf-8')[:16].replace('/', '_^')
def get_se(audio_path, vc_model, target_dir='processed', vad=True):
device = vc_model.device
version = vc_model.version
print("OpenVoice version:", version)
audio_name = f"{os.path.basename(audio_path).rsplit('.', 1)[0]}_{version}_{hash_numpy_array(audio_path)}"
se_path = os.path.join(target_dir, audio_name, 'se.pth')
# if os.path.isfile(se_path):
# se = torch.load(se_path).to(device)
# return se, audio_name
# if os.path.isdir(audio_path):
# wavs_folder = audio_path
# if vad:
# wavs_folder = split_audio_vad(audio_path, target_dir=target_dir, audio_name=audio_name)
# else:
# wavs_folder = split_audio_whisper(audio_path, target_dir=target_dir, audio_name=audio_name)
# audio_segs = glob(f'{wavs_folder}/*.wav')
# if len(audio_segs) == 0:
# raise NotImplementedError('No audio segments found!')
return vc_model.extract_se([audio_path], se_save_path=se_path), audio_name
import torch
from torch.nn import functional as F
import numpy as np
DEFAULT_MIN_BIN_WIDTH = 1e-3
DEFAULT_MIN_BIN_HEIGHT = 1e-3
DEFAULT_MIN_DERIVATIVE = 1e-3
def piecewise_rational_quadratic_transform(
inputs,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=False,
tails=None,
tail_bound=1.0,
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
min_derivative=DEFAULT_MIN_DERIVATIVE,
):
if tails is None:
spline_fn = rational_quadratic_spline
spline_kwargs = {}
else:
spline_fn = unconstrained_rational_quadratic_spline
spline_kwargs = {"tails": tails, "tail_bound": tail_bound}
outputs, logabsdet = spline_fn(
inputs=inputs,
unnormalized_widths=unnormalized_widths,
unnormalized_heights=unnormalized_heights,
unnormalized_derivatives=unnormalized_derivatives,
inverse=inverse,
min_bin_width=min_bin_width,
min_bin_height=min_bin_height,
min_derivative=min_derivative,
**spline_kwargs
)
return outputs, logabsdet
def searchsorted(bin_locations, inputs, eps=1e-6):
bin_locations[..., -1] += eps
return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1
def unconstrained_rational_quadratic_spline(
inputs,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=False,
tails="linear",
tail_bound=1.0,
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
min_derivative=DEFAULT_MIN_DERIVATIVE,
):
inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
outside_interval_mask = ~inside_interval_mask
outputs = torch.zeros_like(inputs)
logabsdet = torch.zeros_like(inputs)
if tails == "linear":
unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
constant = np.log(np.exp(1 - min_derivative) - 1)
unnormalized_derivatives[..., 0] = constant
unnormalized_derivatives[..., -1] = constant
outputs[outside_interval_mask] = inputs[outside_interval_mask]
logabsdet[outside_interval_mask] = 0
else:
raise RuntimeError("{} tails are not implemented.".format(tails))
(
outputs[inside_interval_mask],
logabsdet[inside_interval_mask],
) = rational_quadratic_spline(
inputs=inputs[inside_interval_mask],
unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
inverse=inverse,
left=-tail_bound,
right=tail_bound,
bottom=-tail_bound,
top=tail_bound,
min_bin_width=min_bin_width,
min_bin_height=min_bin_height,
min_derivative=min_derivative,
)
return outputs, logabsdet
def rational_quadratic_spline(
inputs,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=False,
left=0.0,
right=1.0,
bottom=0.0,
top=1.0,
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
min_derivative=DEFAULT_MIN_DERIVATIVE,
):
if torch.min(inputs) < left or torch.max(inputs) > right:
raise ValueError("Input to a transform is not within its domain")
num_bins = unnormalized_widths.shape[-1]
if min_bin_width * num_bins > 1.0:
raise ValueError("Minimal bin width too large for the number of bins")
if min_bin_height * num_bins > 1.0:
raise ValueError("Minimal bin height too large for the number of bins")
widths = F.softmax(unnormalized_widths, dim=-1)
widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
cumwidths = torch.cumsum(widths, dim=-1)
cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
cumwidths = (right - left) * cumwidths + left
cumwidths[..., 0] = left
cumwidths[..., -1] = right
widths = cumwidths[..., 1:] - cumwidths[..., :-1]
derivatives = min_derivative + F.softplus(unnormalized_derivatives)
heights = F.softmax(unnormalized_heights, dim=-1)
heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
cumheights = torch.cumsum(heights, dim=-1)
cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
cumheights = (top - bottom) * cumheights + bottom
cumheights[..., 0] = bottom
cumheights[..., -1] = top
heights = cumheights[..., 1:] - cumheights[..., :-1]
if inverse:
bin_idx = searchsorted(cumheights, inputs)[..., None]
else:
bin_idx = searchsorted(cumwidths, inputs)[..., None]
input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
delta = heights / widths
input_delta = delta.gather(-1, bin_idx)[..., 0]
input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
input_heights = heights.gather(-1, bin_idx)[..., 0]
if inverse:
a = (inputs - input_cumheights) * (
input_derivatives + input_derivatives_plus_one - 2 * input_delta
) + input_heights * (input_delta - input_derivatives)
b = input_heights * input_derivatives - (inputs - input_cumheights) * (
input_derivatives + input_derivatives_plus_one - 2 * input_delta
)
c = -input_delta * (inputs - input_cumheights)
discriminant = b.pow(2) - 4 * a * c
assert (discriminant >= 0).all()
root = (2 * c) / (-b - torch.sqrt(discriminant))
outputs = root * input_bin_widths + input_cumwidths
theta_one_minus_theta = root * (1 - root)
denominator = input_delta + (
(input_derivatives + input_derivatives_plus_one - 2 * input_delta)
* theta_one_minus_theta
)
derivative_numerator = input_delta.pow(2) * (
input_derivatives_plus_one * root.pow(2)
+ 2 * input_delta * theta_one_minus_theta
+ input_derivatives * (1 - root).pow(2)
)
logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
return outputs, -logabsdet
else:
theta = (inputs - input_cumwidths) / input_bin_widths
theta_one_minus_theta = theta * (1 - theta)
numerator = input_heights * (
input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta
)
denominator = input_delta + (
(input_derivatives + input_derivatives_plus_one - 2 * input_delta)
* theta_one_minus_theta
)
outputs = input_cumheights + numerator / denominator
derivative_numerator = input_delta.pow(2) * (
input_derivatives_plus_one * theta.pow(2)
+ 2 * input_delta * theta_one_minus_theta
+ input_derivatives * (1 - theta).pow(2)
)
logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
return outputs, logabsdet
import re
import json
import numpy as np
def get_hparams_from_file(config_path):
with open(config_path, "r", encoding="utf-8") as f:
data = f.read()
config = json.loads(data)
hparams = HParams(**config)
return hparams
class HParams:
def __init__(self, **kwargs):
for k, v in kwargs.items():
if type(v) == dict:
v = HParams(**v)
self[k] = v
def keys(self):
return self.__dict__.keys()
def items(self):
return self.__dict__.items()
def values(self):
return self.__dict__.values()
def __len__(self):
return len(self.__dict__)
def __getitem__(self, key):
return getattr(self, key)
def __setitem__(self, key, value):
return setattr(self, key, value)
def __contains__(self, key):
return key in self.__dict__
def __repr__(self):
return self.__dict__.__repr__()
def string_to_bits(string, pad_len=8):
# Convert each character to its ASCII value
ascii_values = [ord(char) for char in string]
# Convert ASCII values to binary representation
binary_values = [bin(value)[2:].zfill(8) for value in ascii_values]
# Convert binary strings to integer arrays
bit_arrays = [[int(bit) for bit in binary] for binary in binary_values]
# Convert list of arrays to NumPy array
numpy_array = np.array(bit_arrays)
numpy_array_full = np.zeros((pad_len, 8), dtype=numpy_array.dtype)
numpy_array_full[:, 2] = 1
max_len = min(pad_len, len(numpy_array))
numpy_array_full[:max_len] = numpy_array[:max_len]
return numpy_array_full
def bits_to_string(bits_array):
# Convert each row of the array to a binary string
binary_values = [''.join(str(bit) for bit in row) for row in bits_array]
# Convert binary strings to ASCII values
ascii_values = [int(binary, 2) for binary in binary_values]
# Convert ASCII values to characters
output_string = ''.join(chr(value) for value in ascii_values)
return output_string
def split_sentence(text, min_len=10, language_str='[EN]'):
if language_str in ['EN']:
sentences = split_sentences_latin(text, min_len=min_len)
else:
sentences = split_sentences_zh(text, min_len=min_len)
return sentences
def split_sentences_latin(text, min_len=10):
"""Split Long sentences into list of short ones
Args:
str: Input sentences.
Returns:
List[str]: list of output sentences.
"""
# deal with dirty sentences
text = re.sub('[。!?;]', '.', text)
text = re.sub('[,]', ',', text)
text = re.sub('[“”]', '"', text)
text = re.sub('[‘’]', "'", text)
text = re.sub(r"[\<\>\(\)\[\]\"\«\»]+", "", text)
text = re.sub('[\n\t ]+', ' ', text)
text = re.sub('([,.!?;])', r'\1 $#!', text)
# split
sentences = [s.strip() for s in text.split('$#!')]
if len(sentences[-1]) == 0: del sentences[-1]
new_sentences = []
new_sent = []
count_len = 0
for ind, sent in enumerate(sentences):
# print(sent)
new_sent.append(sent)
count_len += len(sent.split(" "))
if count_len > min_len or ind == len(sentences) - 1:
count_len = 0
new_sentences.append(' '.join(new_sent))
new_sent = []
return merge_short_sentences_latin(new_sentences)
def merge_short_sentences_latin(sens):
"""Avoid short sentences by merging them with the following sentence.
Args:
List[str]: list of input sentences.
Returns:
List[str]: list of output sentences.
"""
sens_out = []
for s in sens:
# If the previous sentence is too short, merge them with
# the current sentence.
if len(sens_out) > 0 and len(sens_out[-1].split(" ")) <= 2:
sens_out[-1] = sens_out[-1] + " " + s
else:
sens_out.append(s)
try:
if len(sens_out[-1].split(" ")) <= 2:
sens_out[-2] = sens_out[-2] + " " + sens_out[-1]
sens_out.pop(-1)
except:
pass
return sens_out
def split_sentences_zh(text, min_len=10):
text = re.sub('[。!?;]', '.', text)
text = re.sub('[,]', ',', text)
# 将文本中的换行符、空格和制表符替换为空格
text = re.sub('[\n\t ]+', ' ', text)
# 在标点符号后添加一个空格
text = re.sub('([,.!?;])', r'\1 $#!', text)
# 分隔句子并去除前后空格
# sentences = [s.strip() for s in re.split('(。|!|?|;)', text)]
sentences = [s.strip() for s in text.split('$#!')]
if len(sentences[-1]) == 0: del sentences[-1]
new_sentences = []
new_sent = []
count_len = 0
for ind, sent in enumerate(sentences):
new_sent.append(sent)
count_len += len(sent)
if count_len > min_len or ind == len(sentences) - 1:
count_len = 0
new_sentences.append(' '.join(new_sent))
new_sent = []
return merge_short_sentences_zh(new_sentences)
def merge_short_sentences_zh(sens):
# return sens
"""Avoid short sentences by merging them with the following sentence.
Args:
List[str]: list of input sentences.
Returns:
List[str]: list of output sentences.
"""
sens_out = []
for s in sens:
# If the previous sentense is too short, merge them with
# the current sentence.
if len(sens_out) > 0 and len(sens_out[-1]) <= 2:
sens_out[-1] = sens_out[-1] + " " + s
else:
sens_out.append(s)
try:
if len(sens_out[-1]) <= 2:
sens_out[-2] = sens_out[-2] + " " + sens_out[-1]
sens_out.pop(-1)
except:
pass
return sens_out
\ No newline at end of file
from dac.nn.quantize import ResidualVectorQuantize
from torch import nn
from modules.wavenet import WN
import torch
import torchaudio
import torchaudio.functional as audio_F
import numpy as np
from .alias_free_torch import *
from torch.nn.utils import weight_norm
from torch import nn, sin, pow
from einops.layers.torch import Rearrange
from dac.model.encodec import SConv1d
def init_weights(m):
if isinstance(m, nn.Conv1d):
nn.init.trunc_normal_(m.weight, std=0.02)
nn.init.constant_(m.bias, 0)
def WNConv1d(*args, **kwargs):
return weight_norm(nn.Conv1d(*args, **kwargs))
def WNConvTranspose1d(*args, **kwargs):
return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
class SnakeBeta(nn.Module):
"""
A modified Snake function which uses separate parameters for the magnitude of the periodic components
Shape:
- Input: (B, C, T)
- Output: (B, C, T), same shape as the input
Parameters:
- alpha - trainable parameter that controls frequency
- beta - trainable parameter that controls magnitude
References:
- This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
https://arxiv.org/abs/2006.08195
Examples:
>>> a1 = snakebeta(256)
>>> x = torch.randn(256)
>>> x = a1(x)
"""
def __init__(
self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
):
"""
Initialization.
INPUT:
- in_features: shape of the input
- alpha - trainable parameter that controls frequency
- beta - trainable parameter that controls magnitude
alpha is initialized to 1 by default, higher values = higher-frequency.
beta is initialized to 1 by default, higher values = higher-magnitude.
alpha will be trained along with the rest of your model.
"""
super(SnakeBeta, self).__init__()
self.in_features = in_features
# initialize alpha
self.alpha_logscale = alpha_logscale
if self.alpha_logscale: # log scale alphas initialized to zeros
self.alpha = nn.Parameter(torch.zeros(in_features) * alpha)
self.beta = nn.Parameter(torch.zeros(in_features) * alpha)
else: # linear scale alphas initialized to ones
self.alpha = nn.Parameter(torch.ones(in_features) * alpha)
self.beta = nn.Parameter(torch.ones(in_features) * alpha)
self.alpha.requires_grad = alpha_trainable
self.beta.requires_grad = alpha_trainable
self.no_div_by_zero = 0.000000001
def forward(self, x):
"""
Forward pass of the function.
Applies the function to the input elementwise.
SnakeBeta := x + 1/b * sin^2 (xa)
"""
alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
beta = self.beta.unsqueeze(0).unsqueeze(-1)
if self.alpha_logscale:
alpha = torch.exp(alpha)
beta = torch.exp(beta)
x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
return x
class ResidualUnit(nn.Module):
def __init__(self, dim: int = 16, dilation: int = 1):
super().__init__()
pad = ((7 - 1) * dilation) // 2
self.block = nn.Sequential(
Activation1d(activation=SnakeBeta(dim, alpha_logscale=True)),
WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
Activation1d(activation=SnakeBeta(dim, alpha_logscale=True)),
WNConv1d(dim, dim, kernel_size=1),
)
def forward(self, x):
return x + self.block(x)
class CNNLSTM(nn.Module):
def __init__(self, indim, outdim, head, global_pred=False):
super().__init__()
self.global_pred = global_pred
self.model = nn.Sequential(
ResidualUnit(indim, dilation=1),
ResidualUnit(indim, dilation=2),
ResidualUnit(indim, dilation=3),
Activation1d(activation=SnakeBeta(indim, alpha_logscale=True)),
Rearrange("b c t -> b t c"),
)
self.heads = nn.ModuleList([nn.Linear(indim, outdim) for i in range(head)])
def forward(self, x):
# x: [B, C, T]
x = self.model(x)
if self.global_pred:
x = torch.mean(x, dim=1, keepdim=False)
outs = [head(x) for head in self.heads]
return outs
def sequence_mask(length, max_length=None):
if max_length is None:
max_length = length.max()
x = torch.arange(max_length, dtype=length.dtype, device=length.device)
return x.unsqueeze(0) < length.unsqueeze(1)
class FAquantizer(nn.Module):
def __init__(self, in_dim=1024,
n_p_codebooks=1,
n_c_codebooks=2,
n_t_codebooks=2,
n_r_codebooks=3,
codebook_size=1024,
codebook_dim=8,
quantizer_dropout=0.5,
causal=False,
separate_prosody_encoder=False,
timbre_norm=False,):
super(FAquantizer, self).__init__()
conv1d_type = SConv1d# if causal else nn.Conv1d
self.prosody_quantizer = ResidualVectorQuantize(
input_dim=in_dim,
n_codebooks=n_p_codebooks,
codebook_size=codebook_size,
codebook_dim=codebook_dim,
quantizer_dropout=quantizer_dropout,
)
self.content_quantizer = ResidualVectorQuantize(
input_dim=in_dim,
n_codebooks=n_c_codebooks,
codebook_size=codebook_size,
codebook_dim=codebook_dim,
quantizer_dropout=quantizer_dropout,
)
self.residual_quantizer = ResidualVectorQuantize(
input_dim=in_dim,
n_codebooks=n_r_codebooks,
codebook_size=codebook_size,
codebook_dim=codebook_dim,
quantizer_dropout=quantizer_dropout,
)
self.melspec_linear = conv1d_type(in_channels=20, out_channels=256, kernel_size=1, causal=causal)
self.melspec_encoder = WN(hidden_channels=256, kernel_size=5, dilation_rate=1, n_layers=8, gin_channels=0, p_dropout=0.2, causal=causal)
self.melspec_linear2 = conv1d_type(in_channels=256, out_channels=1024, kernel_size=1, causal=causal)
self.prob_random_mask_residual = 0.75
SPECT_PARAMS = {
"n_fft": 2048,
"win_length": 1200,
"hop_length": 300,
}
MEL_PARAMS = {
"n_mels": 80,
}
self.to_mel = torchaudio.transforms.MelSpectrogram(
n_mels=MEL_PARAMS["n_mels"], sample_rate=24000, **SPECT_PARAMS
)
self.mel_mean, self.mel_std = -4, 4
self.frame_rate = 24000 / 300
self.hop_length = 300
def preprocess(self, wave_tensor, n_bins=20):
mel_tensor = self.to_mel(wave_tensor.squeeze(1))
mel_tensor = (torch.log(1e-5 + mel_tensor) - self.mel_mean) / self.mel_std
return mel_tensor[:, :n_bins, :int(wave_tensor.size(-1) / self.hop_length)]
def forward(self, x, wave_segments):
outs = 0
prosody_feature = self.preprocess(wave_segments)
f0_input = prosody_feature # (B, T, 20)
f0_input = self.melspec_linear(f0_input)
f0_input = self.melspec_encoder(f0_input, torch.ones(f0_input.shape[0], 1, f0_input.shape[2]).to(
f0_input.device).bool())
f0_input = self.melspec_linear2(f0_input)
common_min_size = min(f0_input.size(2), x.size(2))
f0_input = f0_input[:, :, :common_min_size]
x = x[:, :, :common_min_size]
z_p, codes_p, latents_p, commitment_loss_p, codebook_loss_p = self.prosody_quantizer(
f0_input, 1
)
outs += z_p.detach()
z_c, codes_c, latents_c, commitment_loss_c, codebook_loss_c = self.content_quantizer(
x, 2
)
outs += z_c.detach()
residual_feature = x - z_p.detach() - z_c.detach()
z_r, codes_r, latents_r, commitment_loss_r, codebook_loss_r = self.residual_quantizer(
residual_feature, 3
)
quantized = [z_p, z_c, z_r]
codes = [codes_p, codes_c, codes_r]
return quantized, codes
\ No newline at end of file
from io import BytesIO
import os
from typing import List, Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from librosa.util import normalize, pad_center, tiny
from scipy.signal import get_window
import logging
logger = logging.getLogger(__name__)
class STFT(torch.nn.Module):
def __init__(
self, filter_length=1024, hop_length=512, win_length=None, window="hann"
):
"""
This module implements an STFT using 1D convolution and 1D transpose convolutions.
This is a bit tricky so there are some cases that probably won't work as working
out the same sizes before and after in all overlap add setups is tough. Right now,
this code should work with hop lengths that are half the filter length (50% overlap
between frames).
Keyword Arguments:
filter_length {int} -- Length of filters used (default: {1024})
hop_length {int} -- Hop length of STFT (restrict to 50% overlap between frames) (default: {512})
win_length {[type]} -- Length of the window function applied to each frame (if not specified, it
equals the filter length). (default: {None})
window {str} -- Type of window to use (options are bartlett, hann, hamming, blackman, blackmanharris)
(default: {'hann'})
"""
super(STFT, self).__init__()
self.filter_length = filter_length
self.hop_length = hop_length
self.win_length = win_length if win_length else filter_length
self.window = window
self.forward_transform = None
self.pad_amount = int(self.filter_length / 2)
fourier_basis = np.fft.fft(np.eye(self.filter_length))
cutoff = int((self.filter_length / 2 + 1))
fourier_basis = np.vstack(
[np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])]
)
forward_basis = torch.FloatTensor(fourier_basis)
inverse_basis = torch.FloatTensor(np.linalg.pinv(fourier_basis))
assert filter_length >= self.win_length
# get window and zero center pad it to filter_length
fft_window = get_window(window, self.win_length, fftbins=True)
fft_window = pad_center(fft_window, size=filter_length)
fft_window = torch.from_numpy(fft_window).float()
# window the bases
forward_basis *= fft_window
inverse_basis = (inverse_basis.T * fft_window).T
self.register_buffer("forward_basis", forward_basis.float())
self.register_buffer("inverse_basis", inverse_basis.float())
self.register_buffer("fft_window", fft_window.float())
def transform(self, input_data, return_phase=False):
"""Take input data (audio) to STFT domain.
Arguments:
input_data {tensor} -- Tensor of floats, with shape (num_batch, num_samples)
Returns:
magnitude {tensor} -- Magnitude of STFT with shape (num_batch,
num_frequencies, num_frames)
phase {tensor} -- Phase of STFT with shape (num_batch,
num_frequencies, num_frames)
"""
input_data = F.pad(
input_data,
(self.pad_amount, self.pad_amount),
mode="reflect",
)
forward_transform = input_data.unfold(
1, self.filter_length, self.hop_length
).permute(0, 2, 1)
forward_transform = torch.matmul(self.forward_basis, forward_transform)
cutoff = int((self.filter_length / 2) + 1)
real_part = forward_transform[:, :cutoff, :]
imag_part = forward_transform[:, cutoff:, :]
magnitude = torch.sqrt(real_part**2 + imag_part**2)
if return_phase:
phase = torch.atan2(imag_part.data, real_part.data)
return magnitude, phase
else:
return magnitude
def inverse(self, magnitude, phase):
"""Call the inverse STFT (iSTFT), given magnitude and phase tensors produced
by the ```transform``` function.
Arguments:
magnitude {tensor} -- Magnitude of STFT with shape (num_batch,
num_frequencies, num_frames)
phase {tensor} -- Phase of STFT with shape (num_batch,
num_frequencies, num_frames)
Returns:
inverse_transform {tensor} -- Reconstructed audio given magnitude and phase. Of
shape (num_batch, num_samples)
"""
cat = torch.cat(
[magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
)
fold = torch.nn.Fold(
output_size=(1, (cat.size(-1) - 1) * self.hop_length + self.filter_length),
kernel_size=(1, self.filter_length),
stride=(1, self.hop_length),
)
inverse_transform = torch.matmul(self.inverse_basis, cat)
inverse_transform = fold(inverse_transform)[
:, 0, 0, self.pad_amount : -self.pad_amount
]
window_square_sum = (
self.fft_window.pow(2).repeat(cat.size(-1), 1).T.unsqueeze(0)
)
window_square_sum = fold(window_square_sum)[
:, 0, 0, self.pad_amount : -self.pad_amount
]
inverse_transform /= window_square_sum
return inverse_transform
def forward(self, input_data):
"""Take input data (audio) to STFT domain and then back to audio.
Arguments:
input_data {tensor} -- Tensor of floats, with shape (num_batch, num_samples)
Returns:
reconstruction {tensor} -- Reconstructed audio given magnitude and phase. Of
shape (num_batch, num_samples)
"""
self.magnitude, self.phase = self.transform(input_data, return_phase=True)
reconstruction = self.inverse(self.magnitude, self.phase)
return reconstruction
from time import time as ttime
class BiGRU(nn.Module):
def __init__(self, input_features, hidden_features, num_layers):
super(BiGRU, self).__init__()
self.gru = nn.GRU(
input_features,
hidden_features,
num_layers=num_layers,
batch_first=True,
bidirectional=True,
)
def forward(self, x):
return self.gru(x)[0]
class ConvBlockRes(nn.Module):
def __init__(self, in_channels, out_channels, momentum=0.01):
super(ConvBlockRes, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=(3, 3),
stride=(1, 1),
padding=(1, 1),
bias=False,
),
nn.BatchNorm2d(out_channels, momentum=momentum),
nn.ReLU(),
nn.Conv2d(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=(3, 3),
stride=(1, 1),
padding=(1, 1),
bias=False,
),
nn.BatchNorm2d(out_channels, momentum=momentum),
nn.ReLU(),
)
# self.shortcut:Optional[nn.Module] = None
if in_channels != out_channels:
self.shortcut = nn.Conv2d(in_channels, out_channels, (1, 1))
def forward(self, x: torch.Tensor):
if not hasattr(self, "shortcut"):
return self.conv(x) + x
else:
return self.conv(x) + self.shortcut(x)
class Encoder(nn.Module):
def __init__(
self,
in_channels,
in_size,
n_encoders,
kernel_size,
n_blocks,
out_channels=16,
momentum=0.01,
):
super(Encoder, self).__init__()
self.n_encoders = n_encoders
self.bn = nn.BatchNorm2d(in_channels, momentum=momentum)
self.layers = nn.ModuleList()
self.latent_channels = []
for i in range(self.n_encoders):
self.layers.append(
ResEncoderBlock(
in_channels, out_channels, kernel_size, n_blocks, momentum=momentum
)
)
self.latent_channels.append([out_channels, in_size])
in_channels = out_channels
out_channels *= 2
in_size //= 2
self.out_size = in_size
self.out_channel = out_channels
def forward(self, x: torch.Tensor):
concat_tensors: List[torch.Tensor] = []
x = self.bn(x)
for i, layer in enumerate(self.layers):
t, x = layer(x)
concat_tensors.append(t)
return x, concat_tensors
class ResEncoderBlock(nn.Module):
def __init__(
self, in_channels, out_channels, kernel_size, n_blocks=1, momentum=0.01
):
super(ResEncoderBlock, self).__init__()
self.n_blocks = n_blocks
self.conv = nn.ModuleList()
self.conv.append(ConvBlockRes(in_channels, out_channels, momentum))
for i in range(n_blocks - 1):
self.conv.append(ConvBlockRes(out_channels, out_channels, momentum))
self.kernel_size = kernel_size
if self.kernel_size is not None:
self.pool = nn.AvgPool2d(kernel_size=kernel_size)
def forward(self, x):
for i, conv in enumerate(self.conv):
x = conv(x)
if self.kernel_size is not None:
return x, self.pool(x)
else:
return x
class Intermediate(nn.Module): #
def __init__(self, in_channels, out_channels, n_inters, n_blocks, momentum=0.01):
super(Intermediate, self).__init__()
self.n_inters = n_inters
self.layers = nn.ModuleList()
self.layers.append(
ResEncoderBlock(in_channels, out_channels, None, n_blocks, momentum)
)
for i in range(self.n_inters - 1):
self.layers.append(
ResEncoderBlock(out_channels, out_channels, None, n_blocks, momentum)
)
def forward(self, x):
for i, layer in enumerate(self.layers):
x = layer(x)
return x
class ResDecoderBlock(nn.Module):
def __init__(self, in_channels, out_channels, stride, n_blocks=1, momentum=0.01):
super(ResDecoderBlock, self).__init__()
out_padding = (0, 1) if stride == (1, 2) else (1, 1)
self.n_blocks = n_blocks
self.conv1 = nn.Sequential(
nn.ConvTranspose2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=(3, 3),
stride=stride,
padding=(1, 1),
output_padding=out_padding,
bias=False,
),
nn.BatchNorm2d(out_channels, momentum=momentum),
nn.ReLU(),
)
self.conv2 = nn.ModuleList()
self.conv2.append(ConvBlockRes(out_channels * 2, out_channels, momentum))
for i in range(n_blocks - 1):
self.conv2.append(ConvBlockRes(out_channels, out_channels, momentum))
def forward(self, x, concat_tensor):
x = self.conv1(x)
x = torch.cat((x, concat_tensor), dim=1)
for i, conv2 in enumerate(self.conv2):
x = conv2(x)
return x
class Decoder(nn.Module):
def __init__(self, in_channels, n_decoders, stride, n_blocks, momentum=0.01):
super(Decoder, self).__init__()
self.layers = nn.ModuleList()
self.n_decoders = n_decoders
for i in range(self.n_decoders):
out_channels = in_channels // 2
self.layers.append(
ResDecoderBlock(in_channels, out_channels, stride, n_blocks, momentum)
)
in_channels = out_channels
def forward(self, x: torch.Tensor, concat_tensors: List[torch.Tensor]):
for i, layer in enumerate(self.layers):
x = layer(x, concat_tensors[-1 - i])
return x
class DeepUnet(nn.Module):
def __init__(
self,
kernel_size,
n_blocks,
en_de_layers=5,
inter_layers=4,
in_channels=1,
en_out_channels=16,
):
super(DeepUnet, self).__init__()
self.encoder = Encoder(
in_channels, 128, en_de_layers, kernel_size, n_blocks, en_out_channels
)
self.intermediate = Intermediate(
self.encoder.out_channel // 2,
self.encoder.out_channel,
inter_layers,
n_blocks,
)
self.decoder = Decoder(
self.encoder.out_channel, en_de_layers, kernel_size, n_blocks
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x, concat_tensors = self.encoder(x)
x = self.intermediate(x)
x = self.decoder(x, concat_tensors)
return x
class E2E(nn.Module):
def __init__(
self,
n_blocks,
n_gru,
kernel_size,
en_de_layers=5,
inter_layers=4,
in_channels=1,
en_out_channels=16,
):
super(E2E, self).__init__()
self.unet = DeepUnet(
kernel_size,
n_blocks,
en_de_layers,
inter_layers,
in_channels,
en_out_channels,
)
self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1))
if n_gru:
self.fc = nn.Sequential(
BiGRU(3 * 128, 256, n_gru),
nn.Linear(512, 360),
nn.Dropout(0.25),
nn.Sigmoid(),
)
else:
self.fc = nn.Sequential(
nn.Linear(3 * nn.N_MELS, nn.N_CLASS), nn.Dropout(0.25), nn.Sigmoid()
)
def forward(self, mel):
# print(mel.shape)
mel = mel.transpose(-1, -2).unsqueeze(1)
x = self.cnn(self.unet(mel)).transpose(1, 2).flatten(-2)
x = self.fc(x)
# print(x.shape)
return x
from librosa.filters import mel
class MelSpectrogram(torch.nn.Module):
def __init__(
self,
is_half,
n_mel_channels,
sampling_rate,
win_length,
hop_length,
n_fft=None,
mel_fmin=0,
mel_fmax=None,
clamp=1e-5,
):
super().__init__()
n_fft = win_length if n_fft is None else n_fft
self.hann_window = {}
mel_basis = mel(
sr=sampling_rate,
n_fft=n_fft,
n_mels=n_mel_channels,
fmin=mel_fmin,
fmax=mel_fmax,
htk=True,
)
mel_basis = torch.from_numpy(mel_basis).float()
self.register_buffer("mel_basis", mel_basis)
self.n_fft = win_length if n_fft is None else n_fft
self.hop_length = hop_length
self.win_length = win_length
self.sampling_rate = sampling_rate
self.n_mel_channels = n_mel_channels
self.clamp = clamp
self.is_half = is_half
def forward(self, audio, keyshift=0, speed=1, center=True):
factor = 2 ** (keyshift / 12)
n_fft_new = int(np.round(self.n_fft * factor))
win_length_new = int(np.round(self.win_length * factor))
hop_length_new = int(np.round(self.hop_length * speed))
keyshift_key = str(keyshift) + "_" + str(audio.device)
if keyshift_key not in self.hann_window:
self.hann_window[keyshift_key] = torch.hann_window(win_length_new).to(
audio.device
)
if "privateuseone" in str(audio.device):
if not hasattr(self, "stft"):
self.stft = STFT(
filter_length=n_fft_new,
hop_length=hop_length_new,
win_length=win_length_new,
window="hann",
).to(audio.device)
magnitude = self.stft.transform(audio)
else:
fft = torch.stft(
audio,
n_fft=n_fft_new,
hop_length=hop_length_new,
win_length=win_length_new,
window=self.hann_window[keyshift_key],
center=center,
return_complex=True,
)
magnitude = torch.sqrt(fft.real.pow(2) + fft.imag.pow(2))
if keyshift != 0:
size = self.n_fft // 2 + 1
resize = magnitude.size(1)
if resize < size:
magnitude = F.pad(magnitude, (0, 0, 0, size - resize))
magnitude = magnitude[:, :size, :] * self.win_length / win_length_new
mel_output = torch.matmul(self.mel_basis, magnitude)
if self.is_half == True:
mel_output = mel_output.half()
log_mel_spec = torch.log(torch.clamp(mel_output, min=self.clamp))
return log_mel_spec
class RMVPE:
def __init__(self, model_path: str, is_half, device=None, use_jit=False):
self.resample_kernel = {}
self.resample_kernel = {}
self.is_half = is_half
if device is None:
device = "cuda:0" if torch.cuda.is_available() else "cpu"
self.device = device
self.mel_extractor = MelSpectrogram(
is_half, 128, 16000, 1024, 160, None, 30, 8000
).to(device)
if "privateuseone" in str(device):
import onnxruntime as ort
ort_session = ort.InferenceSession(
"%s/rmvpe.onnx" % os.environ["rmvpe_root"],
providers=["DmlExecutionProvider"],
)
self.model = ort_session
else:
if str(self.device) == "cuda":
self.device = torch.device("cuda:0")
def get_default_model():
model = E2E(4, 1, (2, 2))
ckpt = torch.load(model_path, map_location="cpu")
model.load_state_dict(ckpt)
model.eval()
if is_half:
model = model.half()
else:
model = model.float()
return model
self.model = get_default_model()
self.model = self.model.to(device)
cents_mapping = 20 * np.arange(360) + 1997.3794084376191
self.cents_mapping = np.pad(cents_mapping, (4, 4)) # 368
def mel2hidden(self, mel):
with torch.no_grad():
n_frames = mel.shape[-1]
n_pad = 32 * ((n_frames - 1) // 32 + 1) - n_frames
if n_pad > 0:
mel = F.pad(mel, (0, n_pad), mode="constant")
if "privateuseone" in str(self.device):
onnx_input_name = self.model.get_inputs()[0].name
onnx_outputs_names = self.model.get_outputs()[0].name
hidden = self.model.run(
[onnx_outputs_names],
input_feed={onnx_input_name: mel.cpu().numpy()},
)[0]
else:
mel = mel.half() if self.is_half else mel.float()
hidden = self.model(mel)
return hidden[:, :n_frames]
def decode(self, hidden, thred=0.03):
cents_pred = self.to_local_average_cents(hidden, thred=thred)
f0 = 10 * (2 ** (cents_pred / 1200))
f0[f0 == 10] = 0
# f0 = np.array([10 * (2 ** (cent_pred / 1200)) if cent_pred else 0 for cent_pred in cents_pred])
return f0
def infer_from_audio(self, audio, thred=0.03):
# torch.cuda.synchronize()
# t0 = ttime()
if not torch.is_tensor(audio):
audio = torch.from_numpy(audio)
mel = self.mel_extractor(
audio.float().to(self.device).unsqueeze(0), center=True
)
# print(123123123,mel.device.type)
# torch.cuda.synchronize()
# t1 = ttime()
hidden = self.mel2hidden(mel)
# torch.cuda.synchronize()
# t2 = ttime()
# print(234234,hidden.device.type)
if "privateuseone" not in str(self.device):
hidden = hidden.squeeze(0).cpu().numpy()
else:
hidden = hidden[0]
if self.is_half == True:
hidden = hidden.astype("float32")
f0 = self.decode(hidden, thred=thred)
# torch.cuda.synchronize()
# t3 = ttime()
# print("hmvpe:%s\t%s\t%s\t%s"%(t1-t0,t2-t1,t3-t2,t3-t0))
return f0
def infer_from_audio_batch(self, audio, thred=0.03):
# torch.cuda.synchronize()
# t0 = ttime()
if not torch.is_tensor(audio):
audio = torch.from_numpy(audio)
mel = self.mel_extractor(
audio.float().to(self.device), center=True
)
# print(123123123,mel.device.type)
# torch.cuda.synchronize()
# t1 = ttime()
hidden = self.mel2hidden(mel)
# torch.cuda.synchronize()
# t2 = ttime()
# print(234234,hidden.device.type)
if "privateuseone" not in str(self.device):
hidden = hidden.cpu().numpy()
else:
pass
if self.is_half == True:
hidden = hidden.astype("float32")
f0s = []
for bib in range(hidden.shape[0]):
f0s.append(self.decode(hidden[bib], thred=thred))
f0s = np.stack(f0s)
f0s = torch.from_numpy(f0s).to(self.device)
# torch.cuda.synchronize()
# t3 = ttime()
# print("hmvpe:%s\t%s\t%s\t%s"%(t1-t0,t2-t1,t3-t2,t3-t0))
return f0s
def to_local_average_cents(self, salience, thred=0.05):
# t0 = ttime()
center = np.argmax(salience, axis=1) # 帧长#index
salience = np.pad(salience, ((0, 0), (4, 4))) # 帧长,368
# t1 = ttime()
center += 4
todo_salience = []
todo_cents_mapping = []
starts = center - 4
ends = center + 5
for idx in range(salience.shape[0]):
todo_salience.append(salience[:, starts[idx] : ends[idx]][idx])
todo_cents_mapping.append(self.cents_mapping[starts[idx] : ends[idx]])
# t2 = ttime()
todo_salience = np.array(todo_salience) # 帧长,9
todo_cents_mapping = np.array(todo_cents_mapping) # 帧长,9
product_sum = np.sum(todo_salience * todo_cents_mapping, 1)
weight_sum = np.sum(todo_salience, 1) # 帧长
devided = product_sum / weight_sum # 帧长
# t3 = ttime()
maxx = np.max(salience, axis=1) # 帧长
devided[maxx <= thred] = 0
# t4 = ttime()
# print("decode:%s\t%s\t%s\t%s" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
return devided
from .pretrained import Vocos
__version__ = "0.1.0"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment