# Copyright (c) 2022 Horizon Inc. (author: Xingchen Song)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import argparse
import logging
from typing import List

import librosa
import torch
import torchaudio
import yaml
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm
import torchaudio.compliance.kaldi as kaldi

from wenet.utils.init_model import init_model
from wenet.utils.checkpoint import load_checkpoint
from wenet.utils.file_utils import read_symbol_table
from wenet.utils.mask import make_pad_mask
from wenet.utils.common import replace_duplicates_with_blank


def get_args():
    parser = argparse.ArgumentParser(
        description='Analyze latency and plot CTC-Spike.')
    parser.add_argument('--config', required=True, type=str,
                        help='configuration file')
    parser.add_argument('--gpu', type=int, default=0,
                        help='gpu id for this rank, -1 for cpu')
    parser.add_argument('--ckpt', required=True, type=str,
                        help='model checkpoint')
    parser.add_argument('--tag', required=True, type=str,
                        help='image subtitle')
    parser.add_argument('--wavscp', required=True, type=str,
                        help='wav.scp')
    parser.add_argument('--alignment', required=True, type=str,
                        help='force alignment, generated by Kaldi.')
    parser.add_argument('--chunk_size', required=True, type=int,
                        help='chunk size')
    parser.add_argument('--left_chunks', default=-1, type=int,
                        help='left chunks')
    parser.add_argument('--font', required=True, type=str,
                        help='font file')
    parser.add_argument('--dict', required=True, type=str,
                        help='dict file')
    parser.add_argument('--result_dir', required=True, type=str,
                        help='directory for saving pdf')
    parser.add_argument('--model_type', default='ctc',
                        choices=['ctc', 'transducer'],
                        help='show latency metrics from ctc models'
                             ' or rnn-t models')
    args = parser.parse_args()
    return args


def main():
    args = get_args()
    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s %(levelname)s %(message)s')
    torch.manual_seed(777)
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    symbol_table = read_symbol_table(args.dict)
    char_dict = {v: k for k, v in symbol_table.items()}

    # 1. Load model
    with open(args.config, 'r') as fin:
        conf = yaml.load(fin, Loader=yaml.FullLoader)
    use_cuda = args.gpu >= 0 and torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    model = init_model(conf)
    load_checkpoint(model, args.ckpt)
    model = model.eval().to(device)
    subsampling = model.encoder.embed.subsampling_rate
    eos = model.eos_symbol()
    with open(args.wavscp, 'r') as fin:
        wavs = fin.readlines()
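
    # NOTE: the code below assumes each line of --wavscp looks like
    # "utt_id /path/to/utt.wav" (key and path separated by the first space,
    # matching split(' ', 1)).  For every utterance the encoder is run with
    # the given chunk_size / left_chunks, and a per-frame greedy hypothesis
    # plus its score is kept as the "streaming timestamp" of each token.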

    # 2. Forward model (get streaming_timestamps)
    timestamps = {}
    for idx, wav in enumerate(wavs):
        if idx % 100 == 0:
            logging.info("processed {}.".format(idx))
        key, wav = wav.strip().split(' ', 1)
        waveform, sr = torchaudio.load(wav)
        resample_rate = conf['dataset_conf']['resample_conf']['resample_rate']
        waveform = torchaudio.transforms.Resample(
            orig_freq=sr, new_freq=resample_rate)(waveform)
        waveform = waveform * (1 << 15)
        # Only keep key, feat, label
        mat = kaldi.fbank(
            waveform,
            num_mel_bins=conf['dataset_conf']['fbank_conf']['num_mel_bins'],
            frame_length=conf['dataset_conf']['fbank_conf']['frame_length'],
            frame_shift=conf['dataset_conf']['fbank_conf']['frame_shift'],
            dither=0.0,
            energy_floor=0.0,
            sample_frequency=resample_rate,
        )
        speech = mat.unsqueeze(0).to(device)
        speech_lengths = torch.tensor([mat.size(0)]).to(device)
        # Let's assume batch_size = 1
        encoder_out, encoder_mask = model.encoder(
            speech, speech_lengths, args.chunk_size, args.left_chunks)
        maxlen = encoder_out.size(1)  # (B, maxlen, encoder_dim)
        encoder_out_lens = encoder_mask.squeeze(1).sum(1)
        # CTC greedy search
        if args.model_type == 'ctc':
            ctc_probs = model.ctc.log_softmax(
                encoder_out)  # (B, maxlen, vocab_size)
            topk_prob, topk_index = ctc_probs.topk(1, dim=2)  # (B, maxlen, 1)
            topk_index = topk_index.view(1, maxlen)  # (B, maxlen)
            topk_prob = topk_prob.view(1, maxlen)  # (B, maxlen)
            mask = make_pad_mask(encoder_out_lens, maxlen)  # (B, maxlen)
            topk_index = topk_index.masked_fill_(mask, eos)  # (B, maxlen)
            topk_prob = topk_prob.masked_fill_(mask, 0.0)  # (B, maxlen)
            hyps = [hyp.tolist() for hyp in topk_index]
            hyps = [replace_duplicates_with_blank(hyp) for hyp in hyps]
            scores = [prob.tolist() for prob in topk_prob]
            timestamps[key] = [hyps[0], scores[0], wav]
        if args.model_type == 'transducer':
            hyps = []
            scores = []
            # fake padding
            padding = torch.zeros(1, 1).to(encoder_out.device)
            # sos
            pred_input_step = torch.tensor([model.blank]).reshape(1, 1)
            cache = model.predictor.init_state(1,
                                               method="zero",
                                               device=encoder_out.device)
            new_cache: List[torch.Tensor] = []
            t = 0
            prev_out_nblk = True
            pred_out_step = None
            per_frame_max_noblk = 1
            per_frame_noblk = 0
            while t < encoder_out_lens:
                encoder_out_step = encoder_out[:, t:t + 1, :]  # [1, 1, E]
                if prev_out_nblk:
                    step_outs = model.predictor.forward_step(
                        pred_input_step, padding, cache)
                    pred_out_step, new_cache = step_outs[0], step_outs[1]
                joint_out_step = model.joint(encoder_out_step,
                                             pred_out_step)  # [1,1,v]
                joint_out_probs = joint_out_step.log_softmax(dim=-1)
                scores.append(torch.max(joint_out_probs).item())
                joint_out_max = joint_out_probs.argmax(dim=-1).squeeze()  # []
                if joint_out_max != model.blank:
                    hyps.append(joint_out_max.item())
                    prev_out_nblk = True
                    per_frame_noblk = per_frame_noblk + 1
                    pred_input_step = joint_out_max.reshape(1, 1)
                    # state_m, state_c = state_out_m, state_out_c
                    cache = new_cache
                if joint_out_max == model.blank or \
                        per_frame_noblk >= per_frame_max_noblk:
                    if joint_out_max == model.blank:
                        prev_out_nblk = False
                        hyps.append(model.blank)
                    # TODO(Mddct): make t in chunk for streaming,
                    # or t shouldn't be too long to predict non-blank
                    t = t + 1
                    per_frame_noblk = 0
            timestamps[key] = [hyps, scores, wav]
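
    # NOTE: the --alignment file is assumed to hold one utterance per line,
    # "utt_id token_0 token_1 ...", with exactly one token per 10ms acoustic
    # frame (blank frames included), so frame i maps to i * 10 ms.  Streaming
    # hypotheses live on subsampled encoder frames, hence i * subsampling * 10.
    # With fa/st the per-token emission times (in ms) of the force alignment
    # and the streaming model, the metrics below are:
    #   FirstTokenDelay = st[0]  - fa[0]
    #   LastTokenDelay  = st[-1] - fa[-1]
    #   AvgTokenDelay   = (sum(st) - sum(fa)) / len(st)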

    # 3. Analyze latency
    with open(args.alignment, 'r') as fin:
        aligns = fin.readlines()
    not_found, len_unequal, ignored = 0, 0, 0
    datas = []
    for align in aligns:
        key, align = align.strip().split(' ', 1)
        if key not in timestamps:
            not_found += 1
            continue
        fa, st = [], []  # force_alignment, streaming_timestamps
        text_fa, text_st = "", ""
        for i, token in enumerate(align.split()):
            if token != '<blank>':
                text_fa += token
                # NOTE(xcsong): W/O subsample
                fa.append(i * 10)
        # ignore alignment_errors >= 70ms
        frames_fa = len(align.split())
        frames_st = len(timestamps[key][0]) * subsampling
        if abs(frames_st - frames_fa) >= 7:
            ignored += 1
            continue
        for i, token_id in enumerate(timestamps[key][0]):
            if token_id != 0:
                text_st += char_dict[token_id]
                # NOTE(xcsong): W subsample
                st.append(i * subsampling * 10)
        if len(fa) != len(st):
            len_unequal += 1
            continue
        # datas[i] = [key, text_fa, text_st, list_of_diff,
        #             FirstTokenDelay, LastTokenDelay, AvgTokenDelay,
        #             streaming_timestamps, force_alignment]
        datas.append([key, text_fa, text_st,
                      [a - b for a, b in zip(st, fa)],
                      st[0] - fa[0], st[-1] - fa[-1],
                      (sum(st) - sum(fa)) / len(st),
                      timestamps[key], align.split()])
    logging.info(
        "not found: {}, length unequal: {}, ignored: {}, valid samples: {}"
        .format(not_found, len_unequal, ignored, len(datas)))

    # 4. Plot and print
    num_datas = len(datas)
    names = ['FirstTokenDelay', 'LastTokenDelay', 'AvgTokenDelay']
    names_index = [4, 5, 6]
    parts = ['max', 'P90', 'P75', 'P50', 'P25', 'min']
    parts_index = [num_datas - 1, int(num_datas * 0.90),
                   int(num_datas * 0.75), int(num_datas * 0.50),
                   int(num_datas * 0.25), 0]
    for name, name_idx in zip(names, names_index):
        def f(name_idx=name_idx):
            return name_idx
        datas.sort(key=lambda x: x[f()])
        logging.info("==========================")
        for p, i in zip(parts, parts_index):
            data = datas[i]
            # i.e., LastTokenDelay P90: 270.000 ms (wav_id: BAC009S0902W0144)
            logging.info("{} {}: {:.3f} ms (wav_id: {})".format(
                name, p, data[f()], datas[i][0]))
            font = fm.FontProperties(fname=args.font)
            plt.rcParams['axes.unicode_minus'] = False
            # we will have 2 sub-plots (force-align + streaming timestamps)
            # plus one wav-plot
            fig, axes = plt.subplots(figsize=(60, 60), nrows=3, ncols=1)
            for j in range(2):
                if j == 0:  # subplot-0: streaming_timestamps
                    plt_prefix = args.tag + "_" + name + "_" + p
                    x = np.arange(len(data[7][0])) * subsampling
                    hyps, scores = data[7][0], data[7][1]
                else:  # subplot-1: force_alignments
                    plt_prefix = "force_alignment"
                    x = np.arange(len(data[8]))
                    hyps = [symbol_table[d] for d in data[8]]
                    scores = [0.0] * len(data[8])
                axes[j].set_title(plt_prefix, fontsize=30)
                for frame, token, prob in zip(x, hyps, scores):
                    if char_dict[token] != '<blank>':
                        axes[j].bar(
                            frame, np.exp(prob),
                            label='{} {:.3f}'.format(char_dict[token],
                                                     np.exp(prob)),
                        )
                        axes[j].text(
                            frame, np.exp(prob),
                            '{} {:.3f} {}'.format(char_dict[token],
                                                  np.exp(prob), frame),
                            fontdict=dict(fontsize=24),
                            fontproperties=font,
                        )
                    else:
                        axes[j].bar(
                            frame, 0.01,
                            label='{} {:.3f}'.format(char_dict[token],
                                                     np.exp(prob)),
                        )
                axes[j].tick_params(labelsize=25)
            # subplot-2: wav
            # wav, hardcode sample_rate to 16000
            samples, sr = librosa.load(data[7][2], sr=16000)
            time = np.arange(0, len(samples)) * (1.0 / sr)
            axes[-1].plot(time, samples)
            # i.e., RESULT_DIR/LTD_P90_120ms_BAC009S0768W0342.pdf
            plt.savefig(args.result_dir + "/" + name + "_" + p + "_" +
                        str(data[f()]) + "ms" + "_" + data[0] + ".pdf")


if __name__ == '__main__':
    main()
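
# A hypothetical invocation (the script name and all paths below are
# placeholders, adapt them to your own setup; the flags themselves are the
# ones defined in get_args() above):
#
#   python analyze_latency.py \
#       --config exp/conformer/train.yaml \
#       --gpu 0 \
#       --ckpt exp/conformer/final.pt \
#       --tag chunk16 \
#       --wavscp data/test/wav.scp \
#       --alignment data/test/force_align.txt \
#       --chunk_size 16 \
#       --left_chunks -1 \
#       --font path/to/font.ttf \
#       --dict data/dict/lang_char.txt \
#       --result_dir exp/conformer/latency \
#       --model_type ctc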