Commit a7785cc6 authored by Sugon_ldc

delete soft link

parent 9a2a05ca
# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import re
import yaml
import torch
from collections import OrderedDict
import datetime
def load_checkpoint(model: torch.nn.Module, path: str) -> dict:
if torch.cuda.is_available():
logging.info('Checkpoint: loading from checkpoint %s for GPU' % path)
checkpoint = torch.load(path)
else:
logging.info('Checkpoint: loading from checkpoint %s for CPU' % path)
checkpoint = torch.load(path, map_location='cpu')
model.load_state_dict(checkpoint, strict=False)
    info_path = re.sub(r'\.pt$', '.yaml', path)
configs = {}
if os.path.exists(info_path):
with open(info_path, 'r') as fin:
configs = yaml.load(fin, Loader=yaml.FullLoader)
return configs
def save_checkpoint(model: torch.nn.Module, path: str, infos=None):
'''
Args:
infos (dict or None): any info you want to save.
'''
logging.info('Checkpoint: save to checkpoint %s' % path)
if isinstance(model, torch.nn.DataParallel):
state_dict = model.module.state_dict()
elif isinstance(model, torch.nn.parallel.DistributedDataParallel):
state_dict = model.module.state_dict()
else:
state_dict = model.state_dict()
torch.save(state_dict, path)
    info_path = re.sub(r'\.pt$', '.yaml', path)
if infos is None:
infos = {}
infos['save_time'] = datetime.datetime.now().strftime('%d/%m/%Y %H:%M:%S')
with open(info_path, 'w') as fout:
data = yaml.dump(infos)
fout.write(data)
def filter_modules(model_state_dict, modules):
new_mods = []
incorrect_mods = []
mods_model = model_state_dict.keys()
for mod in modules:
if any(key.startswith(mod) for key in mods_model):
new_mods += [mod]
else:
incorrect_mods += [mod]
if incorrect_mods:
logging.warning(
"module(s) %s don't match or (partially match) "
"available modules in model.",
incorrect_mods,
)
logging.warning("for information, the existing modules in model are:")
logging.warning("%s", mods_model)
return new_mods
def load_trained_modules(model: torch.nn.Module, args):
# Load encoder modules with pre-trained model(s).
enc_model_path = args.enc_init
enc_modules = args.enc_init_mods
main_state_dict = model.state_dict()
logging.warning("model(s) found for pre-initialization")
if os.path.isfile(enc_model_path):
logging.info('Checkpoint: loading from checkpoint %s for CPU' %
enc_model_path)
model_state_dict = torch.load(enc_model_path, map_location='cpu')
modules = filter_modules(model_state_dict, enc_modules)
partial_state_dict = OrderedDict()
for key, value in model_state_dict.items():
if any(key.startswith(m) for m in modules):
partial_state_dict[key] = value
main_state_dict.update(partial_state_dict)
else:
logging.warning("model was not found : %s", enc_model_path)
model.load_state_dict(main_state_dict)
configs = {}
return configs
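# Illustrative usage sketch (not part of the original file): round-trip a toy
# model through save_checkpoint/load_checkpoint. The model and file names are
# placeholders chosen for the example.
if __name__ == '__main__':
    import tempfile
    demo_model = torch.nn.Linear(8, 4)
    with tempfile.TemporaryDirectory() as tmpdir:
        ckpt_path = os.path.join(tmpdir, 'demo.pt')
        save_checkpoint(demo_model, ckpt_path, infos={'epoch': 0, 'cv_loss': 1.23})
        infos = load_checkpoint(demo_model, ckpt_path)
        print(infos['epoch'], infos['cv_loss'])  # 0 1.23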
# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import math
import sys
import numpy as np
def _load_json_cmvn(json_cmvn_file):
""" Load the json format cmvn stats file and calculate cmvn
Args:
json_cmvn_file: cmvn stats file in json format
Returns:
a numpy array of [means, vars]
"""
with open(json_cmvn_file) as f:
cmvn_stats = json.load(f)
means = cmvn_stats['mean_stat']
variance = cmvn_stats['var_stat']
count = cmvn_stats['frame_num']
for i in range(len(means)):
means[i] /= count
variance[i] = variance[i] / count - means[i] * means[i]
if variance[i] < 1.0e-20:
variance[i] = 1.0e-20
variance[i] = 1.0 / math.sqrt(variance[i])
cmvn = np.array([means, variance])
return cmvn
def _load_kaldi_cmvn(kaldi_cmvn_file):
""" Load the kaldi format cmvn stats file and calculate cmvn
Args:
kaldi_cmvn_file: kaldi text style global cmvn file, which
is generated by:
compute-cmvn-stats --binary=false scp:feats.scp global_cmvn
Returns:
a numpy array of [means, vars]
"""
means = []
variance = []
with open(kaldi_cmvn_file, 'r') as fid:
# kaldi binary file start with '\0B'
if fid.read(2) == '\0B':
logging.error('kaldi cmvn binary file is not supported, please '
'recompute it by: compute-cmvn-stats --binary=false '
' scp:feats.scp global_cmvn')
sys.exit(1)
fid.seek(0)
arr = fid.read().split()
assert (arr[0] == '[')
assert (arr[-2] == '0')
assert (arr[-1] == ']')
feat_dim = int((len(arr) - 2 - 2) / 2)
for i in range(1, feat_dim + 1):
means.append(float(arr[i]))
count = float(arr[feat_dim + 1])
for i in range(feat_dim + 2, 2 * feat_dim + 2):
variance.append(float(arr[i]))
for i in range(len(means)):
means[i] /= count
variance[i] = variance[i] / count - means[i] * means[i]
if variance[i] < 1.0e-20:
variance[i] = 1.0e-20
variance[i] = 1.0 / math.sqrt(variance[i])
cmvn = np.array([means, variance])
return cmvn
def load_cmvn(cmvn_file, is_json):
if is_json:
cmvn = _load_json_cmvn(cmvn_file)
else:
cmvn = _load_kaldi_cmvn(cmvn_file)
return cmvn[0], cmvn[1]
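# Illustrative usage sketch (not part of the original file): write a tiny JSON
# cmvn stats file and load it back. The statistics below are made-up numbers.
if __name__ == '__main__':
    import os
    import tempfile
    stats = {'mean_stat': [10.0, 20.0], 'var_stat': [60.0, 220.0], 'frame_num': 5}
    with tempfile.TemporaryDirectory() as tmpdir:
        cmvn_path = os.path.join(tmpdir, 'global_cmvn.json')
        with open(cmvn_path, 'w') as f:
            json.dump(stats, f)
        mean, istd = load_cmvn(cmvn_path, is_json=True)
        print(mean)  # per-dimension mean
        print(istd)  # per-dimension inverse standard deviation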
# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""Unility functions for Transformer."""
import math
from typing import List, Tuple
import torch
from torch.nn.utils.rnn import pad_sequence
IGNORE_ID = -1
def pad_list(xs: List[torch.Tensor], pad_value: int):
"""Perform padding for the list of tensors.
Args:
xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
pad_value (float): Value for padding.
Returns:
Tensor: Padded tensor (B, Tmax, `*`).
Examples:
>>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
>>> x
[tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
>>> pad_list(x, 0)
tensor([[1., 1., 1., 1.],
[1., 1., 0., 0.],
[1., 0., 0., 0.]])
"""
n_batch = len(xs)
max_len = max([x.size(0) for x in xs])
pad = torch.zeros(n_batch, max_len, dtype=xs[0].dtype, device=xs[0].device)
pad = pad.fill_(pad_value)
for i in range(n_batch):
pad[i, :xs[i].size(0)] = xs[i]
return pad
def add_blank(ys_pad: torch.Tensor, blank: int,
ignore_id: int) -> torch.Tensor:
""" Prepad blank for transducer predictor
Args:
ys_pad (torch.Tensor): batch of padded target sequences (B, Lmax)
blank (int): index of <blank>
Returns:
ys_in (torch.Tensor) : (B, Lmax + 1)
Examples:
>>> blank = 0
>>> ignore_id = -1
>>> ys_pad
tensor([[ 1, 2, 3, 4, 5],
[ 4, 5, 6, -1, -1],
[ 7, 8, 9, -1, -1]], dtype=torch.int32)
>>> ys_in = add_blank(ys_pad, 0, -1)
>>> ys_in
tensor([[0, 1, 2, 3, 4, 5],
[0, 4, 5, 6, 0, 0],
[0, 7, 8, 9, 0, 0]])
"""
bs = ys_pad.size(0)
_blank = torch.tensor([blank],
dtype=torch.long,
requires_grad=False,
device=ys_pad.device)
_blank = _blank.repeat(bs).unsqueeze(1) # [bs,1]
out = torch.cat([_blank, ys_pad], dim=1) # [bs, Lmax+1]
return torch.where(out == ignore_id, blank, out)
def add_sos_eos(ys_pad: torch.Tensor, sos: int, eos: int,
ignore_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
"""Add <sos> and <eos> labels.
Args:
ys_pad (torch.Tensor): batch of padded target sequences (B, Lmax)
sos (int): index of <sos>
        eos (int): index of <eos>
ignore_id (int): index of padding
Returns:
ys_in (torch.Tensor) : (B, Lmax + 1)
ys_out (torch.Tensor) : (B, Lmax + 1)
Examples:
>>> sos_id = 10
>>> eos_id = 11
>>> ignore_id = -1
>>> ys_pad
tensor([[ 1, 2, 3, 4, 5],
[ 4, 5, 6, -1, -1],
[ 7, 8, 9, -1, -1]], dtype=torch.int32)
>>> ys_in,ys_out=add_sos_eos(ys_pad, sos_id , eos_id, ignore_id)
>>> ys_in
tensor([[10, 1, 2, 3, 4, 5],
[10, 4, 5, 6, 11, 11],
[10, 7, 8, 9, 11, 11]])
>>> ys_out
tensor([[ 1, 2, 3, 4, 5, 11],
[ 4, 5, 6, 11, -1, -1],
[ 7, 8, 9, 11, -1, -1]])
"""
_sos = torch.tensor([sos],
dtype=torch.long,
requires_grad=False,
device=ys_pad.device)
_eos = torch.tensor([eos],
dtype=torch.long,
requires_grad=False,
device=ys_pad.device)
ys = [y[y != ignore_id] for y in ys_pad] # parse padded ys
ys_in = [torch.cat([_sos, y], dim=0) for y in ys]
ys_out = [torch.cat([y, _eos], dim=0) for y in ys]
return pad_list(ys_in, eos), pad_list(ys_out, ignore_id)
def reverse_pad_list(ys_pad: torch.Tensor,
ys_lens: torch.Tensor,
pad_value: float = -1.0) -> torch.Tensor:
"""Reverse padding for the list of tensors.
Args:
ys_pad (tensor): The padded tensor (B, Tokenmax).
ys_lens (tensor): The lens of token seqs (B)
pad_value (int): Value for padding.
Returns:
Tensor: Padded tensor (B, Tokenmax).
Examples:
>>> x
tensor([[1, 2, 3, 4], [5, 6, 7, 0], [8, 9, 0, 0]])
>>> pad_list(x, 0)
tensor([[4, 3, 2, 1],
[7, 6, 5, 0],
[9, 8, 0, 0]])
"""
r_ys_pad = pad_sequence([(torch.flip(y.int()[:i], [0]))
for y, i in zip(ys_pad, ys_lens)], True,
pad_value)
return r_ys_pad
def th_accuracy(pad_outputs: torch.Tensor, pad_targets: torch.Tensor,
ignore_label: int) -> float:
"""Calculate accuracy.
Args:
pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
pad_targets (LongTensor): Target label tensors (B, Lmax).
ignore_label (int): Ignore label id.
Returns:
float: Accuracy value (0.0 - 1.0).
"""
pad_pred = pad_outputs.view(pad_targets.size(0), pad_targets.size(1),
pad_outputs.size(1)).argmax(2)
mask = pad_targets != ignore_label
numerator = torch.sum(
pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
denominator = torch.sum(mask)
return float(numerator) / float(denominator)
def get_rnn(rnn_type: str) -> torch.nn.Module:
assert rnn_type in ["rnn", "lstm", "gru"]
if rnn_type == "rnn":
return torch.nn.RNN
elif rnn_type == "lstm":
return torch.nn.LSTM
else:
return torch.nn.GRU
def get_activation(act):
"""Return activation function."""
# Lazy load to avoid unused import
from wenet.transformer.swish import Swish
activation_funcs = {
"hardtanh": torch.nn.Hardtanh,
"tanh": torch.nn.Tanh,
"relu": torch.nn.ReLU,
"selu": torch.nn.SELU,
"swish": getattr(torch.nn, "SiLU", Swish),
"gelu": torch.nn.GELU
}
return activation_funcs[act]()
def get_subsample(config):
input_layer = config["encoder_conf"]["input_layer"]
assert input_layer in ["conv2d", "conv2d6", "conv2d8"]
if input_layer == "conv2d":
return 4
elif input_layer == "conv2d6":
return 6
elif input_layer == "conv2d8":
return 8
def remove_duplicates_and_blank(hyp: List[int]) -> List[int]:
new_hyp: List[int] = []
cur = 0
while cur < len(hyp):
if hyp[cur] != 0:
new_hyp.append(hyp[cur])
prev = cur
while cur < len(hyp) and hyp[cur] == hyp[prev]:
cur += 1
return new_hyp
def replace_duplicates_with_blank(hyp: List[int]) -> List[int]:
new_hyp: List[int] = []
cur = 0
while cur < len(hyp):
new_hyp.append(hyp[cur])
prev = cur
cur += 1
while cur < len(hyp) and hyp[cur] == hyp[prev] and hyp[cur] != 0:
new_hyp.append(0)
cur += 1
return new_hyp
def log_add(args: List[float]) -> float:
"""
Stable log add
"""
if all(a == -float('inf') for a in args):
return -float('inf')
a_max = max(args)
lsp = math.log(sum(math.exp(a - a_max) for a in args))
return a_max + lsp
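# Illustrative usage sketch (not part of the original file): exercise a few of
# the helpers above on toy inputs; all values are made up.
if __name__ == '__main__':
    ys = torch.tensor([[1, 2, 3, 4, 5], [4, 5, 6, IGNORE_ID, IGNORE_ID]])
    ys_in, ys_out = add_sos_eos(ys, sos=10, eos=11, ignore_id=IGNORE_ID)
    print(ys_in)   # <sos>-prefixed targets, padded with eos
    print(ys_out)  # <eos>-suffixed targets, padded with ignore_id
    print(remove_duplicates_and_blank([0, 3, 3, 0, 0, 4, 5, 5]))  # [3, 4, 5]
    print(log_add([math.log(0.3), math.log(0.7)]))  # ~0.0 == log(0.3 + 0.7)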
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# This file is originally copied from tools/compute-wer.py and modified to calculate the char accuracy
import re, sys, unicodedata
import codecs
remove_tag = True
spacelist= [' ', '\t', '\r', '\n']
puncts = ['!', ',', '?',
'、', '。', '!', ',', ';', '?',
':', '「', '」', '︰', '『', '』', '《', '》']
def characterize(string) :
res = []
i = 0
while i < len(string):
char = string[i]
if char in puncts:
i += 1
continue
cat1 = unicodedata.category(char)
#https://unicodebook.readthedocs.io/unicode.html#unicode-categories
if cat1 == 'Zs' or cat1 == 'Cn' or char in spacelist: # space or not assigned
i += 1
continue
if cat1 == 'Lo': # letter-other
res.append(char)
i += 1
else:
            # some input looks like <unk><noise>; we want to separate it into two words.
sep = ' '
if char == '<': sep = '>'
j = i+1
while j < len(string):
c = string[j]
if ord(c) >= 128 or (c in spacelist) or (c==sep):
break
j += 1
if j < len(string) and string[j] == '>':
j += 1
res.append(string[i:j])
i = j
return res
def stripoff_tags(x):
if not x: return ''
chars = []
i = 0; T=len(x)
while i < T:
if x[i] == '<':
while i < T and x[i] != '>':
i += 1
i += 1
else:
chars.append(x[i])
i += 1
return ''.join(chars)
def normalize(sentence, ignore_words, cs, split=None):
""" sentence, ignore_words are both in unicode
"""
new_sentence = []
for token in sentence:
x = token
if not cs:
x = x.upper()
if x in ignore_words:
continue
if remove_tag:
x = stripoff_tags(x)
if not x:
continue
if split and x in split:
new_sentence += split[x]
else:
new_sentence.append(x)
return new_sentence
class Calculator :
def __init__(self) :
self.data = {}
self.space = []
self.cost = {}
self.cost['cor'] = 0
self.cost['sub'] = 1
self.cost['del'] = 1
self.cost['ins'] = 1
def calculate(self, lab, rec) :
# Initialization
lab.insert(0, '')
rec.insert(0, '')
while len(self.space) < len(lab) :
self.space.append([])
for row in self.space :
for element in row :
element['dist'] = 0
element['error'] = 'non'
while len(row) < len(rec) :
row.append({'dist' : 0, 'error' : 'non'})
for i in range(len(lab)) :
self.space[i][0]['dist'] = i
self.space[i][0]['error'] = 'del'
for j in range(len(rec)) :
self.space[0][j]['dist'] = j
self.space[0][j]['error'] = 'ins'
self.space[0][0]['error'] = 'non'
for token in lab :
if token not in self.data and len(token) > 0 :
self.data[token] = {'all' : 0, 'cor' : 0, 'sub' : 0, 'ins' : 0, 'del' : 0}
for token in rec :
if token not in self.data and len(token) > 0 :
self.data[token] = {'all' : 0, 'cor' : 0, 'sub' : 0, 'ins' : 0, 'del' : 0}
# Computing edit distance
for i, lab_token in enumerate(lab) :
for j, rec_token in enumerate(rec) :
if i == 0 or j == 0 :
continue
min_dist = sys.maxsize
min_error = 'none'
dist = self.space[i-1][j]['dist'] + self.cost['del']
error = 'del'
if dist < min_dist :
min_dist = dist
min_error = error
dist = self.space[i][j-1]['dist'] + self.cost['ins']
error = 'ins'
if dist < min_dist :
min_dist = dist
min_error = error
if lab_token == rec_token :
dist = self.space[i-1][j-1]['dist'] + self.cost['cor']
error = 'cor'
else :
dist = self.space[i-1][j-1]['dist'] + self.cost['sub']
error = 'sub'
if dist < min_dist :
min_dist = dist
min_error = error
self.space[i][j]['dist'] = min_dist
self.space[i][j]['error'] = min_error
# Tracing back
result = {'lab':[], 'rec':[], 'all':0, 'cor':0, 'sub':0, 'ins':0, 'del':0}
i = len(lab) - 1
j = len(rec) - 1
while True :
if self.space[i][j]['error'] == 'cor' : # correct
if len(lab[i]) > 0 :
self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
self.data[lab[i]]['cor'] = self.data[lab[i]]['cor'] + 1
result['all'] = result['all'] + 1
result['cor'] = result['cor'] + 1
result['lab'].insert(0, lab[i])
result['rec'].insert(0, rec[j])
i = i - 1
j = j - 1
elif self.space[i][j]['error'] == 'sub' : # substitution
if len(lab[i]) > 0 :
self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
self.data[lab[i]]['sub'] = self.data[lab[i]]['sub'] + 1
result['all'] = result['all'] + 1
result['sub'] = result['sub'] + 1
result['lab'].insert(0, lab[i])
result['rec'].insert(0, rec[j])
i = i - 1
j = j - 1
elif self.space[i][j]['error'] == 'del' : # deletion
if len(lab[i]) > 0 :
self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
self.data[lab[i]]['del'] = self.data[lab[i]]['del'] + 1
result['all'] = result['all'] + 1
result['del'] = result['del'] + 1
result['lab'].insert(0, lab[i])
result['rec'].insert(0, "")
i = i - 1
elif self.space[i][j]['error'] == 'ins' : # insertion
if len(rec[j]) > 0 :
self.data[rec[j]]['ins'] = self.data[rec[j]]['ins'] + 1
result['ins'] = result['ins'] + 1
result['lab'].insert(0, "")
result['rec'].insert(0, rec[j])
j = j - 1
elif self.space[i][j]['error'] == 'non' : # starting point
break
else : # shouldn't reach here
print('this should not happen , i = {i} , j = {j} , error = {error}'.format(i = i, j = j, error = self.space[i][j]['error']))
return result
def overall(self) :
result = {'all':0, 'cor':0, 'sub':0, 'ins':0, 'del':0}
for token in self.data :
result['all'] = result['all'] + self.data[token]['all']
result['cor'] = result['cor'] + self.data[token]['cor']
result['sub'] = result['sub'] + self.data[token]['sub']
result['ins'] = result['ins'] + self.data[token]['ins']
result['del'] = result['del'] + self.data[token]['del']
return result
def cluster(self, data) :
result = {'all':0, 'cor':0, 'sub':0, 'ins':0, 'del':0}
for token in data :
if token in self.data :
result['all'] = result['all'] + self.data[token]['all']
result['cor'] = result['cor'] + self.data[token]['cor']
result['sub'] = result['sub'] + self.data[token]['sub']
result['ins'] = result['ins'] + self.data[token]['ins']
result['del'] = result['del'] + self.data[token]['del']
return result
def keys(self) :
return list(self.data.keys())
def width(string):
return sum(1 + (unicodedata.east_asian_width(c) in "AFW") for c in string)
def default_cluster(word) :
unicode_names = [ unicodedata.name(char) for char in word ]
for i in reversed(range(len(unicode_names))) :
if unicode_names[i].startswith('DIGIT') : # 1
unicode_names[i] = 'Number' # 'DIGIT'
elif (unicode_names[i].startswith('CJK UNIFIED IDEOGRAPH') or
unicode_names[i].startswith('CJK COMPATIBILITY IDEOGRAPH')) :
# 明 / 郎
unicode_names[i] = 'Mandarin' # 'CJK IDEOGRAPH'
elif (unicode_names[i].startswith('LATIN CAPITAL LETTER') or
unicode_names[i].startswith('LATIN SMALL LETTER')) :
# A / a
unicode_names[i] = 'English' # 'LATIN LETTER'
elif unicode_names[i].startswith('HIRAGANA LETTER') : # は こ め
unicode_names[i] = 'Japanese' # 'GANA LETTER'
elif (unicode_names[i].startswith('AMPERSAND') or
unicode_names[i].startswith('APOSTROPHE') or
unicode_names[i].startswith('COMMERCIAL AT') or
unicode_names[i].startswith('DEGREE CELSIUS') or
unicode_names[i].startswith('EQUALS SIGN') or
unicode_names[i].startswith('FULL STOP') or
unicode_names[i].startswith('HYPHEN-MINUS') or
unicode_names[i].startswith('LOW LINE') or
unicode_names[i].startswith('NUMBER SIGN') or
unicode_names[i].startswith('PLUS SIGN') or
unicode_names[i].startswith('SEMICOLON')) :
# & / ' / @ / ℃ / = / . / - / _ / # / + / ;
del unicode_names[i]
else :
return 'Other'
if len(unicode_names) == 0 :
return 'Other'
if len(unicode_names) == 1 :
return unicode_names[0]
for i in range(len(unicode_names)-1) :
if unicode_names[i] != unicode_names[i+1] :
return 'Other'
return unicode_names[0]
def compute_char_acc(args):
calculator = Calculator()
cluster_file = ''
ignore_words = set()
tochar = True
verbose= 1
padding_symbol= ' '
case_sensitive = False
max_words_per_line = sys.maxsize
split = None
if not case_sensitive:
ig=set([w.upper() for w in ignore_words])
ignore_words = ig
default_clusters = {}
default_words = {}
ref_file = args.val_ref_file
hyp_file = args.val_hyp_file
rec_set = {}
if split and not case_sensitive:
newsplit = dict()
for w in split:
words = split[w]
for i in range(len(words)):
words[i] = words[i].upper()
newsplit[w.upper()] = words
split = newsplit
with codecs.open(hyp_file, 'r', 'utf-8') as fh:
for line in fh:
if tochar:
array = characterize(line)
else:
array = line.strip().split()
if len(array)==0: continue
fid = array[0]
rec_set[fid] = normalize(array[1:], ignore_words, case_sensitive, split)
    # compute error rate on the intersection of the reference file and hyp file
for line in open(ref_file, 'r', encoding='utf-8') :
if tochar:
array = characterize(line)
else:
array = line.rstrip('\n').split()
if len(array)==0: continue
fid = array[0]
if fid not in rec_set:
continue
lab = normalize(array[1:], ignore_words, case_sensitive, split)
rec = rec_set[fid]
#if verbose:
# print('\nutt: %s' % fid)
for word in rec + lab :
if word not in default_words :
default_cluster_name = default_cluster(word)
if default_cluster_name not in default_clusters :
default_clusters[default_cluster_name] = {}
if word not in default_clusters[default_cluster_name] :
default_clusters[default_cluster_name][word] = 1
default_words[word] = default_cluster_name
result = calculator.calculate(lab, rec)
if verbose:
if result['all'] != 0 :
wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all']
else :
wer = 0.0
#print('WER: %4.2f %%' % wer, end = ' ')
#print('N=%d C=%d S=%d D=%d I=%d' %
# (result['all'], result['cor'], result['sub'], result['del'], result['ins']))
space = {}
space['lab'] = []
space['rec'] = []
for idx in range(len(result['lab'])) :
len_lab = width(result['lab'][idx])
len_rec = width(result['rec'][idx])
length = max(len_lab, len_rec)
space['lab'].append(length-len_lab)
space['rec'].append(length-len_rec)
upper_lab = len(result['lab'])
upper_rec = len(result['rec'])
lab1, rec1 = 0, 0
while lab1 < upper_lab or rec1 < upper_rec:
#if verbose > 1:
# print('lab(%s):' % fid.encode('utf-8'), end = ' ')
#else:
# print('lab:', end = ' ')
lab2 = min(upper_lab, lab1 + max_words_per_line)
for idx in range(lab1, lab2):
token = result['lab'][idx]
#print('{token}'.format(token = token), end = '')
#for n in range(space['lab'][idx]) :
# print(padding_symbol, end = '')
#print(' ',end='')
#print()
#if verbose > 1:
# print('rec(%s):' % fid.encode('utf-8'), end = ' ')
#else:
# print('rec:', end = ' ')
rec2 = min(upper_rec, rec1 + max_words_per_line)
for idx in range(rec1, rec2):
token = result['rec'][idx]
#print('{token}'.format(token = token), end = '')
#for n in range(space['rec'][idx]) :
# print(padding_symbol, end = '')
#print(' ',end='')
#print('\n', end='\n')
lab1 = lab2
rec1 = rec2
#if verbose:
# print('===========================================================================')
# print()
result = calculator.overall()
if result['all'] != 0 :
wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all']
else :
wer = 0.0
#print('Overall -> %4.2f %%' % wer, end = ' ')
#print('N=%d C=%d S=%d D=%d I=%d' %
# (result['all'], result['cor'], result['sub'], result['del'], result['ins']))
#if not verbose:
# print()
char_acc = 100.0 - wer
return char_acc
if verbose:
for cluster_id in default_clusters :
result = calculator.cluster([ k for k in default_clusters[cluster_id] ])
if result['all'] != 0 :
wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all']
else :
wer = 0.0
#print('%s -> %4.2f %%' % (cluster_id, wer), end = ' ')
#print('N=%d C=%d S=%d D=%d I=%d' %
# (result['all'], result['cor'], result['sub'], result['del'], result['ins']))
#print()
#print('===========================================================================')
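# Illustrative usage sketch (not part of the original script): compute
# edit-distance statistics for one reference/hypothesis pair directly with the
# Calculator class; the two Mandarin strings are made up.
if __name__ == '__main__':
    calc = Calculator()
    lab = characterize('今天 天气 不错')
    rec = characterize('今天 天汽 不错')
    res = calc.calculate(lab, rec)
    # 6 tokens in total, 5 correct, 1 substitution, no insertions/deletions
    print(res['all'], res['cor'], res['sub'], res['ins'], res['del'])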
# Copyright (c) 2021 Shaoshang Qi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
def override_config(configs, override_list):
new_configs = copy.deepcopy(configs)
for item in override_list:
arr = item.split()
if len(arr) != 2:
print(f"the overrive {item} format not correct, skip it")
continue
keys = arr[0].split('.')
s_configs = new_configs
for i, key in enumerate(keys):
            if key not in s_configs:
                print(f"the override {item} format is not correct, skip it")
                break
if i == len(keys) - 1:
param_type = type(s_configs[key])
if param_type != bool:
s_configs[key] = param_type(arr[1])
else:
s_configs[key] = arr[1] in ['true', 'True']
print(f"override {arr[0]} with {arr[1]}")
else:
s_configs = s_configs[key]
return new_configs
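# Illustrative usage sketch (not part of the original file): override nested
# config values from "dotted.key value" strings; the config below is made up.
if __name__ == '__main__':
    base = {'encoder_conf': {'num_blocks': 12, 'use_cnn_module': True}}
    new = override_config(base, ['encoder_conf.num_blocks 6',
                                 'encoder_conf.use_cnn_module false'])
    print(new)  # {'encoder_conf': {'num_blocks': 6, 'use_cnn_module': False}}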
# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import torch
def insert_blank(label, blank_id=0):
"""Insert blank token between every two label token."""
label = np.expand_dims(label, 1)
blanks = np.zeros((label.shape[0], 1), dtype=np.int64) + blank_id
label = np.concatenate([blanks, label], axis=1)
label = label.reshape(-1)
label = np.append(label, label[0])
return label
def forced_align(ctc_probs: torch.Tensor,
y: torch.Tensor,
blank_id=0) -> list:
"""ctc forced alignment.
Args:
torch.Tensor ctc_probs: hidden state sequence, 2d tensor (T, D)
torch.Tensor y: id sequence tensor 1d tensor (L)
int blank_id: blank symbol index
Returns:
torch.Tensor: alignment result
"""
y_insert_blank = insert_blank(y, blank_id)
log_alpha = torch.zeros((ctc_probs.size(0), len(y_insert_blank)))
log_alpha = log_alpha - float('inf') # log of zero
state_path = (torch.zeros(
(ctc_probs.size(0), len(y_insert_blank)), dtype=torch.int16) - 1
) # state path
# init start state
log_alpha[0, 0] = ctc_probs[0][y_insert_blank[0]]
log_alpha[0, 1] = ctc_probs[0][y_insert_blank[1]]
for t in range(1, ctc_probs.size(0)):
for s in range(len(y_insert_blank)):
if y_insert_blank[s] == blank_id or s < 2 or y_insert_blank[
s] == y_insert_blank[s - 2]:
candidates = torch.tensor(
[log_alpha[t - 1, s], log_alpha[t - 1, s - 1]])
prev_state = [s, s - 1]
else:
candidates = torch.tensor([
log_alpha[t - 1, s],
log_alpha[t - 1, s - 1],
log_alpha[t - 1, s - 2],
])
prev_state = [s, s - 1, s - 2]
log_alpha[t, s] = torch.max(candidates) + ctc_probs[t][y_insert_blank[s]]
state_path[t, s] = prev_state[torch.argmax(candidates)]
state_seq = -1 * torch.ones((ctc_probs.size(0), 1), dtype=torch.int16)
candidates = torch.tensor([
log_alpha[-1, len(y_insert_blank) - 1],
log_alpha[-1, len(y_insert_blank) - 2]
])
prev_state = [len(y_insert_blank) - 1, len(y_insert_blank) - 2]
state_seq[-1] = prev_state[torch.argmax(candidates)]
for t in range(ctc_probs.size(0) - 2, -1, -1):
state_seq[t] = state_path[t + 1, state_seq[t + 1, 0]]
output_alignment = []
for t in range(0, ctc_probs.size(0)):
output_alignment.append(y_insert_blank[state_seq[t, 0]])
return output_alignment
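# Illustrative usage sketch (not part of the original file): force-align a toy
# label sequence against random CTC log-probabilities; sizes are arbitrary.
if __name__ == '__main__':
    torch.manual_seed(0)
    num_frames, vocab_size = 6, 4
    ctc_probs = torch.randn(num_frames, vocab_size).log_softmax(dim=1)
    y = torch.tensor([1, 2])
    # prints one state (blank id or label id) per input frame
    print(forced_align(ctc_probs, y, blank_id=0))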
# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from contextlib import nullcontext
# if your python version < 3.7 use the below one
# from contextlib import suppress as nullcontext
import torch
from torch.nn.utils import clip_grad_norm_
from wenet.utils.global_vars import get_global_steps, global_steps_inc, get_num_trained_samples, num_trained_samples_inc
import time
class Executor:
def __init__(self):
self.step = 0
def train(self, model, optimizer, scheduler, data_loader, device, writer,
args, scaler):
''' Train one epoch
'''
model.train()
clip = args.get('grad_clip', 50.0)
log_interval = args.get('log_interval', 10)
rank = args.get('rank', 0)
epoch = args.get('epoch', 0)
accum_grad = args.get('accum_grad', 1)
is_distributed = args.get('is_distributed', True)
use_amp = args.get('use_amp', False)
logging.info('using accumulate grad, new batch size is {} times'
' larger than before'.format(accum_grad))
if use_amp:
assert scaler is not None
# A context manager to be used in conjunction with an instance of
# torch.nn.parallel.DistributedDataParallel to be able to train
# with uneven inputs across participating processes.
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
model_context = model.join
else:
model_context = nullcontext
num_seen_utts = 0
with model_context():
for batch_idx, batch in enumerate(data_loader):
key, feats, target, feats_lengths, target_lengths = batch
feats = feats.to(device)
target = target.to(device)
feats_lengths = feats_lengths.to(device)
target_lengths = target_lengths.to(device)
num_utts = target_lengths.size(0)
if num_utts == 0:
continue
context = None
# Disable gradient synchronizations across DDP processes.
# Within this context, gradients will be accumulated on module
# variables, which will later be synchronized.
if is_distributed and batch_idx % accum_grad != 0:
context = model.no_sync
# Used for single gpu training and DDP gradient synchronization
# processes.
else:
context = nullcontext
with context():
# autocast context
# The more details about amp can be found in
# https://pytorch.org/docs/stable/notes/amp_examples.html
with torch.cuda.amp.autocast(scaler is not None):
loss_dict = model(feats, feats_lengths, target,
target_lengths)
loss = loss_dict['loss'] / accum_grad
if use_amp:
scaler.scale(loss).backward()
else:
loss.backward()
num_seen_utts += num_utts
global_steps_inc()
num_trained_samples_inc(num_utts)
if batch_idx % accum_grad == 0:
#if rank == 0 and writer is not None:
# writer.add_scalar('train_loss', loss, self.step)
# Use mixed precision training
if use_amp:
scaler.unscale_(optimizer)
grad_norm = clip_grad_norm_(model.parameters(), clip)
# Must invoke scaler.update() if unscale_() is used in
# the iteration to avoid the following error:
# RuntimeError: unscale_() has already been called
# on this optimizer since the last update().
# We don't check grad here since that if the gradient
# has inf/nan values, scaler.step will skip
# optimizer.step().
scaler.step(optimizer)
scaler.update()
else:
grad_norm = clip_grad_norm_(model.parameters(), clip)
if torch.isfinite(grad_norm):
optimizer.step()
optimizer.zero_grad()
scheduler.step()
self.step += 1
#if batch_idx % log_interval == 0:
# lr = optimizer.param_groups[0]['lr']
# log_str = 'TRAIN Batch {}/{} loss {:.6f} '.format(
# epoch, batch_idx,
# loss.item() * accum_grad)
# for name, value in loss_dict.items():
# if name != 'loss' and value is not None:
# log_str += '{} {:.6f} '.format(name, value.item())
# log_str += 'lr {:.8f} rank {}'.format(lr, rank)
# logging.debug(log_str)
lr = optimizer.param_groups[0]['lr']
loss_str = "%.4f" % (loss.item() * accum_grad)
global_steps = get_global_steps()
num_trained_samples = get_num_trained_samples()
step_output = f'[PerfLog] {{"event": "STEP_END", "value": {{"epoch": {epoch+1}, "global_steps": {global_steps},"loss": {loss_str},"num_trained_samples": {num_trained_samples}, "learning_rate": {lr:.9f}}}}}'
logging.info(f'rank {rank}: ' + step_output)
def cv(self, model, data_loader, device, args):
        ''' Cross validation on the cv data set
'''
model.eval()
rank = args.get('rank', 0)
epoch = args.get('epoch', 0)
log_interval = args.get('log_interval', 10)
# in order to avoid division by 0
num_seen_utts = 1
total_loss = 0.0
with torch.no_grad():
for batch_idx, batch in enumerate(data_loader):
key, feats, target, feats_lengths, target_lengths = batch
feats = feats.to(device)
target = target.to(device)
feats_lengths = feats_lengths.to(device)
target_lengths = target_lengths.to(device)
num_utts = target_lengths.size(0)
if num_utts == 0:
continue
loss_dict = model(feats, feats_lengths, target, target_lengths)
loss = loss_dict['loss']
if torch.isfinite(loss):
num_seen_utts += num_utts
total_loss += loss.item() * num_utts
if batch_idx % log_interval == 0:
log_str = 'CV Batch {}/{} loss {:.6f} '.format(
epoch, batch_idx, loss.item())
for name, value in loss_dict.items():
if name != 'loss' and value is not None:
log_str += '{} {:.6f} '.format(name, value.item())
log_str += 'history loss {:.6f}'.format(total_loss /
num_seen_utts)
log_str += ' rank {}'.format(rank)
logging.debug(log_str)
return total_loss, num_seen_utts
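# Illustrative, self-contained sketch (not part of the original file) of the
# gradient-accumulation + AMP update pattern used in Executor.train above, on
# a toy model with random data; accum_grad and clip are placeholder values.
if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    toy_model = torch.nn.Linear(10, 2).to(device)
    optimizer = torch.optim.Adam(toy_model.parameters(), lr=1e-3)
    use_cuda = torch.cuda.is_available()
    scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)
    accum_grad, clip = 4, 5.0
    for batch_idx in range(8):
        x = torch.randn(16, 10, device=device)
        with torch.cuda.amp.autocast(enabled=use_cuda):
            loss = toy_model(x).pow(2).mean() / accum_grad
        scaler.scale(loss).backward()
        if batch_idx % accum_grad == 0:
            scaler.unscale_(optimizer)
            clip_grad_norm_(toy_model.parameters(), clip)
            scaler.step(optimizer)  # skips the step if grads are inf/nan
            scaler.update()
            optimizer.zero_grad()
    print('final loss %.4f' % (loss.item() * accum_grad))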
# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
def read_lists(list_file):
lists = []
with open(list_file, 'r', encoding='utf8') as fin:
for line in fin:
lists.append(line.strip())
return lists
def read_non_lang_symbols(non_lang_sym_path):
"""read non-linguistic symbol from file.
The file format is like below:
{NOISE}\n
{BRK}\n
...
Args:
        non_lang_sym_path: non-linguistic symbol file path, None means no
            such symbols.
"""
if non_lang_sym_path is None:
return None
else:
syms = read_lists(non_lang_sym_path)
non_lang_syms_pattern = re.compile(r"(\[[^\[\]]+\]|<[^<>]+>|{[^{}]+})")
for sym in syms:
if non_lang_syms_pattern.fullmatch(sym) is None:
class BadSymbolFormat(Exception):
pass
raise BadSymbolFormat(
"Non-linguistic symbols should be "
"formatted in {xxx}/<xxx>/[xxx], consider"
" modify '%s' to meet the requirment. "
"More details can be found in discussions here : "
"https://github.com/wenet-e2e/wenet/pull/819" % (sym))
return syms
def read_symbol_table(symbol_table_file):
symbol_table = {}
with open(symbol_table_file, 'r', encoding='utf8') as fin:
for line in fin:
arr = line.strip().split()
assert len(arr) == 2
symbol_table[arr[0]] = int(arr[1])
return symbol_table
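# Illustrative usage sketch (not part of the original file): write a tiny
# symbol table to a temporary file and read it back; the symbols are made up.
if __name__ == '__main__':
    import os
    import tempfile
    with tempfile.TemporaryDirectory() as tmpdir:
        dict_path = os.path.join(tmpdir, 'units.txt')
        with open(dict_path, 'w', encoding='utf8') as fout:
            fout.write('<blank> 0\n<unk> 1\n你 2\n好 3\n')
        print(read_symbol_table(dict_path))  # {'<blank>': 0, '<unk>': 1, ...}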
# Global variables and their getters and setters for cross python packages
# global_steps
global_steps = 0
def get_global_steps():
return global_steps
def set_global_steps(value):
global global_steps
global_steps = value
def global_steps_inc():
global global_steps
global_steps += 1
# num_trained_samples
num_trained_samples = 0
def get_num_trained_samples():
return num_trained_samples
def set_num_trained_samples(value):
global num_trained_samples
num_trained_samples = value
def num_trained_samples_inc(value):
global num_trained_samples
num_trained_samples += value
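# Illustrative usage sketch (not part of the original file): the counters are
# plain module-level globals shared by every module that imports this one.
if __name__ == '__main__':
    set_global_steps(0)
    set_num_trained_samples(0)
    for _ in range(3):
        global_steps_inc()
        num_trained_samples_inc(16)
    print(get_global_steps(), get_num_trained_samples())  # 3 48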
# Copyright (c) 2022 Binbin Zhang (binbzha@qq.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from wenet.transducer.joint import TransducerJoint
from wenet.transducer.predictor import (ConvPredictor, EmbeddingPredictor,
RNNPredictor)
from wenet.transducer.transducer import Transducer
from wenet.transformer.asr_model import ASRModel
from wenet.transformer.cmvn import GlobalCMVN
from wenet.transformer.ctc import CTC
from wenet.transformer.decoder import BiTransformerDecoder, TransformerDecoder
from wenet.transformer.encoder import ConformerEncoder, TransformerEncoder
from wenet.squeezeformer.encoder import SqueezeformerEncoder
from wenet.efficient_conformer.encoder import EfficientConformerEncoder
from wenet.utils.cmvn import load_cmvn
def init_model(configs):
if configs['cmvn_file'] is not None:
mean, istd = load_cmvn(configs['cmvn_file'], configs['is_json_cmvn'])
global_cmvn = GlobalCMVN(
torch.from_numpy(mean).float(),
torch.from_numpy(istd).float())
else:
global_cmvn = None
input_dim = configs['input_dim']
vocab_size = configs['output_dim']
encoder_type = configs.get('encoder', 'conformer')
decoder_type = configs.get('decoder', 'bitransformer')
if encoder_type == 'conformer':
encoder = ConformerEncoder(input_dim,
global_cmvn=global_cmvn,
**configs['encoder_conf'])
elif encoder_type == 'squeezeformer':
encoder = SqueezeformerEncoder(input_dim,
global_cmvn=global_cmvn,
**configs['encoder_conf'])
elif encoder_type == 'efficientConformer':
encoder = EfficientConformerEncoder(input_dim,
global_cmvn=global_cmvn,
**configs['encoder_conf'],
**configs['encoder_conf']['efficient_conf']
if 'efficient_conf' in
configs['encoder_conf'] else {})
else:
encoder = TransformerEncoder(input_dim,
global_cmvn=global_cmvn,
**configs['encoder_conf'])
if decoder_type == 'transformer':
decoder = TransformerDecoder(vocab_size, encoder.output_size(),
**configs['decoder_conf'])
else:
assert 0.0 < configs['model_conf']['reverse_weight'] < 1.0
assert configs['decoder_conf']['r_num_blocks'] > 0
decoder = BiTransformerDecoder(vocab_size, encoder.output_size(),
**configs['decoder_conf'])
ctc = CTC(vocab_size, encoder.output_size())
# Init joint CTC/Attention or Transducer model
if 'predictor' in configs:
predictor_type = configs.get('predictor', 'rnn')
if predictor_type == 'rnn':
predictor = RNNPredictor(vocab_size, **configs['predictor_conf'])
elif predictor_type == 'embedding':
predictor = EmbeddingPredictor(vocab_size,
**configs['predictor_conf'])
configs['predictor_conf']['output_size'] = configs[
'predictor_conf']['embed_size']
elif predictor_type == 'conv':
predictor = ConvPredictor(vocab_size, **configs['predictor_conf'])
configs['predictor_conf']['output_size'] = configs[
'predictor_conf']['embed_size']
else:
raise NotImplementedError(
"only rnn, embedding and conv type support now")
configs['joint_conf']['enc_output_size'] = configs['encoder_conf'][
'output_size']
configs['joint_conf']['pred_output_size'] = configs['predictor_conf'][
'output_size']
joint = TransducerJoint(vocab_size, **configs['joint_conf'])
model = Transducer(vocab_size=vocab_size,
blank=0,
predictor=predictor,
encoder=encoder,
attention_decoder=decoder,
joint=joint,
ctc=ctc,
**configs['model_conf'])
else:
model = ASRModel(vocab_size=vocab_size,
encoder=encoder,
decoder=decoder,
ctc=ctc,
**configs['model_conf'])
return model
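# Illustrative usage sketch (not part of the original file): build a small
# CTC/attention model from a minimal config dict. The keyword arguments follow
# the encoder/decoder constructors, but every value here is a made-up toy
# setting, not a recommended configuration.
if __name__ == '__main__':
    demo_configs = {
        'cmvn_file': None,
        'is_json_cmvn': True,
        'input_dim': 80,
        'output_dim': 100,
        'encoder': 'conformer',
        'decoder': 'transformer',
        'encoder_conf': {'output_size': 64, 'attention_heads': 2,
                         'linear_units': 128, 'num_blocks': 2},
        'decoder_conf': {'attention_heads': 2, 'linear_units': 128,
                         'num_blocks': 2},
        'model_conf': {'ctc_weight': 0.3},
    }
    demo_model = init_model(demo_configs)
    print(sum(p.numel() for p in demo_model.parameters()), 'parameters')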