Commit 51782715 authored by liugh5's avatar liugh5
Browse files

update

parent 8b4e9acd
from enum import Enum
class Tone(Enum):
    """Tone categories; comments map each value onto its ZhHK tone and,
    where applicable, the EnUS stress level it doubles as."""

    UnAssigned = -1
    NoneTone = 0
    YinPing = 1  # ZhHK: YinPingYinRu EnUS: primary stress
    YangPing = 2  # ZhHK: YinShang EnUS: secondary stress
    ShangSheng = 3  # ZhHK: YinQuZhongRu
    QuSheng = 4  # ZhHK: YangPing
    QingSheng = 5  # ZhHK: YangShang
    YangQuYangRu = 6  # ZhHK: YangQuYangRu

    @classmethod
    def parse(cls, in_str):
        """Parse a Tone from a member name ("YinPing") or a digit string
        ("1"); unknown strings fall back to NoneTone.

        Non-string inputs are treated as raw enum values. The previous
        implementation called super().__new__, which is not the Enum
        value-lookup protocol; member lookup goes through cls(value).
        """
        if not isinstance(in_str, str):
            return cls(in_str)
        for member in cls:
            if in_str in (member.name, str(member.value)):
                return member
        return cls.NoneTone
class BreakLevel(Enum):
    """Prosodic break levels, L0 (no break) up to L4 (strongest)."""

    UnAssigned = -1
    L0 = 0
    L1 = 1
    L2 = 2
    L3 = 3
    L4 = 4

    @classmethod
    def parse(cls, in_str):
        """Parse a BreakLevel from a member name ("L2") or a digit string
        ("2"); unknown strings fall back to UnAssigned.

        Non-string inputs are treated as raw enum values (the previous
        super().__new__ call was not the Enum lookup protocol).
        """
        if not isinstance(in_str, str):
            return cls(in_str)
        for member in cls:
            if in_str in (member.name, str(member.value)):
                return member
        return cls.UnAssigned
class SentencePurpose(Enum):
    """Communicative purpose of a sentence (statement / question /
    exclamation / command)."""

    Declarative = 0
    Interrogative = 1
    Exclamatory = 2
    Imperative = 3
class Language(Enum):
    """Supported languages, valued by Windows LCID; ZhEn (mixed
    Chinese/English) is the bitwise union of ZhCN and EnUS."""

    Neutral = 0
    EnUS = 1033
    EnGB = 2057
    ZhCN = 2052
    PinYin = 2053
    WuuShanghai = 2054
    Sichuan = 2055
    ZhHK = 3076
    ZhEn = ZhCN | EnUS  # mixed Chinese/English

    @classmethod
    def parse(cls, in_str):
        """Parse a Language from a member name or LCID string; unknown
        strings fall back to Neutral.

        Non-string inputs are treated as raw enum values (the previous
        super().__new__ call was not the Enum lookup protocol).
        """
        if not isinstance(in_str, str):
            return cls(in_str)
        if in_str in ("ZhEn", "2052|1033"):
            # ZhEn is spelled as the union of its parts in metadata files.
            return cls.ZhEn
        for member in cls:
            if in_str in (member.name, str(member.value)):
                return member
        return cls.Neutral
"""
Phone Types
"""
class PhoneCVType(Enum):
    """Consonant/vowel classification of a phone."""

    NULL = -1
    Consonant = 1
    Vowel = 2

    @classmethod
    def parse(cls, in_str):
        """Parse from a member name or its lowercase form ("Vowel" /
        "vowel"); unknown strings fall back to NULL.

        Non-string inputs are treated as raw enum values (the previous
        super().__new__ call was not the Enum lookup protocol).
        """
        if not isinstance(in_str, str):
            return cls(in_str)
        for member in cls:
            if in_str in (member.name, member.name.lower()):
                return member
        return cls.NULL
class PhoneIFType(Enum):
    """Initial/final classification of a phone (Chinese syllable parts)."""

    NULL = -1
    Initial = 1
    Final = 2

    @classmethod
    def parse(cls, in_str):
        """Parse from a member name or its lowercase form; unknown strings
        fall back to NULL. Non-string inputs are treated as raw enum
        values (the previous super().__new__ call was not the Enum
        lookup protocol).
        """
        if not isinstance(in_str, str):
            return cls(in_str)
        for member in cls:
            if in_str in (member.name, member.name.lower()):
                return member
        return cls.NULL
class PhoneUVType(Enum):
    """Voiced/unvoiced classification of a phone."""

    NULL = -1
    Voiced = 1
    UnVoiced = 2

    @classmethod
    def parse(cls, in_str):
        """Parse from a member name or its lowercase form ("UnVoiced" /
        "unvoiced"); unknown strings fall back to NULL. Non-string inputs
        are treated as raw enum values (the previous super().__new__ call
        was not the Enum lookup protocol).
        """
        if not isinstance(in_str, str):
            return cls(in_str)
        for member in cls:
            if in_str in (member.name, member.name.lower()):
                return member
        return cls.NULL
class PhoneAPType(Enum):
    """Articulation-place classification of a phone."""

    NULL = -1
    DoubleLips = 1
    LipTooth = 2
    FrontTongue = 3
    CentralTongue = 4
    BackTongue = 5
    Dorsal = 6
    Velar = 7
    Low = 8
    Middle = 9
    High = 10

    @classmethod
    def parse(cls, in_str):
        """Parse from a member name or its lowercase form ("DoubleLips" /
        "doublelips"); unknown strings fall back to NULL. Non-string
        inputs are treated as raw enum values (the previous
        super().__new__ call was not the Enum lookup protocol).
        """
        if not isinstance(in_str, str):
            return cls(in_str)
        for member in cls:
            if in_str in (member.name, member.name.lower()):
                return member
        return cls.NULL
class PhoneAMType(Enum):
    """Articulation-manner classification of a phone."""

    NULL = -1
    Stop = 1
    Affricate = 2
    Fricative = 3
    Nasal = 4
    Lateral = 5
    Open = 6
    Close = 7

    @classmethod
    def parse(cls, in_str):
        """Parse from a member name or its lowercase form ("Stop" /
        "stop"); unknown strings fall back to NULL. Non-string inputs are
        treated as raw enum values (the previous super().__new__ call was
        not the Enum lookup protocol).
        """
        if not isinstance(in_str, str):
            return cls(in_str)
        for member in cls:
            if in_str in (member.name, member.name.lower()):
                return member
        return cls.NULL
import re
import unicodedata
import codecs
# Token patterns for parsing annotated prosody scripts.
WordPattern = r"((?P<Word>\w+)(\(\w+\))?)"  # word with optional (pronunciation)
BreakPattern = r"(?P<Break>(\*?#(?P<BreakLevel>[0-4])))"  # break mark #0..#4
MarkPattern = r"(?P<Mark>[、,。!?:“”《》·])"  # CJK punctuation kept as marks
POSPattern = r"(?P<POS>(\*?\|(?P<POSClass>[1-9])))"  # part-of-speech tag |1..|9
PhraseTonePattern = r"(?P<PhraseTone>(\*?%([L|H])))"  # phrase tone %L / %H
NgBreakPattern = r"^ng(?P<break>\d)"  # "ng" fused with a trailing break digit
RegexWord = re.compile(WordPattern + r"\s*")
RegexBreak = re.compile(BreakPattern + r"\s*")
RegexID = re.compile(r"^(?P<ID>.*?)\s")  # leading utterance id
# A sentence is a sequence of words / breaks / marks / POS tags / phrase tones.
RegexSentence = re.compile(
    r"({}|{}|{}|{}|{})\s*".format(
        WordPattern, BreakPattern, MarkPattern, POSPattern, PhraseTonePattern
    )
)
RegexForeignLang = re.compile(r"[A-Z@]")  # crude foreign-language cue
RegexSpace = re.compile(r"^\s*")
RegexNeutralTone = re.compile(r"[1-5]5")  # tone digit followed by neutral-tone 5
def do_character_normalization(line):
    """Return line with Unicode NFKC normalization applied (e.g. fullwidth
    forms folded to their compatibility equivalents)."""
    return unicodedata.normalize("NFKC", line)
