Commit ab9c00af authored by yangzhong's avatar yangzhong
Browse files

init submission

parents
Pipeline #3176 failed with stages
in 0 seconds
import torch
import torchaudio
from torch import nn
from indextts.utils.common import safe_log
class FeatureExtractor(nn.Module):
"""Base class for feature extractors."""
def forward(self, audio: torch.Tensor, **kwargs) -> torch.Tensor:
"""
Extract features from the given audio.
Args:
audio (Tensor): Input audio waveform.
Returns:
Tensor: Extracted features of shape (B, C, L), where B is the batch size,
C denotes output features, and L is the sequence length.
"""
raise NotImplementedError("Subclasses must implement the forward method.")
class MelSpectrogramFeatures(FeatureExtractor):
def __init__(self, sample_rate=24000, n_fft=1024, hop_length=256, win_length=None,
n_mels=100, mel_fmin=0, mel_fmax=None, normalize=False, padding="center"):
super().__init__()
if padding not in ["center", "same"]:
raise ValueError("Padding must be 'center' or 'same'.")
self.padding = padding
self.mel_spec = torchaudio.transforms.MelSpectrogram(
sample_rate=sample_rate,
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
power=1,
normalized=normalize,
f_min=mel_fmin,
f_max=mel_fmax,
n_mels=n_mels,
center=padding == "center",
)
def forward(self, audio, **kwargs):
if self.padding == "same":
pad = self.mel_spec.win_length - self.mel_spec.hop_length
audio = torch.nn.functional.pad(audio, (pad // 2, pad // 2), mode="reflect")
mel = self.mel_spec(audio)
mel = safe_log(mel)
return mel
# -*- coding: utf-8 -*-
from tn.chinese.normalizer import Normalizer as ZhNormalizer
from tn.english.normalizer import Normalizer as EnNormalizer
import os
import traceback
import re
from typing import List, Union, overload
import warnings
from indextts.utils.common import tokenize_by_CJK_char, de_tokenized_by_CJK_char
from sentencepiece import SentencePieceProcessor
class TextNormalizer:
def __init__(self):
self.zh_normalizer = None
self.en_normalizer = None
self.char_rep_map = {
":": ",",
";": ",",
";": ",",
",": ",",
"。": ".",
"!": "!",
"?": "?",
"\n": " ",
"·": "-",
"、": ",",
"...": "…",
",,,": "…",
",,,": "…",
"……": "…",
"“": "'",
"”": "'",
'"': "'",
"‘": "'",
"’": "'",
"(": "'",
")": "'",
"(": "'",
")": "'",
"《": "'",
"》": "'",
"【": "'",
"】": "'",
"[": "'",
"]": "'",
"—": "-",
"~": "-",
"~": "-",
"「": "'",
"」": "'",
":": ",",
}
self.zh_char_rep_map = {
"$": ".",
**self.char_rep_map,
}
def match_email(self, email):
# 正则表达式匹配邮箱格式:数字英文@数字英文.英文
pattern = r"^[a-zA-Z0-9]+@[a-zA-Z0-9]+\.[a-zA-Z]+$"
return re.match(pattern, email) is not None
PINYIN_TONE_PATTERN = r"(?<![a-z])((?:[bpmfdtnlgkhjqxzcsryw]|[zcs]h)?(?:[aeiouüv]|[ae]i|u[aio]|ao|ou|i[aue]|[uüv]e|[uvü]ang?|uai|[aeiuv]n|[aeio]ng|ia[no]|i[ao]ng)|ng|er)([1-5])"
"""
匹配拼音声调格式:pinyin+数字,声调1-5,5表示轻声
例如:xuan4, jve2, ying1, zhong4, shang5
不匹配:beta1, voice2
"""
NAME_PATTERN = r"[\u4e00-\u9fff]+(?:[-·—][\u4e00-\u9fff]+){1,2}"
"""
匹配人名,格式:中文·中文,中文·中文-中文
例如:克里斯托弗·诺兰,约瑟夫·高登-莱维特
"""
# 匹配常见英语缩写 's,仅用于替换为 is,不匹配所有 's
ENGLISH_CONTRACTION_PATTERN = r"(what|where|who|which|how|t?here|it|s?he|that|this)'s"
def use_chinese(self, s):
has_chinese = bool(re.search(r"[\u4e00-\u9fff]", s))
has_alpha = bool(re.search(r"[a-zA-Z]", s))
is_email = self.match_email(s)
if has_chinese or not has_alpha or is_email:
return True
has_pinyin = bool(re.search(TextNormalizer.PINYIN_TONE_PATTERN, s, re.IGNORECASE))
return has_pinyin
def load(self):
# print(os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))
# sys.path.append(model_dir)
import platform
if self.zh_normalizer is not None and self.en_normalizer is not None:
return
if platform.system() == "Darwin" or platform.system() == "Windows":
from wetext import Normalizer
self.zh_normalizer = Normalizer(remove_erhua=False, lang="zh", operator="tn")
self.en_normalizer = Normalizer(lang="en", operator="tn")
else:
from tn.chinese.normalizer import Normalizer as NormalizerZh
from tn.english.normalizer import Normalizer as NormalizerEn
# use new cache dir for build tagger rules with disable remove_interjections and remove_erhua
cache_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "tagger_cache")
if not os.path.exists(cache_dir):
os.makedirs(cache_dir)
with open(os.path.join(cache_dir, ".gitignore"), "w") as f:
f.write("*\n")
self.zh_normalizer = NormalizerZh(
cache_dir=cache_dir, remove_interjections=False, remove_erhua=False, overwrite_cache=False
)
self.en_normalizer = NormalizerEn(overwrite_cache=False)
def normalize(self, text: str) -> str:
if not self.zh_normalizer or not self.en_normalizer:
print("Error, text normalizer is not initialized !!!")
return ""
if self.use_chinese(text):
text = re.sub(TextNormalizer.ENGLISH_CONTRACTION_PATTERN, r"\1 is", text, flags=re.IGNORECASE)
replaced_text, pinyin_list = self.save_pinyin_tones(text.rstrip())
replaced_text, original_name_list = self.save_names(replaced_text)
try:
result = self.zh_normalizer.normalize(replaced_text)
except Exception:
result = ""
print(traceback.format_exc())
# 恢复人名
result = self.restore_names(result, original_name_list)
# 恢复拼音声调
result = self.restore_pinyin_tones(result, pinyin_list)
pattern = re.compile("|".join(re.escape(p) for p in self.zh_char_rep_map.keys()))
result = pattern.sub(lambda x: self.zh_char_rep_map[x.group()], result)
else:
try:
text = re.sub(TextNormalizer.ENGLISH_CONTRACTION_PATTERN, r"\1 is", text, flags=re.IGNORECASE)
result = self.en_normalizer.normalize(text)
except Exception:
result = text
print(traceback.format_exc())
pattern = re.compile("|".join(re.escape(p) for p in self.char_rep_map.keys()))
result = pattern.sub(lambda x: self.char_rep_map[x.group()], result)
return result
def correct_pinyin(self, pinyin: str):
"""
将 jqx 的韵母为 u/ü 的拼音转换为 v
如:ju -> jv , que -> qve, xün -> xvn
"""
if pinyin[0] not in "jqxJQX":
return pinyin
# 匹配 jqx 的韵母为 u/ü 的拼音
pattern = r"([jqx])[uü](n|e|an)*(\d)"
repl = r"\g<1>v\g<2>\g<3>"
pinyin = re.sub(pattern, repl, pinyin, flags=re.IGNORECASE)
return pinyin.upper()
def save_names(self, original_text):
"""
替换人名为占位符 <n_a>、 <n_b>, ...
例如:克里斯托弗·诺兰 -> <n_a>
"""
# 人名
name_pattern = re.compile(TextNormalizer.NAME_PATTERN, re.IGNORECASE)
original_name_list = re.findall(name_pattern, original_text)
if len(original_name_list) == 0:
return (original_text, None)
original_name_list = list(set("".join(n) for n in original_name_list))
transformed_text = original_text
# 替换占位符 <n_a>、 <n_b>, ...
for i, name in enumerate(original_name_list):
number = chr(ord("a") + i)
transformed_text = transformed_text.replace(name, f"<n_{number}>")
return transformed_text, original_name_list
def restore_names(self, normalized_text, original_name_list):
"""
恢复人名为原来的文字
例如:<n_a> -> original_name_list[0]
"""
if not original_name_list or len(original_name_list) == 0:
return normalized_text
transformed_text = normalized_text
# 替换为占位符 <n_a>、 <n_b>, ...
for i, name in enumerate(original_name_list):
number = chr(ord("a") + i)
transformed_text = transformed_text.replace(f"<n_{number}>", name)
return transformed_text
def save_pinyin_tones(self, original_text):
"""
替换拼音声调为占位符 <pinyin_a>, <pinyin_b>, ...
例如:xuan4 -> <pinyin_a>
"""
# 声母韵母+声调数字
origin_pinyin_pattern = re.compile(TextNormalizer.PINYIN_TONE_PATTERN, re.IGNORECASE)
original_pinyin_list = re.findall(origin_pinyin_pattern, original_text)
if len(original_pinyin_list) == 0:
return (original_text, None)
original_pinyin_list = list(set("".join(p) for p in original_pinyin_list))
transformed_text = original_text
# 替换为占位符 <pinyin_a>, <pinyin_b>, ...
for i, pinyin in enumerate(original_pinyin_list):
number = chr(ord("a") + i)
transformed_text = transformed_text.replace(pinyin, f"<pinyin_{number}>")
# print("original_text: ", original_text)
# print("transformed_text: ", transformed_text)
return transformed_text, original_pinyin_list
def restore_pinyin_tones(self, normalized_text, original_pinyin_list):
"""
恢复拼音中的音调数字(1-5)为原来的拼音
例如:<pinyin_a> -> original_pinyin_list[0]
"""
if not original_pinyin_list or len(original_pinyin_list) == 0:
return normalized_text
transformed_text = normalized_text
# 替换占位符 <pinyin_a>, <pinyin_b>, ...
for i, pinyin in enumerate(original_pinyin_list):
number = chr(ord("a") + i)
pinyin = self.correct_pinyin(pinyin)
transformed_text = transformed_text.replace(f"<pinyin_{number}>", pinyin)
# print("normalized_text: ", normalized_text)
# print("transformed_text: ", transformed_text)
return transformed_text
class TextTokenizer:
def __init__(self, vocab_file: str, normalizer: TextNormalizer = None):
self.vocab_file = vocab_file
self.normalizer = normalizer
if self.vocab_file is None:
raise ValueError("vocab_file is None")
if not os.path.exists(self.vocab_file):
raise ValueError(f"vocab_file {self.vocab_file} does not exist")
if self.normalizer:
self.normalizer.load()
# 加载词表
self.sp_model = SentencePieceProcessor(model_file=self.vocab_file)
self.pre_tokenizers = [
# 预处理器
tokenize_by_CJK_char,
]
@property
def vocab_size(self):
return self.sp_model.GetPieceSize()
@property
def unk_token(self):
return "<unk>"
@property
def pad_token(self):
return None
@property
def bos_token(self):
return "<s>"
@property
def eos_token(self):
return "</s>"
@property
def pad_token_id(self):
return -1
@property
def bos_token_id(self):
return 0
@property
def eos_token_id(self):
return 1
@property
def unk_token_id(self):
return self.sp_model.unk_id()
@property
def special_tokens_map(self):
return {
"unk_token": self.unk_token,
"pad_token": self.pad_token,
"bos_token": self.bos_token,
"eos_token": self.eos_token,
}
def get_vocab(self):
vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
return vocab
@overload
def convert_ids_to_tokens(self, ids: int) -> str: ...
@overload
def convert_ids_to_tokens(self, ids: List[int]) -> List[str]: ...
def convert_ids_to_tokens(self, ids: Union[List[int], int]):
return self.sp_model.IdToPiece(ids)
def convert_tokens_to_ids(self, tokens: Union[List[str], str]) -> List[int]:
if isinstance(tokens, str):
tokens = [tokens]
return [self.sp_model.PieceToId(token) for token in tokens]
def tokenize(self, text: str) -> List[str]:
return self.encode(text, out_type=str)
def encode(self, text: str, **kwargs):
if len(text) == 0:
return []
if len(text.strip()) == 1:
return self.sp_model.Encode(text, out_type=kwargs.pop("out_type", int), **kwargs)
# 预处理
if self.normalizer:
text = self.normalizer.normalize(text)
if len(self.pre_tokenizers) > 0:
for pre_tokenizer in self.pre_tokenizers:
text = pre_tokenizer(text)
return self.sp_model.Encode(text, out_type=kwargs.pop("out_type", int), **kwargs)
def batch_encode(self, texts: List[str], **kwargs):
# 预处理
if self.normalizer:
texts = [self.normalizer.normalize(text) for text in texts]
if len(self.pre_tokenizers) > 0:
for pre_tokenizer in self.pre_tokenizers:
texts = [pre_tokenizer(text) for text in texts]
return self.sp_model.Encode(texts, out_type=kwargs.pop("out_type", int), **kwargs)
def decode(self, ids: Union[List[int], int], do_lower_case=False, **kwargs):
if isinstance(ids, int):
ids = [ids]
decoded = self.sp_model.Decode(ids, out_type=kwargs.pop("out_type", str), **kwargs)
return de_tokenized_by_CJK_char(decoded, do_lower_case=do_lower_case)
@staticmethod
def split_sentences_by_token(
tokenized_str: List[str], split_tokens: List[str], max_tokens_per_sentence: int
) -> List[List[str]]:
"""
将tokenize后的结果按特定token进一步分割
"""
# 处理特殊情况
if len(tokenized_str) == 0:
return []
sentences: List[List[str]] = []
current_sentence = []
current_sentence_tokens_len = 0
for i in range(len(tokenized_str)):
token = tokenized_str[i]
current_sentence.append(token)
current_sentence_tokens_len += 1
if current_sentence_tokens_len <= max_tokens_per_sentence:
if token in split_tokens and current_sentence_tokens_len > 2:
if i < len(tokenized_str) - 1:
if tokenized_str[i + 1] in ["'", "▁'"]:
# 后续token是',则不切分
current_sentence.append(tokenized_str[i + 1])
i += 1
sentences.append(current_sentence)
current_sentence = []
current_sentence_tokens_len = 0
continue
# 如果当前tokens的长度超过最大限制
if not ("," in split_tokens or "▁," in split_tokens ) and ("," in current_sentence or "▁," in current_sentence):
# 如果当前tokens中有,,则按,分割
sub_sentences = TextTokenizer.split_sentences_by_token(
current_sentence, [",", "▁,"], max_tokens_per_sentence=max_tokens_per_sentence
)
elif "-" not in split_tokens and "-" in current_sentence:
# 没有,,则按-分割
sub_sentences = TextTokenizer.split_sentences_by_token(
current_sentence, ["-"], max_tokens_per_sentence=max_tokens_per_sentence
)
else:
# 按照长度分割
sub_sentences = []
for j in range(0, len(current_sentence), max_tokens_per_sentence):
if j + max_tokens_per_sentence < len(current_sentence):
sub_sentences.append(current_sentence[j : j + max_tokens_per_sentence])
else:
sub_sentences.append(current_sentence[j:])
warnings.warn(
f"The tokens length of sentence exceeds limit: {max_tokens_per_sentence}, "
f"Tokens in sentence: {current_sentence}."
"Maybe unexpected behavior",
RuntimeWarning,
)
sentences.extend(sub_sentences)
current_sentence = []
current_sentence_tokens_len = 0
if current_sentence_tokens_len > 0:
assert current_sentence_tokens_len <= max_tokens_per_sentence
sentences.append(current_sentence)
# 如果相邻的句子加起来长度小于最大限制,则合并
merged_sentences = []
for sentence in sentences:
if len(sentence) == 0:
continue
if len(merged_sentences) == 0:
merged_sentences.append(sentence)
elif len(merged_sentences[-1]) + len(sentence) <= max_tokens_per_sentence:
merged_sentences[-1] = merged_sentences[-1] + sentence
else:
merged_sentences.append(sentence)
return merged_sentences
punctuation_marks_tokens = [
".",
"!",
"?",
"▁.",
# "▁!", # unk
"▁?",
"▁...", # ellipsis
]
def split_sentences(self, tokenized: List[str], max_tokens_per_sentence=120) -> List[List[str]]:
return TextTokenizer.split_sentences_by_token(
tokenized, self.punctuation_marks_tokens, max_tokens_per_sentence=max_tokens_per_sentence
)
if __name__ == "__main__":
# 测试程序
text_normalizer = TextNormalizer()
cases = [
"IndexTTS 正式发布1.0版本了,效果666",
"晕XUAN4是一种GAN3觉",
"我爱你!",
"I love you!",
"“我爱你”的英语是“I love you”",
"2.5平方电线",
"共465篇,约315万字",
"2002年的第一场雪,下在了2003年",
"速度是10km/h",
"现在是北京时间2025年01月11日 20:00",
"他这条裤子是2012年买的,花了200块钱",
"电话:135-4567-8900",
"1键3连",
"他这条视频点赞3000+,评论1000+,收藏500+",
"这是1024元的手机,你要吗?",
"受不liao3你了",
"“衣裳”不读衣chang2,而是读衣shang5",
"最zhong4要的是:不要chong2蹈覆辙",
"不zuo1死就不会死",
"See you at 8:00 AM",
"8:00 AM 开会",
"Couting down 3, 2, 1, go!",
"数到3就开始:1、2、3",
"This sales for 2.5% off, only $12.5.",
"5G网络是4G网络的升级版,2G网络是3G网络的前身",
"苹果于2030/1/2发布新 iPhone 2X 系列手机,最低售价仅 ¥12999",
"这酒...里...有毒...",
# 异常case
"只有,,,才是最好的",
"babala2是什么?", # babala二是什么?
"用beta1测试", # 用beta一测试
"have you ever been to beta2?", # have you ever been to beta two?
"such as XTTS, CosyVoice2, Fish-Speech, and F5-TTS", # such as xtts,cosyvoice two,fish-speech,and f five-tts
"where's the money?", # where is the money?
"who's there?", # who is there?
"which's the best?", # which is the best?
"how's it going?", # how is it going?
"今天是个好日子 it's a good day", # 今天是个好日子 it is a good day
# 人名
"约瑟夫·高登-莱维特(Joseph Gordon-Levitt is an American actor)",
"蒂莫西·唐纳德·库克(英文名:Timothy Donald Cook),通称蒂姆·库克(Tim Cook),美国商业经理、工业工程师和工业开发商,现任苹果公司首席执行官。",
# 长句子
"《盗梦空间》是由美国华纳兄弟影片公司出品的电影,由克里斯托弗·诺兰执导并编剧,莱昂纳多·迪卡普里奥、玛丽昂·歌迪亚、约瑟夫·高登-莱维特、艾利奥特·佩吉、汤姆·哈迪等联袂主演,2010年7月16日在美国上映,2010年9月1日在中国内地上映,2020年8月28日在中国内地重映。影片剧情游走于梦境与现实之间,被定义为“发生在意识结构内的当代动作科幻片”,讲述了由莱昂纳多·迪卡普里奥扮演的造梦师,带领特工团队进入他人梦境,从他人的潜意识中盗取机密,并重塑他人梦境的故事。",
"清晨拉开窗帘,阳光洒在窗台的Bloomixy花艺礼盒上——薰衣草香薰蜡烛唤醒嗅觉,永生花束折射出晨露般光泽。设计师将“自然绽放美学”融入每个细节:手工陶瓷花瓶可作首饰收纳,香薰精油含依兰依兰舒缓配方。限量款附赠《365天插花灵感手册》,让每个平凡日子都有花开仪式感。\n宴会厅灯光暗下的刹那,Glimmeria星月系列耳坠开始发光——瑞士冷珐琅工艺让蓝宝石如银河流动,钛合金骨架仅3.2g无负重感。设计师秘密:内置微型重力感应器,随步伐产生0.01mm振幅,打造“行走的星光”。七夕限定礼盒含星座定制铭牌,让爱意如星辰永恒闪耀。",
"电影1:“黑暗骑士”(演员:克里斯蒂安·贝尔、希斯·莱杰;导演:克里斯托弗·诺兰);电影2:“盗梦空间”(演员:莱昂纳多·迪卡普里奥;导演:克里斯托弗·诺兰);电影3:“钢琴家”(演员:艾德里安·布洛迪;导演:罗曼·波兰斯基);电影4:“泰坦尼克号”(演员:莱昂纳多·迪卡普里奥;导演:詹姆斯·卡梅隆);电影5:“阿凡达”(演员:萨姆·沃辛顿;导演:詹姆斯·卡梅隆);电影6:“南方公园:大电影”(演员:马特·斯通、托马斯·艾恩格瑞;导演:特雷·帕克)",
]
# 测试分词器
tokenizer = TextTokenizer(
vocab_file="checkpoints/bpe.model",
normalizer=text_normalizer,
)
codes = tokenizer.batch_encode(
cases,
out_type=int,
)
print(f"vocab_size: {tokenizer.vocab_size}")
# print(f"pad_token: {tokenizer.pad_token}, pad_token_id: {tokenizer.pad_token_id}")
print(f"bos_token: {tokenizer.bos_token}, bos_token_id: {tokenizer.bos_token_id}")
print(f"eos_token: {tokenizer.eos_token}, eos_token_id: {tokenizer.eos_token_id}")
print(f"unk_token: {tokenizer.unk_token}, unk_token_id: {tokenizer.unk_token_id}")
# 测试拼音 (8474-10201)
for id in range(8474, 10201):
pinyin = tokenizer.convert_ids_to_tokens(id)
if re.match(TextNormalizer.PINYIN_TONE_PATTERN, pinyin, re.IGNORECASE) is None:
print(f"{pinyin} should be matched")
for badcase in [
"beta1", "better1", "voice2", "bala2", "babala2", "hunger2"
]:
if re.match(TextNormalizer.PINYIN_TONE_PATTERN, badcase, re.IGNORECASE) is not None:
print(f"{badcase} should not be matched!")
# 不应该有 unk_token_id
for t in set([*TextTokenizer.punctuation_marks_tokens, ",", "▁,", "-", "▁..."]):
tokens = tokenizer.convert_tokens_to_ids(t)
if tokenizer.unk_token_id in tokens:
print(f"Warning: {t} is unknown token")
print(f"`{t}`", "->", tokens, "->", tokenizer.convert_ids_to_tokens(tokens))
for ch in set(tokenizer.normalizer.zh_char_rep_map.values()):
# 测试 normalize后的字符能被分词器识别
print(f"`{ch}`", "->", tokenizer.sp_model.Encode(ch, out_type=str))
print(f"` {ch}`", "->", tokenizer.sp_model.Encode(f" {ch}", out_type=str))
max_tokens_per_sentence=120
for i in range(len(cases)):
print(f"原始文本: {cases[i]}")
print(f"Normalized: {text_normalizer.normalize(cases[i])}")
tokens = tokenizer.tokenize(cases[i])
print("Tokenzied: ", ", ".join([f"`{t}`" for t in tokens]))
sentences = tokenizer.split_sentences(tokens, max_tokens_per_sentence=max_tokens_per_sentence)
print("Splitted sentences count:", len(sentences))
if len(sentences) > 1:
for j in range(len(sentences)):
print(f" {j}, count:", len(sentences[j]), ", tokens:", "".join(sentences[j]))
if len(sentences[j]) > max_tokens_per_sentence:
print(f"Warning: sentence {j} is too long, length: {len(sentences[j])}")
#print(f"Token IDs (first 10): {codes[i][:10]}")
if tokenizer.unk_token in codes[i]:
print(f"Warning: `{cases[i]}` contains UNKNOWN token")
print(f"Decoded: {tokenizer.decode(codes[i], do_lower_case=True)}")
print("-" * 50)
# Copyright (c) 2024 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn.utils import weight_norm
from indextts.utils.maskgct.models.codec.amphion_codec.quantize import (
ResidualVQ,
VectorQuantize,
FactorizedVectorQuantize,
LookupFreeQuantize,
)
from indextts.utils.maskgct.models.codec.amphion_codec.vocos import Vocos
def WNConv1d(*args, **kwargs):
return weight_norm(nn.Conv1d(*args, **kwargs))
def WNConvTranspose1d(*args, **kwargs):
return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
# Scripting this brings model speed up 1.4x
@torch.jit.script
def snake(x, alpha):
shape = x.shape
x = x.reshape(shape[0], shape[1], -1)
x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
x = x.reshape(shape)
return x
class Snake1d(nn.Module):
def __init__(self, channels):
super().__init__()
self.alpha = nn.Parameter(torch.ones(1, channels, 1))
def forward(self, x):
return snake(x, self.alpha)
def init_weights(m):
if isinstance(m, nn.Conv1d):
nn.init.trunc_normal_(m.weight, std=0.02)
nn.init.constant_(m.bias, 0)
if isinstance(m, nn.Linear):
nn.init.trunc_normal_(m.weight, std=0.02)
nn.init.constant_(m.bias, 0)
class ResidualUnit(nn.Module):
def __init__(self, dim: int = 16, dilation: int = 1):
super().__init__()
pad = ((7 - 1) * dilation) // 2
self.block = nn.Sequential(
Snake1d(dim),
WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
Snake1d(dim),
WNConv1d(dim, dim, kernel_size=1),
)
def forward(self, x):
y = self.block(x)
pad = (x.shape[-1] - y.shape[-1]) // 2
if pad > 0:
x = x[..., pad:-pad]
return x + y
class EncoderBlock(nn.Module):
def __init__(self, dim: int = 16, stride: int = 1):
super().__init__()
self.block = nn.Sequential(
ResidualUnit(dim // 2, dilation=1),
ResidualUnit(dim // 2, dilation=3),
ResidualUnit(dim // 2, dilation=9),
Snake1d(dim // 2),
WNConv1d(
dim // 2,
dim,
kernel_size=2 * stride,
stride=stride,
padding=math.ceil(stride / 2),
),
)
def forward(self, x):
return self.block(x)
class CodecEncoder(nn.Module):
def __init__(
self,
d_model: int = 64,
up_ratios: list = [4, 5, 5, 6],
out_channels: int = 256,
use_tanh: bool = False,
cfg=None,
):
super().__init__()
d_model = cfg.d_model if cfg is not None else d_model
up_ratios = cfg.up_ratios if cfg is not None else up_ratios
out_channels = cfg.out_channels if cfg is not None else out_channels
use_tanh = cfg.use_tanh if cfg is not None else use_tanh
# Create first convolution
self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)]
# Create EncoderBlocks that double channels as they downsample by `stride`
for stride in up_ratios:
d_model *= 2
self.block += [EncoderBlock(d_model, stride=stride)]
# Create last convolution
self.block += [
Snake1d(d_model),
WNConv1d(d_model, out_channels, kernel_size=3, padding=1),
]
if use_tanh:
self.block += [nn.Tanh()]
# Wrap black into nn.Sequential
self.block = nn.Sequential(*self.block)
self.enc_dim = d_model
self.reset_parameters()
def forward(self, x):
return self.block(x)
def reset_parameters(self):
self.apply(init_weights)
class DecoderBlock(nn.Module):
def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1):
super().__init__()
self.block = nn.Sequential(
Snake1d(input_dim),
WNConvTranspose1d(
input_dim,
output_dim,
kernel_size=2 * stride,
stride=stride,
padding=stride // 2 + stride % 2,
output_padding=stride % 2,
),
ResidualUnit(output_dim, dilation=1),
ResidualUnit(output_dim, dilation=3),
ResidualUnit(output_dim, dilation=9),
)
def forward(self, x):
return self.block(x)
class CodecDecoder(nn.Module):
def __init__(
self,
in_channels: int = 256,
upsample_initial_channel: int = 1536,
up_ratios: list = [5, 5, 4, 2],
num_quantizers: int = 8,
codebook_size: int = 1024,
codebook_dim: int = 256,
quantizer_type: str = "vq",
quantizer_dropout: float = 0.5,
commitment: float = 0.25,
codebook_loss_weight: float = 1.0,
use_l2_normlize: bool = False,
codebook_type: str = "euclidean",
kmeans_init: bool = False,
kmeans_iters: int = 10,
decay: float = 0.8,
eps: float = 1e-5,
threshold_ema_dead_code: int = 2,
weight_init: bool = False,
use_vocos: bool = False,
vocos_dim: int = 384,
vocos_intermediate_dim: int = 1152,
vocos_num_layers: int = 8,
n_fft: int = 800,
hop_size: int = 200,
padding: str = "same",
cfg=None,
):
super().__init__()
in_channels = (
cfg.in_channels
if cfg is not None and hasattr(cfg, "in_channels")
else in_channels
)
upsample_initial_channel = (
cfg.upsample_initial_channel
if cfg is not None and hasattr(cfg, "upsample_initial_channel")
else upsample_initial_channel
)
up_ratios = (
cfg.up_ratios
if cfg is not None and hasattr(cfg, "up_ratios")
else up_ratios
)
num_quantizers = (
cfg.num_quantizers
if cfg is not None and hasattr(cfg, "num_quantizers")
else num_quantizers
)
codebook_size = (
cfg.codebook_size
if cfg is not None and hasattr(cfg, "codebook_size")
else codebook_size
)
codebook_dim = (
cfg.codebook_dim
if cfg is not None and hasattr(cfg, "codebook_dim")
else codebook_dim
)
quantizer_type = (
cfg.quantizer_type
if cfg is not None and hasattr(cfg, "quantizer_type")
else quantizer_type
)
quantizer_dropout = (
cfg.quantizer_dropout
if cfg is not None and hasattr(cfg, "quantizer_dropout")
else quantizer_dropout
)
commitment = (
cfg.commitment
if cfg is not None and hasattr(cfg, "commitment")
else commitment
)
codebook_loss_weight = (
cfg.codebook_loss_weight
if cfg is not None and hasattr(cfg, "codebook_loss_weight")
else codebook_loss_weight
)
use_l2_normlize = (
cfg.use_l2_normlize
if cfg is not None and hasattr(cfg, "use_l2_normlize")
else use_l2_normlize
)
codebook_type = (
cfg.codebook_type
if cfg is not None and hasattr(cfg, "codebook_type")
else codebook_type
)
kmeans_init = (
cfg.kmeans_init
if cfg is not None and hasattr(cfg, "kmeans_init")
else kmeans_init
)
kmeans_iters = (
cfg.kmeans_iters
if cfg is not None and hasattr(cfg, "kmeans_iters")
else kmeans_iters
)
decay = cfg.decay if cfg is not None and hasattr(cfg, "decay") else decay
eps = cfg.eps if cfg is not None and hasattr(cfg, "eps") else eps
threshold_ema_dead_code = (
cfg.threshold_ema_dead_code
if cfg is not None and hasattr(cfg, "threshold_ema_dead_code")
else threshold_ema_dead_code
)
weight_init = (
cfg.weight_init
if cfg is not None and hasattr(cfg, "weight_init")
else weight_init
)
use_vocos = (
cfg.use_vocos
if cfg is not None and hasattr(cfg, "use_vocos")
else use_vocos
)
vocos_dim = (
cfg.vocos_dim
if cfg is not None and hasattr(cfg, "vocos_dim")
else vocos_dim
)
vocos_intermediate_dim = (
cfg.vocos_intermediate_dim
if cfg is not None and hasattr(cfg, "vocos_intermediate_dim")
else vocos_intermediate_dim
)
vocos_num_layers = (
cfg.vocos_num_layers
if cfg is not None and hasattr(cfg, "vocos_num_layers")
else vocos_num_layers
)
n_fft = cfg.n_fft if cfg is not None and hasattr(cfg, "n_fft") else n_fft
hop_size = (
cfg.hop_size if cfg is not None and hasattr(cfg, "hop_size") else hop_size
)
padding = (
cfg.padding if cfg is not None and hasattr(cfg, "padding") else padding
)
if quantizer_type == "vq":
self.quantizer = ResidualVQ(
input_dim=in_channels,
num_quantizers=num_quantizers,
codebook_size=codebook_size,
codebook_dim=codebook_dim,
quantizer_type=quantizer_type,
quantizer_dropout=quantizer_dropout,
commitment=commitment,
codebook_loss_weight=codebook_loss_weight,
use_l2_normlize=use_l2_normlize,
codebook_type=codebook_type,
kmeans_init=kmeans_init,
kmeans_iters=kmeans_iters,
decay=decay,
eps=eps,
threshold_ema_dead_code=threshold_ema_dead_code,
weight_init=weight_init,
)
elif quantizer_type == "fvq":
self.quantizer = ResidualVQ(
input_dim=in_channels,
num_quantizers=num_quantizers,
codebook_size=codebook_size,
codebook_dim=codebook_dim,
quantizer_type=quantizer_type,
quantizer_dropout=quantizer_dropout,
commitment=commitment,
codebook_loss_weight=codebook_loss_weight,
use_l2_normlize=use_l2_normlize,
)
elif quantizer_type == "lfq":
self.quantizer = ResidualVQ(
input_dim=in_channels,
num_quantizers=num_quantizers,
codebook_size=codebook_size,
codebook_dim=codebook_dim,
quantizer_type=quantizer_type,
)
else:
raise ValueError(f"Unknown quantizer type {quantizer_type}")
if not use_vocos:
# Add first conv layer
channels = upsample_initial_channel
layers = [WNConv1d(in_channels, channels, kernel_size=7, padding=3)]
# Add upsampling + MRF blocks
for i, stride in enumerate(up_ratios):
input_dim = channels // 2**i
output_dim = channels // 2 ** (i + 1)
layers += [DecoderBlock(input_dim, output_dim, stride)]
# Add final conv layer
layers += [
Snake1d(output_dim),
WNConv1d(output_dim, 1, kernel_size=7, padding=3),
nn.Tanh(),
]
self.model = nn.Sequential(*layers)
if use_vocos:
self.model = Vocos(
input_channels=in_channels,
dim=vocos_dim,
intermediate_dim=vocos_intermediate_dim,
num_layers=vocos_num_layers,
adanorm_num_embeddings=None,
n_fft=n_fft,
hop_size=hop_size,
padding=padding,
)
self.reset_parameters()
def forward(self, x=None, vq=False, eval_vq=False, n_quantizers=None):
"""
if vq is True, x = encoder output, then return quantized output;
else, x = quantized output, then return decoder output
"""
if vq is True:
if eval_vq:
self.quantizer.eval()
(
quantized_out,
all_indices,
all_commit_losses,
all_codebook_losses,
all_quantized,
) = self.quantizer(x, n_quantizers=n_quantizers)
return (
quantized_out,
all_indices,
all_commit_losses,
all_codebook_losses,
all_quantized,
)
return self.model(x)
def quantize(self, x, n_quantizers=None):
self.quantizer.eval()
quantized_out, vq, _, _, _ = self.quantizer(x, n_quantizers=n_quantizers)
return quantized_out, vq
# TODO: check consistency of vq2emb and quantize
def vq2emb(self, vq, n_quantizers=None):
return self.quantizer.vq2emb(vq, n_quantizers=n_quantizers)
def decode(self, x):
return self.model(x)
def latent2dist(self, x, n_quantizers=None):
return self.quantizer.latent2dist(x, n_quantizers=n_quantizers)
def reset_parameters(self):
self.apply(init_weights)
# Copyright (c) 2024 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from indextts.utils.maskgct.models.codec.amphion_codec.quantize.factorized_vector_quantize import (
FactorizedVectorQuantize,
)
from indextts.utils.maskgct.models.codec.amphion_codec.quantize.vector_quantize import VectorQuantize
from indextts.utils.maskgct.models.codec.amphion_codec.quantize.lookup_free_quantize import LookupFreeQuantize
from indextts.utils.maskgct.models.codec.amphion_codec.quantize.residual_vq import ResidualVQ
# Copyright (c) 2024 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn.utils import weight_norm
def WNConv1d(*args, **kwargs):
return weight_norm(nn.Conv1d(*args, **kwargs))
def WNConvTranspose1d(*args, **kwargs):
return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
class FactorizedVectorQuantize(nn.Module):
def __init__(
self,
input_dim,
codebook_size,
codebook_dim,
commitment=0.005,
codebook_loss_weight=1.0,
use_l2_normlize=True,
):
super().__init__()
self.input_dim = input_dim
self.codebook_size = codebook_size
self.codebook_dim = codebook_dim
self.commitment = commitment
self.codebook_loss_weight = codebook_loss_weight
self.use_l2_normlize = use_l2_normlize
if self.input_dim != self.codebook_dim:
self.in_project = WNConv1d(self.input_dim, self.codebook_dim, kernel_size=1)
self.out_project = WNConv1d(
self.codebook_dim, self.input_dim, kernel_size=1
)
else:
self.in_project = nn.Identity()
self.out_project = nn.Identity()
self.codebook = nn.Embedding(self.codebook_size, self.codebook_dim)
def forward(self, z):
"""
Parameters
----------
z: torch.Tensor[B x D x T]
Returns
-------
z_q: torch.Tensor[B x D x T]
Quantized continuous representation of input
commit_loss: Tensor[B]
Commitment loss to train encoder to predict vectors closer to codebook entries
codebook_loss: Tensor[B]
Codebook loss to update the codebook
indices: torch.Tensor[B x T]
Codebook indices (quantized discrete representation of input)
z_e: torch.Tensor[B x D x T]
Projected latents (continuous representation of input before quantization)
"""
# Factorized codes project input into low-dimensional space if self.input_dim != self.codebook_dim
z_e = self.in_project(z)
z_q, indices = self.decode_latents(z_e)
# Compute commitment loss and codebook loss
if self.training:
commit_loss = (
F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
* self.commitment
)
codebook_loss = (
F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
* self.codebook_loss_weight
)
else:
commit_loss = torch.zeros(z.shape[0], device=z.device)
codebook_loss = torch.zeros(z.shape[0], device=z.device)
z_q = z_e + (z_q - z_e).detach()
z_q = self.out_project(z_q)
return z_q, commit_loss, codebook_loss, indices, z_e
def embed_code(self, embed_id):
return F.embedding(embed_id, self.codebook.weight)
def decode_code(self, embed_id):
return self.embed_code(embed_id).transpose(1, 2)
def decode_latents(self, latents):
encodings = rearrange(latents, "b d t -> (b t) d")
codebook = self.codebook.weight
# L2 normalize encodings and codebook
if self.use_l2_normlize:
encodings = F.normalize(encodings)
codebook = F.normalize(codebook)
# Compute euclidean distance between encodings and codebook,
# if use_l2_normlize is True, the distance is equal to cosine distance
dist = (
encodings.pow(2).sum(1, keepdim=True)
- 2 * encodings @ codebook.t()
+ codebook.pow(2).sum(1, keepdim=True).t()
)
indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
z_q = self.decode_code(indices)
return z_q, indices
def vq2emb(self, vq, out_proj=True):
emb = self.decode_code(vq)
if out_proj:
emb = self.out_project(emb)
return emb
def latent2dist(self, latents):
encodings = rearrange(latents, "b d t -> (b t) d")
codebook = self.codebook.weight
# L2 normalize encodings and codebook
if self.use_l2_normlize:
encodings = F.normalize(encodings)
codebook = F.normalize(codebook)
# Compute euclidean distance between encodings and codebook,
# if use_l2_normlize is True, the distance is equal to cosine distance
dist = (
encodings.pow(2).sum(1, keepdim=True)
- 2 * encodings @ codebook.t()
+ codebook.pow(2).sum(1, keepdim=True).t()
) # (b*t, k)
indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
dist = rearrange(dist, "(b t) k -> b t k", b=latents.size(0))
z_q = self.decode_code(indices)
return -dist, indices, z_q
# Copyright (c) 2024 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn.utils import weight_norm
def WNConv1d(*args, **kwargs):
return weight_norm(nn.Conv1d(*args, **kwargs))
def WNConvTranspose1d(*args, **kwargs):
return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
class LookupFreeQuantize(nn.Module):
def __init__(
self,
input_dim,
codebook_size,
codebook_dim,
):
super().__init__()
self.input_dim = input_dim
self.codebook_size = codebook_size
self.codebook_dim = codebook_dim
assert 2**codebook_dim == codebook_size
if self.input_dim != self.codebook_dim:
self.in_project = WNConv1d(self.input_dim, self.codebook_dim, kernel_size=1)
self.out_project = WNConv1d(
self.codebook_dim, self.input_dim, kernel_size=1
)
else:
self.in_project = nn.Identity()
self.out_project = nn.Identity()
def forward(self, z):
z_e = self.in_project(z)
z_e = F.sigmoid(z_e)
z_q = z_e + (torch.round(z_e) - z_e).detach()
z_q = self.out_project(z_q)
commit_loss = torch.zeros(z.shape[0], device=z.device)
codebook_loss = torch.zeros(z.shape[0], device=z.device)
bits = (
2
** torch.arange(self.codebook_dim, device=z.device)
.unsqueeze(0)
.unsqueeze(-1)
.long()
) # (1, d, 1)
indices = (torch.round(z_e.clone().detach()).long() * bits).sum(1).long()
return z_q, commit_loss, codebook_loss, indices, z_e
def vq2emb(self, vq, out_proj=True):
emb = torch.zeros(
vq.shape[0], self.codebook_dim, vq.shape[-1], device=vq.device
) # (B, d, T)
for i in range(self.codebook_dim):
emb[:, i, :] = (vq % 2).float()
vq = vq // 2
if out_proj:
emb = self.out_project(emb)
return emb
# Copyright (c) 2024 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn.utils import weight_norm
from indextts.utils.maskgct.models.codec.amphion_codec.quantize.factorized_vector_quantize import (
FactorizedVectorQuantize,
)
from indextts.utils.maskgct.models.codec.amphion_codec.quantize.vector_quantize import VectorQuantize
from indextts.utils.maskgct.models.codec.amphion_codec.quantize.lookup_free_quantize import LookupFreeQuantize
class ResidualVQ(nn.Module):
"""
Introduced in SoundStream: An end2end neural audio codec
https://arxiv.org/abs/2107.03312
"""
def __init__(
self,
input_dim: int = 256,
num_quantizers: int = 8,
codebook_size: int = 1024,
codebook_dim: int = 256,
quantizer_type: str = "vq", # "vq" or "fvq" or "lfq"
quantizer_dropout: float = 0.5,
**kwargs,
):
super().__init__()
self.input_dim = input_dim
self.num_quantizers = num_quantizers
self.codebook_size = codebook_size
self.codebook_dim = codebook_dim
self.quantizer_type = quantizer_type
self.quantizer_dropout = quantizer_dropout
if quantizer_type == "vq":
VQ = VectorQuantize
elif quantizer_type == "fvq":
VQ = FactorizedVectorQuantize
elif quantizer_type == "lfq":
VQ = LookupFreeQuantize
else:
raise ValueError(f"Unknown quantizer type {quantizer_type}")
self.quantizers = nn.ModuleList(
[
VQ(
input_dim=input_dim,
codebook_size=codebook_size,
codebook_dim=codebook_dim,
**kwargs,
)
for _ in range(num_quantizers)
]
)
def forward(self, z, n_quantizers: int = None):
"""
Parameters
----------
z : Tensor[B x D x T]
n_quantizers : int, optional
No. of quantizers to use
(n_quantizers < self.n_codebooks ex: for quantizer dropout)
Note: if `self.quantizer_dropout` is True, this argument is ignored
when in training mode, and a random number of quantizers is used.
Returns
-------
"quantized_out" : Tensor[B x D x T]
Quantized continuous representation of input
"all_indices" : Tensor[N x B x T]
Codebook indices for each codebook
(quantized discrete representation of input)
"all_commit_losses" : Tensor[N]
"all_codebook_losses" : Tensor[N]
"all_quantized" : Tensor[N x B x D x T]
"""
quantized_out = 0.0
residual = z
all_commit_losses = []
all_codebook_losses = []
all_indices = []
all_quantized = []
if n_quantizers is None:
n_quantizers = self.num_quantizers
if self.training:
n_quantizers = torch.ones((z.shape[0],)) * self.num_quantizers + 1
dropout = torch.randint(1, self.num_quantizers + 1, (z.shape[0],))
n_dropout = int(z.shape[0] * self.quantizer_dropout)
n_quantizers[:n_dropout] = dropout[:n_dropout]
n_quantizers = n_quantizers.to(z.device)
for i, quantizer in enumerate(self.quantizers):
if self.training is False and i >= n_quantizers:
break
z_q_i, commit_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
residual
)
# Create mask to apply quantizer dropout
mask = (
torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
)
quantized_out = quantized_out + z_q_i * mask[:, None, None]
residual = residual - z_q_i
commit_loss_i = (commit_loss_i * mask).mean()
codebook_loss_i = (codebook_loss_i * mask).mean()
all_commit_losses.append(commit_loss_i)
all_codebook_losses.append(codebook_loss_i)
all_indices.append(indices_i)
all_quantized.append(z_q_i)
all_commit_losses, all_codebook_losses, all_indices, all_quantized = map(
torch.stack,
(all_commit_losses, all_codebook_losses, all_indices, all_quantized),
)
return (
quantized_out,
all_indices,
all_commit_losses,
all_codebook_losses,
all_quantized,
)
def vq2emb(self, vq, n_quantizers=None):
quantized_out = 0.0
if n_quantizers is None:
n_quantizers = self.num_quantizers
for idx, quantizer in enumerate(self.quantizers):
if idx >= n_quantizers:
break
quantized_out += quantizer.vq2emb(vq[idx])
return quantized_out
def latent2dist(self, z, n_quantizers=None):
quantized_out = 0.0
residual = z
all_dists = []
all_indices = []
if n_quantizers is None:
n_quantizers = self.num_quantizers
for i, quantizer in enumerate(self.quantizers):
if self.training is False and i >= n_quantizers:
break
dist_i, indices_i, z_q_i = quantizer.latent2dist(residual)
all_dists.append(dist_i)
all_indices.append(indices_i)
quantized_out = quantized_out + z_q_i
residual = residual - z_q_i
all_dists = torch.stack(all_dists)
all_indices = torch.stack(all_indices)
return all_dists, all_indices
# Copyright (c) 2024 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from torch.nn.utils import weight_norm
def WNConv1d(*args, **kwargs):
return weight_norm(nn.Conv1d(*args, **kwargs))
def WNConvTranspose1d(*args, **kwargs):
return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
def l2norm(t):
return F.normalize(t, p=2, dim=-1)
def ema_inplace(moving_avg, new, decay):
moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
def laplace_smoothing(x, n_categories, eps=1e-5):
return (x + eps) / (x.sum() + n_categories * eps)
def sample_vectors(samples, num):
num_samples, device = samples.shape[0], samples.device
if num_samples >= num:
indices = torch.randperm(num_samples, device=device)[:num]
else:
indices = torch.randint(0, num_samples, (num,), device=device)
return samples[indices]
def kmeans(samples, num_clusters, num_iters=10, use_cosine_sim=False):
dim, dtype, device = samples.shape[-1], samples.dtype, samples.device
means = sample_vectors(samples, num_clusters)
for _ in range(num_iters):
if use_cosine_sim:
dists = samples @ means.t()
else:
diffs = rearrange(samples, "n d -> n () d") - rearrange(
means, "c d -> () c d"
)
dists = -(diffs**2).sum(dim=-1)
buckets = dists.max(dim=-1).indices
bins = torch.bincount(buckets, minlength=num_clusters)
zero_mask = bins == 0
bins_min_clamped = bins.masked_fill(zero_mask, 1)
new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
new_means = new_means / bins_min_clamped[..., None]
if use_cosine_sim:
new_means = l2norm(new_means)
means = torch.where(zero_mask[..., None], means, new_means)
return means, bins
class EuclideanCodebook(nn.Module):
def __init__(
self,
dim,
codebook_size,
kmeans_init=False,
kmeans_iters=10,
decay=0.8,
eps=1e-5,
threshold_ema_dead_code=2,
weight_init=False,
):
super().__init__()
self.decay = decay
init_fn = torch.randn if not weight_init else torch.zeros
embed = init_fn(codebook_size, dim)
if weight_init:
nn.init.uniform_(embed, -1 / codebook_size, 1 / codebook_size)
self.codebook_size = codebook_size
self.kmeans_iters = kmeans_iters
self.eps = eps
self.threshold_ema_dead_code = threshold_ema_dead_code
self.register_buffer(
"initted", torch.Tensor([not kmeans_init])
) # if kmeans_init is True, then initted is False; otherwise, initted is True
self.register_buffer("cluster_size", torch.zeros(codebook_size))
self.register_buffer("embed", embed)
self.register_buffer("embed_avg", embed.clone())
def init_embed_(self, data):
embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
self.embed.data.copy_(embed)
self.embed_avg.data.copy_(embed)
self.cluster_size.data.copy_(cluster_size)
self.initted.data.copy_(torch.Tensor([True]))
def replace(self, samples, mask):
modified_codebook = torch.where(
mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
)
self.embed.data.copy_(modified_codebook)
def expire_codes_(self, batch_samples):
if self.threshold_ema_dead_code == 0:
return
expired_codes = self.cluster_size < self.threshold_ema_dead_code
if not torch.any(expired_codes):
return
batch_samples = rearrange(batch_samples, "... d -> (...) d")
self.replace(batch_samples, mask=expired_codes)
def forward(self, x):
shape, dtype = x.shape, x.dtype
flatten = rearrange(x, "... d -> (...) d")
embed = self.embed.t() # (codebook_size, dim) -> (dim, codebook_size)
if not self.initted:
self.init_embed_(flatten)
dist = -(
flatten.pow(2).sum(1, keepdim=True)
- 2 * flatten @ embed
+ embed.pow(2).sum(0, keepdim=True)
)
embed_ind = dist.max(dim=-1).indices
embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
embed_ind = embed_ind.view(*shape[:-1])
quantize = F.embedding(embed_ind, self.embed)
if self.training:
ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
embed_sum = (
flatten.t() @ embed_onehot
) # (dim, ...) @ (..., codebook_size) -> (dim, codebook_size)
ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
cluster_size = (
laplace_smoothing(self.cluster_size, self.codebook_size, self.eps)
* self.cluster_size.sum()
)
embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
self.embed.data.copy_(embed_normalized)
self.expire_codes_(x)
return quantize, embed_ind
def vq2emb(self, vq):
quantize = F.embedding(vq, self.embed)
return quantize
def latent2dist(self, x):
shape, dtype = x.shape, x.dtype
flatten = rearrange(x, "... d -> (...) d")
embed = self.embed.t() # (codebook_size, dim) -> (dim, codebook_size)
if not self.initted:
self.init_embed_(flatten)
dist = -(
flatten.pow(2).sum(1, keepdim=True)
- 2 * flatten @ embed
+ embed.pow(2).sum(0, keepdim=True)
)
embed_ind = dist.max(dim=-1).indices
embed_ind = embed_ind.view(*shape[:-1])
quantize = F.embedding(embed_ind, self.embed)
dist = dist.view(*shape[:-1], -1)
return dist, embed_ind, quantize
class SimpleCodebook(nn.Module):
def __init__(
self,
dim,
codebook_size,
use_l2_normlize=False,
):
super().__init__()
self.dim = dim
self.codebook_size = codebook_size
self.use_l2_normlize = use_l2_normlize
self.embed = nn.Embedding(self.codebook_size, self.dim)
def forward(self, x):
shape, dtype = x.shape, x.dtype
flatten = rearrange(x, "... d -> (...) d")
embed = self.embed.weight.t() # (codebook_size, dim) -> (dim, codebook_size)
if self.use_l2_normlize:
flatten = F.normalize(flatten)
embed = F.normalize(embed)
dist = -(
flatten.pow(2).sum(1, keepdim=True)
- 2 * flatten @ embed
+ embed.pow(2).sum(0, keepdim=True)
)
embed_ind = dist.max(dim=-1).indices
embed_ind = embed_ind.view(*shape[:-1])
quantize = F.embedding(embed_ind, self.embed)
return quantize, embed_ind
def vq2emb(self, vq):
quantize = F.embedding(vq, self.embed.weight)
return quantize
def latent2dist(self, x):
shape, dtype = x.shape, x.dtype
flatten = rearrange(x, "... d -> (...) d")
embed = self.embed.weight.t() # (codebook_size, dim) -> (dim, codebook_size)
if self.use_l2_normlize:
flatten = F.normalize(flatten)
embed = F.normalize(embed)
dist = -(
flatten.pow(2).sum(1, keepdim=True)
- 2 * flatten @ embed
+ embed.pow(2).sum(0, keepdim=True)
)
embed_ind = dist.max(dim=-1).indices
embed_ind = embed_ind.view(*shape[:-1])
quantize = F.embedding(embed_ind, self.embed)
dist = dist.view(*shape[:-1], -1)
return dist, embed_ind, quantize
class VectorQuantize(nn.Module):
"""Vector quantization and factorized vecotor quantization implementation
Args:
input_dim (int): Dimension of input.
codebook_size (int): Codebook size.
codebook_dim (int): Codebook dimension. We suggest use codebook_dim = input_dim
if use codebook_type == "euclidean", otherwise, if you want to use
factorized vector quantization, use codebook_dim as small number (e.g. 8 or 32).
commitment (float): Weight for commitment loss.
use_l2_normlize (bool): Whether to use l2 normlized codes for factorized vecotor quantization,
we suggest use it as True if you want to use factorized vector quantization
kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
kmeans_iters (int): Number of iterations used for kmeans initialization.
decay (float): Decay for exponential moving average over the codebooks.
epsilon (float): Epsilon value for numerical stability.
threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
that have an exponential moving average cluster size less than the specified threshold with
randomly selected vector from the current batch.
"""
def __init__(
self,
input_dim,
codebook_size,
codebook_dim,
commitment=0.005,
codebook_loss_weight=1.0,
use_l2_normlize=False,
codebook_type="euclidean", # "euclidean" or "simple"
kmeans_init=False,
kmeans_iters=10,
decay=0.8,
eps=1e-5,
threshold_ema_dead_code=2,
weight_init=False,
):
super().__init__()
self.input_dim = input_dim
self.codebook_size = codebook_size
self.codebook_dim = codebook_dim
self.commitment = commitment
self.codebook_loss_weight = codebook_loss_weight
self.use_l2_normlize = use_l2_normlize
self.codebook_type = codebook_type
self.kmeans_init = kmeans_init
self.kmeans_iters = kmeans_iters
self.decay = decay
self.eps = eps
self.threshold_ema_dead_code = threshold_ema_dead_code
self.weight_init = weight_init
if self.input_dim != self.codebook_dim:
self.in_project = WNConv1d(self.input_dim, self.codebook_dim, kernel_size=1)
self.out_project = WNConv1d(
self.codebook_dim, self.input_dim, kernel_size=1
)
else:
self.in_project = nn.Identity()
self.out_project = nn.Identity()
if self.codebook_type == "euclidean":
self.codebook = EuclideanCodebook(
self.codebook_dim,
codebook_size=self.codebook_size,
kmeans_init=self.kmeans_init,
kmeans_iters=self.kmeans_iters,
decay=self.decay,
eps=self.eps,
threshold_ema_dead_code=self.threshold_ema_dead_code,
weight_init=self.weight_init,
)
elif self.codebook_type == "simple":
self.codebook = SimpleCodebook(
self.codebook_dim,
codebook_size=self.codebook_size,
use_l2_normlize=self.use_l2_normlize,
)
else:
raise NotImplementedError(
f"codebook_type {self.codebook_type} is not implemented!"
)
def forward(self, z):
"""
Parameters
----------
z: torch.Tensor[B x D x T]
Returns
-------
z_q: torch.Tensor[B x D x T]
Quantized continuous representation of input
commit_loss: Tensor[B]
Commitment loss to train encoder to predict vectors closer to codebook entries
codebook_loss: Tensor[B]
Codebook loss to update the codebook
indices: torch.Tensor[B x T]
Codebook indices (quantized discrete representation of input)
z_e: torch.Tensor[B x D x T]
Projected latents (continuous representation of input before quantization)
"""
# Factorized codes project input into low-dimensional space if self.input_dim != self.codebook_dim
z_e = self.in_project(z)
z_q, indices = self.decode_latents(z_e)
# Compute commitment loss and codebook loss
if self.training:
commit_loss = (
F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
* self.commitment
)
codebook_loss = (
F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
* self.codebook_loss_weight
)
else:
commit_loss = torch.zeros(z.shape[0], device=z.device)
codebook_loss = torch.zeros(z.shape[0], device=z.device)
z_q = z_e + (z_q - z_e).detach()
z_q = self.out_project(z_q)
return z_q, commit_loss, codebook_loss, indices, z_e
def decode_latents(self, latents):
encodings = rearrange(latents, "b d t -> b t d")
z_q, indices = self.codebook(encodings)
z_q = z_q.transpose(1, 2)
return z_q, indices
def vq2emb(self, vq, out_proj=True):
emb = self.codebook.vq2emb(vq)
emb = emb.transpose(1, 2)
if out_proj:
emb = self.out_project(emb)
return emb
def latent2dist(self, latents):
latents = rearrange(latents, "b d t -> b t d")
dist, embed_ind, quantize = self.codebook.latent2dist(latents)
return dist, embed_ind, quantize.transpose(1, 2)
# Copyright (c) 2024 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional, Tuple
import numpy as np
import scipy
import torch
from torch import nn, view_as_real, view_as_complex
from torch import nn
from torch.nn.utils import weight_norm, remove_weight_norm
from torchaudio.functional.functional import _hz_to_mel, _mel_to_hz
import librosa
def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor:
"""
Computes the element-wise logarithm of the input tensor with clipping to avoid near-zero values.
Args:
x (Tensor): Input tensor.
clip_val (float, optional): Minimum value to clip the input tensor. Defaults to 1e-7.
Returns:
Tensor: Element-wise logarithm of the input tensor with clipping applied.
"""
return torch.log(torch.clip(x, min=clip_val))
def symlog(x: torch.Tensor) -> torch.Tensor:
return torch.sign(x) * torch.log1p(x.abs())
def symexp(x: torch.Tensor) -> torch.Tensor:
return torch.sign(x) * (torch.exp(x.abs()) - 1)
class STFT(nn.Module):
def __init__(
self,
n_fft: int,
hop_length: int,
win_length: int,
center=True,
):
super().__init__()
self.center = center
self.n_fft = n_fft
self.hop_length = hop_length
self.win_length = win_length
window = torch.hann_window(win_length)
self.register_buffer("window", window)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# x: (B, T * hop_length)
if not self.center:
pad = self.win_length - self.hop_length
x = torch.nn.functional.pad(x, (pad // 2, pad // 2), mode="reflect")
stft_spec = torch.stft(
x,
self.n_fft,
hop_length=self.hop_length,
win_length=self.win_length,
window=self.window,
center=self.center,
return_complex=False,
) # (B, n_fft // 2 + 1, T, 2)
rea = stft_spec[:, :, :, 0] # (B, n_fft // 2 + 1, T, 2)
imag = stft_spec[:, :, :, 1] # (B, n_fft // 2 + 1, T, 2)
log_mag = torch.log(
torch.abs(torch.sqrt(torch.pow(rea, 2) + torch.pow(imag, 2))) + 1e-5
) # (B, n_fft // 2 + 1, T)
phase = torch.atan2(imag, rea) # (B, n_fft // 2 + 1, T)
return log_mag, phase
class ISTFT(nn.Module):
"""
Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with
windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges.
See issue: https://github.com/pytorch/pytorch/issues/62323
Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs.
The NOLA constraint is met as we trim padded samples anyway.
Args:
n_fft (int): Size of Fourier transform.
hop_length (int): The distance between neighboring sliding window frames.
win_length (int): The size of window frame and STFT filter.
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
"""
def __init__(
self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"
):
super().__init__()
if padding not in ["center", "same"]:
raise ValueError("Padding must be 'center' or 'same'.")
self.padding = padding
self.n_fft = n_fft
self.hop_length = hop_length
self.win_length = win_length
window = torch.hann_window(win_length)
self.register_buffer("window", window)
def forward(self, spec: torch.Tensor) -> torch.Tensor:
"""
Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.
Args:
spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,
N is the number of frequency bins, and T is the number of time frames.
Returns:
Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal.
"""
if self.padding == "center":
# Fallback to pytorch native implementation
return torch.istft(
spec,
self.n_fft,
self.hop_length,
self.win_length,
self.window,
center=True,
)
elif self.padding == "same":
pad = (self.win_length - self.hop_length) // 2
else:
raise ValueError("Padding must be 'center' or 'same'.")
assert spec.dim() == 3, "Expected a 3D tensor as input"
B, N, T = spec.shape
# Inverse FFT
ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
ifft = ifft * self.window[None, :, None]
# Overlap and Add
output_size = (T - 1) * self.hop_length + self.win_length
y = torch.nn.functional.fold(
ifft,
output_size=(1, output_size),
kernel_size=(1, self.win_length),
stride=(1, self.hop_length),
)[:, 0, 0, pad:-pad]
# Window envelope
window_sq = self.window.square().expand(1, T, -1).transpose(1, 2)
window_envelope = torch.nn.functional.fold(
window_sq,
output_size=(1, output_size),
kernel_size=(1, self.win_length),
stride=(1, self.hop_length),
).squeeze()[pad:-pad]
# Normalize
assert (window_envelope > 1e-11).all()
y = y / window_envelope
return y
class MDCT(nn.Module):
"""
Modified Discrete Cosine Transform (MDCT) module.
Args:
frame_len (int): Length of the MDCT frame.
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
"""
def __init__(self, frame_len: int, padding: str = "same"):
super().__init__()
if padding not in ["center", "same"]:
raise ValueError("Padding must be 'center' or 'same'.")
self.padding = padding
self.frame_len = frame_len
N = frame_len // 2
n0 = (N + 1) / 2
window = torch.from_numpy(scipy.signal.cosine(frame_len)).float()
self.register_buffer("window", window)
pre_twiddle = torch.exp(-1j * torch.pi * torch.arange(frame_len) / frame_len)
post_twiddle = torch.exp(-1j * torch.pi * n0 * (torch.arange(N) + 0.5) / N)
# view_as_real: NCCL Backend does not support ComplexFloat data type
# https://github.com/pytorch/pytorch/issues/71613
self.register_buffer("pre_twiddle", view_as_real(pre_twiddle))
self.register_buffer("post_twiddle", view_as_real(post_twiddle))
def forward(self, audio: torch.Tensor) -> torch.Tensor:
"""
Apply the Modified Discrete Cosine Transform (MDCT) to the input audio.
Args:
audio (Tensor): Input audio waveform of shape (B, T), where B is the batch size
and T is the length of the audio.
Returns:
Tensor: MDCT coefficients of shape (B, L, N), where L is the number of output frames
and N is the number of frequency bins.
"""
if self.padding == "center":
audio = torch.nn.functional.pad(
audio, (self.frame_len // 2, self.frame_len // 2)
)
elif self.padding == "same":
# hop_length is 1/2 frame_len
audio = torch.nn.functional.pad(
audio, (self.frame_len // 4, self.frame_len // 4)
)
else:
raise ValueError("Padding must be 'center' or 'same'.")
x = audio.unfold(-1, self.frame_len, self.frame_len // 2)
N = self.frame_len // 2
x = x * self.window.expand(x.shape)
X = torch.fft.fft(
x * view_as_complex(self.pre_twiddle).expand(x.shape), dim=-1
)[..., :N]
res = X * view_as_complex(self.post_twiddle).expand(X.shape) * np.sqrt(1 / N)
return torch.real(res) * np.sqrt(2)
class IMDCT(nn.Module):
"""
Inverse Modified Discrete Cosine Transform (IMDCT) module.
Args:
frame_len (int): Length of the MDCT frame.
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
"""
def __init__(self, frame_len: int, padding: str = "same"):
super().__init__()
if padding not in ["center", "same"]:
raise ValueError("Padding must be 'center' or 'same'.")
self.padding = padding
self.frame_len = frame_len
N = frame_len // 2
n0 = (N + 1) / 2
window = torch.from_numpy(scipy.signal.cosine(frame_len)).float()
self.register_buffer("window", window)
pre_twiddle = torch.exp(1j * torch.pi * n0 * torch.arange(N * 2) / N)
post_twiddle = torch.exp(1j * torch.pi * (torch.arange(N * 2) + n0) / (N * 2))
self.register_buffer("pre_twiddle", view_as_real(pre_twiddle))
self.register_buffer("post_twiddle", view_as_real(post_twiddle))
def forward(self, X: torch.Tensor) -> torch.Tensor:
"""
Apply the Inverse Modified Discrete Cosine Transform (IMDCT) to the input MDCT coefficients.
Args:
X (Tensor): Input MDCT coefficients of shape (B, L, N), where B is the batch size,
L is the number of frames, and N is the number of frequency bins.
Returns:
Tensor: Reconstructed audio waveform of shape (B, T), where T is the length of the audio.
"""
B, L, N = X.shape
Y = torch.zeros((B, L, N * 2), dtype=X.dtype, device=X.device)
Y[..., :N] = X
Y[..., N:] = -1 * torch.conj(torch.flip(X, dims=(-1,)))
y = torch.fft.ifft(
Y * view_as_complex(self.pre_twiddle).expand(Y.shape), dim=-1
)
y = (
torch.real(y * view_as_complex(self.post_twiddle).expand(y.shape))
* np.sqrt(N)
* np.sqrt(2)
)
result = y * self.window.expand(y.shape)
output_size = (1, (L + 1) * N)
audio = torch.nn.functional.fold(
result.transpose(1, 2),
output_size=output_size,
kernel_size=(1, self.frame_len),
stride=(1, self.frame_len // 2),
)[:, 0, 0, :]
if self.padding == "center":
pad = self.frame_len // 2
elif self.padding == "same":
pad = self.frame_len // 4
else:
raise ValueError("Padding must be 'center' or 'same'.")
audio = audio[:, pad:-pad]
return audio
class FourierHead(nn.Module):
"""Base class for inverse fourier modules."""
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
L is the sequence length, and H denotes the model dimension.
Returns:
Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
"""
raise NotImplementedError("Subclasses must implement the forward method.")
class ISTFTHead(FourierHead):
"""
ISTFT Head module for predicting STFT complex coefficients.
Args:
dim (int): Hidden dimension of the model.
n_fft (int): Size of Fourier transform.
hop_length (int): The distance between neighboring sliding window frames, which should align with
the resolution of the input features.
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
"""
def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"):
super().__init__()
out_dim = n_fft + 2
self.out = torch.nn.Linear(dim, out_dim)
self.istft = ISTFT(
n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass of the ISTFTHead module.
Args:
x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
L is the sequence length, and H denotes the model dimension.
Returns:
Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
"""
x = self.out(x).transpose(1, 2)
mag, p = x.chunk(2, dim=1)
mag = torch.exp(mag)
mag = torch.clip(
mag, max=1e2
) # safeguard to prevent excessively large magnitudes
# wrapping happens here. These two lines produce real and imaginary value
x = torch.cos(p)
y = torch.sin(p)
# recalculating phase here does not produce anything new
# only costs time
# phase = torch.atan2(y, x)
# S = mag * torch.exp(phase * 1j)
# better directly produce the complex value
S = mag * (x + 1j * y)
audio = self.istft(S)
return audio
class IMDCTSymExpHead(FourierHead):
"""
IMDCT Head module for predicting MDCT coefficients with symmetric exponential function
Args:
dim (int): Hidden dimension of the model.
mdct_frame_len (int): Length of the MDCT frame.
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
sample_rate (int, optional): The sample rate of the audio. If provided, the last layer will be initialized
based on perceptual scaling. Defaults to None.
clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
"""
def __init__(
self,
dim: int,
mdct_frame_len: int,
padding: str = "same",
sample_rate: Optional[int] = None,
clip_audio: bool = False,
):
super().__init__()
out_dim = mdct_frame_len // 2
self.out = nn.Linear(dim, out_dim)
self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)
self.clip_audio = clip_audio
if sample_rate is not None:
# optionally init the last layer following mel-scale
m_max = _hz_to_mel(sample_rate // 2)
m_pts = torch.linspace(0, m_max, out_dim)
f_pts = _mel_to_hz(m_pts)
scale = 1 - (f_pts / f_pts.max())
with torch.no_grad():
self.out.weight.mul_(scale.view(-1, 1))
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass of the IMDCTSymExpHead module.
Args:
x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
L is the sequence length, and H denotes the model dimension.
Returns:
Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
"""
x = self.out(x)
x = symexp(x)
x = torch.clip(
x, min=-1e2, max=1e2
) # safeguard to prevent excessively large magnitudes
audio = self.imdct(x)
if self.clip_audio:
audio = torch.clip(x, min=-1.0, max=1.0)
return audio
class IMDCTCosHead(FourierHead):
"""
IMDCT Head module for predicting MDCT coefficients with parametrizing MDCT = exp(m) · cos(p)
Args:
dim (int): Hidden dimension of the model.
mdct_frame_len (int): Length of the MDCT frame.
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
"""
def __init__(
self,
dim: int,
mdct_frame_len: int,
padding: str = "same",
clip_audio: bool = False,
):
super().__init__()
self.clip_audio = clip_audio
self.out = nn.Linear(dim, mdct_frame_len)
self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass of the IMDCTCosHead module.
Args:
x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
L is the sequence length, and H denotes the model dimension.
Returns:
Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
"""
x = self.out(x)
m, p = x.chunk(2, dim=2)
m = torch.exp(m).clip(
max=1e2
) # safeguard to prevent excessively large magnitudes
audio = self.imdct(m * torch.cos(p))
if self.clip_audio:
audio = torch.clip(x, min=-1.0, max=1.0)
return audio
class ConvNeXtBlock(nn.Module):
"""ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.
Args:
dim (int): Number of input channels.
intermediate_dim (int): Dimensionality of the intermediate layer.
layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
Defaults to None.
adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
None means non-conditional LayerNorm. Defaults to None.
"""
def __init__(
self,
dim: int,
intermediate_dim: int,
layer_scale_init_value: float,
adanorm_num_embeddings: Optional[int] = None,
):
super().__init__()
self.dwconv = nn.Conv1d(
dim, dim, kernel_size=7, padding=3, groups=dim
) # depthwise conv
self.adanorm = adanorm_num_embeddings is not None
if adanorm_num_embeddings:
self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
else:
self.norm = nn.LayerNorm(dim, eps=1e-6)
self.pwconv1 = nn.Linear(
dim, intermediate_dim
) # pointwise/1x1 convs, implemented with linear layers
self.act = nn.GELU()
self.pwconv2 = nn.Linear(intermediate_dim, dim)
self.gamma = (
nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
if layer_scale_init_value > 0
else None
)
def forward(
self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None
) -> torch.Tensor:
residual = x
x = self.dwconv(x)
x = x.transpose(1, 2) # (B, C, T) -> (B, T, C)
if self.adanorm:
assert cond_embedding_id is not None
x = self.norm(x, cond_embedding_id)
else:
x = self.norm(x)
x = self.pwconv1(x)
x = self.act(x)
x = self.pwconv2(x)
if self.gamma is not None:
x = self.gamma * x
x = x.transpose(1, 2) # (B, T, C) -> (B, C, T)
x = residual + x
return x
class AdaLayerNorm(nn.Module):
"""
Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` classes
Args:
num_embeddings (int): Number of embeddings.
embedding_dim (int): Dimension of the embeddings.
"""
def __init__(self, num_embeddings: int, embedding_dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.dim = embedding_dim
self.scale = nn.Embedding(
num_embeddings=num_embeddings, embedding_dim=embedding_dim
)
self.shift = nn.Embedding(
num_embeddings=num_embeddings, embedding_dim=embedding_dim
)
torch.nn.init.ones_(self.scale.weight)
torch.nn.init.zeros_(self.shift.weight)
def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor) -> torch.Tensor:
scale = self.scale(cond_embedding_id)
shift = self.shift(cond_embedding_id)
x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps)
x = x * scale + shift
return x
class ResBlock1(nn.Module):
"""
ResBlock adapted from HiFi-GAN V1 (https://github.com/jik876/hifi-gan) with dilated 1D convolutions,
but without upsampling layers.
Args:
dim (int): Number of input channels.
kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3.
dilation (tuple[int], optional): Dilation factors for the dilated convolutions.
Defaults to (1, 3, 5).
lrelu_slope (float, optional): Negative slope of the LeakyReLU activation function.
Defaults to 0.1.
layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
Defaults to None.
"""
def __init__(
self,
dim: int,
kernel_size: int = 3,
dilation: Tuple[int, int, int] = (1, 3, 5),
lrelu_slope: float = 0.1,
layer_scale_init_value: Optional[float] = None,
):
super().__init__()
self.lrelu_slope = lrelu_slope
self.convs1 = nn.ModuleList(
[
weight_norm(
nn.Conv1d(
dim,
dim,
kernel_size,
1,
dilation=dilation[0],
padding=self.get_padding(kernel_size, dilation[0]),
)
),
weight_norm(
nn.Conv1d(
dim,
dim,
kernel_size,
1,
dilation=dilation[1],
padding=self.get_padding(kernel_size, dilation[1]),
)
),
weight_norm(
nn.Conv1d(
dim,
dim,
kernel_size,
1,
dilation=dilation[2],
padding=self.get_padding(kernel_size, dilation[2]),
)
),
]
)
self.convs2 = nn.ModuleList(
[
weight_norm(
nn.Conv1d(
dim,
dim,
kernel_size,
1,
dilation=1,
padding=self.get_padding(kernel_size, 1),
)
),
weight_norm(
nn.Conv1d(
dim,
dim,
kernel_size,
1,
dilation=1,
padding=self.get_padding(kernel_size, 1),
)
),
weight_norm(
nn.Conv1d(
dim,
dim,
kernel_size,
1,
dilation=1,
padding=self.get_padding(kernel_size, 1),
)
),
]
)
self.gamma = nn.ParameterList(
[
(
nn.Parameter(
layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
)
if layer_scale_init_value is not None
else None
),
(
nn.Parameter(
layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
)
if layer_scale_init_value is not None
else None
),
(
nn.Parameter(
layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
)
if layer_scale_init_value is not None
else None
),
]
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
for c1, c2, gamma in zip(self.convs1, self.convs2, self.gamma):
xt = torch.nn.functional.leaky_relu(x, negative_slope=self.lrelu_slope)
xt = c1(xt)
xt = torch.nn.functional.leaky_relu(xt, negative_slope=self.lrelu_slope)
xt = c2(xt)
if gamma is not None:
xt = gamma * xt
x = xt + x
return x
def remove_weight_norm(self):
for l in self.convs1:
remove_weight_norm(l)
for l in self.convs2:
remove_weight_norm(l)
@staticmethod
def get_padding(kernel_size: int, dilation: int = 1) -> int:
return int((kernel_size * dilation - dilation) / 2)
class Backbone(nn.Module):
"""Base class for the generator's backbone. It preserves the same temporal resolution across all layers."""
def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
"""
Args:
x (Tensor): Input tensor of shape (B, C, L), where B is the batch size,
C denotes output features, and L is the sequence length.
Returns:
Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length,
and H denotes the model dimension.
"""
raise NotImplementedError("Subclasses must implement the forward method.")
class VocosBackbone(Backbone):
"""
Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization
Args:
input_channels (int): Number of input features channels.
dim (int): Hidden dimension of the model.
intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock.
num_layers (int): Number of ConvNeXtBlock layers.
layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`.
adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
None means non-conditional model. Defaults to None.
"""
def __init__(
self,
input_channels: int,
dim: int,
intermediate_dim: int,
num_layers: int,
layer_scale_init_value: Optional[float] = None,
adanorm_num_embeddings: Optional[int] = None,
):
super().__init__()
self.input_channels = input_channels
self.embed = nn.Conv1d(input_channels, dim, kernel_size=7, padding=3)
self.adanorm = adanorm_num_embeddings is not None
if adanorm_num_embeddings:
self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
else:
self.norm = nn.LayerNorm(dim, eps=1e-6)
layer_scale_init_value = layer_scale_init_value or 1 / num_layers
self.convnext = nn.ModuleList(
[
ConvNeXtBlock(
dim=dim,
intermediate_dim=intermediate_dim,
layer_scale_init_value=layer_scale_init_value,
adanorm_num_embeddings=adanorm_num_embeddings,
)
for _ in range(num_layers)
]
)
self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, (nn.Conv1d, nn.Linear)):
nn.init.trunc_normal_(m.weight, std=0.02)
nn.init.constant_(m.bias, 0)
def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
bandwidth_id = kwargs.get("bandwidth_id", None)
x = self.embed(x)
if self.adanorm:
assert bandwidth_id is not None
x = self.norm(x.transpose(1, 2), cond_embedding_id=bandwidth_id)
else:
x = self.norm(x.transpose(1, 2))
x = x.transpose(1, 2)
for conv_block in self.convnext:
x = conv_block(x, cond_embedding_id=bandwidth_id)
x = self.final_layer_norm(x.transpose(1, 2))
return x
class VocosResNetBackbone(Backbone):
"""
Vocos backbone module built with ResBlocks.
Args:
input_channels (int): Number of input features channels.
dim (int): Hidden dimension of the model.
num_blocks (int): Number of ResBlock1 blocks.
layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to None.
"""
def __init__(
self,
input_channels,
dim,
num_blocks,
layer_scale_init_value=None,
):
super().__init__()
self.input_channels = input_channels
self.embed = weight_norm(
nn.Conv1d(input_channels, dim, kernel_size=3, padding=1)
)
layer_scale_init_value = layer_scale_init_value or 1 / num_blocks / 3
self.resnet = nn.Sequential(
*[
ResBlock1(dim=dim, layer_scale_init_value=layer_scale_init_value)
for _ in range(num_blocks)
]
)
def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
x = self.embed(x)
x = self.resnet(x)
x = x.transpose(1, 2)
return x
class Vocos(nn.Module):
def __init__(
self,
input_channels: int = 256,
dim: int = 384,
intermediate_dim: int = 1152,
num_layers: int = 8,
n_fft: int = 800,
hop_size: int = 200,
padding: str = "same",
adanorm_num_embeddings=None,
cfg=None,
):
super().__init__()
input_channels = (
cfg.input_channels
if cfg is not None and hasattr(cfg, "input_channels")
else input_channels
)
dim = cfg.dim if cfg is not None and hasattr(cfg, "dim") else dim
intermediate_dim = (
cfg.intermediate_dim
if cfg is not None and hasattr(cfg, "intermediate_dim")
else intermediate_dim
)
num_layers = (
cfg.num_layers
if cfg is not None and hasattr(cfg, "num_layers")
else num_layers
)
adanorm_num_embeddings = (
cfg.adanorm_num_embeddings
if cfg is not None and hasattr(cfg, "adanorm_num_embeddings")
else adanorm_num_embeddings
)
n_fft = cfg.n_fft if cfg is not None and hasattr(cfg, "n_fft") else n_fft
hop_size = (
cfg.hop_size if cfg is not None and hasattr(cfg, "hop_size") else hop_size
)
padding = (
cfg.padding if cfg is not None and hasattr(cfg, "padding") else padding
)
self.backbone = VocosBackbone(
input_channels=input_channels,
dim=dim,
intermediate_dim=intermediate_dim,
num_layers=num_layers,
adanorm_num_embeddings=adanorm_num_embeddings,
)
self.head = ISTFTHead(dim, n_fft, hop_size, padding)
def forward(self, x):
x = self.backbone(x)
x = self.head(x)
return x[:, None, :]
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Iterable
import torch
import numpy as np
import torch.utils.data
from torch.nn.utils.rnn import pad_sequence
from utils.data_utils import *
from torch.utils.data import ConcatDataset, Dataset
class CodecDataset(torch.utils.data.Dataset):
def __init__(self, cfg, dataset, is_valid=False):
"""
Args:
cfg: config
dataset: dataset name
is_valid: whether to use train or valid dataset
"""
assert isinstance(dataset, str)
processed_data_dir = os.path.join(cfg.preprocess.processed_dir, dataset)
meta_file = cfg.preprocess.valid_file if is_valid else cfg.preprocess.train_file
self.metafile_path = os.path.join(processed_data_dir, meta_file)
self.metadata = self.get_metadata()
self.data_root = processed_data_dir
self.cfg = cfg
if cfg.preprocess.use_audio:
self.utt2audio_path = {}
for utt_info in self.metadata:
dataset = utt_info["Dataset"]
uid = utt_info["Uid"]
utt = "{}_{}".format(dataset, uid)
self.utt2audio_path[utt] = os.path.join(
cfg.preprocess.processed_dir,
dataset,
cfg.preprocess.audio_dir,
uid + ".npy",
)
elif cfg.preprocess.use_label:
self.utt2label_path = {}
for utt_info in self.metadata:
dataset = utt_info["Dataset"]
uid = utt_info["Uid"]
utt = "{}_{}".format(dataset, uid)
self.utt2label_path[utt] = os.path.join(
cfg.preprocess.processed_dir,
dataset,
cfg.preprocess.label_dir,
uid + ".npy",
)
elif cfg.preprocess.use_one_hot:
self.utt2one_hot_path = {}
for utt_info in self.metadata:
dataset = utt_info["Dataset"]
uid = utt_info["Uid"]
utt = "{}_{}".format(dataset, uid)
self.utt2one_hot_path[utt] = os.path.join(
cfg.preprocess.processed_dir,
dataset,
cfg.preprocess.one_hot_dir,
uid + ".npy",
)
if cfg.preprocess.use_mel:
self.utt2mel_path = {}
for utt_info in self.metadata:
dataset = utt_info["Dataset"]
uid = utt_info["Uid"]
utt = "{}_{}".format(dataset, uid)
self.utt2mel_path[utt] = os.path.join(
cfg.preprocess.processed_dir,
dataset,
cfg.preprocess.mel_dir,
uid + ".npy",
)
if cfg.preprocess.use_frame_pitch:
self.utt2frame_pitch_path = {}
for utt_info in self.metadata:
dataset = utt_info["Dataset"]
uid = utt_info["Uid"]
utt = "{}_{}".format(dataset, uid)
self.utt2frame_pitch_path[utt] = os.path.join(
cfg.preprocess.processed_dir,
dataset,
cfg.preprocess.pitch_dir,
uid + ".npy",
)
if cfg.preprocess.use_uv:
self.utt2uv_path = {}
for utt_info in self.metadata:
dataset = utt_info["Dataset"]
uid = utt_info["Uid"]
utt = "{}_{}".format(dataset, uid)
self.utt2uv_path[utt] = os.path.join(
cfg.preprocess.processed_dir,
dataset,
cfg.preprocess.uv_dir,
uid + ".npy",
)
if cfg.preprocess.use_amplitude_phase:
self.utt2logamp_path = {}
self.utt2pha_path = {}
self.utt2rea_path = {}
self.utt2imag_path = {}
for utt_info in self.metadata:
dataset = utt_info["Dataset"]
uid = utt_info["Uid"]
utt = "{}_{}".format(dataset, uid)
self.utt2logamp_path[utt] = os.path.join(
cfg.preprocess.processed_dir,
dataset,
cfg.preprocess.log_amplitude_dir,
uid + ".npy",
)
self.utt2pha_path[utt] = os.path.join(
cfg.preprocess.processed_dir,
dataset,
cfg.preprocess.phase_dir,
uid + ".npy",
)
self.utt2rea_path[utt] = os.path.join(
cfg.preprocess.processed_dir,
dataset,
cfg.preprocess.real_dir,
uid + ".npy",
)
self.utt2imag_path[utt] = os.path.join(
cfg.preprocess.processed_dir,
dataset,
cfg.preprocess.imaginary_dir,
uid + ".npy",
)
def __getitem__(self, index):
utt_info = self.metadata[index]
dataset = utt_info["Dataset"]
uid = utt_info["Uid"]
utt = "{}_{}".format(dataset, uid)
single_feature = dict()
if self.cfg.preprocess.use_mel:
mel = np.load(self.utt2mel_path[utt])
assert mel.shape[0] == self.cfg.preprocess.n_mel # [n_mels, T]
if "target_len" not in single_feature.keys():
single_feature["target_len"] = mel.shape[1]
single_feature["mel"] = mel
if self.cfg.preprocess.use_frame_pitch:
frame_pitch = np.load(self.utt2frame_pitch_path[utt])
if "target_len" not in single_feature.keys():
single_feature["target_len"] = len(frame_pitch)
aligned_frame_pitch = align_length(
frame_pitch, single_feature["target_len"]
)
single_feature["frame_pitch"] = aligned_frame_pitch
if self.cfg.preprocess.use_audio:
audio = np.load(self.utt2audio_path[utt])
single_feature["audio"] = audio
return single_feature
def get_metadata(self):
with open(self.metafile_path, "r", encoding="utf-8") as f:
metadata = json.load(f)
return metadata
def get_dataset_name(self):
return self.metadata[0]["Dataset"]
def __len__(self):
return len(self.metadata)
class CodecConcatDataset(ConcatDataset):
def __init__(self, datasets: Iterable[Dataset], full_audio_inference=False):
"""Concatenate a series of datasets with their random inference audio merged."""
super().__init__(datasets)
self.cfg = self.datasets[0].cfg
self.metadata = []
# Merge metadata
for dataset in self.datasets:
self.metadata += dataset.metadata
# Merge random inference features
if full_audio_inference:
self.eval_audios = []
self.eval_dataset_names = []
if self.cfg.preprocess.use_mel:
self.eval_mels = []
if self.cfg.preprocess.use_frame_pitch:
self.eval_pitchs = []
for dataset in self.datasets:
self.eval_audios.append(dataset.eval_audio)
self.eval_dataset_names.append(dataset.get_dataset_name())
if self.cfg.preprocess.use_mel:
self.eval_mels.append(dataset.eval_mel)
if self.cfg.preprocess.use_frame_pitch:
self.eval_pitchs.append(dataset.eval_pitch)
class CodecCollator(object):
"""Zero-pads model inputs and targets based on number of frames per step"""
def __init__(self, cfg):
self.cfg = cfg
def __call__(self, batch):
packed_batch_features = dict()
# mel: [b, n_mels, frame]
# frame_pitch: [b, frame]
# audios: [b, frame * hop_size]
for key in batch[0].keys():
if key == "target_len":
packed_batch_features["target_len"] = torch.LongTensor(
[b["target_len"] for b in batch]
)
masks = [
torch.ones((b["target_len"], 1), dtype=torch.long) for b in batch
]
packed_batch_features["mask"] = pad_sequence(
masks, batch_first=True, padding_value=0
)
elif key == "mel":
values = [torch.from_numpy(b[key]).T for b in batch]
packed_batch_features[key] = pad_sequence(
values, batch_first=True, padding_value=0
)
else:
values = [torch.from_numpy(b[key]) for b in batch]
packed_batch_features[key] = pad_sequence(
values, batch_first=True, padding_value=0
)
return packed_batch_features
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import torch
import json
import json5
import time
import accelerate
import random
import numpy as np
import shutil
from pathlib import Path
from tqdm import tqdm
from glob import glob
from accelerate.logging import get_logger
from torch.utils.data import DataLoader
from models.vocoders.vocoder_dataset import (
VocoderDataset,
VocoderCollator,
VocoderConcatDataset,
)
from models.vocoders.gan.generator import bigvgan, hifigan, melgan, nsfhifigan, apnet
from models.vocoders.flow.waveglow import waveglow
from models.vocoders.diffusion.diffwave import diffwave
from models.vocoders.autoregressive.wavenet import wavenet
from models.vocoders.autoregressive.wavernn import wavernn
from models.vocoders.gan import gan_vocoder_inference
from models.vocoders.diffusion import diffusion_vocoder_inference
from utils.io import save_audio
_vocoders = {
"diffwave": diffwave.DiffWave,
"wavernn": wavernn.WaveRNN,
"wavenet": wavenet.WaveNet,
"waveglow": waveglow.WaveGlow,
"nsfhifigan": nsfhifigan.NSFHiFiGAN,
"bigvgan": bigvgan.BigVGAN,
"hifigan": hifigan.HiFiGAN,
"melgan": melgan.MelGAN,
"apnet": apnet.APNet,
}
# Forward call for generalized Inferencor
_vocoder_forward_funcs = {
# "world": world_inference.synthesis_audios,
# "wavernn": wavernn_inference.synthesis_audios,
# "wavenet": wavenet_inference.synthesis_audios,
"diffwave": diffusion_vocoder_inference.vocoder_inference,
"nsfhifigan": gan_vocoder_inference.vocoder_inference,
"bigvgan": gan_vocoder_inference.vocoder_inference,
"melgan": gan_vocoder_inference.vocoder_inference,
"hifigan": gan_vocoder_inference.vocoder_inference,
"apnet": gan_vocoder_inference.vocoder_inference,
}
# APIs for other tasks. e.g. SVC, TTS, TTA...
_vocoder_infer_funcs = {
# "world": world_inference.synthesis_audios,
# "wavernn": wavernn_inference.synthesis_audios,
# "wavenet": wavenet_inference.synthesis_audios,
"diffwave": diffusion_vocoder_inference.synthesis_audios,
"nsfhifigan": gan_vocoder_inference.synthesis_audios,
"bigvgan": gan_vocoder_inference.synthesis_audios,
"melgan": gan_vocoder_inference.synthesis_audios,
"hifigan": gan_vocoder_inference.synthesis_audios,
"apnet": gan_vocoder_inference.synthesis_audios,
}
class VocoderInference(object):
def __init__(self, args=None, cfg=None, infer_type="from_dataset"):
super().__init__()
start = time.monotonic_ns()
self.args = args
self.cfg = cfg
self.infer_type = infer_type
# Init accelerator
self.accelerator = accelerate.Accelerator()
self.accelerator.wait_for_everyone()
# Get logger
with self.accelerator.main_process_first():
self.logger = get_logger("inference", log_level=args.log_level)
# Log some info
self.logger.info("=" * 56)
self.logger.info("||\t\t" + "New inference process started." + "\t\t||")
self.logger.info("=" * 56)
self.logger.info("\n")
self.vocoder_dir = args.vocoder_dir
self.logger.debug(f"Vocoder dir: {args.vocoder_dir}")
os.makedirs(args.output_dir, exist_ok=True)
if os.path.exists(os.path.join(args.output_dir, "pred")):
shutil.rmtree(os.path.join(args.output_dir, "pred"))
if os.path.exists(os.path.join(args.output_dir, "gt")):
shutil.rmtree(os.path.join(args.output_dir, "gt"))
os.makedirs(os.path.join(args.output_dir, "pred"), exist_ok=True)
os.makedirs(os.path.join(args.output_dir, "gt"), exist_ok=True)
# Set random seed
with self.accelerator.main_process_first():
start = time.monotonic_ns()
self._set_random_seed(self.cfg.train.random_seed)
end = time.monotonic_ns()
self.logger.debug(
f"Setting random seed done in {(end - start) / 1e6:.2f}ms"
)
self.logger.debug(f"Random seed: {self.cfg.train.random_seed}")
# Setup inference mode
if self.infer_type == "infer_from_dataset":
self.cfg.dataset = self.args.infer_datasets
elif self.infer_type == "infer_from_feature":
self._build_tmp_dataset_from_feature()
self.cfg.dataset = ["tmp"]
elif self.infer_type == "infer_from_audio":
self._build_tmp_dataset_from_audio()
self.cfg.dataset = ["tmp"]
# Setup data loader
with self.accelerator.main_process_first():
self.logger.info("Building dataset...")
start = time.monotonic_ns()
self.test_dataloader = self._build_dataloader()
end = time.monotonic_ns()
self.logger.info(f"Building dataset done in {(end - start) / 1e6:.2f}ms")
# Build model
with self.accelerator.main_process_first():
self.logger.info("Building model...")
start = time.monotonic_ns()
self.model = self._build_model()
end = time.monotonic_ns()
self.logger.info(f"Building model done in {(end - start) / 1e6:.3f}ms")
# Init with accelerate
self.logger.info("Initializing accelerate...")
start = time.monotonic_ns()
self.accelerator = accelerate.Accelerator()
(self.model, self.test_dataloader) = self.accelerator.prepare(
self.model, self.test_dataloader
)
end = time.monotonic_ns()
self.accelerator.wait_for_everyone()
self.logger.info(f"Initializing accelerate done in {(end - start) / 1e6:.3f}ms")
with self.accelerator.main_process_first():
self.logger.info("Loading checkpoint...")
start = time.monotonic_ns()
if os.path.isdir(args.vocoder_dir):
if os.path.isdir(os.path.join(args.vocoder_dir, "checkpoint")):
self._load_model(os.path.join(args.vocoder_dir, "checkpoint"))
else:
self._load_model(os.path.join(args.vocoder_dir))
else:
self._load_model(os.path.join(args.vocoder_dir))
end = time.monotonic_ns()
self.logger.info(f"Loading checkpoint done in {(end - start) / 1e6:.3f}ms")
self.model.eval()
self.accelerator.wait_for_everyone()
def _build_tmp_dataset_from_feature(self):
if os.path.exists(os.path.join(self.cfg.preprocess.processed_dir, "tmp")):
shutil.rmtree(os.path.join(self.cfg.preprocess.processed_dir, "tmp"))
utts = []
mels = glob(os.path.join(self.args.feature_folder, "mels", "*.npy"))
for i, mel in enumerate(mels):
uid = mel.split("/")[-1].split(".")[0]
utt = {"Dataset": "tmp", "Uid": uid, "index": i}
utts.append(utt)
os.makedirs(os.path.join(self.cfg.preprocess.processed_dir, "tmp"))
with open(
os.path.join(self.cfg.preprocess.processed_dir, "tmp", "test.json"), "w"
) as f:
json.dump(utts, f)
meta_info = {"dataset": "tmp", "test": {"size": len(utts)}}
with open(
os.path.join(self.cfg.preprocess.processed_dir, "tmp", "meta_info.json"),
"w",
) as f:
json.dump(meta_info, f)
features = glob(os.path.join(self.args.feature_folder, "*"))
for feature in features:
feature_name = feature.split("/")[-1]
if os.path.isfile(feature):
continue
shutil.copytree(
os.path.join(self.args.feature_folder, feature_name),
os.path.join(self.cfg.preprocess.processed_dir, "tmp", feature_name),
)
def _build_tmp_dataset_from_audio(self):
if os.path.exists(os.path.join(self.cfg.preprocess.processed_dir, "tmp")):
shutil.rmtree(os.path.join(self.cfg.preprocess.processed_dir, "tmp"))
utts = []
audios = glob(os.path.join(self.args.audio_folder, "*"))
for i, audio in enumerate(audios):
uid = audio.split("/")[-1].split(".")[0]
utt = {"Dataset": "tmp", "Uid": uid, "index": i, "Path": audio}
utts.append(utt)
os.makedirs(os.path.join(self.cfg.preprocess.processed_dir, "tmp"))
with open(
os.path.join(self.cfg.preprocess.processed_dir, "tmp", "test.json"), "w"
) as f:
json.dump(utts, f)
meta_info = {"dataset": "tmp", "test": {"size": len(utts)}}
with open(
os.path.join(self.cfg.preprocess.processed_dir, "tmp", "meta_info.json"),
"w",
) as f:
json.dump(meta_info, f)
from processors import acoustic_extractor
acoustic_extractor.extract_utt_acoustic_features_serial(
utts, os.path.join(self.cfg.preprocess.processed_dir, "tmp"), self.cfg
)
def _build_test_dataset(self):
return VocoderDataset, VocoderCollator
def _build_model(self):
model = _vocoders[self.cfg.model.generator](self.cfg)
return model
def _build_dataloader(self):
"""Build dataloader which merges a series of datasets."""
Dataset, Collator = self._build_test_dataset()
datasets_list = []
for dataset in self.cfg.dataset:
subdataset = Dataset(self.cfg, dataset, is_valid=True)
datasets_list.append(subdataset)
test_dataset = VocoderConcatDataset(datasets_list, full_audio_inference=False)
test_collate = Collator(self.cfg)
test_batch_size = min(self.cfg.inference.batch_size, len(test_dataset))
test_dataloader = DataLoader(
test_dataset,
collate_fn=test_collate,
num_workers=1,
batch_size=test_batch_size,
shuffle=False,
)
self.test_batch_size = test_batch_size
self.test_dataset = test_dataset
return test_dataloader
def _load_model(self, checkpoint_dir, from_multi_gpu=False):
"""Load model from checkpoint. If a folder is given, it will
load the latest checkpoint in checkpoint_dir. If a path is given
it will load the checkpoint specified by checkpoint_path.
**Only use this method after** ``accelerator.prepare()``.
"""
if os.path.isdir(checkpoint_dir):
if "epoch" in checkpoint_dir and "step" in checkpoint_dir:
checkpoint_path = checkpoint_dir
else:
# Load the latest accelerator state dicts
ls = [
str(i)
for i in Path(checkpoint_dir).glob("*")
if not "audio" in str(i)
]
ls.sort(
key=lambda x: int(x.split("/")[-1].split("_")[0].split("-")[-1]),
reverse=True,
)
checkpoint_path = ls[0]
accelerate.load_checkpoint_and_dispatch(
self.accelerator.unwrap_model(self.model),
os.path.join(checkpoint_path, "pytorch_model.bin"),
)
return str(checkpoint_path)
else:
# Load old .pt checkpoints
if self.cfg.model.generator in [
"bigvgan",
"hifigan",
"melgan",
"nsfhifigan",
]:
ckpt = torch.load(
checkpoint_dir,
map_location=(
torch.device("cuda")
if torch.cuda.is_available()
else torch.device("cpu")
),
)
if from_multi_gpu:
pretrained_generator_dict = ckpt["generator_state_dict"]
generator_dict = self.model.state_dict()
new_generator_dict = {
k.split("module.")[-1]: v
for k, v in pretrained_generator_dict.items()
if (
k.split("module.")[-1] in generator_dict
and v.shape == generator_dict[k.split("module.")[-1]].shape
)
}
generator_dict.update(new_generator_dict)
self.model.load_state_dict(generator_dict)
else:
self.model.load_state_dict(ckpt["generator_state_dict"])
else:
self.model.load_state_dict(torch.load(checkpoint_dir)["state_dict"])
return str(checkpoint_dir)
def inference(self):
"""Inference via batches"""
for i, batch in tqdm(enumerate(self.test_dataloader)):
if self.cfg.preprocess.use_frame_pitch:
audio_pred = _vocoder_forward_funcs[self.cfg.model.generator](
self.cfg,
self.model,
batch["mel"].transpose(-1, -2),
f0s=batch["frame_pitch"].float(),
device=next(self.model.parameters()).device,
)
else:
audio_pred = _vocoder_forward_funcs[self.cfg.model.generator](
self.cfg,
self.model,
batch["mel"].transpose(-1, -2),
device=next(self.model.parameters()).device,
)
audio_ls = audio_pred.chunk(self.test_batch_size)
audio_gt_ls = batch["audio"].cpu().chunk(self.test_batch_size)
length_ls = batch["target_len"].cpu().chunk(self.test_batch_size)
j = 0
for it, it_gt, l in zip(audio_ls, audio_gt_ls, length_ls):
l = l.item()
it = it.squeeze(0).squeeze(0)[: l * self.cfg.preprocess.hop_size]
it_gt = it_gt.squeeze(0)[: l * self.cfg.preprocess.hop_size]
uid = self.test_dataset.metadata[i * self.test_batch_size + j]["Uid"]
save_audio(
os.path.join(self.args.output_dir, "pred", "{}.wav").format(uid),
it,
self.cfg.preprocess.sample_rate,
)
save_audio(
os.path.join(self.args.output_dir, "gt", "{}.wav").format(uid),
it_gt,
self.cfg.preprocess.sample_rate,
)
j += 1
if os.path.exists(os.path.join(self.cfg.preprocess.processed_dir, "tmp")):
shutil.rmtree(os.path.join(self.cfg.preprocess.processed_dir, "tmp"))
def _set_random_seed(self, seed):
"""Set random seed for all possible random modules."""
random.seed(seed)
np.random.seed(seed)
torch.random.manual_seed(seed)
def _count_parameters(self, model):
return sum(p.numel() for p in model.parameters())
def _dump_cfg(self, path):
os.makedirs(os.path.dirname(path), exist_ok=True)
json5.dump(
self.cfg,
open(path, "w"),
indent=4,
sort_keys=True,
ensure_ascii=False,
quote_keys=True,
)
def load_nnvocoder(
cfg,
vocoder_name,
weights_file,
from_multi_gpu=False,
):
"""Load the specified vocoder.
cfg: the vocoder config filer.
weights_file: a folder or a .pt path.
from_multi_gpu: automatically remove the "module" string in state dicts if "True".
"""
print("Loading Vocoder from Weights file: {}".format(weights_file))
# Build model
model = _vocoders[vocoder_name](cfg)
if not os.path.isdir(weights_file):
# Load from .pt file
if vocoder_name in ["bigvgan", "hifigan", "melgan", "nsfhifigan"]:
ckpt = torch.load(
weights_file,
map_location=(
torch.device("cuda")
if torch.cuda.is_available()
else torch.device("cpu")
),
)
if from_multi_gpu:
pretrained_generator_dict = ckpt["generator_state_dict"]
generator_dict = model.state_dict()
new_generator_dict = {
k.split("module.")[-1]: v
for k, v in pretrained_generator_dict.items()
if (
k.split("module.")[-1] in generator_dict
and v.shape == generator_dict[k.split("module.")[-1]].shape
)
}
generator_dict.update(new_generator_dict)
model.load_state_dict(generator_dict)
else:
model.load_state_dict(ckpt["generator_state_dict"])
else:
model.load_state_dict(torch.load(weights_file)["state_dict"])
else:
# Load from accelerator state dict
weights_file = os.path.join(weights_file, "checkpoint")
ls = [str(i) for i in Path(weights_file).glob("*") if not "audio" in str(i)]
ls.sort(key=lambda x: int(x.split("_")[-3].split("-")[-1]), reverse=True)
checkpoint_path = ls[0]
accelerator = accelerate.Accelerator()
model = accelerator.prepare(model)
accelerator.load_state(checkpoint_path)
if torch.cuda.is_available():
model = model.cuda()
model = model.eval()
return model
def tensorize(data, device, n_samples):
"""
data: a list of numpy array
"""
assert type(data) == list
if n_samples:
data = data[:n_samples]
data = [torch.as_tensor(x, device=device) for x in data]
return data
def synthesis(
cfg,
vocoder_weight_file,
n_samples,
pred,
f0s=None,
batch_size=64,
fast_inference=False,
):
"""Synthesis audios from a given vocoder and series of given features.
cfg: vocoder config.
vocoder_weight_file: a folder of accelerator state dict or a path to the .pt file.
pred: a list of numpy arrays. [(seq_len1, acoustic_features_dim), (seq_len2, acoustic_features_dim), ...]
"""
vocoder_name = cfg.model.generator
print("Synthesis audios using {} vocoder...".format(vocoder_name))
###### TODO: World Vocoder Refactor ######
# if vocoder_name == "world":
# world_inference.synthesis_audios(
# cfg, dataset_name, split, n_samples, pred, save_dir, tag
# )
# return
# ====== Loading neural vocoder model ======
vocoder = load_nnvocoder(
cfg, vocoder_name, weights_file=vocoder_weight_file, from_multi_gpu=True
)
device = next(vocoder.parameters()).device
# ====== Inference for predicted acoustic features ======
# pred: (frame_len, n_mels) -> (n_mels, frame_len)
mels_pred = tensorize([p.T for p in pred], device, n_samples)
print("For predicted mels, #sample = {}...".format(len(mels_pred)))
audios_pred = _vocoder_infer_funcs[vocoder_name](
cfg,
vocoder,
mels_pred,
f0s=f0s,
batch_size=batch_size,
fast_inference=fast_inference,
)
return audios_pred
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment