Commit a7785cc6 authored by Sugon_ldc

delete soft link

parent 9a2a05ca
#!/usr/bin/env python3
# coding:utf-8
# Copyright (c) 2022 SDCI Co. Ltd (author: veelion)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
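# Benchmark a running websocket_server_main with concurrent recognition
# requests: it reports RTF, per-request latency statistics, and finally
# computes CER against a reference transcription.
# Example invocation (script name, paths and port are illustrative):
#   python3 websocket_benchmark.py -u ws://127.0.0.1:10086 \
#       -w data/test/wav.scp -t data/test/text -s trans.txt -n 16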
import os
import json
import time
import asyncio
import argparse
import websockets
import soundfile as sf
import statistics
WS_START = json.dumps({
'signal': 'start',
'nbest': 1,
'continuous_decoding': False,
})
WS_END = json.dumps({
'signal': 'end'
})
async def ws_rec(data, ws_uri):
begin = time.time()
conn = await websockets.connect(ws_uri, ping_timeout=200)
# step 1: send start
await conn.send(WS_START)
ret = await conn.recv()
# step 2: send audio data
await conn.send(data)
# step 3: send end
await conn.send(WS_END)
# step 4: receive result
texts = []
    while True:
ret = await conn.recv()
ret = json.loads(ret)
if ret['type'] == 'final_result':
nbest = json.loads(ret['nbest'])
text = nbest[0]['sentence']
texts.append(text)
elif ret['type'] == 'speech_end':
break
# step 5: close
try:
await conn.close()
except Exception as e:
        # closing may fail if the server has already dropped the connection
        # without sending a close frame; just log the exception
print(e)
time_cost = time.time() - begin
return {
'text': ''.join(texts),
'time': time_cost,
}
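# A minimal standalone check of ws_rec (illustrative; assumes a 16 kHz mono
# wav file and a server listening on the given uri):
#   data, sr = sf.read('test.wav', dtype='int16')
#   print(asyncio.run(ws_rec(data.tobytes(), 'ws://127.0.0.1:10086')))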
def get_args():
parser = argparse.ArgumentParser(description='')
parser.add_argument(
'-u', '--ws_uri', required=True,
help="websocket_server_main's uri, e.g. ws://127.0.0.1:10086")
parser.add_argument(
'-w', '--wav_scp', required=True,
help='path to wav_scp_file')
parser.add_argument(
'-t', '--trans', required=True,
help='path to trans_text_file of wavs')
parser.add_argument(
'-s', '--save_to', required=True,
help='path to save transcription')
parser.add_argument(
'-n', '--num_concurrence', type=int, required=True,
help='num of concurrence for query')
args = parser.parse_args()
return args
def print_result(info):
length = max([len(k) for k in info])
for k, v in info.items():
print(f'\t{k: >{length}} : {v}')
async def main(args):
wav_scp = []
total_duration = 0
with open(args.wav_scp) as f:
for line in f:
zz = line.strip().split()
assert len(zz) == 2
data, sr = sf.read(zz[1], dtype='int16')
assert sr == 16000
duration = (len(data)) / 16000
total_duration += duration
wav_scp.append((zz[0], data.tobytes()))
print(f'{len(wav_scp) = }, {total_duration = }')
tasks = []
failed = 0
texts = []
request_times = []
begin = time.time()
for i, (_uttid, data) in enumerate(wav_scp):
task = asyncio.create_task(ws_rec(data, args.ws_uri))
tasks.append((_uttid, task))
if len(tasks) < args.num_concurrence:
continue
print((f'{i=}, start {args.num_concurrence} '
f'queries @ {time.strftime("%m-%d %H:%M:%S")}'))
for uttid, task in tasks:
result = await task
texts.append(f'{uttid}\t{result["text"]}\n')
request_times.append(result['time'])
tasks = []
print(f'\tdone @ {time.strftime("%m-%d %H:%M:%S")}')
if tasks:
for uttid, task in tasks:
result = await task
texts.append(f'{uttid}\t{result["text"]}\n')
request_times.append(result['time'])
request_time = time.time() - begin
rtf = request_time / total_duration
print('For all concurrence:')
print_result({
'failed': failed,
'total_duration': total_duration,
'request_time': request_time,
'RTF': rtf,
})
print('For one request:')
print_result({
'mean': statistics.mean(request_times),
'median': statistics.median(request_times),
'max_time': max(request_times),
'min_time': min(request_times),
})
with open(args.save_to, 'w', encoding='utf8') as fsave:
fsave.write(''.join(texts))
    # calculate CER
cmd = (f'python ../compute-wer.py --char=1 --v=1 '
f'{args.trans} {args.save_to} > '
f'{args.save_to}-test-{args.num_concurrence}.cer.txt')
print(cmd)
os.system(cmd)
print('done')
if __name__ == '__main__':
args = get_args()
asyncio.run(main(args))
../../../wenet/
\ No newline at end of file
# Copyright (c) 2021 Mobvoi Inc. (authors: Di Wu)
# 2022 Tinnove Inc (authors: Wei Ren)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import argparse
import copy
import logging
import os
import sys
import torch
import yaml
from torch.utils.data import DataLoader
from textgrid import TextGrid, IntervalTier
from wenet.dataset.dataset import Dataset
from wenet.utils.checkpoint import load_checkpoint
from wenet.utils.file_utils import read_symbol_table, read_non_lang_symbols
from wenet.utils.ctc_util import forced_align
from wenet.utils.common import get_subsample
from wenet.utils.init_model import init_model
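# Generate CTC forced alignments for each utterance and optionally dump
# Praat .lab / .TextGrid files for inspection.
# Example invocation (script path and file names are illustrative):
#   python wenet/bin/alignment.py --config train.yaml --checkpoint final.pt \
#       --input_file data.list --dict units.txt --result_file ali.res --gen_praat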
def generator_textgrid(maxtime, lines, output):
# Download Praat: https://www.fon.hum.uva.nl/praat/
interval = maxtime / (len(lines) + 1)
margin = 0.0001
tg = TextGrid(maxTime=maxtime)
linetier = IntervalTier(name="line", maxTime=maxtime)
for l in lines:
s, e, w = l.split()
linetier.add(minTime=float(s) + margin, maxTime=float(e), mark=w)
tg.append(linetier)
print("successfully generator {}".format(output))
tg.write(output)
def get_frames_timestamp(alignment):
    # convert the alignment to a Praat-friendly format; Praat ("doing
    # phonetics by computer") helps to visually inspect the alignment
timestamp = []
# get frames level duration for each token
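    # For example, with blank id 0 an alignment such as
    # [0, 0, 5, 5, 0, 3, 0, 0] is grouped into [[0, 0, 5, 5], [0, 3, 0, 0]]:
    # the blanks before a token are attached to that token and any trailing
    # blanks are merged into the last group.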
start = 0
end = 0
while end < len(alignment):
while end < len(alignment) and alignment[end] == 0:
end += 1
if end == len(alignment):
timestamp[-1] += alignment[start:]
break
end += 1
while end < len(alignment) and alignment[end - 1] == alignment[end]:
end += 1
timestamp.append(alignment[start:end])
start = end
return timestamp
def get_labformat(timestamp, subsample):
begin = 0
duration = 0
labformat = []
for idx, t in enumerate(timestamp):
        # 25 ms frame_length, 10 ms hop_length, 1 / subsample
# time duration
duration = len(t) * 0.01 * subsample
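        # e.g. with subsample == 4, a group of 5 frames spans
        # 5 * 0.01 * 4 = 0.2 seconds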
if idx < len(timestamp) - 1:
print("{:.2f} {:.2f} {}".format(begin, begin + duration,
char_dict[t[-1]]))
labformat.append("{:.2f} {:.2f} {}\n".format(
begin, begin + duration, char_dict[t[-1]]))
else:
non_blank = 0
for i in t:
if i != 0:
token = i
break
print("{:.2f} {:.2f} {}".format(begin, begin + duration,
char_dict[token]))
labformat.append("{:.2f} {:.2f} {}\n".format(
begin, begin + duration, char_dict[token]))
begin = begin + duration
return labformat
if __name__ == '__main__':
parser = argparse.ArgumentParser(
description='use ctc to generate alignment')
parser.add_argument('--config', required=True, help='config file')
parser.add_argument('--input_file', required=True, help='format data file')
parser.add_argument('--data_type',
default='raw',
choices=['raw', 'shard'],
help='train and cv data type')
parser.add_argument('--gpu',
type=int,
default=-1,
help='gpu id for this rank, -1 for cpu')
parser.add_argument('--checkpoint', required=True, help='checkpoint model')
parser.add_argument('--dict', required=True, help='dict file')
parser.add_argument('--non_lang_syms',
help="non-linguistic symbol file. One symbol per line.")
parser.add_argument('--result_file',
required=True,
help='alignment result file')
parser.add_argument('--batch_size', type=int, default=1, help='batch size')
parser.add_argument('--gen_praat',
action='store_true',
help='convert alignment to a praat format')
parser.add_argument('--bpe_model',
default=None,
type=str,
help='bpe model for english part')
args = parser.parse_args()
print(args)
logging.basicConfig(level=logging.DEBUG,
format='%(asctime)s %(levelname)s %(message)s')
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
if args.batch_size > 1:
logging.fatal('alignment mode must be running with batch_size == 1')
sys.exit(1)
with open(args.config, 'r') as fin:
configs = yaml.load(fin, Loader=yaml.FullLoader)
# Load dict
char_dict = {}
with open(args.dict, 'r') as fin:
for line in fin:
arr = line.strip().split()
assert len(arr) == 2
char_dict[int(arr[1])] = arr[0]
eos = len(char_dict) - 1
symbol_table = read_symbol_table(args.dict)
# Init dataset and data loader
ali_conf = copy.deepcopy(configs['dataset_conf'])
ali_conf['filter_conf']['max_length'] = 102400
ali_conf['filter_conf']['min_length'] = 0
ali_conf['filter_conf']['token_max_length'] = 102400
ali_conf['filter_conf']['token_min_length'] = 0
ali_conf['filter_conf']['max_output_input_ratio'] = 102400
ali_conf['filter_conf']['min_output_input_ratio'] = 0
ali_conf['speed_perturb'] = False
ali_conf['spec_aug'] = False
ali_conf['shuffle'] = False
ali_conf['sort'] = False
ali_conf['fbank_conf']['dither'] = 0.0
ali_conf['batch_conf']['batch_type'] = "static"
ali_conf['batch_conf']['batch_size'] = args.batch_size
non_lang_syms = read_non_lang_symbols(args.non_lang_syms)
ali_dataset = Dataset(args.data_type,
args.input_file,
symbol_table,
ali_conf,
args.bpe_model,
non_lang_syms,
partition=False)
ali_data_loader = DataLoader(ali_dataset, batch_size=None, num_workers=0)
# Init asr model from configs
model = init_model(configs)
load_checkpoint(model, args.checkpoint)
use_cuda = args.gpu >= 0 and torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
model = model.to(device)
model.eval()
with torch.no_grad(), open(args.result_file, 'w',
encoding='utf-8') as fout:
for batch_idx, batch in enumerate(ali_data_loader):
print("#" * 80)
key, feat, target, feats_length, target_length = batch
print(key)
feat = feat.to(device)
target = target.to(device)
feats_length = feats_length.to(device)
target_length = target_length.to(device)
# Let's assume B = batch_size and N = beam_size
# 1. Encoder
encoder_out, encoder_mask = model._forward_encoder(
feat, feats_length) # (B, maxlen, encoder_dim)
maxlen = encoder_out.size(1)
ctc_probs = model.ctc.log_softmax(
encoder_out) # (1, maxlen, vocab_size)
# print(ctc_probs.size(1))
ctc_probs = ctc_probs.squeeze(0)
target = target.squeeze(0)
alignment = forced_align(ctc_probs, target)
print(alignment)
fout.write('{} {}\n'.format(key[0], alignment))
if args.gen_praat:
timestamp = get_frames_timestamp(alignment)
print(timestamp)
subsample = get_subsample(configs)
labformat = get_labformat(timestamp, subsample)
lab_path = os.path.join(os.path.dirname(args.result_file),
key[0] + ".lab")
with open(lab_path, 'w', encoding='utf-8') as f:
f.writelines(labformat)
textgrid_path = os.path.join(os.path.dirname(args.result_file),
key[0] + ".TextGrid")
generator_textgrid(maxtime=(len(alignment) + 1) * 0.01 *
subsample,
lines=labformat,
output=textgrid_path)
# Copyright (c) 2020 Mobvoi Inc (Di Wu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import argparse
import glob
import yaml
import numpy as np
import torch
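# Average the parameters of several checkpoints: either the last --num
# checkpoints by modification time, or the --num checkpoints with the best
# cv_loss when --val_best is given.
# Example invocation (script path and file names are illustrative):
#   python wenet/bin/average_model.py --dst_model exp/avg_5.pt \
#       --src_path exp/ --num 5 --val_best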
def get_args():
parser = argparse.ArgumentParser(description='average model')
parser.add_argument('--dst_model', required=True, help='averaged model')
parser.add_argument('--src_path',
required=True,
help='src model path for average')
parser.add_argument('--val_best',
action="store_true",
                        help='average the checkpoints with the best cv loss')
parser.add_argument('--num',
default=5,
type=int,
help='nums for averaged model')
parser.add_argument('--min_epoch',
default=0,
type=int,
help='min epoch used for averaging model')
parser.add_argument('--max_epoch',
default=65536,
type=int,
help='max epoch used for averaging model')
args = parser.parse_args()
print(args)
return args
def main():
args = get_args()
checkpoints = []
val_scores = []
if args.val_best:
yamls = glob.glob('{}/[!train]*.yaml'.format(args.src_path))
for y in yamls:
with open(y, 'r') as f:
dic_yaml = yaml.load(f, Loader=yaml.FullLoader)
loss = dic_yaml['cv_loss']
epoch = dic_yaml['epoch']
if epoch >= args.min_epoch and epoch <= args.max_epoch:
val_scores += [[epoch, loss]]
val_scores = np.array(val_scores)
sort_idx = np.argsort(val_scores[:, -1])
        sorted_val_scores = val_scores[sort_idx]
print("best val scores = " + str(sorted_val_scores[:args.num, 1]))
print("selected epochs = " +
str(sorted_val_scores[:args.num, 0].astype(np.int64)))
path_list = [
args.src_path + '/{}.pt'.format(int(epoch))
for epoch in sorted_val_scores[:args.num, 0]
]
else:
path_list = glob.glob('{}/[0-9]*.pt'.format(args.src_path))
path_list = sorted(path_list, key=os.path.getmtime)
path_list = path_list[-args.num:]
print(path_list)
avg = None
num = args.num
assert num == len(path_list)
for path in path_list:
print('Processing {}'.format(path))
states = torch.load(path, map_location=torch.device('cpu'))
if avg is None:
avg = states
else:
for k in avg.keys():
avg[k] += states[k]
# average
for k in avg.keys():
if avg[k] is not None:
            # since PyTorch 1.6, use true_divide instead of /=
avg[k] = torch.true_divide(avg[k], num)
print('Saving to {}'.format(args.dst_model))
torch.save(avg, args.dst_model)
if __name__ == '__main__':
main()
# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import argparse
import os
import torch
import yaml
from wenet.utils.checkpoint import load_checkpoint
from wenet.utils.init_model import init_model
def get_args():
parser = argparse.ArgumentParser(description='export your script model')
parser.add_argument('--config', required=True, help='config file')
parser.add_argument('--checkpoint', required=True, help='checkpoint model')
parser.add_argument('--output_file', default=None, help='output file')
parser.add_argument('--output_quant_file',
default=None,
help='output quantized model file')
args = parser.parse_args()
return args
def main():
args = get_args()
    # No GPU is needed for model export
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
with open(args.config, 'r') as fin:
configs = yaml.load(fin, Loader=yaml.FullLoader)
model = init_model(configs)
print(model)
load_checkpoint(model, args.checkpoint)
# Export jit torch script model
if args.output_file:
script_model = torch.jit.script(model)
script_model.save(args.output_file)
print('Export model successfully, see {}'.format(args.output_file))
# Export quantized jit torch script model
if args.output_quant_file:
quantized_model = torch.quantization.quantize_dynamic(
model, {torch.nn.Linear}, dtype=torch.qint8
)
print(quantized_model)
script_quant_model = torch.jit.script(quantized_model)
script_quant_model.save(args.output_quant_file)
print('Export quantized model successfully, '
'see {}'.format(args.output_quant_file))
if __name__ == '__main__':
main()
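# The exported TorchScript file can be loaded back for a quick sanity check
# (a minimal sketch; 'final.zip' is an illustrative --output_file value):
#   loaded = torch.jit.load('final.zip')
#   loaded.eval()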
# Copyright (c) 2022, Horizon Inc. Xingchen Song (sxc19@tsinghua.org.cn)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""NOTE(xcsong): Currently, we only support
1. specific conformer encoder architecture, see:
encoder: conformer
encoder_conf:
activation_type: **must be** relu
        attention_heads: 2, 4, 8, or any number that divides output_size
causal: **must be** true
cnn_module_kernel: 1 ~ 7
cnn_module_norm: **must be** batch_norm
input_layer: **must be** conv2d8
linear_units: 1 ~ 2048
normalize_before: **must be** true
num_blocks: 1 ~ 12
output_size: 1 ~ 512
pos_enc_layer_type: **must be** no_pos
selfattention_layer_type: **must be** selfattn
use_cnn_module: **must be** true
use_dynamic_chunk: **must be** true
use_dynamic_left_chunk: **must be** true
2. specific decoding method: ctc_greedy_search
"""
from __future__ import print_function
import os
import sys
import copy
import math
import yaml
import logging
from typing import Tuple
import torch
import numpy as np
from wenet.transformer.embedding import NoPositionalEncoding
from wenet.utils.checkpoint import load_checkpoint
from wenet.utils.init_model import init_model
from wenet.bin.export_onnx_cpu import (get_args, to_numpy,
print_input_output_info)
try:
import onnx
import onnxruntime
except ImportError:
print('Please install onnx and onnxruntime!')
sys.exit(1)
logger = logging.getLogger(__file__)
logger.setLevel(logging.INFO)
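# Example invocation (script path and values are illustrative); get_args()
# is shared with export_onnx_cpu.py, which appears later in this commit:
#   python wenet/bin/export_onnx_bpu.py --config train.yaml --checkpoint final.pt \
#       --output_dir exp/onnx_bpu --chunk_size 8 --num_decoding_left_chunks 16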
class BPULayerNorm(torch.nn.Module):
"""Refactor torch.nn.LayerNorm to meet 4-D dataflow."""
def __init__(self, module, chunk_size=8, run_on_bpu=False):
super().__init__()
original = copy.deepcopy(module)
self.hidden = module.weight.size(0)
self.chunk_size = chunk_size
self.run_on_bpu = run_on_bpu
if self.run_on_bpu:
self.weight = torch.nn.Parameter(
module.weight.reshape(1, self.hidden, 1, 1).repeat(
1, 1, 1, chunk_size))
self.bias = torch.nn.Parameter(
module.bias.reshape(1, self.hidden, 1, 1).repeat(
1, 1, 1, chunk_size))
self.negtive = torch.nn.Parameter(
torch.ones((1, self.hidden, 1, chunk_size)) * -1.0)
self.eps = torch.nn.Parameter(
torch.zeros((1, self.hidden, 1, chunk_size)) + module.eps)
self.mean_conv_1 = torch.nn.Conv2d(self.hidden, 1, 1, bias=False)
self.mean_conv_1.weight = torch.nn.Parameter(
torch.ones(self.hidden, self.hidden, 1, 1) / (1.0 * self.hidden))
self.mean_conv_2 = torch.nn.Conv2d(self.hidden, 1, 1, bias=False)
self.mean_conv_2.weight = torch.nn.Parameter(
torch.ones(self.hidden, self.hidden, 1, 1) / (1.0 * self.hidden))
else:
self.norm = module
self.check_equal(original)
def check_equal(self, module):
random_data = torch.randn(1, self.chunk_size, self.hidden)
orig_out = module(random_data)
new_out = self.forward(random_data.transpose(1, 2).unsqueeze(2))
np.testing.assert_allclose(
to_numpy(orig_out), to_numpy(new_out.squeeze(2).transpose(1, 2)),
rtol=1e-02, atol=1e-03)
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.run_on_bpu:
u = self.mean_conv_1(x) # (1, h, 1, c)
numerator = x + u * self.negtive # (1, h, 1, c)
s = torch.pow(numerator, 2) # (1, h, 1, c)
s = self.mean_conv_2(s) # (1, h, 1, c)
denominator = torch.sqrt(s + self.eps) # (1, h, 1, c)
x = torch.div(numerator, denominator) # (1, h, 1, c)
x = x * self.weight + self.bias
else:
x = x.squeeze(2).transpose(1, 2).contiguous()
x = self.norm(x)
x = x.transpose(1, 2).contiguous().unsqueeze(2)
return x
class BPUIdentity(torch.nn.Module):
"""Refactor torch.nn.Identity().
For inserting BPU node whose input == output.
"""
def __init__(self, channels):
super().__init__()
self.channels = channels
self.identity_conv = torch.nn.Conv2d(
channels, channels, 1, groups=channels, bias=False)
torch.nn.init.dirac_(
self.identity_conv.weight.data, groups=channels)
self.check_equal()
def check_equal(self):
random_data = torch.randn(1, self.channels, 1, 10)
result = self.forward(random_data)
np.testing.assert_allclose(
to_numpy(random_data), to_numpy(result),
rtol=1e-02, atol=1e-03)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Identity with 4-D dataflow, input == output.
Args:
x (torch.Tensor): (batch, in_channel, 1, time)
Returns:
(torch.Tensor): (batch, in_channel, 1, time).
"""
return self.identity_conv(x)
class BPULinear(torch.nn.Module):
"""Refactor torch.nn.Linear or pointwise_conv"""
def __init__(self, module, is_pointwise_conv=False):
super().__init__()
# Unchanged submodules and attributes
original = copy.deepcopy(module)
self.idim = module.weight.size(1)
self.odim = module.weight.size(0)
self.is_pointwise_conv = is_pointwise_conv
# Modify weight & bias
self.linear = torch.nn.Conv2d(self.idim, self.odim, 1, 1)
if is_pointwise_conv:
# (odim, idim, kernel=1) -> (odim, idim, 1, 1)
self.linear.weight = torch.nn.Parameter(
module.weight.unsqueeze(-1))
else:
# (odim, idim) -> (odim, idim, 1, 1)
self.linear.weight = torch.nn.Parameter(
module.weight.unsqueeze(2).unsqueeze(3))
self.linear.bias = module.bias
self.check_equal(original)
def check_equal(self, module):
random_data = torch.randn(1, 8, self.idim)
if self.is_pointwise_conv:
random_data = random_data.transpose(1, 2)
original_result = module(random_data)
if self.is_pointwise_conv:
random_data = random_data.transpose(1, 2)
original_result = original_result.transpose(1, 2)
random_data = random_data.transpose(1, 2).unsqueeze(2)
new_result = self.forward(random_data)
np.testing.assert_allclose(
to_numpy(original_result),
to_numpy(new_result.squeeze(2).transpose(1, 2)),
rtol=1e-02, atol=1e-03)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Linear with 4-D dataflow.
Args:
x (torch.Tensor): (batch, in_channel, 1, time)
Returns:
(torch.Tensor): (batch, out_channel, 1, time).
"""
return self.linear(x)
class BPUGlobalCMVN(torch.nn.Module):
"""Refactor wenet/transformer/cmvn.py::GlobalCMVN"""
def __init__(self, module):
super().__init__()
# Unchanged submodules and attributes
self.norm_var = module.norm_var
# NOTE(xcsong): Expand to 4-D tensor, (mel_dim) -> (1, 1, mel_dim, 1)
self.mean = module.mean.unsqueeze(-1).unsqueeze(0).unsqueeze(0)
self.istd = module.istd.unsqueeze(-1).unsqueeze(0).unsqueeze(0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""CMVN with 4-D dataflow.
Args:
x (torch.Tensor): (batch, 1, mel_dim, time)
Returns:
(torch.Tensor): normalized feature with same shape.
"""
x = x - self.mean
if self.norm_var:
x = x * self.istd
return x
class BPUConv2dSubsampling8(torch.nn.Module):
"""Refactor wenet/transformer/subsampling.py::Conv2dSubsampling8
NOTE(xcsong): Only support pos_enc_class == NoPositionalEncoding
"""
def __init__(self, module):
super().__init__()
# Unchanged submodules and attributes
original = copy.deepcopy(module)
self.right_context = module.right_context
self.subsampling_rate = module.subsampling_rate
assert isinstance(module.pos_enc, NoPositionalEncoding)
# 1. Modify self.conv
# NOTE(xcsong): We change input shape from (1, 1, frames, mel_dim)
# to (1, 1, mel_dim, frames) for more efficient computation.
self.conv = module.conv
for idx in [0, 2, 4]:
self.conv[idx].weight = torch.nn.Parameter(
module.conv[idx].weight.transpose(2, 3)
)
# 2. Modify self.linear
        # NOTE(xcsong): Split the final projection to meet the requirement
        #   of maximum kernel_size (7 for XJ3)
self.linear = torch.nn.ModuleList()
odim = module.linear.weight.size(0) # 512, in this case
freq = module.linear.weight.size(1) // odim # 4608 // 512 == 9
self.odim, self.freq = odim, freq
weight = module.linear.weight.reshape(
odim, odim, freq, 1) # (odim, odim * freq) -> (odim, odim, freq, 1)
self.split_size = []
num_split = (freq - 1) // 7 + 1 # XJ3 requires kernel_size <= 7
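        # e.g. freq == 9 is split into two convolutions with kernel
        # sizes [7, 2]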
slice_begin = 0
for idx in range(num_split):
kernel_size = min(freq, (idx + 1) * 7) - idx * 7
conv_ele = torch.nn.Conv2d(
odim, odim, (kernel_size, 1), (kernel_size, 1))
conv_ele.weight = torch.nn.Parameter(
weight[:, :, slice_begin:slice_begin + kernel_size, :]
)
conv_ele.bias = torch.nn.Parameter(
torch.zeros_like(conv_ele.bias)
)
self.linear.append(conv_ele)
self.split_size.append(kernel_size)
slice_begin += kernel_size
self.linear[0].bias = torch.nn.Parameter(module.linear.bias)
self.check_equal(original)
def check_equal(self, module):
random_data = torch.randn(1, 67, 80)
mask = torch.zeros(1, 1, 67)
original_result, _, _ = module(random_data, mask) # (1, 8, 512)
random_data = random_data.transpose(1, 2).unsqueeze(0) # (1, 1, 80, 67)
new_result = self.forward(random_data) # (1, 512, 1, 8)
np.testing.assert_allclose(
to_numpy(original_result),
to_numpy(new_result.squeeze(2).transpose(1, 2)),
rtol=1e-02, atol=1e-03)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Subsample x with 4-D dataflow.
Args:
x (torch.Tensor): Input tensor (#batch, 1, mel_dim, time).
Returns:
torch.Tensor: Subsampled tensor (#batch, odim, 1, time'),
where time' = time // 8.
"""
x = self.conv(x) # (1, odim, freq, time')
x_out = torch.zeros(x.size(0), self.odim, 1, x.size(3))
x = torch.split(x, self.split_size, dim=2)
for idx, (x_part, layer) in enumerate(zip(x, self.linear)):
x_out += layer(x_part)
return x_out
class BPUMultiHeadedAttention(torch.nn.Module):
"""Refactor wenet/transformer/attention.py::MultiHeadedAttention
NOTE(xcsong): Only support attention_class == MultiHeadedAttention,
we do not consider RelPositionMultiHeadedAttention currently.
"""
def __init__(self, module, chunk_size, left_chunks):
super().__init__()
# Unchanged submodules and attributes
original = copy.deepcopy(module)
self.d_k = module.d_k
self.h = module.h
n_feat = self.d_k * self.h
self.chunk_size = chunk_size
self.left_chunks = left_chunks
self.time = chunk_size * (left_chunks + 1)
self.activation = torch.nn.Softmax(dim=-1)
# 1. Modify self.linear_x
self.linear_q = BPULinear(module.linear_q)
self.linear_k = BPULinear(module.linear_k)
self.linear_v = BPULinear(module.linear_v)
self.linear_out = BPULinear(module.linear_out)
# 2. denom
self.register_buffer(
"denom", torch.full((1, self.h, 1, 1), 1.0 / math.sqrt(self.d_k)))
self.check_equal(original)
def check_equal(self, module):
random_data = torch.randn(1, self.chunk_size, self.d_k * self.h)
mask = torch.ones((1, self.h, self.chunk_size, self.time),
dtype=torch.bool)
cache = torch.zeros(1, self.h, self.chunk_size * self.left_chunks,
self.d_k * 2)
original_out, original_cache = module(
random_data, random_data, random_data,
mask[:, 0, :, :], torch.empty(0), cache)
random_data = random_data.transpose(1, 2).unsqueeze(2)
cache = cache.reshape(1, self.h, self.d_k * 2,
self.chunk_size * self.left_chunks)
new_out, new_cache = self.forward(
random_data, random_data, random_data, mask, cache)
np.testing.assert_allclose(
to_numpy(original_out),
to_numpy(new_out.squeeze(2).transpose(1, 2)),
rtol=1e-02, atol=1e-03)
np.testing.assert_allclose(
to_numpy(original_cache),
to_numpy(new_cache.transpose(2, 3)),
rtol=1e-02, atol=1e-03)
def forward(
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
mask: torch.Tensor, cache: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute scaled dot product attention.
Args:
q (torch.Tensor): Query tensor (#batch, size, 1, chunk_size).
k (torch.Tensor): Key tensor (#batch, size, 1, chunk_size).
v (torch.Tensor): Value tensor (#batch, size, 1, chunk_size).
mask (torch.Tensor): Mask tensor,
(#batch, head, chunk_size, cache_t + chunk_size).
cache (torch.Tensor): Cache tensor
(1, head, d_k * 2, cache_t),
where `cache_t == chunk_size * left_chunks`.
Returns:
torch.Tensor: Output tensor (#batch, size, 1, chunk_size).
torch.Tensor: Cache tensor
(1, head, d_k * 2, cache_t + chunk_size)
where `cache_t == chunk_size * left_chunks`
"""
# 1. Forward QKV
q = self.linear_q(q) # (1, d, 1, c) d == size, c == chunk_size
k = self.linear_k(k) # (1, d, 1, c)
v = self.linear_v(v) # (1, d, 1, c)
q = q.view(1, self.h, self.d_k, self.chunk_size)
k = k.view(1, self.h, self.d_k, self.chunk_size)
v = v.view(1, self.h, self.d_k, self.chunk_size)
q = q.transpose(2, 3) # (batch, head, time1, d_k)
k_cache, v_cache = torch.split(cache, cache.size(2) // 2, dim=2)
k = torch.cat((k_cache, k), dim=3)
v = torch.cat((v_cache, v), dim=3)
new_cache = torch.cat((k, v), dim=2)
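        # k and v stay in (1, head, d_k, time2) layout, so the scores below
        # are matmul(q, k) with no extra transpose, and the context is
        # recovered as matmul(v, attn^T) afterwards.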
# 2. (Q^T)K
scores = torch.matmul(q, k) * self.denom # (#b, n_head, time1, time2)
# 3. Forward attention
mask = mask.eq(0)
scores = scores.masked_fill(mask, -float('inf'))
attn = self.activation(scores).masked_fill(mask, 0.0)
attn = attn.transpose(2, 3)
x = torch.matmul(v, attn)
x = x.view(1, self.d_k * self.h, 1, self.chunk_size)
x_out = self.linear_out(x)
return x_out, new_cache
class BPUConvolution(torch.nn.Module):
"""Refactor wenet/transformer/convolution.py::ConvolutionModule
    NOTE(xcsong): Only support use_layer_norm == False
"""
def __init__(self, module):
super().__init__()
# Unchanged submodules and attributes
original = copy.deepcopy(module)
self.lorder = module.lorder
self.use_layer_norm = False
self.activation = module.activation
channels = module.pointwise_conv1.weight.size(1)
self.channels = channels
kernel_size = module.depthwise_conv.weight.size(2)
assert module.use_layer_norm is False
# 1. Modify self.pointwise_conv1
self.pointwise_conv1 = BPULinear(module.pointwise_conv1, True)
# 2. Modify self.depthwise_conv
self.depthwise_conv = torch.nn.Conv2d(
channels, channels, (1, kernel_size),
stride=1, groups=channels)
self.depthwise_conv.weight = torch.nn.Parameter(
module.depthwise_conv.weight.unsqueeze(-2))
self.depthwise_conv.bias = torch.nn.Parameter(
module.depthwise_conv.bias)
# 3. Modify self.norm, Only support batchnorm2d
self.norm = torch.nn.BatchNorm2d(channels)
self.norm.training = False
self.norm.num_features = module.norm.num_features
self.norm.eps = module.norm.eps
self.norm.momentum = module.norm.momentum
self.norm.weight = torch.nn.Parameter(module.norm.weight)
self.norm.bias = torch.nn.Parameter(module.norm.bias)
self.norm.running_mean = module.norm.running_mean
self.norm.running_var = module.norm.running_var
# 4. Modify self.pointwise_conv2
self.pointwise_conv2 = BPULinear(module.pointwise_conv2, True)
# 5. Identity conv, for running `concat` on BPU
self.identity = BPUIdentity(channels)
self.check_equal(original)
def check_equal(self, module):
random_data = torch.randn(1, 8, self.channels)
cache = torch.zeros((1, self.channels, self.lorder))
original_out, original_cache = module(random_data, cache=cache)
random_data = random_data.transpose(1, 2).unsqueeze(2)
cache = cache.unsqueeze(2)
new_out, new_cache = self.forward(random_data, cache)
np.testing.assert_allclose(
to_numpy(original_out),
to_numpy(new_out.squeeze(2).transpose(1, 2)),
rtol=1e-02, atol=1e-03)
np.testing.assert_allclose(
to_numpy(original_cache),
to_numpy(new_cache.squeeze(2)),
rtol=1e-02, atol=1e-03)
def forward(
self, x: torch.Tensor, cache: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute convolution module.
Args:
x (torch.Tensor): Input tensor (#batch, channels, 1, chunk_size).
cache (torch.Tensor): left context cache, it is only
used in causal convolution (#batch, channels, 1, cache_t).
Returns:
torch.Tensor: Output tensor (#batch, channels, 1, chunk_size).
torch.Tensor: Cache tensor (#batch, channels, 1, cache_t).
"""
# Concat cache
x = torch.cat((self.identity(cache), self.identity(x)), dim=3)
new_cache = x[:, :, :, -self.lorder:]
# GLU mechanism
x = self.pointwise_conv1(x) # (batch, 2*channel, 1, dim)
x = torch.nn.functional.glu(x, dim=1) # (b, channel, 1, dim)
# Depthwise Conv
x = self.depthwise_conv(x)
x = self.activation(self.norm(x))
x = self.pointwise_conv2(x)
return x, new_cache
class BPUFFN(torch.nn.Module):
"""Refactor wenet/transformer/positionwise_feed_forward.py::PositionwiseFeedForward
"""
def __init__(self, module):
super().__init__()
# Unchanged submodules and attributes
original = copy.deepcopy(module)
self.activation = module.activation
# 1. Modify self.w_x
self.w_1 = BPULinear(module.w_1)
self.w_2 = BPULinear(module.w_2)
self.check_equal(original)
def check_equal(self, module):
random_data = torch.randn(1, 8, self.w_1.idim)
original_out = module(random_data)
random_data = random_data.transpose(1, 2).unsqueeze(2)
new_out = self.forward(random_data)
np.testing.assert_allclose(
to_numpy(original_out),
to_numpy(new_out.squeeze(2).transpose(1, 2)),
rtol=1e-02, atol=1e-03)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward function.
Args:
xs: input tensor (B, D, 1, L)
Returns:
output tensor, (B, D, 1, L)
"""
return self.w_2(self.activation(self.w_1(x)))
class BPUConformerEncoderLayer(torch.nn.Module):
"""Refactor wenet/transformer/encoder_layer.py::ConformerEncoderLayer
"""
def __init__(self, module, chunk_size, left_chunks, ln_run_on_bpu=False):
super().__init__()
# Unchanged submodules and attributes
original = copy.deepcopy(module)
self.size = module.size
assert module.normalize_before is True
assert module.concat_after is False
# 1. Modify submodules
self.feed_forward_macaron = BPUFFN(module.feed_forward_macaron)
self.self_attn = BPUMultiHeadedAttention(
module.self_attn, chunk_size, left_chunks)
self.conv_module = BPUConvolution(module.conv_module)
self.feed_forward = BPUFFN(module.feed_forward)
# 2. Modify norms
self.norm_ff = BPULayerNorm(module.norm_ff, chunk_size, ln_run_on_bpu)
self.norm_mha = BPULayerNorm(module.norm_mha, chunk_size, ln_run_on_bpu)
self.norm_ff_macron = BPULayerNorm(module.norm_ff_macaron,
chunk_size, ln_run_on_bpu)
self.norm_conv = BPULayerNorm(module.norm_conv,
chunk_size, ln_run_on_bpu)
self.norm_final = BPULayerNorm(module.norm_final,
chunk_size, ln_run_on_bpu)
# 3. 4-D ff_scale
self.register_buffer(
"ff_scale", torch.full((1, self.size, 1, 1), module.ff_scale))
self.check_equal(original)
def check_equal(self, module):
time1 = self.self_attn.chunk_size
time2 = self.self_attn.time
h, d_k = self.self_attn.h, self.self_attn.d_k
random_x = torch.randn(1, time1, self.size)
att_mask = torch.ones(1, h, time1, time2)
att_cache = torch.zeros(1, h, time2 - time1, d_k * 2)
cnn_cache = torch.zeros(1, self.size, self.conv_module.lorder)
original_x, _, original_att_cache, original_cnn_cache = module(
random_x, att_mask[:, 0, :, :], torch.empty(0),
att_cache=att_cache, cnn_cache=cnn_cache
)
random_x = random_x.transpose(1, 2).unsqueeze(2)
att_cache = att_cache.reshape(1, h, d_k * 2, time2 - time1)
cnn_cache = cnn_cache.unsqueeze(2)
new_x, new_att_cache, new_cnn_cache = self.forward(
random_x, att_mask, att_cache, cnn_cache
)
np.testing.assert_allclose(
to_numpy(original_att_cache),
to_numpy(new_att_cache.transpose(2, 3)),
rtol=1e-02, atol=1e-03)
np.testing.assert_allclose(
to_numpy(original_x),
to_numpy(new_x.squeeze(2).transpose(1, 2)),
rtol=1e-02, atol=1e-03)
np.testing.assert_allclose(
to_numpy(original_cnn_cache),
to_numpy(new_cnn_cache.squeeze(2)),
rtol=1e-02, atol=1e-03)
def forward(
self, x: torch.Tensor, att_mask: torch.Tensor,
att_cache: torch.Tensor, cnn_cache: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Compute encoded features.
Args:
x (torch.Tensor): (#batch, size, 1, chunk_size)
att_mask (torch.Tensor): Mask tensor for the input
(#batch, head, chunk_size, cache_t1 + chunk_size),
att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
(#batch=1, head, d_k * 2, cache_t1), head * d_k == size.
cnn_cache (torch.Tensor): Convolution cache in conformer layer
(#batch=1, size, 1, cache_t2)
Returns:
torch.Tensor: Output tensor (#batch, size, 1, chunk_size).
torch.Tensor: att_cache tensor,
(1, head, d_k * 2, cache_t1 + chunk_size).
            torch.Tensor: cnn_cache tensor (#batch, size, 1, cache_t2).
"""
# 1. ffn_macaron
residual = x
x = self.norm_ff_macron(x)
x = residual + self.ff_scale * self.feed_forward_macaron(x)
# 2. attention
residual = x
x = self.norm_mha(x)
x_att, new_att_cache = self.self_attn(
x, x, x, att_mask, att_cache)
x = residual + x_att
# 3. convolution
residual = x
x = self.norm_conv(x)
x, new_cnn_cache = self.conv_module(x, cnn_cache)
x = residual + x
# 4. ffn
residual = x
x = self.norm_ff(x)
x = residual + self.ff_scale * self.feed_forward(x)
# 5. final post-norm
x = self.norm_final(x)
return x, new_att_cache, new_cnn_cache
class BPUConformerEncoder(torch.nn.Module):
"""Refactor wenet/transformer/encoder.py::ConformerEncoder
"""
def __init__(self, module, chunk_size, left_chunks, ln_run_on_bpu=False):
super().__init__()
# Unchanged submodules and attributes
original = copy.deepcopy(module)
output_size = module.output_size()
self._output_size = module.output_size()
self.after_norm = module.after_norm
self.chunk_size = chunk_size
self.left_chunks = left_chunks
self.head = module.encoders[0].self_attn.h
self.layers = len(module.encoders)
# 1. Modify submodules
self.global_cmvn = BPUGlobalCMVN(module.global_cmvn)
self.embed = BPUConv2dSubsampling8(module.embed)
self.encoders = torch.nn.ModuleList()
for layer in module.encoders:
self.encoders.append(BPUConformerEncoderLayer(
layer, chunk_size, left_chunks, ln_run_on_bpu))
# 2. Auxiliary conv
self.identity_cnncache = BPUIdentity(output_size)
self.check_equal(original)
def check_equal(self, module):
time1 = self.encoders[0].self_attn.chunk_size
time2 = self.encoders[0].self_attn.time
layers = self.layers
h, d_k = self.head, self.encoders[0].self_attn.d_k
decoding_window = (self.chunk_size - 1) * \
module.embed.subsampling_rate + \
module.embed.right_context + 1
lorder = self.encoders[0].conv_module.lorder
random_x = torch.randn(1, decoding_window, 80)
att_mask = torch.ones(1, h, time1, time2)
att_cache = torch.zeros(layers, h, time2 - time1, d_k * 2)
cnn_cache = torch.zeros(layers, 1, self._output_size, lorder)
orig_x, orig_att_cache, orig_cnn_cache = module.forward_chunk(
random_x, 0, time2 - time1, att_mask=att_mask[:, 0, :, :],
att_cache=att_cache, cnn_cache=cnn_cache
)
random_x = random_x.unsqueeze(0)
att_cache = att_cache.reshape(1, h * layers, d_k * 2, time2 - time1)
cnn_cache = cnn_cache.reshape(1, self._output_size, layers, lorder)
new_x, new_att_cache, new_cnn_cache = self.forward(
random_x, att_cache, cnn_cache, att_mask
)
caches = torch.split(new_att_cache, h, dim=1)
caches = [c.transpose(2, 3) for c in caches]
np.testing.assert_allclose(
to_numpy(orig_att_cache),
to_numpy(torch.cat(caches, dim=0)),
rtol=1e-02, atol=1e-03)
np.testing.assert_allclose(
to_numpy(orig_x),
to_numpy(new_x.squeeze(2).transpose(1, 2)),
rtol=1e-02, atol=1e-03)
np.testing.assert_allclose(
to_numpy(orig_cnn_cache),
to_numpy(new_cnn_cache.transpose(0, 2).transpose(1, 2)),
rtol=1e-02, atol=1e-03)
def forward(
self, xs: torch.Tensor, att_cache: torch.Tensor,
cnn_cache: torch.Tensor, att_mask: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
""" Forward just one chunk
Args:
xs (torch.Tensor): chunk input, with shape (b=1, 1, time, mel-dim),
where `time == (chunk_size - 1) * subsample_rate + \
subsample.right_context + 1`
att_cache (torch.Tensor): cache tensor for KEY & VALUE in
transformer/conformer attention, with shape
(1, head * elayers, d_k * 2, cache_t1), where
`head * d_k == hidden-dim` and
`cache_t1 == chunk_size * left_chunks`.
cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer,
(1, hidden-dim, elayers, cache_t2), where
`cache_t2 == cnn.lorder - 1`
att_mask (torch.Tensor): Mask tensor for the input
(#batch, head, chunk_size, cache_t1 + chunk_size),
Returns:
torch.Tensor: output of current input xs,
with shape (b=1, hidden-dim, 1, chunk_size).
torch.Tensor: new attention cache required for next chunk, with
same shape as the original att_cache.
torch.Tensor: new conformer cnn cache required for next chunk, with
same shape as the original cnn_cache.
"""
# xs: (B, 1, time, mel_dim) -> (B, 1, mel_dim, time)
xs = xs.transpose(2, 3)
xs = self.global_cmvn(xs)
# xs: (B, 1, mel_dim, time) -> (B, hidden_dim, 1, chunk_size)
xs = self.embed(xs)
att_cache = torch.split(att_cache, self.head, dim=1)
cnn_cache = self.identity_cnncache(cnn_cache)
cnn_cache = torch.split(cnn_cache, 1, dim=2)
r_att_cache = []
r_cnn_cache = []
for i, layer in enumerate(self.encoders):
xs, new_att_cache, new_cnn_cache = layer(
xs, att_mask, att_cache=att_cache[i], cnn_cache=cnn_cache[i])
r_att_cache.append(new_att_cache[:, :, :, self.chunk_size:])
r_cnn_cache.append(new_cnn_cache)
r_att_cache = torch.cat(r_att_cache, dim=1)
r_cnn_cache = self.identity_cnncache(
torch.cat(r_cnn_cache, dim=2))
xs = xs.squeeze(2).transpose(1, 2).contiguous()
xs = self.after_norm(xs)
        # NOTE(xcsong): 4D in, 4D out to meet the requirement of CTC input.
xs = xs.transpose(1, 2).contiguous().unsqueeze(2) # (B, C, 1, T)
return (xs, r_att_cache, r_cnn_cache)
class BPUCTC(torch.nn.Module):
"""Refactor wenet/transformer/ctc.py::CTC
"""
def __init__(self, module):
super().__init__()
# Unchanged submodules and attributes
original = copy.deepcopy(module)
self.idim = module.ctc_lo.weight.size(1)
num_class = module.ctc_lo.weight.size(0)
        # 1. Modify self.ctc_lo: split the final projection to meet the
        #    requirement of maximum in/out channels (2048 for XJ3)
self.ctc_lo = torch.nn.ModuleList()
self.split_size = []
num_split = (num_class - 1) // 2048 + 1
for idx in range(num_split):
out_channel = min(num_class, (idx + 1) * 2048) - idx * 2048
conv_ele = torch.nn.Conv2d(self.idim, out_channel, 1, 1)
self.ctc_lo.append(conv_ele)
self.split_size.append(out_channel)
orig_weight = torch.split(module.ctc_lo.weight, self.split_size, dim=0)
orig_bias = torch.split(module.ctc_lo.bias, self.split_size, dim=0)
for i, (w, b) in enumerate(zip(orig_weight, orig_bias)):
w = w.unsqueeze(2).unsqueeze(3)
self.ctc_lo[i].weight = torch.nn.Parameter(w)
self.ctc_lo[i].bias = torch.nn.Parameter(b)
self.check_equal(original)
def check_equal(self, module):
random_data = torch.randn(1, 100, self.idim)
original_result = module.ctc_lo(random_data)
random_data = random_data.transpose(1, 2).unsqueeze(2)
new_result = self.forward(random_data)
np.testing.assert_allclose(
to_numpy(original_result),
to_numpy(new_result.squeeze(2).transpose(1, 2)),
rtol=1e-02, atol=1e-03)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""frame activations, without softmax.
Args:
Tensor x: 4d tensor (B, hidden_dim, 1, chunk_size)
Returns:
torch.Tensor: (B, num_class, 1, chunk_size)
"""
out = []
for i, layer in enumerate(self.ctc_lo):
out.append(layer(x))
out = torch.cat(out, dim=1)
return out
def export_encoder(asr_model, args):
logger.info("Stage-1: export encoder")
decode_window, mel_dim = args.decoding_window, args.feature_size
encoder = BPUConformerEncoder(
asr_model.encoder, args.chunk_size, args.num_decoding_left_chunks,
args.ln_run_on_bpu)
encoder.eval()
encoder_outpath = os.path.join(args.output_dir, 'encoder.onnx')
logger.info("Stage-1.1: prepare inputs for encoder")
chunk = torch.randn((1, 1, decode_window, mel_dim))
required_cache_size = encoder.chunk_size * encoder.left_chunks
kv_time = required_cache_size + encoder.chunk_size
hidden, layers = encoder._output_size, len(encoder.encoders)
head = encoder.encoders[0].self_attn.h
d_k = hidden // head
lorder = encoder.encoders[0].conv_module.lorder
att_cache = torch.zeros(1, layers * head, d_k * 2, required_cache_size)
att_mask = torch.ones((1, head, encoder.chunk_size, kv_time))
att_mask[:, :, :, :required_cache_size] = 0
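    # The exported graph starts from an empty cache, so the cache positions
    # are masked out here; the consistency check in Stage-1.3 progressively
    # un-masks them chunk by chunk.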
cnn_cache = torch.zeros((1, hidden, layers, lorder))
inputs = (chunk, att_cache, cnn_cache, att_mask)
logger.info("chunk.size(): {} att_cache.size(): {} "
"cnn_cache.size(): {} att_mask.size(): {}".format(
list(chunk.size()), list(att_cache.size()),
list(cnn_cache.size()), list(att_mask.size())))
logger.info("Stage-1.2: torch.onnx.export")
# NOTE(xcsong): Below attributes will be used in
# onnx2horizonbin.py::generate_config()
attributes = {}
attributes['input_name'] = "chunk;att_cache;cnn_cache;att_mask"
attributes['output_name'] = "output;r_att_cache;r_cnn_cache"
attributes['input_type'] = "featuremap;featuremap;featuremap;featuremap"
attributes['norm_type'] = \
"no_preprocess;no_preprocess;no_preprocess;no_preprocess"
attributes['input_layout_train'] = "NCHW;NCHW;NCHW;NCHW"
attributes['input_layout_rt'] = "NCHW;NCHW;NCHW;NCHW"
attributes['input_shape'] = \
"{}x{}x{}x{};{}x{}x{}x{};{}x{}x{}x{};{}x{}x{}x{}".format(
chunk.size(0), chunk.size(1), chunk.size(2), chunk.size(3),
att_cache.size(0), att_cache.size(1), att_cache.size(2),
att_cache.size(3), cnn_cache.size(0), cnn_cache.size(1),
cnn_cache.size(2), cnn_cache.size(3), att_mask.size(0),
att_mask.size(1), att_mask.size(2), att_mask.size(3)
)
torch.onnx.export( # NOTE(xcsong): only support opset==11
encoder, inputs, encoder_outpath, opset_version=11,
export_params=True, do_constant_folding=True,
input_names=attributes['input_name'].split(';'),
output_names=attributes['output_name'].split(';'),
dynamic_axes=None, verbose=False)
onnx_encoder = onnx.load(encoder_outpath)
for k in vars(args):
meta = onnx_encoder.metadata_props.add()
meta.key, meta.value = str(k), str(getattr(args, k))
for k in attributes:
meta = onnx_encoder.metadata_props.add()
meta.key, meta.value = str(k), str(attributes[k])
onnx.checker.check_model(onnx_encoder)
onnx.helper.printable_graph(onnx_encoder.graph)
onnx.save(onnx_encoder, encoder_outpath)
print_input_output_info(onnx_encoder, "onnx_encoder")
logger.info('Export onnx_encoder, done! see {}'.format(encoder_outpath))
logger.info("Stage-1.3: check onnx_encoder and torch_encoder")
torch_output = []
torch_chunk, torch_att_mask = copy.deepcopy(chunk), copy.deepcopy(att_mask)
torch_att_cache = copy.deepcopy(att_cache)
torch_cnn_cache = copy.deepcopy(cnn_cache)
for i in range(10):
logger.info("torch chunk-{}: {}, att_cache: {}, cnn_cache: {}"
", att_mask: {}".format(
i, list(torch_chunk.size()),
list(torch_att_cache.size()),
list(torch_cnn_cache.size()),
list(torch_att_mask.size())))
torch_att_mask[:, :, :, -(encoder.chunk_size * (i + 1)):] = 1
out, torch_att_cache, torch_cnn_cache = encoder(
torch_chunk, torch_att_cache, torch_cnn_cache, torch_att_mask)
torch_output.append(out)
torch_output = torch.cat(torch_output, dim=-1)
onnx_output = []
onnx_chunk, onnx_att_mask = to_numpy(chunk), to_numpy(att_mask)
onnx_att_cache = to_numpy(att_cache)
onnx_cnn_cache = to_numpy(cnn_cache)
ort_session = onnxruntime.InferenceSession(encoder_outpath)
input_names = [node.name for node in onnx_encoder.graph.input]
for i in range(10):
logger.info("onnx chunk-{}: {}, att_cache: {}, cnn_cache: {},"
" att_mask: {}".format(
i, onnx_chunk.shape, onnx_att_cache.shape,
onnx_cnn_cache.shape, onnx_att_mask.shape))
onnx_att_mask[:, :, :, -(encoder.chunk_size * (i + 1)):] = 1
ort_inputs = {
'chunk': onnx_chunk, 'att_cache': onnx_att_cache,
'cnn_cache': onnx_cnn_cache, 'att_mask': onnx_att_mask,
}
ort_outs = ort_session.run(None, ort_inputs)
onnx_att_cache, onnx_cnn_cache = ort_outs[1], ort_outs[2]
onnx_output.append(ort_outs[0])
onnx_output = np.concatenate(onnx_output, axis=-1)
np.testing.assert_allclose(to_numpy(torch_output), onnx_output,
rtol=1e-03, atol=1e-04)
meta = ort_session.get_modelmeta()
logger.info("custom_metadata_map={}".format(meta.custom_metadata_map))
logger.info("Check onnx_encoder, pass!")
return encoder, ort_session
def export_ctc(asr_model, args):
logger.info("Stage-2: export ctc")
ctc = BPUCTC(asr_model.ctc).eval()
ctc_outpath = os.path.join(args.output_dir, 'ctc.onnx')
logger.info("Stage-2.1: prepare inputs for ctc")
hidden = torch.randn((1, args.output_size, 1, args.chunk_size))
logger.info("Stage-2.2: torch.onnx.export")
# NOTE(xcsong): Below attributes will be used in
# onnx2horizonbin.py::generate_config()
attributes = {}
attributes['input_name'], attributes['input_type'] = "hidden", "featuremap"
attributes['norm_type'] = "no_preprocess"
attributes['input_layout_train'] = "NCHW"
attributes['input_layout_rt'] = "NCHW"
attributes['input_shape'] = "{}x{}x{}x{}".format(
hidden.size(0), hidden.size(1), hidden.size(2), hidden.size(3),
)
torch.onnx.export(
ctc, hidden, ctc_outpath, opset_version=11,
export_params=True, do_constant_folding=True,
input_names=['hidden'], output_names=['probs'],
dynamic_axes=None, verbose=False)
onnx_ctc = onnx.load(ctc_outpath)
for k in vars(args):
meta = onnx_ctc.metadata_props.add()
meta.key, meta.value = str(k), str(getattr(args, k))
for k in attributes:
meta = onnx_ctc.metadata_props.add()
meta.key, meta.value = str(k), str(attributes[k])
onnx.checker.check_model(onnx_ctc)
onnx.helper.printable_graph(onnx_ctc.graph)
onnx.save(onnx_ctc, ctc_outpath)
print_input_output_info(onnx_ctc, "onnx_ctc")
logger.info('Export onnx_ctc, done! see {}'.format(ctc_outpath))
logger.info("Stage-2.3: check onnx_ctc and torch_ctc")
torch_output = ctc(hidden)
ort_session = onnxruntime.InferenceSession(ctc_outpath)
onnx_output = ort_session.run(None, {'hidden': to_numpy(hidden)})
np.testing.assert_allclose(to_numpy(torch_output), onnx_output[0],
rtol=1e-03, atol=1e-04)
meta = ort_session.get_modelmeta()
logger.info("custom_metadata_map={}".format(meta.custom_metadata_map))
logger.info("Check onnx_ctc, pass!")
return ctc, ort_session
def export_decoder(asr_model, args):
logger.info("Currently, Decoder is not supported.")
if __name__ == '__main__':
torch.manual_seed(777)
args = get_args()
args.ln_run_on_bpu = False
    # NOTE(xcsong): XJ3 BPU only supports static shapes
assert args.chunk_size > 0
assert args.num_decoding_left_chunks > 0
os.system("mkdir -p " + args.output_dir)
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
with open(args.config, 'r') as fin:
configs = yaml.load(fin, Loader=yaml.FullLoader)
model = init_model(configs)
load_checkpoint(model, args.checkpoint)
model.eval()
print(model)
args.feature_size = configs['input_dim']
args.output_size = model.encoder.output_size()
args.decoding_window = (args.chunk_size - 1) * \
model.encoder.embed.subsampling_rate + \
model.encoder.embed.right_context + 1
export_encoder(model, args)
export_ctc(model, args)
export_decoder(model, args)
# Copyright (c) 2022, Xingchen Song (sxc19@mails.tsinghua.edu.cn)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import argparse
import os
import copy
import sys
import torch
import yaml
import numpy as np
from wenet.utils.checkpoint import load_checkpoint
from wenet.utils.init_model import init_model
try:
import onnx
import onnxruntime
from onnxruntime.quantization import quantize_dynamic, QuantType
except ImportError:
print('Please install onnx and onnxruntime!')
sys.exit(1)
def get_args():
parser = argparse.ArgumentParser(description='export your script model')
parser.add_argument('--config', required=True, help='config file')
parser.add_argument('--checkpoint', required=True, help='checkpoint model')
parser.add_argument('--output_dir', required=True, help='output directory')
parser.add_argument('--chunk_size', required=True,
type=int, help='decoding chunk size')
parser.add_argument('--num_decoding_left_chunks', required=True,
type=int, help='cache chunks')
parser.add_argument('--reverse_weight', default=0.5,
                        type=float, help='reverse_weight in attention_rescoring')
args = parser.parse_args()
return args
def to_numpy(tensor):
if tensor.requires_grad:
return tensor.detach().cpu().numpy()
else:
return tensor.cpu().numpy()
def print_input_output_info(onnx_model, name, prefix="\t\t"):
input_names = [node.name for node in onnx_model.graph.input]
input_shapes = [[d.dim_value for d in node.type.tensor_type.shape.dim]
for node in onnx_model.graph.input]
output_names = [node.name for node in onnx_model.graph.output]
output_shapes = [[d.dim_value for d in node.type.tensor_type.shape.dim]
for node in onnx_model.graph.output]
print("{}{} inputs : {}".format(prefix, name, input_names))
print("{}{} input shapes : {}".format(prefix, name, input_shapes))
print("{}{} outputs: {}".format(prefix, name, output_names))
print("{}{} output shapes : {}".format(prefix, name, output_shapes))
def export_encoder(asr_model, args):
print("Stage-1: export encoder")
encoder = asr_model.encoder
encoder.forward = encoder.forward_chunk
encoder_outpath = os.path.join(args['output_dir'], 'encoder.onnx')
print("\tStage-1.1: prepare inputs for encoder")
chunk = torch.randn(
(args['batch'], args['decoding_window'], args['feature_size']))
offset = 0
# NOTE(xcsong): The uncertainty of `next_cache_start` only appears
    # in the first few chunks; this is caused by the dynamic att_cache shape, i.e.
# (0, 0, 0, 0) for 1st chunk and (elayers, head, ?, d_k*2) for subsequent
# chunks. One way to ease the ONNX export is to keep `next_cache_start`
# as a fixed value. To do this, for the **first** chunk, if
# left_chunks > 0, we feed real cache & real mask to the model, otherwise
# fake cache & fake mask. In this way, we get:
# 1. 16/-1 mode: next_cache_start == 0 for all chunks
# 2. 16/4 mode: next_cache_start == chunk_size for all chunks
# 3. 16/0 mode: next_cache_start == chunk_size for all chunks
# 4. -1/-1 mode: next_cache_start == 0 for all chunks
# NO MORE DYNAMIC CHANGES!!
#
# NOTE(Mddct): We retain the current design for the convenience of supporting some
    # inference frameworks without dynamic shapes. If you're interested in an
    # all-in-one model that supports different chunk sizes, please see:
# https://github.com/wenet-e2e/wenet/pull/1174
if args['left_chunks'] > 0: # 16/4
required_cache_size = args['chunk_size'] * args['left_chunks']
offset = required_cache_size
# Real cache
att_cache = torch.zeros(
(args['num_blocks'], args['head'], required_cache_size,
args['output_size'] // args['head'] * 2))
# Real mask
att_mask = torch.ones(
(args['batch'], 1, required_cache_size + args['chunk_size']),
dtype=torch.bool)
att_mask[:, :, :required_cache_size] = 0
elif args['left_chunks'] <= 0: # 16/-1, -1/-1, 16/0
required_cache_size = -1 if args['left_chunks'] < 0 else 0
# Fake cache
att_cache = torch.zeros(
(args['num_blocks'], args['head'], 0,
args['output_size'] // args['head'] * 2))
# Fake mask
att_mask = torch.ones((0, 0, 0), dtype=torch.bool)
cnn_cache = torch.zeros(
(args['num_blocks'], args['batch'],
args['output_size'], args['cnn_module_kernel'] - 1))
inputs = (chunk, offset, required_cache_size,
att_cache, cnn_cache, att_mask)
print("\t\tchunk.size(): {}\n".format(chunk.size()),
"\t\toffset: {}\n".format(offset),
"\t\trequired_cache: {}\n".format(required_cache_size),
"\t\tatt_cache.size(): {}\n".format(att_cache.size()),
"\t\tcnn_cache.size(): {}\n".format(cnn_cache.size()),
"\t\tatt_mask.size(): {}\n".format(att_mask.size()))
print("\tStage-1.2: torch.onnx.export")
dynamic_axes = {
'chunk': {1: 'T'},
'att_cache': {2: 'T_CACHE'},
'att_mask': {2: 'T_ADD_T_CACHE'},
'output': {1: 'T'},
'r_att_cache': {2: 'T_CACHE'},
}
    # NOTE(xcsong): We keep dynamic axes even in 16/4 mode; this is
    # to avoid padding the last chunk (which usually contains fewer
    # frames than required). For users who want static axes, just pop
    # out the corresponding entries.
# if args['chunk_size'] > 0: # 16/4, 16/-1, 16/0
# dynamic_axes.pop('chunk')
# dynamic_axes.pop('output')
# if args['left_chunks'] >= 0: # 16/4, 16/0
# # NOTE(xsong): since we feed real cache & real mask into the
# # model when left_chunks > 0, the shape of cache will never
# # be changed.
# dynamic_axes.pop('att_cache')
# dynamic_axes.pop('r_att_cache')
torch.onnx.export(
encoder, inputs, encoder_outpath, opset_version=13,
export_params=True, do_constant_folding=True,
input_names=[
'chunk', 'offset', 'required_cache_size',
'att_cache', 'cnn_cache', 'att_mask'
],
output_names=['output', 'r_att_cache', 'r_cnn_cache'],
dynamic_axes=dynamic_axes, verbose=False)
onnx_encoder = onnx.load(encoder_outpath)
for (k, v) in args.items():
meta = onnx_encoder.metadata_props.add()
meta.key, meta.value = str(k), str(v)
onnx.checker.check_model(onnx_encoder)
onnx.helper.printable_graph(onnx_encoder.graph)
    # NOTE(xcsong): to add the metadata above we need to
    #   re-save the model file.
onnx.save(onnx_encoder, encoder_outpath)
print_input_output_info(onnx_encoder, "onnx_encoder")
# Dynamic quantization
model_fp32 = encoder_outpath
model_quant = os.path.join(args['output_dir'], 'encoder.quant.onnx')
quantize_dynamic(model_fp32, model_quant, weight_type=QuantType.QUInt8)
print('\t\tExport onnx_encoder, done! see {}'.format(encoder_outpath))
print("\tStage-1.3: check onnx_encoder and torch_encoder")
torch_output = []
torch_chunk = copy.deepcopy(chunk)
torch_offset = copy.deepcopy(offset)
torch_required_cache_size = copy.deepcopy(required_cache_size)
torch_att_cache = copy.deepcopy(att_cache)
torch_cnn_cache = copy.deepcopy(cnn_cache)
torch_att_mask = copy.deepcopy(att_mask)
for i in range(10):
print("\t\ttorch chunk-{}: {}, offset: {}, att_cache: {},"
" cnn_cache: {}, att_mask: {}".format(
i, list(torch_chunk.size()), torch_offset,
list(torch_att_cache.size()),
list(torch_cnn_cache.size()), list(torch_att_mask.size())))
        # NOTE(xsong): att_mask for the first few chunks needs to change if
        # we use 16/4 mode.
if args['left_chunks'] > 0: # 16/4
torch_att_mask[:, :, -(args['chunk_size'] * (i + 1)):] = 1
out, torch_att_cache, torch_cnn_cache = encoder(
torch_chunk, torch_offset, torch_required_cache_size,
torch_att_cache, torch_cnn_cache, torch_att_mask)
torch_output.append(out)
torch_offset += out.size(1)
torch_output = torch.cat(torch_output, dim=1)
onnx_output = []
onnx_chunk = to_numpy(chunk)
onnx_offset = np.array((offset)).astype(np.int64)
onnx_required_cache_size = np.array((required_cache_size)).astype(np.int64)
onnx_att_cache = to_numpy(att_cache)
onnx_cnn_cache = to_numpy(cnn_cache)
onnx_att_mask = to_numpy(att_mask)
ort_session = onnxruntime.InferenceSession(encoder_outpath)
input_names = [node.name for node in onnx_encoder.graph.input]
for i in range(10):
print("\t\tonnx chunk-{}: {}, offset: {}, att_cache: {},"
" cnn_cache: {}, att_mask: {}".format(
i, onnx_chunk.shape, onnx_offset, onnx_att_cache.shape,
onnx_cnn_cache.shape, onnx_att_mask.shape))
        # NOTE(xsong): att_mask for the first few chunks needs to change if
        # we use 16/4 mode.
if args['left_chunks'] > 0: # 16/4
onnx_att_mask[:, :, -(args['chunk_size'] * (i + 1)):] = 1
ort_inputs = {
'chunk': onnx_chunk, 'offset': onnx_offset,
'required_cache_size': onnx_required_cache_size,
'att_cache': onnx_att_cache, 'cnn_cache': onnx_cnn_cache,
'att_mask': onnx_att_mask
}
# NOTE(xcsong): If we use 16/-1, -1/-1 or 16/0 mode, `next_cache_start`
# will be hardcoded to 0 or chunk_size by ONNX, thus
# required_cache_size and att_mask are no more needed and they will
# be removed by ONNX automatically.
for k in list(ort_inputs):
if k not in input_names:
ort_inputs.pop(k)
ort_outs = ort_session.run(None, ort_inputs)
onnx_att_cache, onnx_cnn_cache = ort_outs[1], ort_outs[2]
onnx_output.append(ort_outs[0])
onnx_offset += ort_outs[0].shape[1]
onnx_output = np.concatenate(onnx_output, axis=1)
np.testing.assert_allclose(to_numpy(torch_output), onnx_output,
rtol=1e-03, atol=1e-05)
meta = ort_session.get_modelmeta()
print("\t\tcustom_metadata_map={}".format(meta.custom_metadata_map))
print("\t\tCheck onnx_encoder, pass!")
def export_ctc(asr_model, args):
print("Stage-2: export ctc")
ctc = asr_model.ctc
ctc.forward = ctc.log_softmax
ctc_outpath = os.path.join(args['output_dir'], 'ctc.onnx')
print("\tStage-2.1: prepare inputs for ctc")
hidden = torch.randn(
(args['batch'], args['chunk_size'] if args['chunk_size'] > 0 else 16,
args['output_size']))
print("\tStage-2.2: torch.onnx.export")
dynamic_axes = {'hidden': {1: 'T'}, 'probs': {1: 'T'}}
torch.onnx.export(
ctc, hidden, ctc_outpath, opset_version=13,
export_params=True, do_constant_folding=True,
input_names=['hidden'], output_names=['probs'],
dynamic_axes=dynamic_axes, verbose=False)
onnx_ctc = onnx.load(ctc_outpath)
for (k, v) in args.items():
meta = onnx_ctc.metadata_props.add()
meta.key, meta.value = str(k), str(v)
onnx.checker.check_model(onnx_ctc)
onnx.helper.printable_graph(onnx_ctc.graph)
onnx.save(onnx_ctc, ctc_outpath)
print_input_output_info(onnx_ctc, "onnx_ctc")
# Dynamic quantization
model_fp32 = ctc_outpath
model_quant = os.path.join(args['output_dir'], 'ctc.quant.onnx')
quantize_dynamic(model_fp32, model_quant, weight_type=QuantType.QUInt8)
print('\t\tExport onnx_ctc, done! see {}'.format(ctc_outpath))
print("\tStage-2.3: check onnx_ctc and torch_ctc")
torch_output = ctc(hidden)
ort_session = onnxruntime.InferenceSession(ctc_outpath)
onnx_output = ort_session.run(None, {'hidden': to_numpy(hidden)})
np.testing.assert_allclose(to_numpy(torch_output), onnx_output[0],
rtol=1e-03, atol=1e-05)
print("\t\tCheck onnx_ctc, pass!")
def export_decoder(asr_model, args):
print("Stage-3: export decoder")
decoder = asr_model
# NOTE(lzhin): parameters of encoder will be automatically removed
# since they are not used during rescoring.
decoder.forward = decoder.forward_attention_decoder
decoder_outpath = os.path.join(args['output_dir'], 'decoder.onnx')
print("\tStage-3.1: prepare inputs for decoder")
# hardcode time->200 nbest->10 len->20, they are dynamic axes.
encoder_out = torch.randn((1, 200, args['output_size']))
hyps = torch.randint(low=0, high=args['vocab_size'],
size=[10, 20])
hyps[:, 0] = args['vocab_size'] - 1 # <sos>
hyps_lens = torch.randint(low=15, high=21, size=[10])
print("\tStage-3.2: torch.onnx.export")
dynamic_axes = {
'hyps': {0: 'NBEST', 1: 'L'}, 'hyps_lens': {0: 'NBEST'},
'encoder_out': {1: 'T'},
'score': {0: 'NBEST', 1: 'L'}, 'r_score': {0: 'NBEST', 1: 'L'}
}
inputs = (hyps, hyps_lens, encoder_out, args['reverse_weight'])
torch.onnx.export(
decoder, inputs, decoder_outpath, opset_version=13,
export_params=True, do_constant_folding=True,
input_names=['hyps', 'hyps_lens', 'encoder_out', 'reverse_weight'],
output_names=['score', 'r_score'],
dynamic_axes=dynamic_axes, verbose=False)
onnx_decoder = onnx.load(decoder_outpath)
for (k, v) in args.items():
meta = onnx_decoder.metadata_props.add()
meta.key, meta.value = str(k), str(v)
onnx.checker.check_model(onnx_decoder)
onnx.helper.printable_graph(onnx_decoder.graph)
onnx.save(onnx_decoder, decoder_outpath)
print_input_output_info(onnx_decoder, "onnx_decoder")
model_fp32 = decoder_outpath
model_quant = os.path.join(args['output_dir'], 'decoder.quant.onnx')
quantize_dynamic(model_fp32, model_quant, weight_type=QuantType.QUInt8)
print('\t\tExport onnx_decoder, done! see {}'.format(
decoder_outpath))
print("\tStage-3.3: check onnx_decoder and torch_decoder")
torch_score, torch_r_score = decoder(
hyps, hyps_lens, encoder_out, args['reverse_weight'])
ort_session = onnxruntime.InferenceSession(decoder_outpath)
input_names = [node.name for node in onnx_decoder.graph.input]
ort_inputs = {
'hyps': to_numpy(hyps),
'hyps_lens': to_numpy(hyps_lens),
'encoder_out': to_numpy(encoder_out),
'reverse_weight': np.array((args['reverse_weight'])),
}
for k in list(ort_inputs):
if k not in input_names:
ort_inputs.pop(k)
onnx_output = ort_session.run(None, ort_inputs)
np.testing.assert_allclose(to_numpy(torch_score), onnx_output[0],
rtol=1e-03, atol=1e-05)
if args['is_bidirectional_decoder'] and args['reverse_weight'] > 0.0:
np.testing.assert_allclose(to_numpy(torch_r_score), onnx_output[1],
rtol=1e-03, atol=1e-05)
print("\t\tCheck onnx_decoder, pass!")
def main():
torch.manual_seed(777)
args = get_args()
output_dir = args.output_dir
os.system("mkdir -p " + output_dir)
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
with open(args.config, 'r') as fin:
configs = yaml.load(fin, Loader=yaml.FullLoader)
model = init_model(configs)
load_checkpoint(model, args.checkpoint)
model.eval()
print(model)
arguments = {}
arguments['output_dir'] = output_dir
arguments['batch'] = 1
arguments['chunk_size'] = args.chunk_size
arguments['left_chunks'] = args.num_decoding_left_chunks
arguments['reverse_weight'] = args.reverse_weight
arguments['output_size'] = configs['encoder_conf']['output_size']
arguments['num_blocks'] = configs['encoder_conf']['num_blocks']
arguments['cnn_module_kernel'] = configs['encoder_conf'].get('cnn_module_kernel', 1)
arguments['head'] = configs['encoder_conf']['attention_heads']
arguments['feature_size'] = configs['input_dim']
arguments['vocab_size'] = configs['output_dim']
# NOTE(xcsong): if chunk_size == -1, hardcode to 67
arguments['decoding_window'] = (args.chunk_size - 1) * \
model.encoder.embed.subsampling_rate + \
model.encoder.embed.right_context + 1 if args.chunk_size > 0 else 67
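    # e.g. with subsampling_rate=4 and right_context=6 (typical for the
    # conformer 4x subsampling), chunk_size=16 gives (16 - 1) * 4 + 6 + 1 = 67,
    # which is why 67 is used as the fallback above.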
arguments['encoder'] = configs['encoder']
arguments['decoder'] = configs['decoder']
arguments['subsampling_rate'] = model.subsampling_rate()
arguments['right_context'] = model.right_context()
arguments['sos_symbol'] = model.sos_symbol()
arguments['eos_symbol'] = model.eos_symbol()
arguments['is_bidirectional_decoder'] = 1 \
if model.is_bidirectional_decoder() else 0
# NOTE(xcsong): Please note that -1/-1 means non-streaming model! It is
# not a [16/4 16/-1 16/0] all-in-one model and it should not be used in
# streaming mode (i.e., setting chunk_size=16 in `decoder_main`). If you
# want to use 16/-1 or any other streaming mode in `decoder_main`,
    # please export the onnx model with the same config.
if arguments['left_chunks'] > 0:
assert arguments['chunk_size'] > 0 # -1/4 not supported
export_encoder(model, arguments)
export_ctc(model, arguments)
export_decoder(model, arguments)
if __name__ == '__main__':
main()
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import argparse
import os
import sys
import torch
import yaml
import logging
from typing import List, Tuple
from wenet.utils.checkpoint import load_checkpoint
from wenet.transformer.ctc import CTC
from wenet.transformer.decoder import TransformerDecoder
from wenet.transformer.encoder import BaseEncoder
from wenet.utils.init_model import init_model
from wenet.utils.mask import make_pad_mask
try:
import onnxruntime
except ImportError:
print('Please install onnxruntime-gpu!')
sys.exit(1)
logger = logging.getLogger(__file__)
logger.setLevel(logging.INFO)
class Encoder(torch.nn.Module):
def __init__(self,
encoder: BaseEncoder,
ctc: CTC,
beam_size: int = 10):
super().__init__()
self.encoder = encoder
self.ctc = ctc
self.beam_size = beam_size
def forward(self, speech: torch.Tensor,
speech_lengths: torch.Tensor,):
"""Encoder
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
Returns:
encoder_out: B x T x F
encoder_out_lens: B
ctc_log_probs: B x T x V
beam_log_probs: B x T x beam_size
beam_log_probs_idx: B x T x beam_size
"""
encoder_out, encoder_mask = self.encoder(speech,
speech_lengths,
-1, -1)
encoder_out_lens = encoder_mask.squeeze(1).sum(1)
ctc_log_probs = self.ctc.log_softmax(encoder_out)
encoder_out_lens = encoder_out_lens.int()
beam_log_probs, beam_log_probs_idx = torch.topk(
ctc_log_probs, self.beam_size, dim=2)
return encoder_out, encoder_out_lens, ctc_log_probs, \
beam_log_probs, beam_log_probs_idx
class StreamingEncoder(torch.nn.Module):
def __init__(self, model, required_cache_size, beam_size, transformer=False):
super().__init__()
self.ctc = model.ctc
self.subsampling_rate = model.encoder.embed.subsampling_rate
self.embed = model.encoder.embed
self.global_cmvn = model.encoder.global_cmvn
self.required_cache_size = required_cache_size
self.beam_size = beam_size
self.encoder = model.encoder
self.transformer = transformer
def forward(self, chunk_xs, chunk_lens, offset,
att_cache, cnn_cache, cache_mask):
"""Streaming Encoder
Args:
xs (torch.Tensor): chunk input, with shape (b, time, mel-dim),
where `time == (chunk_size - 1) * subsample_rate + \
subsample.right_context + 1`
offset (torch.Tensor): offset with shape (b, 1)
1 is retained for triton deployment
            required_cache_size (int): cache size required for next chunk
                computation
                > 0: actual cache size
                <= 0: not allowed in the streaming gpu encoder
att_cache (torch.Tensor): cache tensor for KEY & VALUE in
transformer/conformer attention, with shape
(b, elayers, head, cache_t1, d_k * 2), where
`head * d_k == hidden-dim` and
`cache_t1 == chunk_size * num_decoding_left_chunks`.
cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer,
(b, elayers, b, hidden-dim, cache_t2), where
`cache_t2 == cnn.lorder - 1`
            cache_mask (torch.Tensor): cache mask with shape (b, required_cache_size)
                in a batch of requests, each request may have a different
                history cache. The cache mask is used to indicate the effective
                cache for each request
Returns:
torch.Tensor: log probabilities of ctc output and cutoff by beam size
with shape (b, chunk_size, beam)
torch.Tensor: index of top beam size probabilities for each timestep
with shape (b, chunk_size, beam)
torch.Tensor: output of current input xs,
with shape (b, chunk_size, hidden-dim).
torch.Tensor: new attention cache required for next chunk, with
same shape (b, elayers, head, cache_t1, d_k * 2)
as the original att_cache
torch.Tensor: new conformer cnn cache required for next chunk, with
same shape as the original cnn_cache.
torch.Tensor: new cache mask, with same shape as the original
cache mask
"""
offset = offset.squeeze(1)
T = chunk_xs.size(1)
chunk_mask = ~make_pad_mask(chunk_lens, T).unsqueeze(1)
# B X 1 X T
chunk_mask = chunk_mask.to(chunk_xs.dtype)
# transpose batch & num_layers dim
att_cache = torch.transpose(att_cache, 0, 1)
cnn_cache = torch.transpose(cnn_cache, 0, 1)
# rewrite encoder.forward_chunk
# <---------forward_chunk START--------->
xs = self.global_cmvn(chunk_xs)
        # chunk mask is important for batch inferencing since
        # different sequences in a batch have different lengths
xs, pos_emb, chunk_mask = self.embed(xs, chunk_mask, offset)
cache_size = att_cache.size(3) # required cache size
masks = torch.cat((cache_mask, chunk_mask), dim=2)
index = offset - cache_size
pos_emb = self.embed.position_encoding(index, cache_size + xs.size(1))
pos_emb = pos_emb.to(dtype=xs.dtype)
next_cache_start = -self.required_cache_size
r_cache_mask = masks[:, :, next_cache_start:]
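        # required_cache_size > 0 here, so the negative index keeps only the
        # most recent `required_cache_size` frames of the concatenated mask.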
r_att_cache = []
r_cnn_cache = []
for i, layer in enumerate(self.encoder.encoders):
xs, _, new_att_cache, new_cnn_cache = layer(
xs, masks, pos_emb,
att_cache=att_cache[i],
cnn_cache=cnn_cache[i])
# shape(new_att_cache) is (B, head, attention_key_size, d_k * 2),
# shape(new_cnn_cache) is (B, hidden-dim, cache_t2)
r_att_cache.append(
new_att_cache[:, :, next_cache_start:, :].unsqueeze(1))
if not self.transformer:
r_cnn_cache.append(new_cnn_cache.unsqueeze(1))
if self.encoder.normalize_before:
chunk_out = self.encoder.after_norm(xs)
else:
chunk_out = xs
r_att_cache = torch.cat(r_att_cache, dim=1) # concat on layers idx
if not self.transformer:
r_cnn_cache = torch.cat(r_cnn_cache, dim=1) # concat on layers
# <---------forward_chunk END--------->
log_ctc_probs = self.ctc.log_softmax(chunk_out)
log_probs, log_probs_idx = torch.topk(log_ctc_probs,
self.beam_size,
dim=2)
log_probs = log_probs.to(chunk_xs.dtype)
r_offset = offset + chunk_out.shape[1]
        # the ops below are not supported in TensorRT
# chunk_out_lens = torch.div(chunk_lens, subsampling_rate,
# rounding_mode='floor')
chunk_out_lens = chunk_lens // self.subsampling_rate
r_offset = r_offset.unsqueeze(1)
return log_probs, log_probs_idx, chunk_out, chunk_out_lens, \
r_offset, r_att_cache, r_cnn_cache, r_cache_mask
class StreamingSqueezeformerEncoder(torch.nn.Module):
def __init__(self, model, required_cache_size, beam_size):
super().__init__()
self.ctc = model.ctc
self.subsampling_rate = model.encoder.embed.subsampling_rate
self.embed = model.encoder.embed
self.global_cmvn = model.encoder.global_cmvn
self.required_cache_size = required_cache_size
self.beam_size = beam_size
self.encoder = model.encoder
self.reduce_idx = model.encoder.reduce_idx
self.recover_idx = model.encoder.recover_idx
if self.reduce_idx is None:
self.time_reduce = None
else:
if self.recover_idx is None:
self.time_reduce = 'normal' # no recovery at the end
else:
self.time_reduce = 'recover' # recovery at the end
assert len(self.reduce_idx) == len(self.recover_idx)
def calculate_downsampling_factor(self, i: int) -> int:
if self.reduce_idx is None:
return 1
else:
reduce_exp, recover_exp = 0, 0
for exp, rd_idx in enumerate(self.reduce_idx):
if i >= rd_idx:
reduce_exp = exp + 1
if self.recover_idx is not None:
for exp, rc_idx in enumerate(self.recover_idx):
if i >= rc_idx:
recover_exp = exp + 1
return int(2 ** (reduce_exp - recover_exp))
def forward(self, chunk_xs, chunk_lens, offset,
att_cache, cnn_cache, cache_mask):
"""Streaming Encoder
Args:
xs (torch.Tensor): chunk input, with shape (b, time, mel-dim),
where `time == (chunk_size - 1) * subsample_rate + \
subsample.right_context + 1`
offset (torch.Tensor): offset with shape (b, 1)
1 is retained for triton deployment
            required_cache_size (int): cache size required for next chunk
                computation
                > 0: actual cache size
                <= 0: not allowed in the streaming gpu encoder
att_cache (torch.Tensor): cache tensor for KEY & VALUE in
transformer/conformer attention, with shape
(b, elayers, head, cache_t1, d_k * 2), where
`head * d_k == hidden-dim` and
`cache_t1 == chunk_size * num_decoding_left_chunks`.
cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer,
(b, elayers, b, hidden-dim, cache_t2), where
`cache_t2 == cnn.lorder - 1`
            cache_mask (torch.Tensor): cache mask with shape (b, required_cache_size)
                in a batch of requests, each request may have a different
                history cache. The cache mask is used to indicate the effective
                cache for each request
Returns:
torch.Tensor: log probabilities of ctc output and cutoff by beam size
with shape (b, chunk_size, beam)
torch.Tensor: index of top beam size probabilities for each timestep
with shape (b, chunk_size, beam)
torch.Tensor: output of current input xs,
with shape (b, chunk_size, hidden-dim).
torch.Tensor: new attention cache required for next chunk, with
same shape (b, elayers, head, cache_t1, d_k * 2)
as the original att_cache
torch.Tensor: new conformer cnn cache required for next chunk, with
same shape as the original cnn_cache.
torch.Tensor: new cache mask, with same shape as the original
cache mask
"""
offset = offset.squeeze(1)
T = chunk_xs.size(1)
chunk_mask = ~make_pad_mask(chunk_lens, T).unsqueeze(1)
# B X 1 X T
chunk_mask = chunk_mask.to(chunk_xs.dtype)
# transpose batch & num_layers dim
att_cache = torch.transpose(att_cache, 0, 1)
cnn_cache = torch.transpose(cnn_cache, 0, 1)
# rewrite encoder.forward_chunk
# <---------forward_chunk START--------->
xs = self.global_cmvn(chunk_xs)
        # chunk mask is important for batch inferencing since
        # different sequences in a batch have different lengths
xs, pos_emb, chunk_mask = self.embed(xs, chunk_mask, offset)
elayers, cache_size = att_cache.size(0), att_cache.size(3)
att_mask = torch.cat((cache_mask, chunk_mask), dim=2)
index = offset - cache_size
pos_emb = self.embed.position_encoding(index, cache_size + xs.size(1))
pos_emb = pos_emb.to(dtype=xs.dtype)
next_cache_start = -self.required_cache_size
r_cache_mask = att_mask[:, :, next_cache_start:]
r_att_cache = []
r_cnn_cache = []
mask_pad = torch.ones(1,
xs.size(1),
device=xs.device,
dtype=torch.bool)
mask_pad = mask_pad.unsqueeze(1)
max_att_len: int = 0
recover_activations: \
List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]] = []
index = 0
xs_lens = torch.tensor([xs.size(1)], device=xs.device, dtype=torch.int)
xs = self.encoder.preln(xs)
for i, layer in enumerate(self.encoder.encoders):
if self.reduce_idx is not None:
if self.time_reduce is not None and i in self.reduce_idx:
recover_activations.append(
(xs, att_mask, pos_emb, mask_pad))
xs, xs_lens, att_mask, mask_pad = \
self.encoder.time_reduction_layer(
xs, xs_lens, att_mask, mask_pad)
pos_emb = pos_emb[:, ::2, :]
if self.encoder.pos_enc_layer_type == "rel_pos_repaired":
pos_emb = pos_emb[:, :xs.size(1) * 2 - 1, :]
index += 1
if self.recover_idx is not None:
if self.time_reduce == 'recover' and i in self.recover_idx:
index -= 1
(recover_tensor, recover_att_mask,
recover_pos_emb, recover_mask_pad) \
= recover_activations[index]
# recover output length for ctc decode
xs = xs.unsqueeze(2).repeat(1, 1, 2, 1).flatten(1, 2)
xs = self.encoder.time_recover_layer(xs)
recoverd_t = recover_tensor.size(1)
xs = recover_tensor + xs[:, :recoverd_t, :].contiguous()
att_mask = recover_att_mask
pos_emb = recover_pos_emb
mask_pad = recover_mask_pad
factor = self.calculate_downsampling_factor(i)
xs, _, new_att_cache, new_cnn_cache = layer(
xs, att_mask, pos_emb,
att_cache=att_cache[i][:, :, ::factor, :]
[:, :, :pos_emb.size(1) - xs.size(1), :] if
elayers > 0 else att_cache[:, :, ::factor, :],
cnn_cache=cnn_cache[i] if cnn_cache.size(0) > 0 else cnn_cache
)
cached_att \
= new_att_cache[:, :, next_cache_start // factor:, :]
cached_cnn = new_cnn_cache.unsqueeze(1)
cached_att = cached_att.unsqueeze(3). \
repeat(1, 1, 1, factor, 1).flatten(2, 3)
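            # the att cache fed to time-reduced layers was sub-sampled by
            # `factor` (see the att_cache slicing above); here the new cache is
            # repeated back to full time resolution so every layer's stored
            # cache keeps the same time dimension.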
if i == 0:
# record length for the first block as max length
max_att_len = cached_att.size(2)
r_att_cache.append(cached_att[:, :, :max_att_len, :].unsqueeze(1))
r_cnn_cache.append(cached_cnn)
chunk_out = xs
r_att_cache = torch.cat(r_att_cache, dim=1) # concat on layers idx
r_cnn_cache = torch.cat(r_cnn_cache, dim=1) # concat on layers
# <---------forward_chunk END--------->
log_ctc_probs = self.ctc.log_softmax(chunk_out)
log_probs, log_probs_idx = torch.topk(log_ctc_probs,
self.beam_size,
dim=2)
log_probs = log_probs.to(chunk_xs.dtype)
r_offset = offset + chunk_out.shape[1]
        # the ops below are not supported in TensorRT
# chunk_out_lens = torch.div(chunk_lens, subsampling_rate,
# rounding_mode='floor')
chunk_out_lens = chunk_lens // self.subsampling_rate
r_offset = r_offset.unsqueeze(1)
return log_probs, log_probs_idx, chunk_out, chunk_out_lens, \
r_offset, r_att_cache, r_cnn_cache, r_cache_mask
class Decoder(torch.nn.Module):
def __init__(self,
decoder: TransformerDecoder,
ctc_weight: float = 0.5,
reverse_weight: float = 0.0,
beam_size: int = 10,
decoder_fastertransformer: bool = False):
super().__init__()
self.decoder = decoder
self.ctc_weight = ctc_weight
self.reverse_weight = reverse_weight
self.beam_size = beam_size
self.decoder_fastertransformer = decoder_fastertransformer
def forward(self,
encoder_out: torch.Tensor,
encoder_lens: torch.Tensor,
hyps_pad_sos_eos: torch.Tensor,
hyps_lens_sos: torch.Tensor,
r_hyps_pad_sos_eos: torch.Tensor,
ctc_score: torch.Tensor):
"""Encoder
Args:
encoder_out: B x T x F
encoder_lens: B
hyps_pad_sos_eos: B x beam x (T2+1),
hyps with sos & eos and padded by ignore id
hyps_lens_sos: B x beam, length for each hyp with sos
r_hyps_pad_sos_eos: B x beam x (T2+1),
reversed hyps with sos & eos and padded by ignore id
ctc_score: B x beam, ctc score for each hyp
Returns:
            best_index: B
            decoder_out: B x beam x T2 x V, only returned (before best_index)
                when decoder_fastertransformer is enabled
"""
B, T, F = encoder_out.shape
bz = self.beam_size
B2 = B * bz
encoder_out = encoder_out.repeat(1, bz, 1).view(B2, T, F)
encoder_mask = ~make_pad_mask(encoder_lens, T).unsqueeze(1)
encoder_mask = encoder_mask.repeat(1, bz, 1).view(B2, 1, T)
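        # tile encoder output and mask beam_size times so that every candidate
        # hypothesis of every utterance is rescored in one batched decoder pass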
T2 = hyps_pad_sos_eos.shape[2] - 1
hyps_pad = hyps_pad_sos_eos.view(B2, T2 + 1)
hyps_lens = hyps_lens_sos.view(B2,)
hyps_pad_sos = hyps_pad[:, :-1].contiguous()
hyps_pad_eos = hyps_pad[:, 1:].contiguous()
r_hyps_pad = r_hyps_pad_sos_eos.view(B2, T2 + 1)
r_hyps_pad_sos = r_hyps_pad[:, :-1].contiguous()
r_hyps_pad_eos = r_hyps_pad[:, 1:].contiguous()
decoder_out, r_decoder_out, _ = self.decoder(
encoder_out, encoder_mask, hyps_pad_sos, hyps_lens, r_hyps_pad_sos,
self.reverse_weight)
decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1)
V = decoder_out.shape[-1]
decoder_out = decoder_out.view(B2, T2, V)
mask = ~make_pad_mask(hyps_lens, T2) # B2 x T2
# mask index, remove ignore id
index = torch.unsqueeze(hyps_pad_eos * mask, 2)
score = decoder_out.gather(2, index).squeeze(2) # B2 X T2
# mask padded part
score = score * mask
decoder_out = decoder_out.view(B, bz, T2, V)
if self.reverse_weight > 0:
r_decoder_out = torch.nn.functional.log_softmax(
r_decoder_out, dim=-1)
r_decoder_out = r_decoder_out.view(B2, T2, V)
index = torch.unsqueeze(r_hyps_pad_eos * mask, 2)
r_score = r_decoder_out.gather(2, index).squeeze(2)
r_score = r_score * mask
score = score * (1 - self.reverse_weight) + \
self.reverse_weight * r_score
r_decoder_out = r_decoder_out.view(B, bz, T2, V)
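        # sum token-level log-probs per hypothesis, fuse with the ctc score,
        # then pick the best hypothesis for each utterance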
score = torch.sum(score, axis=1) # B2
score = torch.reshape(score, (B, bz)) + self.ctc_weight * ctc_score
best_index = torch.argmax(score, dim=1)
if self.decoder_fastertransformer:
return decoder_out, best_index
else:
return best_index
def to_numpy(tensors):
out = []
    if isinstance(tensors, torch.Tensor):
tensors = [tensors]
for tensor in tensors:
if tensor.requires_grad:
tensor = tensor.detach().cpu().numpy()
else:
tensor = tensor.cpu().numpy()
out.append(tensor)
return out
def test(xlist, blist, rtol=1e-3, atol=1e-5, tolerate_small_mismatch=True):
for a, b in zip(xlist, blist):
try:
torch.testing.assert_allclose(a, b, rtol=rtol, atol=atol)
except AssertionError as error:
if tolerate_small_mismatch:
print(error)
else:
raise
def export_offline_encoder(model, configs, args, logger, encoder_onnx_path):
bz = 32
seq_len = 100
beam_size = args.beam_size
feature_size = configs["input_dim"]
speech = torch.randn(bz, seq_len, feature_size, dtype=torch.float32)
speech_lens = torch.randint(
low=10, high=seq_len, size=(bz,), dtype=torch.int32)
encoder = Encoder(model.encoder, model.ctc, beam_size)
encoder.eval()
torch.onnx.export(encoder,
(speech, speech_lens),
encoder_onnx_path,
export_params=True,
opset_version=13,
do_constant_folding=True,
input_names=['speech', 'speech_lengths'],
output_names=['encoder_out', 'encoder_out_lens',
'ctc_log_probs',
'beam_log_probs', 'beam_log_probs_idx'],
dynamic_axes={
'speech': {0: 'B', 1: 'T'},
'speech_lengths': {0: 'B'},
'encoder_out': {0: 'B', 1: 'T_OUT'},
'encoder_out_lens': {0: 'B'},
'ctc_log_probs': {0: 'B', 1: 'T_OUT'},
'beam_log_probs': {0: 'B', 1: 'T_OUT'},
'beam_log_probs_idx': {0: 'B', 1: 'T_OUT'},
},
verbose=False
)
with torch.no_grad():
o0, o1, o2, o3, o4 = encoder(speech, speech_lens)
providers = ["CUDAExecutionProvider"]
ort_session = onnxruntime.InferenceSession(encoder_onnx_path,
providers=providers)
ort_inputs = {'speech': to_numpy(speech),
'speech_lengths': to_numpy(speech_lens)}
ort_outs = ort_session.run(None, ort_inputs)
# check encoder output
test(to_numpy([o0, o1, o2, o3, o4]), ort_outs)
logger.info("export offline onnx encoder succeed!")
onnx_config = {"beam_size": args.beam_size,
"reverse_weight": args.reverse_weight,
"ctc_weight": args.ctc_weight,
"fp16": args.fp16}
return onnx_config
def export_online_encoder(model, configs, args, logger, encoder_onnx_path):
decoding_chunk_size = args.decoding_chunk_size
subsampling = model.encoder.embed.subsampling_rate
context = model.encoder.embed.right_context + 1
decoding_window = (decoding_chunk_size - 1) * subsampling + context
batch_size = 32
audio_len = decoding_window
feature_size = configs["input_dim"]
output_size = configs["encoder_conf"]["output_size"]
num_layers = configs["encoder_conf"]["num_blocks"]
# in transformer the cnn module will not be available
transformer = False
cnn_module_kernel = configs["encoder_conf"].get("cnn_module_kernel", 1) - 1
if not cnn_module_kernel:
transformer = True
num_decoding_left_chunks = args.num_decoding_left_chunks
required_cache_size = decoding_chunk_size * num_decoding_left_chunks
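    # the attention cache keeps num_decoding_left_chunks chunks of history,
    # i.e. decoding_chunk_size * num_decoding_left_chunks frames at the
    # subsampled frame rate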
if configs['encoder'] == 'squeezeformer':
encoder = StreamingSqueezeformerEncoder(
model, required_cache_size, args.beam_size)
else:
encoder = StreamingEncoder(
model, required_cache_size, args.beam_size, transformer)
encoder.eval()
# begin to export encoder
chunk_xs = torch.randn(batch_size, audio_len,
feature_size, dtype=torch.float32)
chunk_lens = torch.ones(batch_size, dtype=torch.int32) * audio_len
offset = torch.arange(0, batch_size).unsqueeze(1)
# (elayers, b, head, cache_t1, d_k * 2)
head = configs["encoder_conf"]["attention_heads"]
d_k = configs["encoder_conf"]["output_size"] // head
att_cache = torch.randn(batch_size, num_layers, head,
required_cache_size, d_k * 2,
dtype=torch.float32)
cnn_cache = torch.randn(batch_size, num_layers, output_size,
cnn_module_kernel, dtype=torch.float32)
cache_mask = torch.ones(
batch_size, 1, required_cache_size, dtype=torch.float32)
input_names = ['chunk_xs', 'chunk_lens', 'offset',
'att_cache', 'cnn_cache', 'cache_mask']
output_names = ['log_probs', 'log_probs_idx', 'chunk_out',
'chunk_out_lens', 'r_offset', 'r_att_cache',
'r_cnn_cache', 'r_cache_mask']
input_tensors = (chunk_xs, chunk_lens, offset,
att_cache, cnn_cache, cache_mask)
if transformer:
output_names.pop(6)
all_names = input_names + output_names
dynamic_axes = {}
for name in all_names:
        # only the first dimension is dynamic
        # all other dimensions are fixed
dynamic_axes[name] = {0: 'B'}
torch.onnx.export(encoder,
input_tensors,
encoder_onnx_path,
export_params=True,
opset_version=14,
do_constant_folding=True,
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
verbose=False)
with torch.no_grad():
torch_outs = encoder(chunk_xs, chunk_lens, offset,
att_cache, cnn_cache, cache_mask)
if transformer:
            torch_outs = list(torch_outs)
            torch_outs.pop(6)
ort_session = onnxruntime.InferenceSession(encoder_onnx_path,
providers=["CUDAExecutionProvider"])
ort_inputs = {}
input_tensors = to_numpy(input_tensors)
for idx, name in enumerate(input_names):
ort_inputs[name] = input_tensors[idx]
if transformer:
del ort_inputs['cnn_cache']
ort_outs = ort_session.run(None, ort_inputs)
test(to_numpy(torch_outs), ort_outs, rtol=1e-03, atol=1e-05)
logger.info("export to onnx streaming encoder succeed!")
onnx_config = {
"subsampling_rate": subsampling,
"context": context,
"decoding_chunk_size": decoding_chunk_size,
"num_decoding_left_chunks": num_decoding_left_chunks,
"beam_size": args.beam_size,
"fp16": args.fp16,
"feat_size": feature_size,
"decoding_window": decoding_window,
"cnn_module_kernel_cache": cnn_module_kernel
}
return onnx_config
def export_rescoring_decoder(model, configs, args,
logger, decoder_onnx_path, decoder_fastertransformer):
bz, seq_len = 32, 100
beam_size = args.beam_size
decoder = Decoder(model.decoder,
model.ctc_weight,
model.reverse_weight,
beam_size,
decoder_fastertransformer)
decoder.eval()
hyps_pad_sos_eos = torch.randint(
low=3, high=1000, size=(bz, beam_size, seq_len))
hyps_lens_sos = torch.randint(low=3, high=seq_len, size=(bz, beam_size),
dtype=torch.int32)
r_hyps_pad_sos_eos = torch.randint(
low=3, high=1000, size=(bz, beam_size, seq_len))
output_size = configs["encoder_conf"]["output_size"]
encoder_out = torch.randn(bz, seq_len, output_size, dtype=torch.float32)
encoder_out_lens = torch.randint(
low=3, high=seq_len, size=(bz,), dtype=torch.int32)
ctc_score = torch.randn(bz, beam_size, dtype=torch.float32)
input_names = ['encoder_out', 'encoder_out_lens',
'hyps_pad_sos_eos', 'hyps_lens_sos',
'r_hyps_pad_sos_eos', 'ctc_score']
output_names = ['best_index']
if decoder_fastertransformer:
output_names.insert(0, 'decoder_out')
torch.onnx.export(decoder,
(encoder_out, encoder_out_lens,
hyps_pad_sos_eos, hyps_lens_sos,
r_hyps_pad_sos_eos, ctc_score),
decoder_onnx_path,
export_params=True,
opset_version=13,
do_constant_folding=True,
input_names=input_names,
output_names=output_names,
dynamic_axes={'encoder_out': {0: 'B', 1: 'T'},
'encoder_out_lens': {0: 'B'},
'hyps_pad_sos_eos': {0: 'B', 2: 'T2'},
'hyps_lens_sos': {0: 'B'},
'r_hyps_pad_sos_eos': {0: 'B', 2: 'T2'},
'ctc_score': {0: 'B'},
'best_index': {0: 'B'},
},
verbose=False
)
with torch.no_grad():
o0 = decoder(encoder_out,
encoder_out_lens,
hyps_pad_sos_eos,
hyps_lens_sos,
r_hyps_pad_sos_eos,
ctc_score)
providers = ["CUDAExecutionProvider"]
ort_session = onnxruntime.InferenceSession(decoder_onnx_path,
providers=providers)
input_tensors = [encoder_out, encoder_out_lens, hyps_pad_sos_eos,
hyps_lens_sos, r_hyps_pad_sos_eos, ctc_score]
ort_inputs = {}
input_tensors = to_numpy(input_tensors)
for idx, name in enumerate(input_names):
ort_inputs[name] = input_tensors[idx]
    # if model.reverse_weight == 0, r_hyps_pad_sos_eos will have been removed
    # from the onnx decoder inputs since it doesn't play any role
if model.reverse_weight == 0:
del ort_inputs['r_hyps_pad_sos_eos']
ort_outs = ort_session.run(None, ort_inputs)
# check decoder output
if decoder_fastertransformer:
test(to_numpy(o0), ort_outs, rtol=1e-03, atol=1e-05)
else:
test(to_numpy([o0]), ort_outs, rtol=1e-03, atol=1e-05)
logger.info("export to onnx decoder succeed!")
if __name__ == '__main__':
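    # Example invocation (script and file names here are illustrative, adjust
    # them to your setup):
    #   python export_onnx_gpu.py --config train.yaml --checkpoint final.pt \
    #       --output_onnx_dir onnx_model --streaming --fp16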
parser = argparse.ArgumentParser(description='export x86_gpu model')
parser.add_argument('--config', required=True, help='config file')
parser.add_argument('--checkpoint', required=True, help='checkpoint model')
parser.add_argument('--cmvn_file', required=False, default='', type=str,
help='global_cmvn file, default path is in config file')
parser.add_argument('--reverse_weight', default=-1.0, type=float,
required=False,
                        help='reverse weight for bitransformer, ' +
                        'default value is in config file')
parser.add_argument('--ctc_weight', default=-1.0, type=float,
required=False,
help='ctc weight, default value is in config file')
parser.add_argument('--beam_size', default=10, type=int, required=False,
help="beam size would be ctc output size")
parser.add_argument('--output_onnx_dir',
default="onnx_model",
help='output onnx encoder and decoder directory')
parser.add_argument('--fp16',
action='store_true',
help='whether to export fp16 model, default false')
# arguments for streaming encoder
parser.add_argument('--streaming',
action='store_true',
help="whether to export streaming encoder, default false")
parser.add_argument('--decoding_chunk_size',
default=16,
type=int,
required=False,
help='the decoding chunk size, <=0 is not supported')
parser.add_argument('--num_decoding_left_chunks',
default=5,
type=int,
required=False,
help="number of left chunks, <= 0 is not supported")
parser.add_argument('--decoder_fastertransformer',
action='store_true',
help='return decoder_out and best_index for ft')
args = parser.parse_args()
torch.manual_seed(0)
torch.set_printoptions(precision=10)
with open(args.config, 'r') as fin:
configs = yaml.load(fin, Loader=yaml.FullLoader)
if args.cmvn_file and os.path.exists(args.cmvn_file):
configs['cmvn_file'] = args.cmvn_file
if args.reverse_weight != -1.0 and 'reverse_weight' in configs['model_conf']:
configs['model_conf']['reverse_weight'] = args.reverse_weight
print("Update reverse weight to", args.reverse_weight)
if args.ctc_weight != -1:
print("Update ctc weight to ", args.ctc_weight)
configs['model_conf']['ctc_weight'] = args.ctc_weight
configs["encoder_conf"]["use_dynamic_chunk"] = False
model = init_model(configs)
load_checkpoint(model, args.checkpoint)
model.eval()
if not os.path.exists(args.output_onnx_dir):
os.mkdir(args.output_onnx_dir)
encoder_onnx_path = os.path.join(args.output_onnx_dir, 'encoder.onnx')
export_enc_func = None
if args.streaming:
assert args.decoding_chunk_size > 0
assert args.num_decoding_left_chunks > 0
export_enc_func = export_online_encoder
else:
export_enc_func = export_offline_encoder
onnx_config = export_enc_func(
model, configs, args, logger, encoder_onnx_path)
decoder_onnx_path = os.path.join(args.output_onnx_dir, 'decoder.onnx')
export_rescoring_decoder(model, configs, args, logger,
decoder_onnx_path, args.decoder_fastertransformer)
if args.fp16:
try:
import onnxmltools
from onnxmltools.utils.float16_converter import convert_float_to_float16
except ImportError:
print('Please install onnxmltools!')
sys.exit(1)
encoder_onnx_model = onnxmltools.utils.load_model(encoder_onnx_path)
encoder_onnx_model = convert_float_to_float16(encoder_onnx_model)
encoder_onnx_path = os.path.join(
args.output_onnx_dir, 'encoder_fp16.onnx')
onnxmltools.utils.save_model(encoder_onnx_model, encoder_onnx_path)
decoder_onnx_model = onnxmltools.utils.load_model(decoder_onnx_path)
decoder_onnx_model = convert_float_to_float16(decoder_onnx_model)
decoder_onnx_path = os.path.join(
args.output_onnx_dir, 'decoder_fp16.onnx')
onnxmltools.utils.save_model(decoder_onnx_model, decoder_onnx_path)
# dump configurations
config_dir = os.path.join(args.output_onnx_dir, "config.yaml")
with open(config_dir, "w") as out:
yaml.dump(onnx_config, out)
# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Xiaoyu Chen, Di Wu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import argparse
import copy
import logging
import os
import sys
import torch
import yaml
from torch.utils.data import DataLoader
from wenet.dataset.dataset import Dataset
from wenet.utils.checkpoint import load_checkpoint
from wenet.utils.file_utils import read_symbol_table, read_non_lang_symbols
from wenet.utils.config import override_config
from wenet.utils.init_model import init_model
def get_args():
parser = argparse.ArgumentParser(description='recognize with your model')
parser.add_argument('--config', required=True, help='config file')
parser.add_argument('--test_data', required=True, help='test data file')
parser.add_argument('--data_type',
default='raw',
choices=['raw', 'shard'],
help='train and cv data type')
parser.add_argument('--gpu',
type=int,
default=-1,
help='gpu id for this rank, -1 for cpu')
parser.add_argument('--checkpoint', required=True, help='checkpoint model')
parser.add_argument('--dict', required=True, help='dict file')
parser.add_argument("--non_lang_syms",
help="non-linguistic symbol file. One symbol per line.")
parser.add_argument('--beam_size',
type=int,
default=10,
help='beam size for search')
parser.add_argument('--penalty',
type=float,
default=0.0,
help='length penalty')
parser.add_argument('--result_file', required=True, help='asr result file')
parser.add_argument('--batch_size',
type=int,
default=16,
                        help='batch size for decoding')
parser.add_argument('--mode',
choices=[
'attention', 'ctc_greedy_search',
'ctc_prefix_beam_search', 'attention_rescoring',
'rnnt_greedy_search', 'rnnt_beam_search',
'rnnt_beam_attn_rescoring', 'ctc_beam_td_attn_rescoring',
'hlg_onebest', 'hlg_rescore'
],
default='attention',
help='decoding mode')
parser.add_argument('--search_ctc_weight',
type=float,
default=1.0,
help='ctc weight for nbest generation')
parser.add_argument('--search_transducer_weight',
type=float,
default=0.0,
help='transducer weight for nbest generation')
parser.add_argument('--ctc_weight',
type=float,
default=0.0,
                        help='ctc weight for rescoring in attention rescoring '
                             'decode mode and in transducer attention '
                             'rescoring decode mode')
parser.add_argument('--transducer_weight',
type=float,
default=0.0,
help='transducer weight for rescoring weight in transducer \
attention rescore mode')
parser.add_argument('--attn_weight',
type=float,
default=0.0,
help='attention weight for rescoring weight in transducer \
attention rescore mode')
parser.add_argument('--decoding_chunk_size',
type=int,
default=-1,
help='''decoding chunk size,
<0: for decoding, use full chunk.
>0: for decoding, use fixed chunk size as set.
0: used for training, it's prohibited here''')
parser.add_argument('--num_decoding_left_chunks',
type=int,
default=-1,
help='number of left chunks for decoding')
parser.add_argument('--simulate_streaming',
action='store_true',
help='simulate streaming inference')
parser.add_argument('--reverse_weight',
type=float,
default=0.0,
help='''right to left weight for attention rescoring
decode mode''')
parser.add_argument('--bpe_model',
default=None,
type=str,
help='bpe model for english part')
parser.add_argument('--override_config',
action='append',
default=[],
help="override yaml config")
parser.add_argument('--connect_symbol',
default='',
type=str,
help='used to connect the output characters')
parser.add_argument('--word',
default='',
type=str,
help='word file, only used for hlg decode')
parser.add_argument('--hlg',
default='',
type=str,
help='hlg file, only used for hlg decode')
parser.add_argument('--lm_scale',
type=float,
default=0.0,
help='lm scale for hlg attention rescore decode')
parser.add_argument('--decoder_scale',
type=float,
default=0.0,
                        help='decoder scale for hlg attention rescore decode')
parser.add_argument('--r_decoder_scale',
type=float,
default=0.0,
                        help='reverse decoder scale for hlg attention rescore decode')
args = parser.parse_args()
print(args)
return args
def main():
args = get_args()
logging.basicConfig(level=logging.DEBUG,
format='%(asctime)s %(levelname)s %(message)s')
os.environ['HIP_VISIBLE_DEVICES'] = str(args.gpu)
if args.mode in ['ctc_prefix_beam_search', 'attention_rescoring'
] and args.batch_size > 1:
logging.fatal(
'decoding mode {} must be running with batch_size == 1'.format(
args.mode))
sys.exit(1)
with open(args.config, 'r') as fin:
configs = yaml.load(fin, Loader=yaml.FullLoader)
if len(args.override_config) > 0:
configs = override_config(configs, args.override_config)
symbol_table = read_symbol_table(args.dict)
test_conf = copy.deepcopy(configs['dataset_conf'])
test_conf['filter_conf']['max_length'] = 102400
test_conf['filter_conf']['min_length'] = 0
test_conf['filter_conf']['token_max_length'] = 102400
test_conf['filter_conf']['token_min_length'] = 0
test_conf['filter_conf']['max_output_input_ratio'] = 102400
test_conf['filter_conf']['min_output_input_ratio'] = 0
test_conf['speed_perturb'] = False
test_conf['spec_aug'] = False
test_conf['spec_sub'] = False
test_conf['spec_trim'] = False
test_conf['shuffle'] = False
test_conf['sort'] = False
if 'fbank_conf' in test_conf:
test_conf['fbank_conf']['dither'] = 0.0
elif 'mfcc_conf' in test_conf:
test_conf['mfcc_conf']['dither'] = 0.0
test_conf['batch_conf']['batch_type'] = "static"
test_conf['batch_conf']['batch_size'] = args.batch_size
non_lang_syms = read_non_lang_symbols(args.non_lang_syms)
test_dataset = Dataset(args.data_type,
args.test_data,
symbol_table,
test_conf,
args.bpe_model,
non_lang_syms,
partition=False)
test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=8, pin_memory=True)
# Init asr model from configs
model = init_model(configs)
# Load dict
char_dict = {v: k for k, v in symbol_table.items()}
eos = len(char_dict) - 1
load_checkpoint(model, args.checkpoint)
use_cuda = args.gpu >= 0 and torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
model = model.to(device)
model.eval()
with torch.no_grad(), open(args.result_file, 'w') as fout:
for batch_idx, batch in enumerate(test_data_loader):
keys, feats, target, feats_lengths, target_lengths = batch
feats = feats.to(device)
target = target.to(device)
feats_lengths = feats_lengths.to(device)
target_lengths = target_lengths.to(device)
if args.mode == 'attention':
hyps, _ = model.recognize(
feats,
feats_lengths,
beam_size=args.beam_size,
decoding_chunk_size=args.decoding_chunk_size,
num_decoding_left_chunks=args.num_decoding_left_chunks,
simulate_streaming=args.simulate_streaming)
hyps = [hyp.tolist() for hyp in hyps]
elif args.mode == 'ctc_greedy_search':
hyps, _ = model.ctc_greedy_search(
feats,
feats_lengths,
decoding_chunk_size=args.decoding_chunk_size,
num_decoding_left_chunks=args.num_decoding_left_chunks,
simulate_streaming=args.simulate_streaming)
elif args.mode == 'rnnt_greedy_search':
assert (feats.size(0) == 1)
assert 'predictor' in configs
hyps = model.greedy_search(
feats,
feats_lengths,
decoding_chunk_size=args.decoding_chunk_size,
num_decoding_left_chunks=args.num_decoding_left_chunks,
simulate_streaming=args.simulate_streaming)
elif args.mode == 'rnnt_beam_search':
assert (feats.size(0) == 1)
assert 'predictor' in configs
hyps = model.beam_search(
feats,
feats_lengths,
decoding_chunk_size=args.decoding_chunk_size,
beam_size=args.beam_size,
num_decoding_left_chunks=args.num_decoding_left_chunks,
simulate_streaming=args.simulate_streaming,
ctc_weight=args.search_ctc_weight,
transducer_weight=args.search_transducer_weight)
elif args.mode == 'rnnt_beam_attn_rescoring':
assert (feats.size(0) == 1)
assert 'predictor' in configs
hyps = model.transducer_attention_rescoring(
feats,
feats_lengths,
decoding_chunk_size=args.decoding_chunk_size,
beam_size=args.beam_size,
num_decoding_left_chunks=args.num_decoding_left_chunks,
simulate_streaming=args.simulate_streaming,
ctc_weight=args.ctc_weight,
transducer_weight=args.transducer_weight,
attn_weight=args.attn_weight,
reverse_weight=args.reverse_weight,
search_ctc_weight=args.search_ctc_weight,
search_transducer_weight=args.search_transducer_weight)
elif args.mode == 'ctc_beam_td_attn_rescoring':
assert (feats.size(0) == 1)
assert 'predictor' in configs
hyps = model.transducer_attention_rescoring(
feats,
feats_lengths,
decoding_chunk_size=args.decoding_chunk_size,
beam_size=args.beam_size,
num_decoding_left_chunks=args.num_decoding_left_chunks,
simulate_streaming=args.simulate_streaming,
ctc_weight=args.ctc_weight,
transducer_weight=args.transducer_weight,
attn_weight=args.attn_weight,
reverse_weight=args.reverse_weight,
search_ctc_weight=args.search_ctc_weight,
search_transducer_weight=args.search_transducer_weight,
beam_search_type='ctc')
            # ctc_prefix_beam_search and attention_rescoring only return one
            # result in List[int], change it to List[List[int]] for
            # compatibility with the other batch decoding modes
elif args.mode == 'ctc_prefix_beam_search':
assert (feats.size(0) == 1)
hyp, _ = model.ctc_prefix_beam_search(
feats,
feats_lengths,
args.beam_size,
decoding_chunk_size=args.decoding_chunk_size,
num_decoding_left_chunks=args.num_decoding_left_chunks,
simulate_streaming=args.simulate_streaming)
hyps = [hyp]
elif args.mode == 'attention_rescoring':
assert (feats.size(0) == 1)
hyp, source = model.attention_rescoring(
feats,
feats_lengths,
args.beam_size,
decoding_chunk_size=args.decoding_chunk_size,
num_decoding_left_chunks=args.num_decoding_left_chunks,
ctc_weight=args.ctc_weight,
simulate_streaming=args.simulate_streaming,
reverse_weight=args.reverse_weight)
hyps = [hyp]
elif args.mode == 'hlg_onebest':
hyps = model.hlg_onebest(
feats,
feats_lengths,
decoding_chunk_size=args.decoding_chunk_size,
num_decoding_left_chunks=args.num_decoding_left_chunks,
simulate_streaming=args.simulate_streaming,
hlg=args.hlg,
word=args.word,
symbol_table=symbol_table)
elif args.mode == 'hlg_rescore':
hyps = model.hlg_rescore(
feats,
feats_lengths,
decoding_chunk_size=args.decoding_chunk_size,
num_decoding_left_chunks=args.num_decoding_left_chunks,
simulate_streaming=args.simulate_streaming,
lm_scale=args.lm_scale,
decoder_scale=args.decoder_scale,
r_decoder_scale=args.r_decoder_scale,
hlg=args.hlg,
word=args.word,
symbol_table=symbol_table)
for i, key in enumerate(keys):
content = []
for w in hyps[i]:
if w == eos:
break
content.append(char_dict[w])
logging.info('{} {}'.format(key, args.connect_symbol.join(content)))
fout.write('{} {}\n'.format(key, args.connect_symbol.join(content)))
if __name__ == '__main__':
main()
# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Xiaoyu Chen, Di Wu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script is for testing exported onnx encoder and decoder from
export_onnx_gpu.py. The exported onnx models only support batch offline ASR inference.
It requires a python wrapped c++ ctc decoder.
Please install it by following:
https://github.com/Slyne/ctc_decoder.git
"""
from __future__ import print_function
import argparse
import copy
import logging
import os
import sys
import torch
import yaml
from torch.utils.data import DataLoader
from wenet.dataset.dataset import Dataset
from wenet.utils.common import IGNORE_ID
from wenet.utils.file_utils import read_symbol_table
from wenet.utils.config import override_config
import onnxruntime as rt
import multiprocessing
import numpy as np
try:
from swig_decoders import map_batch, \
ctc_beam_search_decoder_batch, \
TrieVector, PathTrie
except ImportError:
    print('Please install ctc decoders first by referring to\n' +
'https://github.com/Slyne/ctc_decoder.git')
sys.exit(1)
def get_args():
parser = argparse.ArgumentParser(description='recognize with your model')
parser.add_argument('--config', required=True, help='config file')
parser.add_argument('--test_data', required=True, help='test data file')
parser.add_argument('--data_type',
default='raw',
choices=['raw', 'shard'],
help='train and cv data type')
parser.add_argument('--gpu',
type=int,
default=-1,
help='gpu id for this rank, -1 for cpu')
parser.add_argument('--dict', required=True, help='dict file')
parser.add_argument('--encoder_onnx', required=True, help='encoder onnx file')
parser.add_argument('--decoder_onnx', required=True, help='decoder onnx file')
parser.add_argument('--result_file', required=True, help='asr result file')
parser.add_argument('--batch_size',
type=int,
default=32,
                        help='batch size for decoding')
parser.add_argument('--mode',
choices=[
'ctc_greedy_search', 'ctc_prefix_beam_search',
'attention_rescoring'],
default='attention_rescoring',
help='decoding mode')
parser.add_argument('--bpe_model',
default=None,
type=str,
help='bpe model for english part')
parser.add_argument('--override_config',
action='append',
default=[],
help="override yaml config")
parser.add_argument('--fp16',
action='store_true',
help='whether to export fp16 model, default false')
args = parser.parse_args()
print(args)
return args
def main():
args = get_args()
logging.basicConfig(level=logging.DEBUG,
format='%(asctime)s %(levelname)s %(message)s')
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
with open(args.config, 'r') as fin:
configs = yaml.load(fin, Loader=yaml.FullLoader)
if len(args.override_config) > 0:
configs = override_config(configs, args.override_config)
reverse_weight = configs["model_conf"].get("reverse_weight", 0.0)
symbol_table = read_symbol_table(args.dict)
test_conf = copy.deepcopy(configs['dataset_conf'])
test_conf['filter_conf']['max_length'] = 102400
test_conf['filter_conf']['min_length'] = 0
test_conf['filter_conf']['token_max_length'] = 102400
test_conf['filter_conf']['token_min_length'] = 0
test_conf['filter_conf']['max_output_input_ratio'] = 102400
test_conf['filter_conf']['min_output_input_ratio'] = 0
test_conf['speed_perturb'] = False
test_conf['spec_aug'] = False
test_conf['spec_trim'] = False
test_conf['shuffle'] = False
test_conf['sort'] = False
test_conf['fbank_conf']['dither'] = 0.0
test_conf['batch_conf']['batch_type'] = "static"
test_conf['batch_conf']['batch_size'] = args.batch_size
test_dataset = Dataset(args.data_type,
args.test_data,
symbol_table,
test_conf,
args.bpe_model,
partition=False)
test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)
# Init asr model from configs
use_cuda = args.gpu >= 0 and torch.cuda.is_available()
if use_cuda:
EP_list = ['CUDAExecutionProvider', 'CPUExecutionProvider']
else:
EP_list = ['CPUExecutionProvider']
encoder_ort_session = rt.InferenceSession(args.encoder_onnx, providers=EP_list)
decoder_ort_session = None
if args.mode == "attention_rescoring":
decoder_ort_session = rt.InferenceSession(args.decoder_onnx, providers=EP_list)
# Load dict
vocabulary = []
char_dict = {}
with open(args.dict, 'r') as fin:
for line in fin:
arr = line.strip().split()
assert len(arr) == 2
char_dict[int(arr[1])] = arr[0]
vocabulary.append(arr[0])
eos = sos = len(char_dict) - 1
with torch.no_grad(), open(args.result_file, 'w') as fout:
for _, batch in enumerate(test_data_loader):
keys, feats, _, feats_lengths, _ = batch
feats, feats_lengths = feats.numpy(), feats_lengths.numpy()
if args.fp16:
feats = feats.astype(np.float16)
ort_inputs = {
encoder_ort_session.get_inputs()[0].name: feats,
encoder_ort_session.get_inputs()[1].name: feats_lengths}
ort_outs = encoder_ort_session.run(None, ort_inputs)
encoder_out, encoder_out_lens, ctc_log_probs, \
beam_log_probs, beam_log_probs_idx = ort_outs
beam_size = beam_log_probs.shape[-1]
batch_size = beam_log_probs.shape[0]
num_processes = min(multiprocessing.cpu_count(), batch_size)
if args.mode == 'ctc_greedy_search':
                # when beam_size == 1 the last dim is still 1, so index 0 is
                # valid in both cases
                log_probs_idx = beam_log_probs_idx[:, :, 0]
batch_sents = []
for idx, seq in enumerate(log_probs_idx):
batch_sents.append(seq[0:encoder_out_lens[idx]].tolist())
hyps = map_batch(batch_sents, vocabulary, num_processes,
True, 0)
elif args.mode in ('ctc_prefix_beam_search', "attention_rescoring"):
batch_log_probs_seq_list = beam_log_probs.tolist()
batch_log_probs_idx_list = beam_log_probs_idx.tolist()
batch_len_list = encoder_out_lens.tolist()
batch_log_probs_seq = []
batch_log_probs_ids = []
batch_start = [] # only effective in streaming deployment
batch_root = TrieVector()
root_dict = {}
for i in range(len(batch_len_list)):
num_sent = batch_len_list[i]
batch_log_probs_seq.append(
batch_log_probs_seq_list[i][0:num_sent])
batch_log_probs_ids.append(
batch_log_probs_idx_list[i][0:num_sent])
root_dict[i] = PathTrie()
batch_root.append(root_dict[i])
batch_start.append(True)
score_hyps = ctc_beam_search_decoder_batch(batch_log_probs_seq,
batch_log_probs_ids,
batch_root,
batch_start,
beam_size,
num_processes,
0, -2, 0.99999)
if args.mode == 'ctc_prefix_beam_search':
hyps = []
for cand_hyps in score_hyps:
hyps.append(cand_hyps[0][1])
hyps = map_batch(hyps, vocabulary, num_processes, False, 0)
if args.mode == 'attention_rescoring':
ctc_score, all_hyps = [], []
max_len = 0
for hyps in score_hyps:
cur_len = len(hyps)
if len(hyps) < beam_size:
hyps += (beam_size - cur_len) * [(-float("INF"), (0,))]
cur_ctc_score = []
for hyp in hyps:
cur_ctc_score.append(hyp[0])
all_hyps.append(list(hyp[1]))
if len(hyp[1]) > max_len:
max_len = len(hyp[1])
ctc_score.append(cur_ctc_score)
if args.fp16:
ctc_score = np.array(ctc_score, dtype=np.float16)
else:
ctc_score = np.array(ctc_score, dtype=np.float32)
hyps_pad_sos_eos = np.ones(
(batch_size, beam_size, max_len + 2), dtype=np.int64) * IGNORE_ID
r_hyps_pad_sos_eos = np.ones(
(batch_size, beam_size, max_len + 2), dtype=np.int64) * IGNORE_ID
hyps_lens_sos = np.ones((batch_size, beam_size), dtype=np.int32)
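                    # pad each hypothesis to [sos, tokens..., eos] (IGNORE_ID
                    # elsewhere), build a reversed copy for the right-to-left
                    # decoder, and record sos + token counts in hyps_lens_sos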
k = 0
for i in range(batch_size):
for j in range(beam_size):
cand = all_hyps[k]
l = len(cand) + 2
hyps_pad_sos_eos[i][j][0:l] = [sos] + cand + [eos]
r_hyps_pad_sos_eos[i][j][0:l] = [sos] + cand[::-1] + [eos]
hyps_lens_sos[i][j] = len(cand) + 1
k += 1
decoder_ort_inputs = {
decoder_ort_session.get_inputs()[0].name: encoder_out,
decoder_ort_session.get_inputs()[1].name: encoder_out_lens,
decoder_ort_session.get_inputs()[2].name: hyps_pad_sos_eos,
decoder_ort_session.get_inputs()[3].name: hyps_lens_sos,
decoder_ort_session.get_inputs()[-1].name: ctc_score}
if reverse_weight > 0:
r_hyps_pad_sos_eos_name = decoder_ort_session.get_inputs()[4].name
decoder_ort_inputs[r_hyps_pad_sos_eos_name] = r_hyps_pad_sos_eos
best_index = decoder_ort_session.run(None, decoder_ort_inputs)[0]
best_sents = []
k = 0
for idx in best_index:
cur_best_sent = all_hyps[k: k + beam_size][idx]
best_sents.append(cur_best_sent)
k += beam_size
hyps = map_batch(best_sents, vocabulary, num_processes)
for i, key in enumerate(keys):
content = hyps[i]
logging.info('{} {}'.format(key, content))
fout.write('{} {}\n'.format(key, content))
if __name__ == '__main__':
main()
# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import argparse
import copy
import logging
import os
import time
import torch
import torch.distributed as dist
import torch.optim as optim
import yaml
from tensorboardX import SummaryWriter
from torch.utils.data import DataLoader
from wenet.dataset.dataset import Dataset
from wenet.utils.checkpoint import (load_checkpoint, save_checkpoint,
load_trained_modules)
from wenet.utils.executor import Executor
from wenet.utils.file_utils import read_symbol_table, read_non_lang_symbols
from wenet.utils.scheduler import WarmupLR, NoamHoldAnnealing
from wenet.utils.config import override_config
from wenet.utils.init_model import init_model
from wenet.utils.global_vars import get_global_steps, get_num_trained_samples
from wenet.utils.compute_acc import compute_char_acc
def write_pid_file(pid_file_path):
    '''Write a pid file so the process can be watched later.
    Each run overwrites the file at the same path with the current pid.
'''
if os.path.exists(pid_file_path):
os.remove(pid_file_path)
    with open(pid_file_path, "w") as file_d:
        file_d.write("%s\n" % os.getpid())
def get_args():
parser = argparse.ArgumentParser(description='training your network')
parser.add_argument('--config', required=True, help='config file')
parser.add_argument('--data_type',
default='raw',
choices=['raw', 'shard'],
help='train and cv data type')
parser.add_argument('--train_data', required=True, help='train data file')
parser.add_argument('--cv_data', required=True, help='cv data file')
parser.add_argument('--gpu',
type=int,
default=-1,
help='gpu id for this local rank, -1 for cpu')
parser.add_argument('--model_dir', required=True, help='save model dir')
parser.add_argument('--checkpoint', help='checkpoint model')
parser.add_argument('--tensorboard_dir',
default='tensorboard',
help='tensorboard log dir')
parser.add_argument('--ddp.rank',
dest='rank',
default=0,
type=int,
help='global rank for distributed training')
parser.add_argument('--ddp.world_size',
dest='world_size',
default=-1,
type=int,
help='''number of total processes/gpus for
distributed training''')
parser.add_argument('--ddp.dist_backend',
dest='dist_backend',
default='nccl',
choices=['nccl', 'gloo'],
help='distributed backend')
parser.add_argument('--ddp.init_method',
dest='init_method',
default=None,
help='ddp init method')
parser.add_argument('--num_workers',
default=0,
type=int,
help='num of subprocess workers for reading')
parser.add_argument('--pin_memory',
action='store_true',
default=False,
help='Use pinned memory buffers used for reading')
parser.add_argument('--use_amp',
action='store_true',
default=False,
help='Use automatic mixed precision training')
parser.add_argument('--fp16_grad_sync',
action='store_true',
default=False,
help='Use fp16 gradient sync for ddp')
parser.add_argument('--cmvn', default=None, help='global cmvn file')
parser.add_argument('--symbol_table',
required=True,
help='model unit symbol table for training')
parser.add_argument("--non_lang_syms",
help="non-linguistic symbol file. One symbol per line.")
parser.add_argument('--prefetch',
default=100,
type=int,
help='prefetch number')
parser.add_argument('--bpe_model',
default=None,
type=str,
help='bpe model for english part')
parser.add_argument('--override_config',
action='append',
default=[],
help="override yaml config")
parser.add_argument("--enc_init",
default=None,
type=str,
help="Pre-trained model to initialize encoder")
parser.add_argument("--enc_init_mods",
default="encoder.",
                        type=lambda s: [str(mod) for mod in s.split(",") if mod != ""],
                        help="List of encoder modules \
                              to initialize, separated by a comma")
parser.add_argument('--val_ref_file',
dest='val_ref_file',
default='data/test/text',
help='validation ref file')
parser.add_argument('--val_hyp_file',
dest='val_hyp_file',
default='exp/conformer/test_attention_rescoring/text',
help='validation hyp file')
parser.add_argument('--log_dir',
type=str,
default='/data/flagperf/training/result/',
help='Log directory in container.')
args = parser.parse_args()
return args
def main():
args = get_args()
if args.rank == 0:
write_pid_file(args.log_dir)
logging.basicConfig(level=logging.DEBUG,
format='%(asctime)s %(levelname)s %(message)s')
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
# Set random seed
torch.manual_seed(777)
with open(args.config, 'r') as fin:
configs = yaml.load(fin, Loader=yaml.FullLoader)
if len(args.override_config) > 0:
configs = override_config(configs, args.override_config)
distributed = args.world_size > 1
if distributed:
logging.info('training on multiple gpus, this gpu {}'.format(args.gpu))
dist.init_process_group(args.dist_backend,
init_method=args.init_method,
world_size=args.world_size,
rank=args.rank)
symbol_table = read_symbol_table(args.symbol_table)
train_conf = configs['dataset_conf']
cv_conf = copy.deepcopy(train_conf)
cv_conf['speed_perturb'] = False
cv_conf['spec_aug'] = False
cv_conf['spec_sub'] = False
cv_conf['spec_trim'] = False
cv_conf['shuffle'] = False
non_lang_syms = read_non_lang_symbols(args.non_lang_syms)
train_dataset = Dataset(args.data_type, args.train_data, symbol_table,
train_conf, args.bpe_model, non_lang_syms, True)
cv_dataset = Dataset(args.data_type,
args.cv_data,
symbol_table,
cv_conf,
args.bpe_model,
non_lang_syms,
partition=False)
train_data_loader = DataLoader(train_dataset,
batch_size=None,
pin_memory=args.pin_memory,
num_workers=args.num_workers,
prefetch_factor=args.prefetch)
cv_data_loader = DataLoader(cv_dataset,
batch_size=None,
pin_memory=args.pin_memory,
num_workers=args.num_workers,
prefetch_factor=args.prefetch)
if 'fbank_conf' in configs['dataset_conf']:
input_dim = configs['dataset_conf']['fbank_conf']['num_mel_bins']
else:
input_dim = configs['dataset_conf']['mfcc_conf']['num_mel_bins']
vocab_size = len(symbol_table)
# Save configs to model_dir/train.yaml for inference and export
configs['input_dim'] = input_dim
configs['output_dim'] = vocab_size
configs['cmvn_file'] = args.cmvn
configs['is_json_cmvn'] = True
if args.rank == 0:
saved_config_path = os.path.join(args.model_dir, 'train.yaml')
with open(saved_config_path, 'w') as fout:
data = yaml.dump(configs)
fout.write(data)
# Init asr model from configs
model = init_model(configs)
print(model)
num_params = sum(p.numel() for p in model.parameters())
print('the number of model params: {:,d}'.format(num_params))
# !!!IMPORTANT!!!
# Try to export the model by script, if fails, we should refine
# the code to satisfy the script export requirements
if args.rank == 0:
script_model = torch.jit.script(model)
script_model.save(os.path.join(args.model_dir, 'init.zip'))
executor = Executor()
# If specify checkpoint, load some info from checkpoint
if args.checkpoint is not None:
infos = load_checkpoint(model, args.checkpoint)
elif args.enc_init is not None:
logging.info('load pretrained encoders: {}'.format(args.enc_init))
infos = load_trained_modules(model, args)
else:
infos = {}
start_epoch = infos.get('epoch', -1) + 1
cv_loss = infos.get('cv_loss', 0.0)
step = infos.get('step', -1)
num_epochs = configs.get('max_epoch', 100)
model_dir = args.model_dir
writer = None
if args.rank == 0:
os.makedirs(model_dir, exist_ok=True)
exp_id = os.path.basename(model_dir)
#writer = SummaryWriter(os.path.join(args.tensorboard_dir, exp_id))
if distributed:
assert (torch.cuda.is_available())
# cuda model is required for nn.parallel.DistributedDataParallel
model.cuda()
model = torch.nn.parallel.DistributedDataParallel(
model, find_unused_parameters=False)
device = torch.device("cuda")
if args.fp16_grad_sync:
from torch.distributed.algorithms.ddp_comm_hooks import (
default as comm_hooks,
)
model.register_comm_hook(
state=None, hook=comm_hooks.fp16_compress_hook
)
else:
use_cuda = args.gpu >= 0 and torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
model = model.to(device)
if configs['optim'] == 'adam':
optimizer = optim.Adam(model.parameters(), **configs['optim_conf'])
elif configs['optim'] == 'adamw':
optimizer = optim.AdamW(model.parameters(), **configs['optim_conf'])
else:
raise ValueError("unknown optimizer: " + configs['optim'])
if configs['scheduler'] == 'warmuplr':
scheduler = WarmupLR(optimizer, **configs['scheduler_conf'])
elif configs['scheduler'] == 'NoamHoldAnnealing':
scheduler = NoamHoldAnnealing(optimizer, **configs['scheduler_conf'])
else:
raise ValueError("unknown scheduler: " + configs['scheduler'])
final_epoch = None
target_acc = 93.0
final_acc = 0
training_only = 0
configs['rank'] = args.rank
configs['is_distributed'] = distributed
configs['use_amp'] = args.use_amp
if start_epoch == 0 and args.rank == 0:
save_model_path = os.path.join(model_dir, 'init.pt')
save_checkpoint(model, save_model_path)
# Start training loop
executor.step = step
scheduler.set_step(step)
# used for pytorch amp mixed precision training
scaler = None
if args.use_amp:
scaler = torch.cuda.amp.GradScaler()
training_start = time.time()
for epoch in range(start_epoch, num_epochs):
start = time.time()
train_dataset.set_epoch(epoch)
configs['epoch'] = epoch
lr = optimizer.param_groups[0]['lr']
logging.info('Epoch {} TRAIN info lr {}'.format(epoch, lr))
executor.train(model, optimizer, scheduler, train_data_loader, device,
writer, configs, scaler)
total_loss, num_seen_utts = executor.cv(model, cv_data_loader, device,
configs)
cv_loss = total_loss / num_seen_utts
epoch_time = time.time() - start
training_only += epoch_time
dist.barrier()
#logging.info('Epoch {} CV info cv_loss {}'.format(epoch, cv_loss))
if args.rank == 0:
save_model_path = os.path.join(model_dir, '{}.pt'.format(epoch))
save_checkpoint(
model, save_model_path, {
'epoch': epoch,
'lr': lr,
'cv_loss': cv_loss,
'step': executor.step
})
#writer.add_scalar('epoch/cv_loss', cv_loss, epoch)
#writer.add_scalar('epoch/lr', lr, epoch)
final_epoch = epoch
char_acc = 0.0
# Run validation by calling run.sh stage=5
# Only run in rank 0
if args.rank == 0:
start = time.time()
if final_epoch is not None:
final_model_path = os.path.join(model_dir, 'final.pt')
                if os.path.exists(final_model_path):
                    os.remove(final_model_path)
os.symlink('{}.pt'.format(final_epoch), final_model_path)
val_cmd = os.path.join(os.getcwd(), "validate.sh")
logging.info(f'rank {args.rank}: ' + "Start validation")
os.system(val_cmd)
time.sleep(0.5)
char_acc = compute_char_acc(args)
logging.info(f'rank {args.rank}: ' + "Finish validation")
eval_time = time.time() - start
global_steps = get_global_steps()
eval_output = f'[PerfLog] {{"event": "EVALUATE_END", "value": {{"global_steps": {global_steps},"eval_mlm_accuracy":{char_acc:.4f},"eval_time": {eval_time:.2f},"epoch_time":{epoch_time:.2f}}}}}'
logging.info(f'rank {args.rank}: ' + eval_output)
dist.barrier()
torch.cuda.synchronize()
t = torch.tensor([char_acc], device='cuda')
dist.broadcast(t, 0)
char_acc = t[0].item()
if char_acc >= target_acc:
final_acc = char_acc
break
train_time = time.time() - training_start
num_trained_samples = get_num_trained_samples()
samples_sec = num_trained_samples / training_only
    train_output = f'[PerfLog] {{"event": "TRAIN_END", "value": {{"accuracy":{final_acc:.4f},"train_time":{train_time:.2f},"samples/sec":{samples_sec:.2f},"num_trained_samples":{num_trained_samples}}}}}'
logging.info(f'rank {args.rank}: ' + train_output)
if __name__ == '__main__':
main()
# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import torch
import torch.distributed as dist
from torch.utils.data import IterableDataset
import wenet.dataset.processor as processor
from wenet.utils.file_utils import read_lists
class Processor(IterableDataset):
def __init__(self, source, f, *args, **kw):
assert callable(f)
self.source = source
self.f = f
self.args = args
self.kw = kw
def set_epoch(self, epoch):
self.source.set_epoch(epoch)
def __iter__(self):
""" Return an iterator over the source dataset processed by the
given processor.
"""
assert self.source is not None
assert callable(self.f)
return self.f(iter(self.source), *self.args, **self.kw)
def apply(self, f):
assert callable(f)
return Processor(self, f, *self.args, **self.kw)
class DistributedSampler:
def __init__(self, shuffle=True, partition=True):
self.epoch = -1
self.update()
self.shuffle = shuffle
self.partition = partition
def update(self):
assert dist.is_available()
if dist.is_initialized():
self.rank = dist.get_rank()
self.world_size = dist.get_world_size()
else:
self.rank = 0
self.world_size = 1
worker_info = torch.utils.data.get_worker_info()
if worker_info is None:
self.worker_id = 0
self.num_workers = 1
else:
self.worker_id = worker_info.id
self.num_workers = worker_info.num_workers
return dict(rank=self.rank,
world_size=self.world_size,
worker_id=self.worker_id,
num_workers=self.num_workers)
def set_epoch(self, epoch):
self.epoch = epoch
def sample(self, data):
""" Sample data according to rank/world_size/num_workers
Args:
data(List): input data list
Returns:
List: data list after sample
"""
data = list(range(len(data)))
# TODO(Binbin Zhang): fix this
        # We cannot handle uneven data for CV on DDP, so we don't
        # sample data by rank, which means every GPU gets the same,
        # complete CV data
if self.partition:
if self.shuffle:
random.Random(self.epoch).shuffle(data)
data = data[self.rank::self.world_size]
data = data[self.worker_id::self.num_workers]
return data
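# A minimal sketch (ignoring the epoch shuffle, assuming 2 GPUs and 2 dataloader
# workers per GPU) of how the two-level slicing above shards 8 shard indices:
#   data = [0, 1, 2, 3, 4, 5, 6, 7]
#   rank 0 keeps data[0::2] -> [0, 2, 4, 6]; its worker 1 then keeps [2, 6]
#   rank 1 keeps data[1::2] -> [1, 3, 5, 7]; its worker 0 then keeps [1, 5]
# so every (rank, worker) pair sees a disjoint subset when partition=True.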
class DataList(IterableDataset):
def __init__(self, lists, shuffle=True, partition=True):
self.lists = lists
self.sampler = DistributedSampler(shuffle, partition)
def set_epoch(self, epoch):
self.sampler.set_epoch(epoch)
def __iter__(self):
sampler_info = self.sampler.update()
indexes = self.sampler.sample(self.lists)
for index in indexes:
# yield dict(src=src)
data = dict(src=self.lists[index])
data.update(sampler_info)
yield data
def Dataset(data_type,
data_list_file,
symbol_table,
conf,
bpe_model=None,
non_lang_syms=None,
partition=True):
""" Construct dataset from arguments
        We have two shuffle stages in the Dataset. The first is a global
        shuffle at the shard tar/raw file level. The second is a shuffle
        at the training-sample level.
Args:
data_type(str): raw/shard
bpe_model(str): model for english bpe part
partition(bool): whether to do data partition in terms of rank
"""
assert data_type in ['raw', 'shard']
lists = read_lists(data_list_file)
shuffle = conf.get('shuffle', True)
dataset = DataList(lists, shuffle=shuffle, partition=partition)
if data_type == 'shard':
dataset = Processor(dataset, processor.url_opener)
dataset = Processor(dataset, processor.tar_file_and_group)
else:
dataset = Processor(dataset, processor.parse_raw)
dataset = Processor(dataset, processor.tokenize, symbol_table, bpe_model,
non_lang_syms, conf.get('split_with_space', False))
filter_conf = conf.get('filter_conf', {})
dataset = Processor(dataset, processor.filter, **filter_conf)
resample_conf = conf.get('resample_conf', {})
dataset = Processor(dataset, processor.resample, **resample_conf)
speed_perturb = conf.get('speed_perturb', False)
if speed_perturb:
dataset = Processor(dataset, processor.speed_perturb)
feats_type = conf.get('feats_type', 'fbank')
assert feats_type in ['fbank', 'mfcc']
if feats_type == 'fbank':
fbank_conf = conf.get('fbank_conf', {})
dataset = Processor(dataset, processor.compute_fbank, **fbank_conf)
elif feats_type == 'mfcc':
mfcc_conf = conf.get('mfcc_conf', {})
dataset = Processor(dataset, processor.compute_mfcc, **mfcc_conf)
spec_aug = conf.get('spec_aug', True)
spec_sub = conf.get('spec_sub', False)
spec_trim = conf.get('spec_trim', False)
if spec_aug:
spec_aug_conf = conf.get('spec_aug_conf', {})
dataset = Processor(dataset, processor.spec_aug, **spec_aug_conf)
if spec_sub:
spec_sub_conf = conf.get('spec_sub_conf', {})
dataset = Processor(dataset, processor.spec_sub, **spec_sub_conf)
if spec_trim:
spec_trim_conf = conf.get('spec_trim_conf', {})
dataset = Processor(dataset, processor.spec_trim, **spec_trim_conf)
if shuffle:
shuffle_conf = conf.get('shuffle_conf', {})
dataset = Processor(dataset, processor.shuffle, **shuffle_conf)
sort = conf.get('sort', True)
if sort:
sort_conf = conf.get('sort_conf', {})
dataset = Processor(dataset, processor.sort, **sort_conf)
batch_conf = conf.get('batch_conf', {})
dataset = Processor(dataset, processor.batch, **batch_conf)
dataset = Processor(dataset, processor.padding)
return dataset
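# A minimal usage sketch of the processor chain built above (the paths and the
# conf values below are hypothetical, not part of this repo's configs):
#
#   from wenet.utils.file_utils import read_symbol_table
#   symbol_table = read_symbol_table('data/dict/lang_char.txt')
#   conf = {'batch_conf': {'batch_type': 'static', 'batch_size': 8}}
#   dataset = Dataset('raw', 'data/train/data.list', symbol_table, conf)
#   loader = torch.utils.data.DataLoader(dataset, batch_size=None, num_workers=2)
#   for keys, feats, labels, feats_lens, label_lens in loader:
#       ...  # padded batches produced by processor.padding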
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2014-2016 Brno University of Technology (author: Karel Vesely)
# Licensed under the Apache License, Version 2.0 (the "License")
import numpy as np
import sys, os, re, gzip, struct
#################################################
# Adding kaldi tools to shell path,
# Select kaldi,
if not 'KALDI_ROOT' in os.environ:
# Default! To change run python with 'export KALDI_ROOT=/some_dir python'
os.environ['KALDI_ROOT']='/mnt/matylda5/iveselyk/Tools/kaldi-trunk'
# Add kaldi tools to path,
os.environ['PATH'] = os.popen('echo $KALDI_ROOT/src/bin:$KALDI_ROOT/tools/openfst/bin:$KALDI_ROOT/src/fstbin/:$KALDI_ROOT/src/gmmbin/:$KALDI_ROOT/src/featbin/:$KALDI_ROOT/src/lm/:$KALDI_ROOT/src/sgmmbin/:$KALDI_ROOT/src/sgmm2bin/:$KALDI_ROOT/src/fgmmbin/:$KALDI_ROOT/src/latbin/:$KALDI_ROOT/src/nnetbin:$KALDI_ROOT/src/nnet2bin:$KALDI_ROOT/src/nnet3bin:$KALDI_ROOT/src/online2bin/:$KALDI_ROOT/src/ivectorbin/:$KALDI_ROOT/src/lmbin/').readline().strip() + ':' + os.environ['PATH']
#################################################
# Define all custom exceptions,
class UnsupportedDataType(Exception): pass
class UnknownVectorHeader(Exception): pass
class UnknownMatrixHeader(Exception): pass
class BadSampleSize(Exception): pass
class BadInputFormat(Exception): pass
class SubprocessFailed(Exception): pass
#################################################
# Data-type independent helper functions,
def open_or_fd(file, mode='rb'):
""" fd = open_or_fd(file)
  Open a file, a gzipped file, or a pipe, or forward an already-open file descriptor.
  Optionally seeks to an offset if the 'file' argument carries a ':offset' suffix.
"""
offset = None
try:
# strip 'ark:' prefix from r{x,w}filename (optional),
if re.search('^(ark|scp)(,scp|,b|,t|,n?f|,n?p|,b?o|,n?s|,n?cs)*:', file):
(prefix,file) = file.split(':',1)
# separate offset from filename (optional),
if re.search(':[0-9]+$', file):
(file,offset) = file.rsplit(':',1)
# input pipe?
if file[-1] == '|':
fd = popen(file[:-1], 'rb') # custom,
# output pipe?
elif file[0] == '|':
fd = popen(file[1:], 'wb') # custom,
# is it gzipped?
elif file.split('.')[-1] == 'gz':
fd = gzip.open(file, mode)
# a normal file...
else:
fd = open(file, mode)
except TypeError:
# 'file' is opened file descriptor,
fd = file
# Eventually seek to offset,
  if offset is not None: fd.seek(int(offset))
return fd
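# A few example 'file' arguments this helper accepts (paths are hypothetical):
#   open_or_fd('feats.ark')               # plain file
#   open_or_fd('ark:feats.ark:1024')      # 'ark:' prefix stripped, then seek to byte 1024
#   open_or_fd('gunzip -c feats.ark.gz |')  # trailing '|' -> read from a pipe
#   open_or_fd('feats.ark.gz')            # gzipped file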
# based on '/usr/local/lib/python3.4/os.py'
def popen(cmd, mode="rb"):
if not isinstance(cmd, str):
raise TypeError("invalid cmd type (%s, expected string)" % type(cmd))
import subprocess, io, threading
# cleanup function for subprocesses,
def cleanup(proc, cmd):
ret = proc.wait()
if ret > 0:
raise SubprocessFailed('cmd %s returned %d !' % (cmd,ret))
return
# text-mode,
if mode == "r":
proc = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE)
threading.Thread(target=cleanup,args=(proc,cmd)).start() # clean-up thread,
return io.TextIOWrapper(proc.stdout)
elif mode == "w":
proc = subprocess.Popen(cmd, shell=True, stdin=subprocess.PIPE)
threading.Thread(target=cleanup,args=(proc,cmd)).start() # clean-up thread,
return io.TextIOWrapper(proc.stdin)
# binary,
elif mode == "rb":
proc = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE)
threading.Thread(target=cleanup,args=(proc,cmd)).start() # clean-up thread,
return proc.stdout
elif mode == "wb":
proc = subprocess.Popen(cmd, shell=True, stdin=subprocess.PIPE)
threading.Thread(target=cleanup,args=(proc,cmd)).start() # clean-up thread,
return proc.stdin
# sanity,
else:
raise ValueError("invalid mode %s" % mode)
def read_key(fd):
""" [key] = read_key(fd)
Read the utterance-key from the opened ark/stream descriptor 'fd'.
"""
key = ''
while 1:
char = fd.read(1).decode("latin1")
if char == '' : break
if char == ' ' : break
key += char
key = key.strip()
if key == '': return None # end of file,
assert(re.match('^\S+$',key) != None) # check format (no whitespace!)
return key
#################################################
# Integer vectors (alignments, ...),
def read_ali_ark(file_or_fd):
""" Alias to 'read_vec_int_ark()' """
return read_vec_int_ark(file_or_fd)
def read_vec_int_ark(file_or_fd):
""" generator(key,vec) = read_vec_int_ark(file_or_fd)
Create generator of (key,vector<int>) tuples, which reads from the ark file/stream.
file_or_fd : ark, gzipped ark, pipe or opened file descriptor.
Read ark to a 'dictionary':
d = { u:d for u,d in kaldi_io.read_vec_int_ark(file) }
"""
fd = open_or_fd(file_or_fd)
try:
key = read_key(fd)
while key:
ali = read_vec_int(fd)
yield key, ali
key = read_key(fd)
finally:
if fd is not file_or_fd: fd.close()
def read_vec_int_scp(file_or_fd):
""" generator(key,vec) = read_vec_int_scp(file_or_fd)
Returns generator of (key,vector<int>) tuples, read according to kaldi scp.
file_or_fd : scp, gzipped scp, pipe or opened file descriptor.
Iterate the scp:
for key,vec in kaldi_io.read_vec_int_scp(file):
...
Read scp to a 'dictionary':
   d = { key:vec for key,vec in kaldi_io.read_vec_int_scp(file) }
"""
fd = open_or_fd(file_or_fd)
try:
for line in fd:
(key,rxfile) = line.decode().split(' ')
vec = read_vec_int(rxfile)
yield key, vec
finally:
if fd is not file_or_fd : fd.close()
def read_vec_int(file_or_fd):
""" [int-vec] = read_vec_int(file_or_fd)
Read kaldi integer vector, ascii or binary input,
"""
fd = open_or_fd(file_or_fd)
binary = fd.read(2).decode()
if binary == '\0B': # binary flag
assert(fd.read(1).decode() == '\4'); # int-size
vec_size = np.frombuffer(fd.read(4), dtype='int32', count=1)[0] # vector dim
    # Elements of the int32 vector are stored as tuples: (sizeof(int32), value),
vec = np.frombuffer(fd.read(vec_size*5), dtype=[('size','int8'),('value','int32')], count=vec_size)
assert(vec[0]['size'] == 4) # int32 size,
ans = vec[:]['value'] # values are in 2nd column,
else: # ascii,
arr = (binary + fd.readline().decode()).strip().split()
try:
arr.remove('['); arr.remove(']') # optionally
except ValueError:
pass
ans = np.array(arr, dtype=int)
if fd is not file_or_fd : fd.close() # cleanup
return ans
# Writing,
def write_vec_int(file_or_fd, v, key=''):
""" write_vec_int(f, v, key='')
Write a binary kaldi integer vector to filename or stream.
Arguments:
file_or_fd : filename or opened file descriptor for writing,
v : the vector to be stored,
key (optional) : used for writing ark-file, the utterance-id gets written before the vector.
Example of writing single vector:
kaldi_io.write_vec_int(filename, vec)
Example of writing arkfile:
with open(ark_file,'w') as f:
for key,vec in dict.iteritems():
     kaldi_io.write_vec_int(f, vec, key=key)
"""
fd = open_or_fd(file_or_fd, mode='wb')
if sys.version_info[0] == 3: assert(fd.mode == 'wb')
try:
if key != '' : fd.write((key+' ').encode("latin1")) # ark-files have keys (utterance-id),
fd.write('\0B'.encode()) # we write binary!
# dim,
fd.write('\4'.encode()) # int32 type,
fd.write(struct.pack(np.dtype('int32').char, v.shape[0]))
# data,
for i in range(len(v)):
fd.write('\4'.encode()) # int32 type,
fd.write(struct.pack(np.dtype('int32').char, v[i])) # binary,
finally:
if fd is not file_or_fd : fd.close()
#################################################
# Float vectors (confidences, ivectors, ...),
# Reading,
def read_vec_flt_scp(file_or_fd):
""" generator(key,mat) = read_vec_flt_scp(file_or_fd)
Returns generator of (key,vector) tuples, read according to kaldi scp.
file_or_fd : scp, gzipped scp, pipe or opened file descriptor.
Iterate the scp:
for key,vec in kaldi_io.read_vec_flt_scp(file):
...
Read scp to a 'dictionary':
   d = { key:vec for key,vec in kaldi_io.read_vec_flt_scp(file) }
"""
fd = open_or_fd(file_or_fd)
try:
for line in fd:
(key,rxfile) = line.decode().split(' ')
vec = read_vec_flt(rxfile)
yield key, vec
finally:
if fd is not file_or_fd : fd.close()
def read_vec_flt_ark(file_or_fd):
""" generator(key,vec) = read_vec_flt_ark(file_or_fd)
Create generator of (key,vector<float>) tuples, reading from an ark file/stream.
file_or_fd : ark, gzipped ark, pipe or opened file descriptor.
Read ark to a 'dictionary':
d = { u:d for u,d in kaldi_io.read_vec_flt_ark(file) }
"""
fd = open_or_fd(file_or_fd)
try:
key = read_key(fd)
while key:
ali = read_vec_flt(fd)
yield key, ali
key = read_key(fd)
finally:
if fd is not file_or_fd: fd.close()
def read_vec_flt(file_or_fd):
""" [flt-vec] = read_vec_flt(file_or_fd)
Read kaldi float vector, ascii or binary input,
"""
fd = open_or_fd(file_or_fd)
binary = fd.read(2).decode()
if binary == '\0B': # binary flag
# Data type,
header = fd.read(3).decode()
if header == 'FV ': sample_size = 4 # floats
elif header == 'DV ': sample_size = 8 # doubles
else: raise UnknownVectorHeader("The header contained '%s'" % header)
assert(sample_size > 0)
# Dimension,
assert(fd.read(1).decode() == '\4'); # int-size
vec_size = np.frombuffer(fd.read(4), dtype='int32', count=1)[0] # vector dim
# Read whole vector,
buf = fd.read(vec_size * sample_size)
if sample_size == 4 : ans = np.frombuffer(buf, dtype='float32')
elif sample_size == 8 : ans = np.frombuffer(buf, dtype='float64')
else : raise BadSampleSize
return ans
else: # ascii,
arr = (binary + fd.readline().decode()).strip().split()
try:
arr.remove('['); arr.remove(']') # optionally
except ValueError:
pass
ans = np.array(arr, dtype=float)
if fd is not file_or_fd : fd.close() # cleanup
return ans
# Writing,
def write_vec_flt(file_or_fd, v, key=''):
""" write_vec_flt(f, v, key='')
Write a binary kaldi vector to filename or stream. Supports 32bit and 64bit floats.
Arguments:
file_or_fd : filename or opened file descriptor for writing,
v : the vector to be stored,
key (optional) : used for writing ark-file, the utterance-id gets written before the vector.
Example of writing single vector:
kaldi_io.write_vec_flt(filename, vec)
Example of writing arkfile:
with open(ark_file,'w') as f:
for key,vec in dict.iteritems():
kaldi_io.write_vec_flt(f, vec, key=key)
"""
fd = open_or_fd(file_or_fd, mode='wb')
if sys.version_info[0] == 3: assert(fd.mode == 'wb')
try:
if key != '' : fd.write((key+' ').encode("latin1")) # ark-files have keys (utterance-id),
fd.write('\0B'.encode()) # we write binary!
# Data-type,
if v.dtype == 'float32': fd.write('FV '.encode())
elif v.dtype == 'float64': fd.write('DV '.encode())
else: raise UnsupportedDataType("'%s', please use 'float32' or 'float64'" % v.dtype)
# Dim,
fd.write('\04'.encode())
fd.write(struct.pack(np.dtype('uint32').char, v.shape[0])) # dim
# Data,
fd.write(v.tobytes())
finally:
if fd is not file_or_fd : fd.close()
#################################################
# Float matrices (features, transformations, ...),
# Reading,
def read_mat_scp(file_or_fd):
""" generator(key,mat) = read_mat_scp(file_or_fd)
Returns generator of (key,matrix) tuples, read according to kaldi scp.
file_or_fd : scp, gzipped scp, pipe or opened file descriptor.
Iterate the scp:
for key,mat in kaldi_io.read_mat_scp(file):
...
Read scp to a 'dictionary':
d = { key:mat for key,mat in kaldi_io.read_mat_scp(file) }
"""
fd = open_or_fd(file_or_fd)
try:
for line in fd:
(key,rxfile) = line.decode().split(' ')
mat = read_mat(rxfile)
yield key, mat
finally:
if fd is not file_or_fd : fd.close()
def read_mat_ark(file_or_fd):
""" generator(key,mat) = read_mat_ark(file_or_fd)
Returns generator of (key,matrix) tuples, read from ark file/stream.
   file_or_fd : ark, gzipped ark, pipe or opened file descriptor.
Iterate the ark:
for key,mat in kaldi_io.read_mat_ark(file):
...
Read ark to a 'dictionary':
d = { key:mat for key,mat in kaldi_io.read_mat_ark(file) }
"""
fd = open_or_fd(file_or_fd)
try:
key = read_key(fd)
while key:
mat = read_mat(fd)
yield key, mat
key = read_key(fd)
finally:
if fd is not file_or_fd : fd.close()
def read_mat(file_or_fd):
""" [mat] = read_mat(file_or_fd)
Reads single kaldi matrix, supports ascii and binary.
file_or_fd : file, gzipped file, pipe or opened file descriptor.
"""
fd = open_or_fd(file_or_fd)
try:
binary = fd.read(2).decode()
if binary == '\0B' :
mat = _read_mat_binary(fd)
else:
assert(binary == ' [')
mat = _read_mat_ascii(fd)
finally:
if fd is not file_or_fd: fd.close()
return mat
def _read_mat_binary(fd):
# Data type
header = fd.read(3).decode()
# 'CM', 'CM2', 'CM3' are possible values,
if header.startswith('CM'): return _read_compressed_mat(fd, header)
elif header == 'FM ': sample_size = 4 # floats
elif header == 'DM ': sample_size = 8 # doubles
else: raise UnknownMatrixHeader("The header contained '%s'" % header)
assert(sample_size > 0)
# Dimensions
s1, rows, s2, cols = np.frombuffer(fd.read(10), dtype='int8,int32,int8,int32', count=1)[0]
# Read whole matrix
buf = fd.read(rows * cols * sample_size)
if sample_size == 4 : vec = np.frombuffer(buf, dtype='float32')
elif sample_size == 8 : vec = np.frombuffer(buf, dtype='float64')
else : raise BadSampleSize
mat = np.reshape(vec,(rows,cols))
return mat
def _read_mat_ascii(fd):
rows = []
while 1:
line = fd.readline().decode()
if (len(line) == 0) : raise BadInputFormat # eof, should not happen!
if len(line.strip()) == 0 : continue # skip empty line
arr = line.strip().split()
if arr[-1] != ']':
rows.append(np.array(arr,dtype='float32')) # not last line
else:
rows.append(np.array(arr[:-1],dtype='float32')) # last line
mat = np.vstack(rows)
return mat
def _read_compressed_mat(fd, format):
""" Read a compressed matrix,
see: https://github.com/kaldi-asr/kaldi/blob/master/src/matrix/compressed-matrix.h
methods: CompressedMatrix::Read(...), CompressedMatrix::CopyToMat(...),
"""
assert(format == 'CM ') # The formats CM2, CM3 are not supported...
# Format of header 'struct',
global_header = np.dtype([('minvalue','float32'),('range','float32'),('num_rows','int32'),('num_cols','int32')]) # member '.format' is not written,
per_col_header = np.dtype([('percentile_0','uint16'),('percentile_25','uint16'),('percentile_75','uint16'),('percentile_100','uint16')])
# Mapping for percentiles in col-headers,
def uint16_to_float(value, min, range):
    return np.float32(min + range * 1.52590218966964e-05 * value) # 1.52590218966964e-05 == 1/65535
# Mapping for matrix elements,
def uint8_to_float_v2(vec, p0, p25, p75, p100):
# Split the vector by masks,
mask_0_64 = (vec <= 64);
mask_193_255 = (vec > 192);
mask_65_192 = (~(mask_0_64 | mask_193_255));
# Sanity check (useful but slow...),
# assert(len(vec) == np.sum(np.hstack([mask_0_64,mask_65_192,mask_193_255])))
# assert(len(vec) == np.sum(np.any([mask_0_64,mask_65_192,mask_193_255], axis=0)))
# Build the float vector,
ans = np.empty(len(vec), dtype='float32')
ans[mask_0_64] = p0 + (p25 - p0) / 64. * vec[mask_0_64]
ans[mask_65_192] = p25 + (p75 - p25) / 128. * (vec[mask_65_192] - 64)
ans[mask_193_255] = p75 + (p100 - p75) / 63. * (vec[mask_193_255] - 192)
return ans
# Read global header,
globmin, globrange, rows, cols = np.frombuffer(fd.read(16), dtype=global_header, count=1)[0]
  # The data is structured as [Colheader, ... , Colheader, Data, Data , .... ]
# { cols }{ size }
col_headers = np.frombuffer(fd.read(cols*8), dtype=per_col_header, count=cols)
data = np.reshape(np.frombuffer(fd.read(cols*rows), dtype='uint8', count=cols*rows), newshape=(cols,rows)) # stored as col-major,
mat = np.empty((cols,rows), dtype='float32')
for i, col_header in enumerate(col_headers):
col_header_flt = [ uint16_to_float(percentile, globmin, globrange) for percentile in col_header ]
mat[i] = uint8_to_float_v2(data[i], *col_header_flt)
return mat.T # transpose! col-major -> row-major,
def write_ark_scp(key, mat, ark_fout, scp_out):
mat_offset = write_mat(ark_fout, mat, key)
scp_line = '{}\t{}:{}'.format(key, ark_fout.name, mat_offset)
scp_out.write(scp_line)
scp_out.write('\n')
# Writing,
def write_mat(file_or_fd, m, key=''):
""" write_mat(f, m, key='')
Write a binary kaldi matrix to filename or stream. Supports 32bit and 64bit floats.
Arguments:
file_or_fd : filename of opened file descriptor for writing,
m : the matrix to be stored,
key (optional) : used for writing ark-file, the utterance-id gets written before the matrix.
Example of writing single matrix:
kaldi_io.write_mat(filename, mat)
Example of writing arkfile:
with open(ark_file,'w') as f:
for key,mat in dict.iteritems():
kaldi_io.write_mat(f, mat, key=key)
"""
mat_offset = 0
fd = open_or_fd(file_or_fd, mode='wb')
if sys.version_info[0] == 3: assert(fd.mode == 'wb')
try:
if key != '' : fd.write((key+' ').encode("latin1")) # ark-files have keys (utterance-id),
mat_offset = fd.tell()
fd.write('\0B'.encode()) # we write binary!
# Data-type,
if m.dtype == 'float32': fd.write('FM '.encode())
elif m.dtype == 'float64': fd.write('DM '.encode())
else: raise UnsupportedDataType("'%s', please use 'float32' or 'float64'" % m.dtype)
# Dims,
fd.write('\04'.encode())
fd.write(struct.pack(np.dtype('uint32').char, m.shape[0])) # rows
fd.write('\04'.encode())
fd.write(struct.pack(np.dtype('uint32').char, m.shape[1])) # cols
# Data,
fd.write(m.tobytes())
finally:
if fd is not file_or_fd : fd.close()
return mat_offset
#################################################
# 'Posterior' kaldi type (posteriors, confusion network, nnet1 training targets, ...)
# Corresponds to: vector<vector<tuple<int,float> > >
# - outer vector: time axis
# - inner vector: records at the time
# - tuple: int = index, float = value
#
def read_cnet_ark(file_or_fd):
""" Alias of function 'read_post_ark()', 'cnet' = confusion network """
return read_post_ark(file_or_fd)
def read_post_ark(file_or_fd):
""" generator(key,vec<vec<int,float>>) = read_post_ark(file)
Returns generator of (key,posterior) tuples, read from ark file.
file_or_fd : ark, gzipped ark, pipe or opened file descriptor.
Iterate the ark:
for key,post in kaldi_io.read_post_ark(file):
...
Read ark to a 'dictionary':
d = { key:post for key,post in kaldi_io.read_post_ark(file) }
"""
fd = open_or_fd(file_or_fd)
try:
key = read_key(fd)
while key:
post = read_post(fd)
yield key, post
key = read_key(fd)
finally:
if fd is not file_or_fd: fd.close()
def read_post(file_or_fd):
""" [post] = read_post(file_or_fd)
Reads single kaldi 'Posterior' in binary format.
The 'Posterior' is C++ type 'vector<vector<tuple<int,float> > >',
the outer-vector is usually time axis, inner-vector are the records
at given time, and the tuple is composed of an 'index' (integer)
and a 'float-value'. The 'float-value' can represent a probability
or any other numeric value.
Returns vector of vectors of tuples.
"""
fd = open_or_fd(file_or_fd)
ans=[]
binary = fd.read(2).decode(); assert(binary == '\0B'); # binary flag
assert(fd.read(1).decode() == '\4'); # int-size
outer_vec_size = np.frombuffer(fd.read(4), dtype='int32', count=1)[0] # number of frames (or bins)
# Loop over 'outer-vector',
for i in range(outer_vec_size):
assert(fd.read(1).decode() == '\4'); # int-size
inner_vec_size = np.frombuffer(fd.read(4), dtype='int32', count=1)[0] # number of records for frame (or bin)
data = np.frombuffer(fd.read(inner_vec_size*10), dtype=[('size_idx','int8'),('idx','int32'),('size_post','int8'),('post','float32')], count=inner_vec_size)
assert(data[0]['size_idx'] == 4)
assert(data[0]['size_post'] == 4)
ans.append(data[['idx','post']].tolist())
if fd is not file_or_fd: fd.close()
return ans
#################################################
# Kaldi Confusion Network bin begin/end times,
# (kaldi stores CNs time info separately from the Posterior).
#
def read_cntime_ark(file_or_fd):
""" generator(key,vec<tuple<float,float>>) = read_cntime_ark(file_or_fd)
Returns generator of (key,cntime) tuples, read from ark file.
file_or_fd : file, gzipped file, pipe or opened file descriptor.
Iterate the ark:
for key,time in kaldi_io.read_cntime_ark(file):
...
Read ark to a 'dictionary':
   d = { key:time for key,time in kaldi_io.read_cntime_ark(file) }
"""
fd = open_or_fd(file_or_fd)
try:
key = read_key(fd)
while key:
cntime = read_cntime(fd)
yield key, cntime
key = read_key(fd)
finally:
if fd is not file_or_fd : fd.close()
def read_cntime(file_or_fd):
""" [cntime] = read_cntime(file_or_fd)
Reads single kaldi 'Confusion Network time info', in binary format:
C++ type: vector<tuple<float,float> >.
(begin/end times of bins at the confusion network).
Binary layout is '<num-bins> <beg1> <end1> <beg2> <end2> ...'
file_or_fd : file, gzipped file, pipe or opened file descriptor.
Returns vector of tuples.
"""
fd = open_or_fd(file_or_fd)
binary = fd.read(2).decode(); assert(binary == '\0B'); # assuming it's binary
assert(fd.read(1).decode() == '\4'); # int-size
vec_size = np.frombuffer(fd.read(4), dtype='int32', count=1)[0] # number of frames (or bins)
data = np.frombuffer(fd.read(vec_size*10), dtype=[('size_beg','int8'),('t_beg','float32'),('size_end','int8'),('t_end','float32')], count=vec_size)
assert(data[0]['size_beg'] == 4)
assert(data[0]['size_end'] == 4)
ans = data[['t_beg','t_end']].tolist() # Return vector of tuples (t_beg,t_end),
if fd is not file_or_fd : fd.close()
return ans
#################################################
# Segments related,
#
# Segments as 'Bool vectors' can be handy,
# - for 'superposing' the segmentations,
# - for frame-selection in Speaker-ID experiments,
def read_segments_as_bool_vec(segments_file):
""" [ bool_vec ] = read_segments_as_bool_vec(segments_file)
using kaldi 'segments' file for 1 wav, format : '<utt> <rec> <t-beg> <t-end>'
- t-beg, t-end is in seconds,
- assumed 100 frames/second,
"""
segs = np.loadtxt(segments_file, dtype='object,object,f,f', ndmin=1)
# Sanity checks,
assert(len(segs) > 0) # empty segmentation is an error,
assert(len(np.unique([rec[1] for rec in segs ])) == 1) # segments with only 1 wav-file,
# Convert time to frame-indexes,
start = np.rint([100 * rec[2] for rec in segs]).astype(int)
end = np.rint([100 * rec[3] for rec in segs]).astype(int)
# Taken from 'read_lab_to_bool_vec', htk.py,
frms = np.repeat(np.r_[np.tile([False,True], len(end)), False],
np.r_[np.c_[start - np.r_[0, end[:-1]], end-start].flat, 0])
assert np.sum(end-start) == np.sum(frms)
return frms
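# Worked example (hypothetical segments content), assuming 100 frames/second:
#   "utt1 rec1 0.10 0.31"  ->  start frame 10, end frame 31,
# so the returned bool vector has length 31 with frames 10..30 set to True.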
# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import json
import random
import re
import tarfile
from subprocess import PIPE, Popen
from urllib.parse import urlparse
import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi
from torch.nn.utils.rnn import pad_sequence
AUDIO_FORMAT_SETS = set(['flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'])
def url_opener(data):
""" Give url or local file, return file descriptor
Inplace operation.
Args:
data(Iterable[str]): url or local file list
Returns:
Iterable[{src, stream}]
"""
for sample in data:
assert 'src' in sample
# TODO(Binbin Zhang): support HTTP
url = sample['src']
try:
pr = urlparse(url)
# local file
if pr.scheme == '' or pr.scheme == 'file':
stream = open(url, 'rb')
# network file, such as HTTP(HDFS/OSS/S3)/HTTPS/SCP
else:
cmd = f'wget -q -O - {url}'
process = Popen(cmd, shell=True, stdout=PIPE)
sample.update(process=process)
stream = process.stdout
sample.update(stream=stream)
yield sample
except Exception as ex:
logging.warning('Failed to open {}'.format(url))
def tar_file_and_group(data):
""" Expand a stream of open tar files into a stream of tar file contents.
And groups the file with same prefix
Args:
data: Iterable[{src, stream}]
Returns:
Iterable[{key, wav, txt, sample_rate}]
"""
for sample in data:
assert 'stream' in sample
stream = tarfile.open(fileobj=sample['stream'], mode="r|*")
prev_prefix = None
example = {}
valid = True
for tarinfo in stream:
name = tarinfo.name
pos = name.rfind('.')
assert pos > 0
prefix, postfix = name[:pos], name[pos + 1:]
if prev_prefix is not None and prefix != prev_prefix:
example['key'] = prev_prefix
if valid:
yield example
example = {}
valid = True
with stream.extractfile(tarinfo) as file_obj:
try:
if postfix == 'txt':
example['txt'] = file_obj.read().decode('utf8').strip()
elif postfix in AUDIO_FORMAT_SETS:
waveform, sample_rate = torchaudio.load(file_obj)
example['wav'] = waveform
example['sample_rate'] = sample_rate
else:
example[postfix] = file_obj.read()
except Exception as ex:
valid = False
logging.warning('error to parse {}'.format(name))
prev_prefix = prefix
if prev_prefix is not None:
example['key'] = prev_prefix
yield example
stream.close()
if 'process' in sample:
sample['process'].communicate()
sample['stream'].close()
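# A sketch of the shard layout this expects (file names below are hypothetical):
# entries that share a prefix are grouped into one example, e.g. a tar containing
#   utt_0001.wav
#   utt_0001.txt
#   utt_0002.wav
#   utt_0002.txt
# yields one {key, wav, sample_rate, txt} dict per prefix, in archive order.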
def parse_raw(data):
""" Parse key/wav/txt from json line
Args:
data: Iterable[str], str is a json line has key/wav/txt
Returns:
Iterable[{key, wav, txt, sample_rate}]
"""
for sample in data:
assert 'src' in sample
json_line = sample['src']
obj = json.loads(json_line)
assert 'key' in obj
assert 'wav' in obj
assert 'txt' in obj
key = obj['key']
wav_file = obj['wav']
txt = obj['txt']
try:
if 'start' in obj:
assert 'end' in obj
sample_rate = torchaudio.backend.sox_io_backend.info(
wav_file).sample_rate
start_frame = int(obj['start'] * sample_rate)
end_frame = int(obj['end'] * sample_rate)
waveform, _ = torchaudio.backend.sox_io_backend.load(
filepath=wav_file,
num_frames=end_frame - start_frame,
frame_offset=start_frame)
else:
waveform, sample_rate = torchaudio.load(wav_file)
example = dict(key=key,
txt=txt,
wav=waveform,
sample_rate=sample_rate)
yield example
except Exception as ex:
logging.warning('Failed to read {}'.format(wav_file))
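# Example of one json line in the 'raw' data list (the path is hypothetical):
#   {"key": "utt_001", "wav": "/data/wavs/utt_001.wav", "txt": "今天天气怎么样"}
# Optional "start"/"end" fields (in seconds) select a segment of the wav.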
def filter(data,
max_length=10240,
min_length=10,
token_max_length=200,
token_min_length=1,
min_output_input_ratio=0.0005,
max_output_input_ratio=1):
""" Filter sample according to feature and label length
Inplace operation.
        Args:
            data: Iterable[{key, wav, label, sample_rate}]
            max_length: drop utterances longer than max_length (in 10ms frames)
            min_length: drop utterances shorter than min_length (in 10ms frames)
            token_max_length: drop utterances whose label is longer than
                token_max_length, especially when using char units for
                English modeling
            token_min_length: drop utterances whose label is shorter than
                token_min_length
            min_output_input_ratio: minimal ratio of
                token_length / feats_length (in 10ms frames)
            max_output_input_ratio: maximum ratio of
                token_length / feats_length (in 10ms frames)
Returns:
Iterable[{key, wav, label, sample_rate}]
"""
for sample in data:
assert 'sample_rate' in sample
assert 'wav' in sample
assert 'label' in sample
# sample['wav'] is torch.Tensor, we have 100 frames every second
num_frames = sample['wav'].size(1) / sample['sample_rate'] * 100
if num_frames < min_length:
continue
if num_frames > max_length:
continue
if len(sample['label']) < token_min_length:
continue
if len(sample['label']) > token_max_length:
continue
if num_frames != 0:
if len(sample['label']) / num_frames < min_output_input_ratio:
continue
if len(sample['label']) / num_frames > max_output_input_ratio:
continue
yield sample
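# Worked example of the length check above: a 5 s utterance at 16 kHz has
# wav.size(1) = 80000, so num_frames = 80000 / 16000 * 100 = 500 (10 ms frames),
# which passes the default min_length=10 / max_length=10240 bounds; the label
# length is then checked against token_min_length/token_max_length and the
# output/input ratio bounds.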
def resample(data, resample_rate=16000):
""" Resample data.
Inplace operation.
Args:
data: Iterable[{key, wav, label, sample_rate}]
resample_rate: target resample rate
Returns:
Iterable[{key, wav, label, sample_rate}]
"""
for sample in data:
assert 'sample_rate' in sample
assert 'wav' in sample
sample_rate = sample['sample_rate']
waveform = sample['wav']
if sample_rate != resample_rate:
sample['sample_rate'] = resample_rate
sample['wav'] = torchaudio.transforms.Resample(
orig_freq=sample_rate, new_freq=resample_rate)(waveform)
yield sample
def speed_perturb(data, speeds=None):
""" Apply speed perturb to the data.
Inplace operation.
Args:
data: Iterable[{key, wav, label, sample_rate}]
speeds(List[float]): optional speed
Returns:
Iterable[{key, wav, label, sample_rate}]
"""
if speeds is None:
speeds = [0.9, 1.0, 1.1]
for sample in data:
assert 'sample_rate' in sample
assert 'wav' in sample
sample_rate = sample['sample_rate']
waveform = sample['wav']
speed = random.choice(speeds)
if speed != 1.0:
wav, _ = torchaudio.sox_effects.apply_effects_tensor(
waveform, sample_rate,
[['speed', str(speed)], ['rate', str(sample_rate)]])
sample['wav'] = wav
yield sample
def compute_fbank(data,
num_mel_bins=23,
frame_length=25,
frame_shift=10,
dither=0.0):
""" Extract fbank
Args:
data: Iterable[{key, wav, label, sample_rate}]
Returns:
Iterable[{key, feat, label}]
"""
for sample in data:
assert 'sample_rate' in sample
assert 'wav' in sample
assert 'key' in sample
assert 'label' in sample
sample_rate = sample['sample_rate']
waveform = sample['wav']
        waveform = waveform * (1 << 15)  # scale float waveform (in [-1, 1]) to 16-bit sample range, as Kaldi features expect
# Only keep key, feat, label
mat = kaldi.fbank(waveform,
num_mel_bins=num_mel_bins,
frame_length=frame_length,
frame_shift=frame_shift,
dither=dither,
energy_floor=0.0,
sample_frequency=sample_rate)
yield dict(key=sample['key'], label=sample['label'], feat=mat)
def compute_mfcc(data,
num_mel_bins=23,
frame_length=25,
frame_shift=10,
dither=0.0,
num_ceps=40,
high_freq=0.0,
low_freq=20.0):
""" Extract mfcc
Args:
data: Iterable[{key, wav, label, sample_rate}]
Returns:
Iterable[{key, feat, label}]
"""
for sample in data:
assert 'sample_rate' in sample
assert 'wav' in sample
assert 'key' in sample
assert 'label' in sample
sample_rate = sample['sample_rate']
waveform = sample['wav']
        waveform = waveform * (1 << 15)  # scale float waveform (in [-1, 1]) to 16-bit sample range, as Kaldi features expect
# Only keep key, feat, label
mat = kaldi.mfcc(waveform,
num_mel_bins=num_mel_bins,
frame_length=frame_length,
frame_shift=frame_shift,
dither=dither,
num_ceps=num_ceps,
high_freq=high_freq,
low_freq=low_freq,
sample_frequency=sample_rate)
yield dict(key=sample['key'], label=sample['label'], feat=mat)
def __tokenize_by_bpe_model(sp, txt):
tokens = []
# CJK(China Japan Korea) unicode range is [U+4E00, U+9FFF], ref:
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
pattern = re.compile(r'([\u4e00-\u9fff])')
# Example:
# txt = "你好 ITS'S OKAY 的"
# chars = ["你", "好", " ITS'S OKAY ", "的"]
chars = pattern.split(txt.upper())
mix_chars = [w for w in chars if len(w.strip()) > 0]
for ch_or_w in mix_chars:
        # ch_or_w is a single CJK character (e.g., "你"), do nothing.
if pattern.fullmatch(ch_or_w) is not None:
tokens.append(ch_or_w)
        # ch_or_w contains non-CJK characters (e.g., " IT'S OKAY "),
        # encode ch_or_w using the bpe_model.
else:
for p in sp.encode_as_pieces(ch_or_w):
tokens.append(p)
return tokens
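# For the example above, a plausible result (the exact English pieces depend on
# the loaded BPE model) is:
#   __tokenize_by_bpe_model(sp, "你好 IT'S OKAY 的")
#   -> ['你', '好', '▁IT', "'", 'S', '▁OKAY', '的']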
def tokenize(data,
symbol_table,
bpe_model=None,
non_lang_syms=None,
split_with_space=False):
""" Decode text to chars or BPE
Inplace operation
Args:
data: Iterable[{key, wav, txt, sample_rate}]
Returns:
Iterable[{key, wav, txt, tokens, label, sample_rate}]
"""
if non_lang_syms is not None:
non_lang_syms_pattern = re.compile(r"(\[[^\[\]]+\]|<[^<>]+>|{[^{}]+})")
else:
non_lang_syms = {}
non_lang_syms_pattern = None
if bpe_model is not None:
import sentencepiece as spm
sp = spm.SentencePieceProcessor()
sp.load(bpe_model)
else:
sp = None
for sample in data:
assert 'txt' in sample
txt = sample['txt'].strip()
if non_lang_syms_pattern is not None:
parts = non_lang_syms_pattern.split(txt.upper())
parts = [w for w in parts if len(w.strip()) > 0]
else:
parts = [txt]
label = []
tokens = []
for part in parts:
if part in non_lang_syms:
tokens.append(part)
else:
if bpe_model is not None:
tokens.extend(__tokenize_by_bpe_model(sp, part))
else:
if split_with_space:
part = part.split(" ")
for ch in part:
if ch == ' ':
ch = "▁"
tokens.append(ch)
for ch in tokens:
if ch in symbol_table:
label.append(symbol_table[ch])
elif '<unk>' in symbol_table:
label.append(symbol_table['<unk>'])
sample['tokens'] = tokens
sample['label'] = label
yield sample
def spec_aug(data, num_t_mask=2, num_f_mask=2, max_t=50, max_f=10, max_w=80):
""" Do spec augmentation
Inplace operation
Args:
data: Iterable[{key, feat, label}]
num_t_mask: number of time mask to apply
num_f_mask: number of freq mask to apply
max_t: max width of time mask
max_f: max width of freq mask
max_w: max width of time warp
Returns
Iterable[{key, feat, label}]
"""
for sample in data:
assert 'feat' in sample
x = sample['feat']
assert isinstance(x, torch.Tensor)
y = x.clone().detach()
max_frames = y.size(0)
max_freq = y.size(1)
# time mask
for i in range(num_t_mask):
start = random.randint(0, max_frames - 1)
length = random.randint(1, max_t)
end = min(max_frames, start + length)
y[start:end, :] = 0
# freq mask
for i in range(num_f_mask):
start = random.randint(0, max_freq - 1)
length = random.randint(1, max_f)
end = min(max_freq, start + length)
y[:, start:end] = 0
sample['feat'] = y
yield sample
def spec_sub(data, max_t=20, num_t_sub=3):
""" Do spec substitute
Inplace operation
Args:
data: Iterable[{key, feat, label}]
max_t: max width of time substitute
num_t_sub: number of time substitute to apply
Returns
Iterable[{key, feat, label}]
"""
for sample in data:
assert 'feat' in sample
x = sample['feat']
assert isinstance(x, torch.Tensor)
y = x.clone().detach()
max_frames = y.size(0)
for i in range(num_t_sub):
start = random.randint(0, max_frames - 1)
length = random.randint(1, max_t)
end = min(max_frames, start + length)
            # substitute the current span with an earlier span starting at a randomly chosen offset
pos = random.randint(0, start)
y[start:end, :] = x[start - pos:end - pos, :]
sample['feat'] = y
yield sample
def spec_trim(data, max_t=20):
""" Trim tailing frames. Inplace operation.
ref: TrimTail [https://arxiv.org/abs/2211.00522]
Args:
data: Iterable[{key, feat, label}]
max_t: max width of length trimming
Returns
Iterable[{key, feat, label}]
"""
for sample in data:
assert 'feat' in sample
x = sample['feat']
assert isinstance(x, torch.Tensor)
max_frames = x.size(0)
length = random.randint(1, max_t)
if length < max_frames / 2:
y = x.clone().detach()[:max_frames - length]
sample['feat'] = y
yield sample
def shuffle(data, shuffle_size=10000):
""" Local shuffle the data
Args:
data: Iterable[{key, feat, label}]
shuffle_size: buffer size for shuffle
Returns:
Iterable[{key, feat, label}]
"""
buf = []
for sample in data:
buf.append(sample)
if len(buf) >= shuffle_size:
random.shuffle(buf)
for x in buf:
yield x
buf = []
    # The samples left over
random.shuffle(buf)
for x in buf:
yield x
def sort(data, sort_size=500):
""" Sort the data by feature length.
Sort is used after shuffle and before batch, so we can group
utts with similar lengths into a batch, and `sort_size` should
be less than `shuffle_size`
Args:
data: Iterable[{key, feat, label}]
sort_size: buffer size for sort
Returns:
Iterable[{key, feat, label}]
"""
buf = []
for sample in data:
buf.append(sample)
if len(buf) >= sort_size:
buf.sort(key=lambda x: x['feat'].size(0))
for x in buf:
yield x
buf = []
    # The samples left over
buf.sort(key=lambda x: x['feat'].size(0))
for x in buf:
yield x
def static_batch(data, batch_size=16):
""" Static batch the data by `batch_size`
Args:
data: Iterable[{key, feat, label}]
batch_size: batch size
Returns:
Iterable[List[{key, feat, label}]]
"""
buf = []
for sample in data:
buf.append(sample)
if len(buf) >= batch_size:
yield buf
buf = []
if len(buf) > 0:
yield buf
def dynamic_batch(data, max_frames_in_batch=12000):
""" Dynamic batch the data until the total frames in batch
reach `max_frames_in_batch`
Args:
data: Iterable[{key, feat, label}]
max_frames_in_batch: max_frames in one batch
Returns:
Iterable[List[{key, feat, label}]]
"""
buf = []
longest_frames = 0
for sample in data:
assert 'feat' in sample
assert isinstance(sample['feat'], torch.Tensor)
new_sample_frames = sample['feat'].size(0)
longest_frames = max(longest_frames, new_sample_frames)
frames_after_padding = longest_frames * (len(buf) + 1)
if frames_after_padding > max_frames_in_batch:
yield buf
buf = [sample]
longest_frames = new_sample_frames
else:
buf.append(sample)
if len(buf) > 0:
yield buf
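# Worked example of the cap above: with max_frames_in_batch=12000 and every
# utterance about 600 frames long, the 21st utterance would push the padded
# total to 600 * 21 = 12600 > 12000, so the batch is emitted with 20 utterances.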
def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000):
""" Wrapper for static/dynamic batch
"""
if batch_type == 'static':
return static_batch(data, batch_size)
elif batch_type == 'dynamic':
return dynamic_batch(data, max_frames_in_batch)
else:
logging.fatal('Unsupported batch type {}'.format(batch_type))
def padding(data):
""" Padding the data into training data
Args:
data: Iterable[List[{key, feat, label}]]
Returns:
Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)]
"""
for sample in data:
assert isinstance(sample, list)
feats_length = torch.tensor([x['feat'].size(0) for x in sample],
dtype=torch.int32)
order = torch.argsort(feats_length, descending=True)
feats_lengths = torch.tensor(
[sample[i]['feat'].size(0) for i in order], dtype=torch.int32)
sorted_feats = [sample[i]['feat'] for i in order]
sorted_keys = [sample[i]['key'] for i in order]
sorted_labels = [
torch.tensor(sample[i]['label'], dtype=torch.int64) for i in order
]
label_lengths = torch.tensor([x.size(0) for x in sorted_labels],
dtype=torch.int32)
padded_feats = pad_sequence(sorted_feats,
batch_first=True,
padding_value=0)
padding_labels = pad_sequence(sorted_labels,
batch_first=True,
padding_value=-1)
yield (sorted_keys, padded_feats, padding_labels, feats_lengths,
label_lengths)
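# Shapes of one yielded batch (B = batch size, Tmax/Lmax = longest feat/label):
#   keys: List[str] of length B, sorted by decreasing feature length
#   feats: [B, Tmax, D] float tensor, zero padded
#   labels: [B, Lmax] int64 tensor, padded with -1
#   feats_lengths, label_lengths: [B] int32 tensors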
# Copyright (c) 2021 Mobvoi Inc (Chao Yang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import random
import math
import torchaudio
import torch
torchaudio.set_audio_backend("sox_io")
def db2amp(db):
return pow(10, db / 20)
def amp2db(amp):
return 20 * math.log10(amp)
def make_poly_distortion(conf):
"""Generate a db-domain ploynomial distortion function
f(x) = a * x^m * (1-x)^n + x
Args:
conf: a dict {'a': #int, 'm': #int, 'n': #int}
Returns:
        The polynomial function, which can be applied to
a float amplitude value
"""
a = conf['a']
m = conf['m']
n = conf['n']
def poly_distortion(x):
abs_x = abs(x)
if abs_x < 0.000001:
x = x
else:
db_norm = amp2db(abs_x) / 100 + 1
if db_norm < 0:
db_norm = 0
db_norm = a * pow(db_norm, m) * pow((1 - db_norm), n) + db_norm
if db_norm > 1:
db_norm = 1
db = (db_norm - 1) * 100
amp = db2amp(db)
if amp >= 0.9997:
amp = 0.9997
if x > 0:
x = amp
else:
x = -amp
return x
return poly_distortion
def make_quad_distortion():
return make_poly_distortion({'a' : 1, 'm' : 1, 'n' : 1})
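# Worked example of the polynomial above with the quad config a = m = n = 1:
# an input at -50 dB has db_norm = -50 / 100 + 1 = 0.5, which the polynomial
# maps to 0.5 * (1 - 0.5) + 0.5 = 0.75, i.e. the output comes out at
# (0.75 - 1) * 100 = -25 dB.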
# the amplitude is set to the max for all non-zero points
def make_max_distortion(conf):
"""Generate a max distortion function
Args:
conf: a dict {'max_db': float }
            'max_db': the maximum value.
    Returns:
        The max function, which can be applied to
        a float amplitude value
"""
max_db = conf['max_db']
if max_db:
max_amp = db2amp(max_db) # < 0.997
else:
max_amp = 0.997
def max_distortion(x):
if x > 0:
x = max_amp
elif x < 0:
x = -max_amp
else:
x = 0.0
return x
return max_distortion
def make_amp_mask(db_mask=None):
"""Get a amplitude domain mask from db domain mask
Args:
db_mask: Optional. A list of tuple. if None, using default value.
Returns:
A list of tuple. The amplitude domain mask
"""
if db_mask is None:
db_mask = [(-110, -95), (-90, -80), (-65, -60), (-50, -30), (-15, 0)]
amp_mask = [(db2amp(db[0]), db2amp(db[1])) for db in db_mask]
return amp_mask
default_mask = make_amp_mask()
def generate_amp_mask(mask_num):
"""Generate amplitude domain mask randomly in [-100db, 0db]
Args:
mask_num: the slot number of the mask
Returns:
        A list of tuples; each tuple defines a slot.
e.g. [(-100, -80), (-65, -60), (-50, -30), (-15, 0)]
for #mask_num = 4
"""
a = [0] * 2 * mask_num
a[0] = 0
m = []
for i in range(1, 2 * mask_num):
a[i] = a[i - 1] + random.uniform(0.5, 1)
max_val = a[2 * mask_num - 1]
for i in range(0, mask_num):
l = ((a[2 * i] - max_val) / max_val) * 100
r = ((a[2 * i + 1] - max_val) / max_val) * 100
m.append((l, r))
return make_amp_mask(m)
def make_fence_distortion(conf):
"""Generate a fence distortion function
    In this fence-like function, the values inside the mask slots are
    set to the maximum, while the values outside the mask slots are set to 0.
    Separate masks are used for positive and negative amplitudes.
    Args:
        conf: a dict {'mask_number': int, 'max_db': float }
            'mask_number': the slot number in the mask.
            'max_db': the maximum value.
    Returns:
        The fence function, which can be applied to
        a float amplitude value
"""
mask_number = conf['mask_number']
max_db = conf['max_db']
    max_amp = db2amp(max_db)  # the amplitude corresponding to max_db
    if mask_number <= 0:
positive_mask = default_mask
negative_mask = make_amp_mask([(-50, 0)])
else:
positive_mask = generate_amp_mask(mask_number)
negative_mask = generate_amp_mask(mask_number)
def fence_distortion(x):
is_in_mask = False
if x > 0:
for mask in positive_mask:
if x >= mask[0] and x <= mask[1]:
is_in_mask = True
return max_amp
if not is_in_mask:
return 0.0
elif x < 0:
abs_x = abs(x)
for mask in negative_mask:
if abs_x >= mask[0] and abs_x <= mask[1]:
is_in_mask = True
return max_amp
if not is_in_mask:
return 0.0
return x
return fence_distortion
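# Usage sketch (illustrative, not part of the original file): samples whose
# magnitude falls inside one of the (random) mask slots are replaced by
# db2amp(max_db); everything else is zeroed.
#   >>> fence = make_fence_distortion({'mask_number': 1, 'max_db': -30})
#   >>> fence(0.5)   # either db2amp(-30) ~= 0.032 or 0.0, depending on the mask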
#
def make_jag_distortion(conf):
"""Generate a jag distortion function
In this jag-like shape function, the values in mask slots are
not changed, while the values not in mask slots are set to 0.
Use seperated masks for Positive and negetive amplitude.
Args:
conf: a dict {'mask_number': #int}
'mask_number': the slot number in mask.
Returns:
The jag function,which could be applied on
a float amplitude value
"""
mask_number = conf['mask_number']
    if mask_number <= 0:
positive_mask = default_mask
negative_mask = make_amp_mask([(-50, 0)])
else:
positive_mask = generate_amp_mask(mask_number)
negative_mask = generate_amp_mask(mask_number)
def jag_distortion(x):
is_in_mask = False
if x > 0:
for mask in positive_mask:
if x >= mask[0] and x <= mask[1]:
is_in_mask = True
return x
if not is_in_mask:
return 0.0
elif x < 0:
abs_x = abs(x)
for mask in negative_mask:
if abs_x >= mask[0] and abs_x <= mask[1]:
is_in_mask = True
return x
if not is_in_mask:
return 0.0
return x
return jag_distortion
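# Usage sketch (illustrative, not part of the original file): unlike the fence
# distortion, samples inside a mask slot keep their original value; everything
# else is zeroed.
#   >>> jag = make_jag_distortion({'mask_number': 4})
#   >>> [jag(v) for v in (-0.4, 0.01, 0.7)]   # each value is kept or set to 0.0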
# gaining 20db means amp = amp * 10
# gaining -20db means amp = amp / 10
def make_gain_db(conf):
"""Generate a db domain gain function
Args:
conf: a dict {'db': #float}
'db': the gaining value
Returns:
The db gain function, which could be applied on
a float amplitude value
"""
db = conf['db']
def gain_db(x):
return min(0.997, x * pow(10, db / 20))
return gain_db
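# Usage sketch (illustrative, not part of the original file): a +20 db gain
# multiplies the amplitude by 10, clipped at 0.997.
#   >>> gain = make_gain_db({'db': 20})
#   >>> gain(0.01), gain(0.2)   # -> (0.1, 0.997)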
def distort(x, func, rate=0.8):
"""Distort a waveform in sample point level
Args:
x: the origin wavefrom
func: the distort function
rate: sample point-level distort probability
Returns:
the distorted waveform
"""
for i in range(0, x.shape[1]):
a = random.uniform(0, 1)
if a < rate:
x[0][i] = func(float(x[0][i]))
return x
def distort_chain(x, funcs, rate=0.8):
for i in range(0, x.shape[1]):
a = random.uniform(0, 1)
if a < rate:
for func in funcs:
x[0][i] = func(float(x[0][i]))
return x
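# Usage sketch (illustrative, not part of the original file): both helpers
# expect a 2-D waveform of shape (channels, samples), as returned by
# torchaudio.load, and modify channel 0 in place.
#   >>> wav, sr = torchaudio.load('example.wav')   # path is a placeholder
#   >>> wav = distort(wav, make_quad_distortion(), rate=0.5)
#   >>> wav = distort_chain(wav, [make_gain_db({'db': -6}), make_quad_distortion()])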
# x is numpy
def distort_wav_conf(x, distort_type, distort_conf, rate=0.1):
if distort_type == 'gain_db':
gain_db = make_gain_db(distort_conf)
x = distort(x, gain_db)
elif distort_type == 'max_distortion':
max_distortion = make_max_distortion(distort_conf)
x = distort(x, max_distortion, rate=rate)
elif distort_type == 'fence_distortion':
fence_distortion = make_fence_distortion(distort_conf)
x = distort(x, fence_distortion, rate=rate)
elif distort_type == 'jag_distortion':
jag_distortion = make_jag_distortion(distort_conf)
x = distort(x, jag_distortion, rate=rate)
elif distort_type == 'poly_distortion':
poly_distortion = make_poly_distortion(distort_conf)
x = distort(x, poly_distortion, rate=rate)
elif distort_type == 'quad_distortion':
quad_distortion = make_quad_distortion()
x = distort(x, quad_distortion, rate=rate)
elif distort_type == 'none_distortion':
pass
else:
        print('unsupported distort type:', distort_type)
return x
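# Usage sketch (illustrative, not part of the original file): given a loaded
# waveform array x, the distortion is selected by name and configured by dict.
#   >>> y = distort_wav_conf(x, 'fence_distortion',
#   ...                      {'mask_number': 1, 'max_db': -30}, rate=0.2)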
def distort_wav_conf_and_save(distort_type, distort_conf, rate, wav_in, wav_out):
x, sr = torchaudio.load(wav_in)
x = x.detach().numpy()
out = distort_wav_conf(x, distort_type, distort_conf, rate)
torchaudio.save(wav_out, torch.from_numpy(out), sr)
if __name__ == "__main__":
distort_type = sys.argv[1]
wav_in = sys.argv[2]
wav_out = sys.argv[3]
conf = None
rate = 0.1
    # use the type names that distort_wav_conf actually recognizes
    if distort_type == 'jag_distortion':
        conf = {'mask_number': 4}
    elif distort_type == 'fence_distortion':
        conf = {'mask_number': 1, 'max_db': -30}
    elif distort_type == 'poly_distortion':
        conf = {'a': 4, 'm': 2, 'n': 2}
distort_wav_conf_and_save(distort_type, conf, rate, wav_in, wav_out)
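# Example invocation (illustrative; the script name and wav paths below are
# placeholders):
#   python wav_distortion.py poly_distortion in.wav out.wav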