# Single-pass translation table: punctuation -> space, '-'/"'" removed,
# '/' -> #2 and '%' -> #3 break marks.  ('·' is deliberately left alone.)
_PROSODY_TABLE = {c: " " for c in '。、“”‘’|《》【】—―.!?()[]{}~:;+,"'}
_PROSODY_TABLE.update({"-": "", "'": "", "/": "#2", "%": "#3"})
_PROSODY_TRANS = str.maketrans(_PROSODY_TABLE)


def do_prosody_text_normalization(line):
    """Normalize the text field of an "<id><TAB><text>" prosody line.

    Punctuation becomes word boundaries; '-' and "'" are dropped (so
    compounds like "two-year-old" and contractions like "that's" stay one
    word); '/' and '%' become #2/#3 break marks; remaining spaces become
    #1 breaks; a #1 is inserted between Latin and non-Latin script runs.

    Returns the id joined back with the normalized text.
    """
    tokens = line.split("\t")
    text = tokens[1]
    # One C-level pass replaces the former chain of ~30 str.replace calls.
    text = text.translate(_PROSODY_TRANS)
    # Remove useless spaces surrounding explicit break marks.
    text = re.sub(r"(#\d)[ ]+", r"\1", text)
    text = re.sub(r"[ ]+(#\d)", r"\1", text)
    # Replace space runs by #1.
    text = re.sub("[ ]+", "#1", text)
    # Remove a break at the end of the text.
    text = re.sub(r"#\d$", "", text)
    # Add #1 between target language and foreign language.
    text = re.sub(r"([a-zA-Z])([^a-zA-Z\d\#\s\'\%\/\-])", r"\1#1\2", text)
    text = re.sub(r"([^a-zA-Z\d\#\s\'\%\/\-])([a-zA-Z])", r"\1#1\2", text)
    return tokens[0] + "\t" + text
def is_fp_line(line):
    """Return True when every space-separated token on the line is a
    filled-pause category (FP / I / N / Q).

    Note: an empty line splits into a single "" token and therefore
    returns False, matching the original loop's behavior.
    """
    fp_categories = ("FP", "I", "N", "Q")
    return all(ele in fp_categories for ele in line.strip().split(" "))
def format_prosody(src_prosody):
    """Load a prosody file and return its normalized lines.

    Lines shaped "<id><TAB><text>" go through text normalization; a line
    made only of FP categories (see is_fp_line) is treated as an FP
    annotation and is skipped together with the two lines that follow it.

    Args:
        src_prosody: path to the source prosody file (utf-8).

    Returns:
        list[str]: normalized, retained lines.
    """
    formatted_lines = []
    with codecs.open(src_prosody, "r", "utf-8") as f:
        lines = f.readlines()
        idx = 0
        while idx < len(lines):
            line = do_character_normalization(lines[idx])
            if len(line.strip().split("\t")) == 2:
                line = do_prosody_text_normalization(line)
            else:
                fp_enable = is_fp_line(line)
                if fp_enable:
                    # Skip the FP annotation line plus its two payload lines.
                    idx += 3
                    continue
            formatted_lines.append(line)
            idx += 1
    # with codecs.open(tgt_prosody, 'w', 'utf-8') as f:
    #     f.writelines(formatted_lines)
    return formatted_lines
from collections import OrderedDict
import torch
from torch import nn
import torch.nn.functional as F
from .layers import (DenseLayer, DenseTDNNBlock, StatsPool, TDNNLayer, SEDenseTDNNBlock,
TransitLayer)
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, in_planes, planes, stride=1):
super(BasicBlock, self).__init__()
self.conv1 = nn.Conv2d(in_planes,
planes,
kernel_size=3,
stride=(stride, 1),
padding=1,
bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes,
planes,
kernel_size=3,
stride=1,
padding=1,
bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.shortcut = nn.Sequential()
if stride != 1 or in_planes != self.expansion * planes:
self.shortcut = nn.Sequential(
nn.Conv2d(in_planes,
self.expansion * planes,
kernel_size=1,
stride=(stride, 1),
bias=False),
nn.BatchNorm2d(self.expansion * planes))
def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
out += self.shortcut(x)
out = F.relu(out)
return out
class CNN_Head(nn.Module):
    """2-D CNN front end: three stride-2 stages reduce the frequency axis
    by 8x, then channels and reduced frequency are flattened into one
    feature axis.

    Input:  (B, F, T), viewed internally as (B, 1, F, T).
    Output: (B, m_channels * (feat_dim // 8), T).
    """

    def __init__(self,
                 block=BasicBlock,
                 num_blocks=(2, 2),  # was a mutable list default
                 m_channels=32,
                 feat_dim=80):
        super(CNN_Head, self).__init__()
        self.in_planes = m_channels
        self.conv1 = nn.Conv2d(1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(m_channels)
        self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=2)
        # Bug fix: the second stage previously reused num_blocks[0],
        # ignoring num_blocks[1] (identical for the default (2, 2)).
        self.layer2 = self._make_layer(block, m_channels, num_blocks[1], stride=2)
        self.conv2 = nn.Conv2d(m_channels, m_channels, kernel_size=3, stride=(2, 1), padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(m_channels)
        # Three stride-2 reductions on the frequency axis => feat_dim // 8.
        self.out_channels = m_channels * (feat_dim // 8)

    def _make_layer(self, block, planes, num_blocks, stride):
        # First block downsamples; the rest keep the resolution.
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        # Bug fix: use the out-of-place unsqueeze so the caller's tensor
        # is not mutated (unsqueeze_ changed the input shape in place).
        x = x.unsqueeze(1)
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = F.relu(self.bn2(self.conv2(out)))
        # Flatten (channels, freq) into a single feature axis.
        out = out.reshape(out.shape[0], out.shape[1] * out.shape[2], out.shape[3])
        return out
class DTDNN(nn.Module):
    """D-TDNN speaker-embedding network with a 2-D CNN front end.

    Input:  (B, T, F) features with F == feat_dim mel bins.
    Output: (B, embedding_size) speaker embedding.
    """

    def __init__(self,
                 feat_dim=80,
                 embedding_size=192,
                 growth_rate=32,
                 bn_size=4,
                 init_channels=128,
                 config_str='batchnorm-relu',
                 memory_efficient=True):
        super(DTDNN, self).__init__()
        self.head = CNN_Head()
        # The CNN head flattens (channels x reduced freq) into one axis.
        feat_dim = self.head.out_channels
        self.xvector = nn.Sequential(
            OrderedDict([
                ('tdnn',
                 TDNNLayer(feat_dim,
                           init_channels,
                           5,
                           stride=2,
                           dilation=1,
                           padding=-1,  # negative => "same" padding
                           config_str=config_str)),
            ]))
        channels = init_channels
        # Three dense blocks (12/24/16 layers, dilations 1/2/3); each is
        # followed by a transit layer that halves the channel count.
        for i, (num_layers, kernel_size,
                dilation) in enumerate(zip((12, 24, 16), (3, 3, 3), (1, 2, 3))):
            block = SEDenseTDNNBlock(num_layers=num_layers,
                                     in_channels=channels,
                                     out_channels=growth_rate,
                                     bn_channels=bn_size * growth_rate,
                                     kernel_size=kernel_size,
                                     dilation=dilation,
                                     config_str=config_str,
                                     memory_efficient=memory_efficient)
            self.xvector.add_module('block%d' % (i + 1), block)
            channels = channels + num_layers * growth_rate
            self.xvector.add_module(
                'transit%d' % (i + 1),
                TransitLayer(channels,
                             channels // 2,
                             bias=False,
                             config_str=config_str))
            channels //= 2
        self.bn = nn.BatchNorm1d(channels)
        self.relu = nn.ReLU(inplace=True)
        self.xvector.add_module('stats', StatsPool())
        self.xvector.add_module(
            'dense',
            DenseLayer(channels * 2, embedding_size, config_str='batchnorm_'))
        # Kaiming-normal init for every conv/linear weight.
        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight.data)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x):
        x = x.permute(0, 2, 1)  # (B,T,F) => (B,F,T)
        x = self.head(x)
        x = self.xvector.tdnn(x)
        x = self.xvector.block1(x)
        x = self.xvector.transit1(x)
        x = self.xvector.block2(x)
        x = self.xvector.transit2(x)
        x = self.xvector.block3(x)
        x = self.xvector.transit3(x)
        x = self.relu(self.bn(x))
        x = self.xvector.stats(x)
        x = self.xvector.dense(x)
        return x
import torch
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from torch import nn
def get_nonlinear(config_str, channels):
    """Build an nn.Sequential from a '-'-separated spec string.

    Supported tokens: 'relu', 'prelu', 'batchnorm', and 'batchnorm_'
    (BatchNorm1d without affine parameters; registered under the module
    name 'batchnorm' like the plain variant).

    Raises:
        ValueError: on an unknown token.
    """
    factories = {
        'relu': lambda: nn.ReLU(inplace=True),
        'prelu': lambda: nn.PReLU(channels),
        'batchnorm': lambda: nn.BatchNorm1d(channels),
        'batchnorm_': lambda: nn.BatchNorm1d(channels, affine=False),
    }
    nonlinear = nn.Sequential()
    for name in config_str.split('-'):
        if name not in factories:
            raise ValueError('Unexpected module ({}).'.format(name))
        module_name = 'batchnorm' if name.startswith('batchnorm') else name
        nonlinear.add_module(module_name, factories[name]())
    return nonlinear
def statistics_pooling(x, dim=-1, keepdim=False, unbiased=True, eps=1e-2):
    """Concatenate mean and standard deviation computed along dim.

    NOTE(review): eps is accepted but unused here (std is not clamped),
    unlike high_order_statistics_pooling — confirm whether that is
    intentional.
    """
    stats = torch.cat(
        [x.mean(dim=dim), x.std(dim=dim, unbiased=unbiased)], dim=-1
    )
    return stats.unsqueeze(dim=dim) if keepdim else stats
def high_order_statistics_pooling(x,
                                  dim=-1,
                                  keepdim=False,
                                  unbiased=True,
                                  eps=1e-2):
    """Concatenate mean, std, skewness and kurtosis computed along dim.

    The std used for standardization is clamped at eps to avoid division
    by (near-)zero.
    """
    mean = x.mean(dim=dim)
    std = x.std(dim=dim, unbiased=unbiased)
    standardized = (x - mean.unsqueeze(dim=dim)) / std.clamp(min=eps).unsqueeze(dim=dim)
    moments = [mean, std]
    for order in (3, 4):  # 3rd -> skewness, 4th -> kurtosis
        moments.append(standardized.pow(order).mean(dim=dim))
    stats = torch.cat(moments, dim=-1)
    return stats.unsqueeze(dim=dim) if keepdim else stats
class StatsPool(nn.Module):
    """Module wrapper around statistics_pooling (mean + std over time)."""

    def forward(self, x):
        return statistics_pooling(x)
class HighOrderStatsPool(nn.Module):
    """Module wrapper around high_order_statistics_pooling
    (mean + std + skewness + kurtosis over time)."""

    def forward(self, x):
        return high_order_statistics_pooling(x)
class TDNNLayer(nn.Module):
    """1-D conv (TDNN) layer followed by a configurable nonlinearity.

    A negative padding value requests "same" padding derived from the
    kernel size and dilation; only odd kernel sizes are accepted then.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 bias=False,
                 config_str='batchnorm-relu'):
        super(TDNNLayer, self).__init__()
        if padding < 0:
            assert kernel_size % 2 == 1, 'Expect equal paddings, but got even kernel size ({})'.format(
                kernel_size)
            padding = (kernel_size - 1) // 2 * dilation
        self.linear = nn.Conv1d(in_channels,
                                out_channels,
                                kernel_size,
                                stride=stride,
                                padding=padding,
                                dilation=dilation,
                                bias=bias)
        self.nonlinear = get_nonlinear(config_str, out_channels)

    def forward(self, x):
        return self.nonlinear(self.linear(x))
class DenseTDNNLayer(nn.Module):
    """Bottleneck TDNN layer for dense blocks.

    Pipeline: nonlinear1 -> 1x1 bottleneck conv -> nonlinear2 -> k-wide
    dilated conv with "same" padding (odd kernel sizes only).  The
    bottleneck half can be gradient-checkpointed to save memory.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 bn_channels,
                 kernel_size,
                 stride=1,
                 dilation=1,
                 bias=False,
                 config_str='batchnorm-relu',
                 memory_efficient=False):
        super(DenseTDNNLayer, self).__init__()
        assert kernel_size % 2 == 1, 'Expect equal paddings, but got even kernel size ({})'.format(
            kernel_size)
        padding = (kernel_size - 1) // 2 * dilation
        self.memory_efficient = memory_efficient
        self.nonlinear1 = get_nonlinear(config_str, in_channels)
        self.linear1 = nn.Conv1d(in_channels, bn_channels, 1, bias=False)
        self.nonlinear2 = get_nonlinear(config_str, bn_channels)
        self.linear2 = nn.Conv1d(bn_channels,
                                 out_channels,
                                 kernel_size,
                                 stride=stride,
                                 padding=padding,
                                 dilation=dilation,
                                 bias=bias)

    def bn_function(self, x):
        # Bottleneck half, factored out so it can be checkpointed.
        return self.linear1(self.nonlinear1(x))

    def forward(self, x):
        if self.training and self.memory_efficient:
            # Recompute the bottleneck on backward instead of storing it.
            x = cp.checkpoint(self.bn_function, x)
        else:
            x = self.bn_function(x)
        x = self.linear2(self.nonlinear2(x))
        return x
class DenseTDNNBlock(nn.ModuleList):
    """Stack of DenseTDNNLayers with dense connectivity: every layer sees
    the concatenation of the block input and all previous layer outputs,
    so the channel count grows by out_channels per layer."""

    def __init__(self,
                 num_layers,
                 in_channels,
                 out_channels,
                 bn_channels,
                 kernel_size,
                 stride=1,
                 dilation=1,
                 bias=False,
                 config_str='batchnorm-relu',
                 memory_efficient=False):
        super(DenseTDNNBlock, self).__init__()
        for i in range(num_layers):
            layer = DenseTDNNLayer(in_channels=in_channels + i * out_channels,
                                   out_channels=out_channels,
                                   bn_channels=bn_channels,
                                   kernel_size=kernel_size,
                                   stride=stride,
                                   dilation=dilation,
                                   bias=bias,
                                   config_str=config_str,
                                   memory_efficient=memory_efficient)
            self.add_module('tdnnd%d' % (i + 1), layer)

    def forward(self, x):
        for layer in self:
            # Dense connectivity: append each layer's output to its input.
            x = torch.cat([x, layer(x)], dim=1)
        return x
class StatsSelect(nn.Module):
    """Selective-kernel style fusion of parallel branch outputs.

    Branch outputs are summed, summarized with high-order statistics, and
    projected to per-branch softmax attention weights; the weighted sum
    of the branches is returned.  With null=True an extra "no branch"
    option can absorb attention mass but is dropped from the output.
    """

    def __init__(self, channels, branches, null=False, reduction=1):
        super(StatsSelect, self).__init__()
        self.gather = HighOrderStatsPool()
        # 4 stats (mean/std/skewness/kurtosis) per channel -> channels * 4.
        self.linear1 = nn.Conv1d(channels * 4, channels // reduction, 1)
        self.linear2 = nn.ModuleList()
        if null:
            branches += 1
        for _ in range(branches):
            self.linear2.append(nn.Conv1d(channels // reduction, channels, 1))
        self.channels = channels
        self.branches = branches
        self.null = null
        self.reduction = reduction

    def forward(self, x):
        # x: list of branch outputs, each (B, C, T).
        f = torch.cat([_x.unsqueeze(dim=1) for _x in x], dim=1)
        x = torch.sum(f, dim=1)
        x = self.linear1(self.gather(x).unsqueeze(dim=-1))
        s = []
        for linear in self.linear2:
            s.append(linear(x).view(-1, 1, self.channels))
        s = torch.cat(s, dim=1)
        # Softmax across branches -> per-channel attention weights.
        s = F.softmax(s, dim=1).unsqueeze(dim=-1)
        if self.null:
            # Drop the "null" branch weight after the softmax.
            s = s[:, :-1, :, :]
        return torch.sum(f * s, dim=1)

    def extra_repr(self):
        return 'channels={}, branches={}, reduction={}'.format(
            self.channels, self.branches, self.reduction)
class SqueezeExcitation(nn.Module):
def __init__(self, channels, reduction=1):
super(SqueezeExcitation, self).__init__()
self.linear1 = nn.Conv1d(channels, channels // reduction, 1)
self.relu = nn.ReLU(inplace=True)
self.linear2 = nn.Conv1d(channels // reduction, channels, 1)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
s = self.linear1(x.mean(-1, keepdim=True)+self.seg_pooling(x))
s = self.relu(s)
s = self.sigmoid(self.linear2(s))
return x*s
def seg_pooling(self, x, seg_len=100):
s_x = F.max_pool1d(x, kernel_size=seg_len, stride=seg_len, ceil_mode=True)
out = s_x.unsqueeze(-1).expand(-1, -1, -1, seg_len).reshape(*x.shape[:-1], -1)
out = out[:, :, :x.shape[-1]]
return out
class PoolingBlock(nn.Module):
def __init__(self, bn_channels, out_channels, kernel_size, stride, padding, dilation, bias, reduction=2):
super(PoolingBlock, self).__init__()
self.linear_stem = nn.Conv1d(bn_channels,
out_channels,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
bias=bias)
self.linear1 = nn.Conv1d(bn_channels, bn_channels // reduction, 1)
self.relu = nn.ReLU(inplace=True)
# self.bn = nn.BatchNorm1d(out_channels)
self.linear2 = nn.Conv1d(bn_channels // reduction, out_channels, 1)
self.sigmoid = nn.Sigmoid()
# self.linear3 = nn.Conv1d(out_channels, out_channels, 1)
def forward(self, x):
y = self.linear_stem(x)
s = self.linear1(x.mean(-1, keepdim=True)+self.seg_pooling(x))
s = self.relu(s)
s = self.sigmoid(self.linear2(s))
return y*s
def seg_pooling(self, x, seg_len=100):
s_x = F.max_pool1d(x, kernel_size=seg_len, stride=seg_len, ceil_mode=True)
out = s_x.unsqueeze(-1).expand(-1, -1, -1, seg_len).reshape(*x.shape[:-1], -1)
out = out[:, :, :x.shape[-1]]
return out
class MultiBranchDenseTDNNLayer(DenseTDNNLayer):
    """Dense TDNN layer with one conv branch per dilation, fused with
    StatsSelect attention.

    It reuses bn_function (and the checkpointing pattern) from
    DenseTDNNLayer but deliberately calls the grandparent nn.Module
    initializer, skipping DenseTDNNLayer.__init__, because it builds its
    own multi-branch linear2.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 bn_channels,
                 kernel_size,
                 stride=1,
                 dilation=(1, ),
                 bias=False,
                 null=False,
                 reduction=1,
                 config_str='batchnorm-relu',
                 memory_efficient=False):
        super(DenseTDNNLayer, self).__init__()  # grandparent (nn.Module) init
        assert kernel_size % 2 == 1, 'Expect equal paddings, but got even kernel size ({})'.format(
            kernel_size)
        padding = (kernel_size - 1) // 2
        if not isinstance(dilation, (tuple, list)):
            dilation = (dilation, )
        self.memory_efficient = memory_efficient
        self.nonlinear1 = get_nonlinear(config_str, in_channels)
        self.linear1 = nn.Conv1d(in_channels, bn_channels, 1, bias=False)
        self.nonlinear2 = get_nonlinear(config_str, bn_channels)
        self.linear2 = nn.ModuleList()
        # One "same"-padded conv per dilation value.
        for _dilation in dilation:
            self.linear2.append(
                nn.Conv1d(bn_channels,
                          out_channels,
                          kernel_size,
                          stride=stride,
                          padding=padding * _dilation,
                          dilation=_dilation,
                          bias=bias))
        self.select = StatsSelect(out_channels,
                                  len(dilation),
                                  null=null,
                                  reduction=reduction)

    def forward(self, x):
        if self.training and self.memory_efficient:
            # bn_function is inherited from DenseTDNNLayer.
            x = cp.checkpoint(self.bn_function, x)
        else:
            x = self.bn_function(x)
        x = self.nonlinear2(x)
        # Run every branch and fuse with attention weights.
        x = self.select([linear(x) for linear in self.linear2])
        return x
class SEDenseTDNNLayer(nn.Module):
    """Dense TDNN layer whose output conv is a PoolingBlock (conv gated
    by squeeze-excitation-style attention).

    Pipeline: nonlinear1 -> 1x1 bottleneck conv -> nonlinear2 ->
    PoolingBlock.  The bottleneck half can be gradient-checkpointed.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 bn_channels,
                 kernel_size,
                 stride=1,
                 dilation=1,
                 bias=False,
                 config_str='batchnorm-relu',
                 memory_efficient=False):
        super(SEDenseTDNNLayer, self).__init__()
        assert kernel_size % 2 == 1, 'Expect equal paddings, but got even kernel size ({})'.format(
            kernel_size)
        padding = (kernel_size - 1) // 2 * dilation
        self.memory_efficient = memory_efficient
        self.nonlinear1 = get_nonlinear(config_str, in_channels)
        self.linear1 = nn.Conv1d(in_channels, bn_channels, 1, bias=False)
        self.nonlinear2 = get_nonlinear(config_str, bn_channels)
        # Earlier plain-conv + SqueezeExcitation variant kept for reference:
        # self.linear2 = nn.Conv1d(bn_channels,
        #                          out_channels,
        #                          kernel_size,
        #                          stride=stride,
        #                          padding=padding,
        #                          dilation=dilation,
        #                          bias=bias)
        # self.se = SqueezeExcitation(out_channels)
        self.se = PoolingBlock(bn_channels,
                               out_channels,
                               kernel_size,
                               stride=stride,
                               padding=padding,
                               dilation=dilation,
                               bias=bias)

    def bn_function(self, x):
        # Bottleneck half, factored out so it can be checkpointed.
        return self.linear1(self.nonlinear1(x))

    def forward(self, x):
        if self.training and self.memory_efficient:
            x = cp.checkpoint(self.bn_function, x)
        else:
            x = self.bn_function(x)
        # x = self.linear2(self.nonlinear2(x))
        x = self.se(self.nonlinear2(x))
        return x
class SEDenseTDNNBlock(nn.ModuleList):
    """Dense block built from SEDenseTDNNLayers: each layer sees the
    concatenation of the block input and all previous layer outputs, so
    the channel count grows by out_channels per layer."""

    def __init__(self,
                 num_layers,
                 in_channels,
                 out_channels,
                 bn_channels,
                 kernel_size,
                 stride=1,
                 dilation=1,
                 bias=False,
                 config_str='batchnorm-relu',
                 memory_efficient=False):
        super(SEDenseTDNNBlock, self).__init__()
        for i in range(num_layers):
            layer = SEDenseTDNNLayer(in_channels=in_channels + i * out_channels,
                                     out_channels=out_channels,
                                     bn_channels=bn_channels,
                                     kernel_size=kernel_size,
                                     stride=stride,
                                     dilation=dilation,
                                     bias=bias,
                                     config_str=config_str,
                                     memory_efficient=memory_efficient)
            self.add_module('tdnnd%d' % (i + 1), layer)

    def forward(self, x):
        for layer in self:
            # Dense connectivity: append each layer's output to its input.
            x = torch.cat([x, layer(x)], dim=1)
        return x
class MultiBranchDenseTDNNBlock(DenseTDNNBlock):
    """Dense block built from MultiBranchDenseTDNNLayers.

    forward is inherited from DenseTDNNBlock; the super call below
    deliberately targets the grandparent (nn.ModuleList) initializer,
    skipping DenseTDNNBlock.__init__.
    """

    def __init__(self,
                 num_layers,
                 in_channels,
                 out_channels,
                 bn_channels,
                 kernel_size,
                 stride=1,
                 dilation=1,
                 bias=False,
                 null=False,
                 reduction=1,
                 config_str='batchnorm-relu',
                 memory_efficient=False):
        super(DenseTDNNBlock, self).__init__()  # grandparent (nn.ModuleList) init
        for i in range(num_layers):
            layer = MultiBranchDenseTDNNLayer(
                in_channels=in_channels + i * out_channels,
                out_channels=out_channels,
                bn_channels=bn_channels,
                kernel_size=kernel_size,
                stride=stride,
                dilation=dilation,
                bias=bias,
                null=null,
                reduction=reduction,
                config_str=config_str,
                memory_efficient=memory_efficient)
            self.add_module('tdnnd%d' % (i + 1), layer)
class TransitLayer(nn.Module):
    """Channel-reduction layer: a configurable nonlinearity followed by a
    1x1 conv (typically halving the channel count between blocks)."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 bias=True,
                 config_str='batchnorm-relu'):
        super(TransitLayer, self).__init__()
        self.nonlinear = get_nonlinear(config_str, in_channels)
        self.linear = nn.Conv1d(in_channels, out_channels, 1, bias=bias)

    def forward(self, x):
        return self.linear(self.nonlinear(x))
class DenseLayer(nn.Module):
    """1x1 conv followed by a configurable nonlinearity; accepts either
    (B, C) vectors or (B, C, T) sequences."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 bias=False,
                 config_str='batchnorm-relu'):
        super(DenseLayer, self).__init__()
        self.linear = nn.Conv1d(in_channels, out_channels, 1, bias=bias)
        self.nonlinear = get_nonlinear(config_str, out_channels)

    def forward(self, x):
        # A 2-D input is treated as a single frame: add then drop a time axis.
        if len(x.shape) == 2:
            y = self.linear(x.unsqueeze(dim=-1)).squeeze(dim=-1)
        else:
            y = self.linear(x)
        return self.nonlinear(y)
if __name__ == '__main__':
    # Smoke test: run a SqueezeExcitation block on random input and report
    # MACs/params via thop (third-party; only needed for this script path).
    model = SqueezeExcitation(channels=32)
    model.eval()
    x = torch.randn(1, 32, 298)
    y = model(x)
    print(y.size())
    from thop import profile
    macs, num_params = profile(model, inputs=(x, ))
    # num_params = sum(p.numel() for p in model.parameters())
    print("MACs: {} G".format(macs / 1e9))
    print("Params: {} M".format(num_params / 1e6))
\ No newline at end of file
import torch
import torchaudio
import numpy as np
import os
import torchaudio.compliance.kaldi as Kaldi
from .D_TDNN import DTDNN
import logging
import argparse
from glob import glob
# Verbose (DEBUG) logging with file/line prefixes for the extractor script.
logging.basicConfig(
    format="%(asctime)s %(levelname)-4s [%(filename)s:%(lineno)d] %(message)s",
    datefmt="%Y-%m-%d:%H:%M:%S",
    level=logging.DEBUG,
)
class SpeakerEmbeddingProcessor:
    """Extracts D-TDNN speaker embeddings for every wav in a voice dir.

    Per-utterance embeddings are saved to <src_voice_dir>/se/<name>.npy,
    and their mean to <src_voice_dir>/se/se.npy.
    """

    def __init__(self, sample_rate=16000):
        self.sample_rate = sample_rate
        # Skip clips shorter than 30 frames x 10 ms.
        self.min_wav_length = self.sample_rate * 30 * 10 / 1000
        self.pcm_dict = {}
        self.mfcc_dict = {}
        self.se_list = []

    def process(self, src_voice_dir, se_model):
        """Run extraction.

        Args:
            src_voice_dir: voice dir with a "wav" subdirectory of mono
                wav files sampled at self.sample_rate.
            se_model: path to the "se.model" state-dict checkpoint.
        """
        logging.info("[SpeakerEmbeddingProcessor] Speaker embedding extractor started")
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = DTDNN()
        try:
            if os.path.basename(se_model) == "se.model":
                model.load_state_dict(torch.load(se_model, map_location=device))
            else:
                raise Exception("[SpeakerEmbeddingProcessor] se model loading error!!!")
        except Exception as e:
            logging.info(e)
            if os.path.basename(se_model) == 'se.onnx':
                logging.info("[SpeakerEmbeddingProcessor] please update your se model to ensure that the version is greater than or equal to 1.0.5")
            # Bug fix: the original called sys.exit() without importing sys
            # (NameError); SystemExit aborts the script the same way.
            raise SystemExit
        model.eval()
        model.to(device)
        wav_dir = os.path.join(src_voice_dir, "wav")
        se_dir = os.path.join(src_voice_dir, "se")
        se_average_file = os.path.join(se_dir, "se.npy")
        os.makedirs(se_dir, exist_ok=True)
        wav_files = glob(os.path.join(wav_dir, '*.wav'))
        for wav_file in wav_files:
            basename = os.path.splitext(os.path.basename(wav_file))[0]
            se_file = os.path.join(se_dir, basename + '.npy')
            wav, fs = torchaudio.load(wav_file)
            assert wav.shape[0] == 1
            # Generalized from the hard-coded 16000 to the configured rate.
            assert fs == self.sample_rate
            if wav.shape[1] < self.min_wav_length:
                continue  # too short for a reliable embedding
            # 80-dim fbank, mean-normalized over time.
            fbank_feat = Kaldi.fbank(wav, num_mel_bins=80)
            feat = fbank_feat - fbank_feat.mean(dim=0, keepdim=True)
            feat = feat.unsqueeze(0).to(device)
            # Inference only: skip autograd bookkeeping.
            with torch.no_grad():
                speaker_embedding = model(feat)
            speaker_embedding = speaker_embedding.squeeze().cpu().detach().numpy()
            speaker_embedding = np.expand_dims(speaker_embedding, axis=0)
            np.save(se_file, speaker_embedding)
            self.se_list.append(speaker_embedding)
        self.se_average = np.expand_dims(
            np.mean(np.concatenate(self.se_list, axis=0), axis=0),
            axis=0,
        )
        np.save(se_average_file, self.se_average)
        logging.info("[SpeakerEmbeddingProcessor] Speaker embedding extracted successfully!")
if __name__ == '__main__':
    # CLI entry point for speaker-embedding extraction.
    parser = argparse.ArgumentParser(description="Speaker Embedding Processor")
    parser.add_argument("--src_voice_dir", type=str, required=True)
    parser.add_argument('--se_model', required=True)
    args = parser.parse_args()
    sep = SpeakerEmbeddingProcessor()
    # Bug fix: the parser defines --se_model, so the parsed attribute is
    # se_model; args.se_onnx raised AttributeError.
    sep.process(args.src_voice_dir, args.se_model)
\ No newline at end of file
import logging
import os
import sys
import argparse
import yaml
import time
import zipfile
# Make the repository root importable when this file runs as a script.
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))  # NOQA: E402
sys.path.insert(0, os.path.dirname(ROOT_PATH))  # NOQA: E402
try:
    from kantts.datasets.dataset import BERT_Text_Dataset
    from kantts.utils.log import logging_to_file, get_git_revision_hash
    from kantts.utils.ling_unit import text_to_mit_symbols as text_to_symbols
except ImportError:
    raise ImportError("Please install kantts.")
logging.basicConfig(
    format="%(asctime)s,%(msecs)d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
    datefmt="%Y-%m-%d:%H:%M:%S",
    level=logging.INFO,
)
def gen_metafile(
    output_dir,
    split_ratio=0.98,
):
    """Split output_dir/raw_metafile.txt into bert_train.lst and
    bert_valid.lst (train fraction = split_ratio); skips generation when
    both list files already exist."""
    raw_metafile = os.path.join(output_dir, "raw_metafile.txt")
    bert_train_meta = os.path.join(output_dir, "bert_train.lst")
    bert_valid_meta = os.path.join(output_dir, "bert_valid.lst")
    if not os.path.exists(
            bert_train_meta) or not os.path.exists(bert_valid_meta):
        BERT_Text_Dataset.gen_metafile(raw_metafile, output_dir, split_ratio)
    logging.info("BERT Text metafile generated.")
# TODO: Zh-CN as default
def process_mit_style_data(
    text_file,
    resources_zip_file,
    output_dir,
):
    """Convert a MIT-style text file into BERT training metafiles.

    Extracts the TTS resources next to the zip (once), converts each
    input line to symbol sequences with the "F7" voice resources, writes
    them to <output_dir>/raw_metafile.txt, and generates the train/valid
    list files.
    """
    os.makedirs(output_dir, exist_ok=True)
    logging_to_file(os.path.join(output_dir, "data_process_stdout.log"))
    resource_root_dir = os.path.dirname(resources_zip_file)
    resource_dir = os.path.join(resource_root_dir, "resource")
    if not os.path.exists(resource_dir):
        logging.info("Extracting resources...")
        with zipfile.ZipFile(resources_zip_file, "r") as zip_ref:
            zip_ref.extractall(resource_root_dir)
    # Bug fix: read/write explicitly as utf-8 rather than the platform
    # default encoding — the data is Chinese text.
    with open(text_file, "r", encoding="utf-8") as text_data:
        texts = text_data.readlines()
    logging.info("Converting text to symbols...")
    symbols_lst = text_to_symbols(texts, resource_dir, "F7")
    symbols_file = os.path.join(output_dir, "raw_metafile.txt")
    with open(symbols_file, "w", encoding="utf-8") as symbol_data:
        for symbol in symbols_lst:
            symbol_data.write(symbol)
    logging.info("Processing done.")
    # Generate BERT Text metafile
    # TODO: train/valid ratio setting
    gen_metafile(output_dir)
if __name__ == "__main__":
    # CLI entry point: convert a text file into BERT training metafiles.
    parser = argparse.ArgumentParser(description="Dataset preprocessor")
    parser.add_argument("--text_file", type=str, required=True)
    parser.add_argument("--resources_zip_file", type=str, required=True)
    parser.add_argument("--output_dir", type=str, required=True)
    args = parser.parse_args()
    process_mit_style_data(
        args.text_file,
        args.resources_zip_file,
        args.output_dir,
    )
import torch
import torch.nn.functional as F
from kantts.utils.audio_torch import stft, MelSpectrogram
from kantts.models.utils import get_mask_from_lengths
class MelReconLoss(torch.nn.Module):
    """Masked mel-spectrogram reconstruction loss: mean absolute/squared
    error averaged over valid (unpadded) frames and mel bins."""

    def __init__(self, loss_type="mae"):
        # loss_type: "mae" (L1) or "mse" (L2). reduction="none" so padded
        # frames can be masked out before averaging.
        super(MelReconLoss, self).__init__()
        self.loss_type = loss_type
        if loss_type == "mae":
            self.criterion = torch.nn.L1Loss(reduction="none")
        elif loss_type == "mse":
            self.criterion = torch.nn.MSELoss(reduction="none")
        else:
            raise ValueError("Unknown loss type: {}".format(loss_type))

    def forward(self, output_lengths, mel_targets, dec_outputs, postnet_outputs=None):
        """Return (decoder_mel_loss, postnet_mel_loss); the latter is 0.0
        when postnet_outputs is None."""
        output_masks = get_mask_from_lengths(
            output_lengths, max_len=mel_targets.size(1)
        )
        # Invert so True marks valid frames (assumes get_mask_from_lengths
        # flags the padded positions — TODO confirm against its definition).
        output_masks = ~output_masks
        valid_outputs = output_masks.sum()
        mel_loss_ = torch.sum(
            self.criterion(mel_targets, dec_outputs) * output_masks.unsqueeze(-1)
        ) / (valid_outputs * mel_targets.size(-1))
        if postnet_outputs is not None:
            mel_loss = torch.sum(
                self.criterion(mel_targets, postnet_outputs)
                * output_masks.unsqueeze(-1)
            ) / (valid_outputs * mel_targets.size(-1))
        else:
            mel_loss = 0.0
        return mel_loss_, mel_loss
class ProsodyReconLoss(torch.nn.Module):
    """Masked reconstruction loss for the three prosody predictors:
    duration (in log domain), pitch, and energy."""

    def __init__(self, loss_type="mae"):
        # loss_type: "mae" (L1) or "mse" (L2). reduction="none" so padded
        # positions can be masked out before averaging.
        super(ProsodyReconLoss, self).__init__()
        self.loss_type = loss_type
        if loss_type == "mae":
            self.criterion = torch.nn.L1Loss(reduction="none")
        elif loss_type == "mse":
            self.criterion = torch.nn.MSELoss(reduction="none")
        else:
            raise ValueError("Unknown loss type: {}".format(loss_type))

    def forward(
        self,
        input_lengths,
        duration_targets,
        pitch_targets,
        energy_targets,
        log_duration_predictions,
        pitch_predictions,
        energy_predictions,
    ):
        """Return (dur_loss, pitch_loss, energy_loss), each averaged over
        valid (unpadded) input positions."""
        input_masks = get_mask_from_lengths(
            input_lengths, max_len=duration_targets.size(1)
        )
        # Invert so True marks valid positions (assumes the helper flags
        # padding — TODO confirm against its definition).
        input_masks = ~input_masks
        valid_inputs = input_masks.sum()
        dur_loss = (
            torch.sum(
                self.criterion(
                    # +1 keeps log defined for zero-length durations.
                    torch.log(duration_targets.float() + 1), log_duration_predictions
                )
                * input_masks
            )
            / valid_inputs
        )
        pitch_loss = (
            torch.sum(self.criterion(pitch_targets, pitch_predictions) * input_masks)
            / valid_inputs
        )
        energy_loss = (
            torch.sum(self.criterion(energy_targets, energy_predictions) * input_masks)
            / valid_inputs
        )
        return dur_loss, pitch_loss, energy_loss
class FpCELoss(torch.nn.Module):
    """Masked, class-weighted cross-entropy loss over filled-pause
    categories."""

    def __init__(self, loss_type="ce", weight=(1, 4, 4, 8)):
        # weight default changed from a mutable list to a tuple; any
        # sequence of per-class weights is accepted as before.
        super(FpCELoss, self).__init__()
        self.loss_type = loss_type
        # Bug fix: the weight tensor was hard-coded to .cuda(), which
        # crashed on CPU-only machines. A buffer follows the module across
        # .to()/.cuda() calls instead.
        self.register_buffer("weight_ce", torch.FloatTensor(weight))

    def forward(self, input_lengths, fp_pd, fp_label):
        """Return the cross-entropy averaged over valid (unpadded)
        positions; fp_pd is (B, T, C), fp_label is (B, T)."""
        input_masks = get_mask_from_lengths(input_lengths, max_len=fp_label.size(1))
        # Invert so True marks valid positions.
        input_masks = ~input_masks
        valid_inputs = input_masks.sum()
        losses = F.cross_entropy(
            fp_pd.transpose(2, 1),
            fp_label,
            weight=self.weight_ce.to(fp_pd.device),
            reduction="none",
        )
        fp_loss = torch.sum(losses * input_masks) / valid_inputs
        return fp_loss
class GeneratorAdversarialLoss(torch.nn.Module):
    """Adversarial loss for the generator side of a GAN vocoder."""

    def __init__(
        self,
        average_by_discriminators=True,
        loss_type="mse",
    ):
        """Set up the criterion; loss_type is "mse" or "hinge"."""
        super().__init__()
        self.average_by_discriminators = average_by_discriminators
        assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported."
        self.criterion = self._mse_loss if loss_type == "mse" else self._hinge_loss

    def forward(self, outputs):
        """Return the generator adversarial loss.

        Args:
            outputs: a discriminator output tensor, or a list/tuple of
                them (one entry per discriminator).
        """
        if not isinstance(outputs, (tuple, list)):
            return self.criterion(outputs)
        total = 0.0
        for count, single_out in enumerate(outputs, start=1):
            total = total + self.criterion(single_out)
        if self.average_by_discriminators:
            total = total / count
        return total

    def _mse_loss(self, x):
        return F.mse_loss(x, x.new_ones(x.size()))

    def _hinge_loss(self, x):
        return -x.mean()
class DiscriminatorAdversarialLoss(torch.nn.Module):
    """Adversarial loss for the discriminator side of a GAN."""

    def __init__(
        self,
        average_by_discriminators=True,
        loss_type="mse",
    ):
        """Initialize the discriminator adversarial loss.

        Args:
            average_by_discriminators (bool): Average over discriminators
                when lists of outputs are given.
            loss_type (str): Either "mse" or "hinge".
        """
        super().__init__()
        self.average_by_discriminators = average_by_discriminators
        assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported."
        if loss_type == "mse":
            self.fake_criterion = self._mse_fake_loss
            self.real_criterion = self._mse_real_loss
        else:
            self.fake_criterion = self._hinge_fake_loss
            self.real_criterion = self._hinge_real_loss

    def forward(self, outputs_hat, outputs):
        """Compute the discriminator adversarial losses.

        Args:
            outputs_hat (Tensor or list): Discriminator output(s) on
                generated samples.
            outputs (Tensor or list): Discriminator output(s) on
                ground-truth samples.
        Returns:
            Tensor: Real loss value.
            Tensor: Fake loss value.
        """
        if not isinstance(outputs, (tuple, list)):
            return self.real_criterion(outputs), self.fake_criterion(outputs_hat)
        real_loss = 0.0
        fake_loss = 0.0
        count = 0
        for generated, reference in zip(outputs_hat, outputs):
            if isinstance(generated, (tuple, list)):
                # NOTE(kan-bayashi): case including feature maps
                generated = generated[-1]
                reference = reference[-1]
            real_loss = real_loss + self.real_criterion(reference)
            fake_loss = fake_loss + self.fake_criterion(generated)
            count += 1
        if self.average_by_discriminators:
            fake_loss = fake_loss / count
            real_loss = real_loss / count
        return real_loss, fake_loss

    def _mse_real_loss(self, x):
        return F.mse_loss(x, x.new_ones(x.size()))

    def _mse_fake_loss(self, x):
        return F.mse_loss(x, x.new_zeros(x.size()))

    def _hinge_real_loss(self, x):
        return -torch.mean(torch.min(x - 1, x.new_zeros(x.size())))

    def _hinge_fake_loss(self, x):
        return -torch.mean(torch.min(-x - 1, x.new_zeros(x.size())))
class FeatureMatchLoss(torch.nn.Module):
    """Feature-matching loss: L1 between intermediate discriminator features."""

    def __init__(
        self,
        average_by_layers=True,
        average_by_discriminators=True,
    ):
        """Initialize the feature matching loss.

        Args:
            average_by_layers (bool): Average over layers per discriminator.
            average_by_discriminators (bool): Average over discriminators.
        """
        super().__init__()
        self.average_by_layers = average_by_layers
        self.average_by_discriminators = average_by_discriminators

    def forward(self, feats_hat, feats):
        """Compute the feature matching loss.

        Args:
            feats_hat (list): Per-discriminator lists of feature maps from
                generated samples.
            feats (list): Per-discriminator lists of feature maps from
                ground-truth samples.
        Returns:
            Tensor: Scalar feature matching loss.
        """
        total = 0.0
        num_discs = 0
        for layers_hat, layers_ref in zip(feats_hat, feats):
            per_disc = 0.0
            num_layers = 0
            for fh, fr in zip(layers_hat, layers_ref):
                # Ground-truth features are treated as constants (no grad).
                per_disc = per_disc + F.l1_loss(fh, fr.detach())
                num_layers += 1
            if self.average_by_layers:
                per_disc = per_disc / num_layers
            total = total + per_disc
            num_discs += 1
        if self.average_by_discriminators:
            total = total / num_discs
        return total
class MelSpectrogramLoss(torch.nn.Module):
    """L1 loss between mel-spectrograms of generated and reference audio."""

    def __init__(
        self,
        fs=22050,
        fft_size=1024,
        hop_size=256,
        win_length=None,
        window="hann",
        num_mels=80,
        fmin=80,
        fmax=7600,
        center=True,
        normalized=False,
        onesided=True,
        eps=1e-10,
        log_base=10.0,
    ):
        """Initialize the mel-spectrogram loss.

        All parameters are forwarded verbatim to the project's
        MelSpectrogram front end; see that class for their semantics.
        """
        super().__init__()
        self.mel_spectrogram = MelSpectrogram(
            fs=fs,
            fft_size=fft_size,
            hop_size=hop_size,
            win_length=win_length,
            window=window,
            num_mels=num_mels,
            fmin=fmin,
            fmax=fmax,
            center=center,
            normalized=normalized,
            onesided=onesided,
            eps=eps,
            log_base=log_base,
        )

    def forward(self, y_hat, y):
        """Compute the mel-spectrogram L1 loss.

        Args:
            y_hat (Tensor): Generated waveform (B, 1, T).
            y (Tensor): Ground-truth waveform (B, 1, T).
        Returns:
            Tensor: Scalar mel-spectrogram loss.
        """
        predicted_mel = self.mel_spectrogram(y_hat)
        reference_mel = self.mel_spectrogram(y)
        return F.l1_loss(predicted_mel, reference_mel)
class SpectralConvergenceLoss(torch.nn.Module):
    """Spectral convergence loss: relative Frobenius-norm spectral error."""

    def __init__(self):
        """Initialize the (stateless) spectral convergence loss module."""
        super().__init__()

    def forward(self, x_mag, y_mag):
        """Compute the spectral convergence loss.

        Args:
            x_mag (Tensor): Predicted magnitude spectrogram (B, #frames, #freq_bins).
            y_mag (Tensor): Ground-truth magnitude spectrogram (B, #frames, #freq_bins).
        Returns:
            Tensor: Scalar spectral convergence loss.
        """
        residual_norm = torch.norm(y_mag - x_mag, p="fro")
        reference_norm = torch.norm(y_mag, p="fro")
        return residual_norm / reference_norm
class LogSTFTMagnitudeLoss(torch.nn.Module):
    """L1 distance between log-magnitude spectrograms."""

    def __init__(self):
        """Initialize the (stateless) log STFT magnitude loss module."""
        super().__init__()

    def forward(self, x_mag, y_mag):
        """Compute the log STFT magnitude loss.

        Args:
            x_mag (Tensor): Predicted magnitude spectrogram (B, #frames, #freq_bins).
            y_mag (Tensor): Ground-truth magnitude spectrogram (B, #frames, #freq_bins).
        Returns:
            Tensor: Scalar log STFT magnitude loss.
        """
        return F.l1_loss(torch.log(y_mag), torch.log(x_mag))
class STFTLoss(torch.nn.Module):
    """Single-resolution STFT loss (spectral convergence + log magnitude)."""

    def __init__(
        self, fft_size=1024, shift_size=120, win_length=600, window="hann_window"
    ):
        """Initialize the STFT loss.

        Args:
            fft_size (int): FFT size.
            shift_size (int): Hop size in samples.
            win_length (int): Window length in samples.
            window (str): Name of a torch window function (e.g. "hann_window").
        """
        super().__init__()
        self.fft_size = fft_size
        self.shift_size = shift_size
        self.win_length = win_length
        self.spectral_convergence_loss = SpectralConvergenceLoss()
        self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()
        # NOTE(kan-bayashi): Use register_buffer to fix #223
        self.register_buffer("window", getattr(torch, window)(win_length))

    def forward(self, x, y):
        """Compute both STFT sub-losses.

        Args:
            x (Tensor): Predicted signal (B, T).
            y (Tensor): Ground-truth signal (B, T).
        Returns:
            Tensor: Spectral convergence loss value.
            Tensor: Log STFT magnitude loss value.
        """
        stft_args = (self.fft_size, self.shift_size, self.win_length, self.window)
        x_mag = stft(x, *stft_args)
        y_mag = stft(y, *stft_args)
        return (
            self.spectral_convergence_loss(x_mag, y_mag),
            self.log_stft_magnitude_loss(x_mag, y_mag),
        )
class MultiResolutionSTFTLoss(torch.nn.Module):
    """Multi resolution STFT loss module."""

    def __init__(
        self,
        fft_sizes=None,
        hop_sizes=None,
        win_lengths=None,
        window="hann_window",
    ):
        """Initialize Multi resolution STFT loss module.

        Args:
            fft_sizes (list): List of FFT sizes (default: [1024, 2048, 512]).
            hop_sizes (list): List of hop sizes (default: [120, 240, 50]).
            win_lengths (list): List of window lengths (default: [600, 1200, 240]).
            window (str): Window function type.
        """
        super(MultiResolutionSTFTLoss, self).__init__()
        # None sentinels replace the original mutable-list default arguments;
        # the effective defaults are unchanged.
        if fft_sizes is None:
            fft_sizes = [1024, 2048, 512]
        if hop_sizes is None:
            hop_sizes = [120, 240, 50]
        if win_lengths is None:
            win_lengths = [600, 1200, 240]
        assert len(fft_sizes) == len(hop_sizes) == len(win_lengths)
        self.stft_losses = torch.nn.ModuleList()
        for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths):
            self.stft_losses += [STFTLoss(fs, ss, wl, window)]

    def forward(self, x, y):
        """Calculate forward propagation.

        Args:
            x (Tensor): Predicted signal (B, T) or (B, #subband, T).
            y (Tensor): Groundtruth signal (B, T) or (B, #subband, T).
        Returns:
            Tensor: Multi resolution spectral convergence loss value.
            Tensor: Multi resolution log STFT magnitude loss value.
        """
        if len(x.shape) == 3:
            # Fold sub-bands into the batch dimension.
            x = x.view(-1, x.size(2))  # (B, C, T) -> (B x C, T)
            y = y.view(-1, y.size(2))  # (B, C, T) -> (B x C, T)
        sc_loss = 0.0
        mag_loss = 0.0
        for f in self.stft_losses:
            sc_l, mag_l = f(x, y)
            sc_loss += sc_l
            mag_loss += mag_l
        # Always averaged over resolutions.
        sc_loss /= len(self.stft_losses)
        mag_loss /= len(self.stft_losses)
        return sc_loss, mag_loss
class SeqCELoss(torch.nn.Module):
    """Masked sequence cross-entropy loss that also reports the error rate."""

    def __init__(self, loss_type="ce"):
        """Initialize the sequence CE loss.

        Args:
            loss_type (str): Loss identifier, kept for config compatibility.
        """
        super().__init__()
        self.loss_type = loss_type
        self.criterion = torch.nn.CrossEntropyLoss(reduction="none")

    def forward(self, logits, targets, masks):
        """Compute masked mean CE loss and masked classification error rate.

        Args:
            logits (Tensor): Class scores (B, T, num_classes).
            targets (Tensor): Target class indices (B, T).
            masks (Tensor): 1 for valid positions, 0 for padding (B, T).
        Returns:
            Tensor: Scalar masked cross-entropy loss.
            Tensor: Scalar masked error rate.
        """
        flat_logits = logits.contiguous().view(-1, logits.size(-1))
        flat_targets = targets.contiguous().view(-1)
        token_loss = self.criterion(flat_logits, flat_targets)
        flat_preds = torch.argmax(logits, dim=-1).contiguous().view(-1)
        flat_masks = masks.contiguous().view(-1)
        denom = flat_masks.sum()
        loss = (token_loss * flat_masks).sum() / denom
        err = torch.sum((flat_preds != targets.view(-1)) * flat_masks) / denom
        return loss, err
class AttentionBinarizationLoss(torch.nn.Module):
    """Loss pushing soft attention toward a given hard alignment, with warmup."""

    def __init__(self, start_epoch=0, warmup_epoch=100):
        """Initialize the binarization loss.

        Args:
            start_epoch (int): Epoch at which the loss starts ramping in.
            warmup_epoch (int): Number of epochs over which the weight
                ramps linearly from 0 to 1.
        """
        super().__init__()
        self.start_epoch = start_epoch
        self.warmup_epoch = warmup_epoch

    def forward(self, epoch, hard_attention, soft_attention, eps=1e-12):
        """Compute the warmup-weighted binarization loss.

        Args:
            epoch (int): Current training epoch.
            hard_attention (Tensor): Binary alignment (1 at selected cells).
            soft_attention (Tensor): Soft attention probabilities, same shape.
            eps (float): Clamp floor to keep log() finite.
        Returns:
            Tensor: Scalar loss, scaled by the warmup ramp.
        """
        selected = soft_attention[hard_attention == 1]
        log_sum = torch.log(torch.clamp(selected, min=eps)).sum()
        kl_loss = -log_sum / hard_attention.sum()
        if epoch < self.start_epoch:
            ramp = 0
        else:
            ramp = min(1.0, (epoch - self.start_epoch) / self.warmup_epoch)
        return kl_loss * ramp
class AttentionCTCLoss(torch.nn.Module):
def __init__(self, blank_logprob=-1):
super(AttentionCTCLoss, self).__init__()
self.log_softmax = torch.nn.LogSoftmax(dim=3)
self.blank_logprob = blank_logprob
self.CTCLoss = torch.nn.CTCLoss(zero_infinity=True)
def forward(self, attn_logprob, in_lens, out_lens):
key_lens = in_lens
query_lens = out_lens
attn_logprob_padded = F.pad(
input=attn_logprob, pad=(1, 0, 0, 0, 0, 0, 0, 0), value=self.blank_logprob
)
cost_total = 0.0
for bid in range(attn_logprob.shape[0]):
target_seq = torch.arange(1, key_lens[bid] + 1).unsqueeze(0)
curr_logprob = attn_logprob_padded[bid].permute(1, 0, 2)
curr_logprob = curr_logprob[: query_lens[bid], :, : key_lens[bid] + 1]
curr_logprob = self.log_softmax(curr_logprob[None])[0]
ctc_cost = self.CTCLoss(
curr_logprob,
target_seq,
input_lengths=query_lens[bid : bid + 1],
target_lengths=key_lens[bid : bid + 1],
)
cost_total += ctc_cost
cost = cost_total / attn_logprob.shape[0]
return cost
# TODO: create a mapping for new loss functions
# Registry mapping config "Loss" section keys to loss-module classes;
# consumed by criterion_builder below.
# NOTE(review): "stft_loss" and "subband_stft_loss" both map to
# MultiResolutionSTFTLoss — presumably differentiated via config params; confirm.
loss_dict = {
    "generator_adv_loss": GeneratorAdversarialLoss,
    "discriminator_adv_loss": DiscriminatorAdversarialLoss,
    "stft_loss": MultiResolutionSTFTLoss,
    "mel_loss": MelSpectrogramLoss,
    "subband_stft_loss": MultiResolutionSTFTLoss,
    "feat_match_loss": FeatureMatchLoss,
    "MelReconLoss": MelReconLoss,
    "ProsodyReconLoss": ProsodyReconLoss,
    "SeqCELoss": SeqCELoss,
    "AttentionBinarizationLoss": AttentionBinarizationLoss,
    "AttentionCTCLoss": AttentionCTCLoss,
    "FpCELoss": FpCELoss,
}
def criterion_builder(config, device="cpu"):
    """Build the dictionary of enabled loss modules from a config.

    Args:
        config (dict): Config dictionary with a "Loss" section mapping loss
            names to {"enable": bool, "params": dict, "weights": float}.
        device (str): Device the instantiated loss modules are moved to.
    Returns:
        dict: Mapping from loss name to instantiated loss module; each
            module carries a ``weights`` attribute (default 1.0).
    Raises:
        NotImplementedError: If a configured loss name is not in loss_dict.
    """
    criterion = {}
    for name, cfg in config["Loss"].items():
        if name not in loss_dict:
            raise NotImplementedError("{} is not implemented".format(name))
        if not cfg["enable"]:
            continue
        loss_module = loss_dict[name](**cfg.get("params", {})).to(device)
        loss_module.weights = cfg.get("weights", 1.0)
        criterion[name] = loss_module
    return criterion
from torch.optim.lr_scheduler import * # NOQA
from torch.optim.lr_scheduler import _LRScheduler # NOQA
"""Noam Scheduler."""
class FindLR(_LRScheduler):
    """LR range-finder: exponential sweep from base_lr up to max_lr.

    inspired by fast.ai @https://sgugger.github.io/how-do-you-find-a-good-learning-rate.html
    """

    def __init__(self, optimizer, max_steps, max_lr=10):
        """Initialize the LR finder schedule.

        Args:
            optimizer: Wrapped optimizer.
            max_steps (int): Number of steps over which to sweep.
            max_lr (float): Learning rate reached at the final step.
        """
        self.max_steps = max_steps
        self.max_lr = max_lr
        super().__init__(optimizer)

    def get_lr(self):
        """Return per-group LRs growing geometrically with the step count."""
        progress = self.last_epoch / (self.max_steps - 1)
        return [
            base_lr * ((self.max_lr / base_lr) ** progress)
            for base_lr in self.base_lrs
        ]
class NoamLR(_LRScheduler):
    """
    Implements the Noam Learning rate schedule. This corresponds to increasing the learning rate
    linearly for the first ``warmup_steps`` training steps, and decreasing it thereafter proportionally
    to the inverse square root of the step number, scaled by the inverse square root of the
    dimensionality of the model. Time will tell if this is just madness or it's actually important.

    Parameters
    ----------
    warmup_steps: ``int``, required.
        The number of steps to linearly increase the learning rate.
    """

    def __init__(self, optimizer, warmup_steps):
        self.warmup_steps = warmup_steps
        super().__init__(optimizer)

    def get_lr(self):
        """Return per-group LRs scaled by the Noam warmup/decay factor."""
        # Clamp to 1 so step 0 does not produce a zero/invalid scale.
        step = max(1, self.last_epoch)
        decay_factor = step ** (-0.5)
        warmup_factor = step * self.warmup_steps ** (-1.5)
        scale = (self.warmup_steps ** 0.5) * min(decay_factor, warmup_factor)
        return [base_lr * scale for base_lr in self.base_lrs]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment