"""Batch ASR decoding script: CTC search with optional attention / LM rescoring.

Reads a Kaldi-style ``wav.scp`` and reference text file, runs the ESPnet
encoder batch by batch, decodes with the selected ``--mode`` and writes one
``<utt-id> <hypothesis>`` line per utterance plus a timing/RTF log.
"""
import argparse
import multiprocessing
import os
import sys
import time
from typing import Dict, Optional, Tuple

import numpy as np
import soundfile
import torch
from torch.utils.data import DataLoader, Dataset

from espnet2.bin.asr_inference import Speech2Text
from espnet2.torch_utils.device_funcs import to_device

torch.set_num_threads(1)

try:
    from swig_decoders import map_batch, \
        ctc_beam_search_decoder_batch, \
        TrieVector, PathTrie
except ImportError:
    print('Please install ctc decoders first by refering to\n' +
          'https://github.com/Slyne/ctc_decoder.git')
    sys.exit(1)


class CustomAishellDataset(Dataset):
    """Dataset pairing the lines of a wav.scp file with the lines of a text file.

    Entries are matched purely by line position: the i-th non-blank line of
    ``wav_scp_file`` goes with the i-th non-blank line of ``text_file``.
    """

    def __init__(self, wav_scp_file, text_file):
        """Parse both index files.

        Args:
            wav_scp_file: Path to a file whose lines are ``<name> <wav path>``.
            text_file: Path to a file whose lines are ``<name> <tok> <tok> ...``;
                everything after the first field is joined without spaces.
        """
        with open(wav_scp_file, 'r') as wav_scp, open(text_file, 'r') as text:
            # Skip blank lines so a trailing newline neither crashes the
            # wav.scp parsing nor adds a phantom empty label.
            wavs = [line for line in wav_scp if line.strip()]
            texts = [line for line in text if line.strip()]
        self.wav_names = [item.split()[0] for item in wavs]
        self.wav_paths = [item.split()[1] for item in wavs]
        self.labels = ["".join(item.split()[1:]) for item in texts]

    def __len__(self):
        """Return the number of labelled utterances."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Load one utterance.

        Returns:
            Tuple ``(speech, speech_len, label, name)`` where ``speech`` is a
            float32 numpy array and ``speech_len`` is its first dimension.

        Raises:
            AssertionError: If the file is not sampled at 16 kHz.
        """
        speech, sr = soundfile.read(self.wav_paths[idx])
        assert sr == 16000, sr
        speech = np.array(speech, dtype=np.float32)
        speech_len = speech.shape[0]
        label = self.labels[idx]
        name = self.wav_names[idx]
        return speech, speech_len, label, name


def collate_wrapper(batch):
    """Zero-pad a list of dataset items into one batch.

    Args:
        batch: List of ``(speech, speech_len, label, name)`` tuples.

    Returns:
        Tuple ``(speeches, lengths, labels, names)``: a float32 array of shape
        ``(len(batch), max speech_len)``, an int64 array of lengths, and the
        labels and names as lists.
    """
    lengths = np.array([item[1] for item in batch], dtype=np.int64)
    # Size the buffer from the longest utterance in the batch rather than a
    # fixed 30 s, so longer recordings no longer fail and short batches do
    # not allocate more than they need.
    speeches = np.zeros((len(batch), max(lengths)), dtype=np.float32)
    labels = []
    names = []
    for i, (speech, speech_len, label, name) in enumerate(batch):
        speeches[i, :speech_len] = speech
        labels.append(label)
        names.append(name)
    return speeches, lengths, labels, names


def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """Make mask tensor containing indices of padded part.

    Args:
        lengths (torch.Tensor): Batch of lengths (B,).
        max_len (int): Width of the mask; when not positive, the largest
            value in ``lengths`` is used.

    Returns:
        torch.Tensor: Boolean mask of shape (B, max_len), True at padded
            positions.

    Examples:
        >>> lengths = torch.tensor([5, 3, 2])
        >>> make_pad_mask(lengths).int()
        tensor([[0, 0, 0, 0, 0],
                [0, 0, 0, 1, 1],
                [0, 0, 1, 1, 1]], dtype=torch.int32)
    """
    batch_size = lengths.size(0)
    max_len = max_len if max_len > 0 else lengths.max().item()
    seq_range = torch.arange(0,
                             max_len,
                             dtype=torch.int64,
                             device=lengths.device)
    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
    seq_length_expand = lengths.unsqueeze(-1)
    mask = seq_range_expand >= seq_length_expand
    return mask


def get_args():
    """Parse and return the command-line arguments of the decoding script."""
    parser = argparse.ArgumentParser(description='recognize with your model')
    parser.add_argument('--config', required=True, help='config file')
    parser.add_argument('--lm_config', required=True, help='config file')
    parser.add_argument('--gpu',
                        type=int,
                        default=0,
                        help='gpu id for this rank, -1 for cpu')
    parser.add_argument('--wav_scp', required=True, help='wav scp file')
    parser.add_argument('--text', required=True, help='ground truth text file')
    parser.add_argument('--model_path', required=True, help='torch pt model file')
    parser.add_argument('--lm_path', required=True, help='torch pt model file')
    parser.add_argument('--result_file',
                        default='./predictions.txt',
                        help='asr result file')
    parser.add_argument('--log_file',
                        default='./rtf.txt',
                        help='asr decoding log')
    parser.add_argument('--batch_size',
                        type=int,
                        default=24,
                        help='batch_size')
    parser.add_argument('--beam_size',
                        type=int,
                        default=10,
                        help='beam_size')
    parser.add_argument('--mode',
                        choices=[
                            'ctc_greedy_search', 'ctc_prefix_beam_search',
                            'attention_rescoring', 'attention_lm_rescoring',
                            'lm_rescoring'],
                        default='attention_lm_rescoring',
                        help='decoding mode')
    args = parser.parse_args()
    return args


def _collect_nbest(score_hyps, beam_size):
    """Flatten per-utterance n-best lists into fixed-width structures.

    Every utterance is padded to exactly ``beam_size`` candidates with a
    dummy hypothesis of score ``-inf`` so it can never be selected.

    Returns:
        Tuple ``(ctc_score, all_hyps, max_len)``: a float32 tensor of shape
        ``(len(score_hyps), beam_size)``, the token-id lists in utterance-major
        order, and the length of the longest token list.
    """
    ctc_score, all_hyps = [], []
    max_len = 0
    for cand_hyps in score_hyps:
        # Copy into a list so padding does not depend on (or mutate) the
        # container type returned by the decoder.
        cand_hyps = list(cand_hyps)
        missing = beam_size - len(cand_hyps)
        if missing > 0:
            cand_hyps += missing * [(-float("INF"), (0,))]
        ctc_score.append([hyp[0] for hyp in cand_hyps])
        for hyp in cand_hyps:
            all_hyps.append(list(hyp[1]))
            max_len = max(max_len, len(hyp[1]))
    return torch.tensor(ctc_score, dtype=torch.float32), all_hyps, max_len


def _pad_token_seqs(seqs, width, fill_value):
    """Right-pad token-id lists into an int64 tensor of shape (len(seqs), width)."""
    padded = torch.full((len(seqs), width), fill_value, dtype=torch.int64)
    for row, seq in enumerate(seqs):
        padded[row, :len(seq)] = torch.tensor(seq, dtype=torch.int64)
    return padded


def _tile_for_beam(encoder_out, encoder_out_lens, beam_size):
    """Repeat each utterance ``beam_size`` times, keeping copies adjacent.

    Both tensors must use the same utterance-major order as the flattened
    hypotheses; tiling the lengths with ``repeat`` would pair them with the
    wrong utterances whenever lengths differ inside a batch.
    """
    return (encoder_out.repeat_interleave(beam_size, dim=0),
            encoder_out_lens.repeat_interleave(beam_size))


def _lm_scores(lm_model, hyps_pad, hyps_lens, ignore_id, num_utts, beam_size,
               lm_batch_size):
    """Return language-model log-scores of shape (num_utts, beam_size) on CPU.

    ``hyps_pad`` is modified in place: padding entries equal to ``ignore_id``
    are replaced by 0 before being passed to the language model.
    """
    hyps_pad[hyps_pad == ignore_id] = 0
    nnlm_nll, _ = lm_model.batchify_nll(hyps_pad, hyps_lens, lm_batch_size)
    nnlm_scores = -nnlm_nll.sum(dim=1)
    return torch.reshape(nnlm_scores, (num_utts, beam_size)).cpu()


def _select_best(all_scores, all_hyps, beam_size):
    """Pick, for every utterance, the hypothesis with the highest total score."""
    best_index = torch.argmax(all_scores, dim=1)
    return [all_hyps[utt * beam_size + int(best)]
            for utt, best in enumerate(best_index)]


def main():
    """Decode every utterance of ``--wav_scp`` and write results and timing."""
    args = get_args()
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    # Honour the documented "--gpu -1 for cpu" instead of always using CUDA.
    device = "cuda" if args.gpu >= 0 else "cpu"

    dataset = CustomAishellDataset(args.wav_scp, args.text)
    test_data_loader = DataLoader(dataset, batch_size=args.batch_size,
                                  collate_fn=collate_wrapper)
    speech2text = Speech2Text(
        args.config,
        args.model_path,
        None,
        args.lm_config,
        args.lm_path,
        device=device
    )
    asr_model = speech2text.asr_model

    # Manually load the full ESPnetLanguageModel object: Speech2Text only
    # stores the raw language model, and the full object is needed to call
    # its batchify_nll method.
    full_lm_model = None
    if args.lm_config is not None and args.lm_path is not None:
        from espnet2.tasks.lm import LMTask
        full_lm_model, _ = LMTask.build_model_from_file(
            args.lm_config, args.lm_path, device
        )
        full_lm_model.eval()

    bz = args.beam_size
    count_times = []
    time_start = time.perf_counter()
    audio_sample_len = 0
    with torch.no_grad(), open(args.result_file, 'w') as fout:
        for speech, speech_lens, _labels, names in test_data_loader:
            audio_sample_len += np.sum(speech_lens) / 16000
            batch = {"speech": speech, "speech_lengths": speech_lens}
            if isinstance(batch["speech"], np.ndarray):
                batch["speech"] = torch.tensor(batch["speech"])
            if isinstance(batch["speech_lengths"], np.ndarray):
                batch["speech_lengths"] = torch.tensor(batch["speech_lengths"])

            # a. To device
            batch = to_device(batch, device=device)

            # b. Forward Encoder
            # enc: [N, T, C]
            feats, feats_lengths = asr_model.pre_data(**batch)
            feats_lengths_1 = torch.ceil(feats_lengths.float() / 4).long()
            print("feats_lengths_1:", feats_lengths_1)
            ll_time = time.time()
            encoder_out, encoder_out_lens = asr_model.encode(feats, feats_lengths)
            print("encoder_out_lens:", encoder_out_lens)
            # Actual number of utterances; the last batch may be smaller
            # than args.batch_size.
            num_utts = encoder_out.shape[0]

            # ctc_log_probs: [N, T, C]
            ctc_log_probs = torch.nn.functional.log_softmax(
                asr_model.ctc.ctc_lo(encoder_out), dim=2
            )
            beam_log_probs, beam_log_probs_idx = torch.topk(
                ctc_log_probs, bz, dim=2)

            num_processes = min(multiprocessing.cpu_count(), args.batch_size)
            if args.mode == 'ctc_greedy_search':
                # NOTE(review): greedy search only needs the top-1 index, so
                # this restriction looks unnecessary - confirm before removing.
                assert args.beam_size != 1
                log_probs_idx = beam_log_probs_idx[:, :, 0]
                batch_sents = [
                    seq[0:encoder_out_lens[idx]].tolist()
                    for idx, seq in enumerate(log_probs_idx)
                ]
                hyps = map_batch(batch_sents, asr_model.token_list,
                                 num_processes, True, 0)
            else:
                batch_log_probs_seq_list = beam_log_probs.tolist()
                batch_log_probs_idx_list = beam_log_probs_idx.tolist()
                batch_len_list = encoder_out_lens.tolist()
                batch_log_probs_seq = []
                batch_log_probs_ids = []
                batch_start = []  # only effective in streaming deployment
                batch_root = TrieVector()
                # Holds a Python reference to every trie root for the
                # duration of the decoder call.
                root_dict = {}
                for i, num_sent in enumerate(batch_len_list):
                    batch_log_probs_seq.append(
                        batch_log_probs_seq_list[i][0:num_sent])
                    batch_log_probs_ids.append(
                        batch_log_probs_idx_list[i][0:num_sent])
                    root_dict[i] = PathTrie()
                    batch_root.append(root_dict[i])
                    batch_start.append(True)
                score_hyps = ctc_beam_search_decoder_batch(batch_log_probs_seq,
                                                           batch_log_probs_ids,
                                                           batch_root,
                                                           batch_start,
                                                           bz,
                                                           num_processes,
                                                           0, -2, 0.99999)

                if args.mode == 'ctc_prefix_beam_search':
                    hyps = [cand_hyps[0][1] for cand_hyps in score_hyps]
                    hyps = map_batch(hyps, asr_model.token_list,
                                     num_processes, False, 0)
                elif args.mode == 'attention_rescoring':
                    ctc_score, all_hyps, max_len = _collect_nbest(score_hyps, bz)
                    sos, eos = asr_model.sos, asr_model.eos
                    # Decoder input: <sos> + tokens, padded with eos.
                    hyps_pad_sos = _pad_token_seqs(
                        [[sos] + cand for cand in all_hyps], max_len + 1, eos)
                    # Decoder target: tokens + <eos>, padded with ignore_id.
                    hyps_pad_eos = _pad_token_seqs(
                        [cand + [eos] for cand in all_hyps], max_len + 1,
                        asr_model.ignore_id)
                    hyps_lens = torch.tensor(
                        [len(cand) + 1 for cand in all_hyps], dtype=torch.int32)

                    tiled_out, tiled_lens = _tile_for_beam(
                        encoder_out, encoder_out_lens, bz)
                    decoder_out, _ = asr_model.decoder(
                        tiled_out, tiled_lens,
                        hyps_pad_sos.to(device), hyps_lens.to(device))
                    decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1)
                    mask = ~make_pad_mask(hyps_lens, max_len + 1)  # B2 x T2
                    # Multiplying by the mask turns ignore_id entries into a
                    # valid gather index (0); they are zeroed out again below.
                    index = torch.unsqueeze(hyps_pad_eos * mask, 2)
                    score = decoder_out.cpu().gather(2, index).squeeze(2)  # B2 x T2
                    score = torch.sum(score * mask, dim=1)
                    score = torch.reshape(score, (num_utts, bz))
                    all_scores = ctc_score + 0.1 * score  # FIXME: weight needs tuning
                    best_sents = _select_best(all_scores, all_hyps, bz)
                    hyps = map_batch(best_sents, asr_model.token_list, num_processes)
                elif args.mode == 'attention_lm_rescoring':
                    ctc_score, all_hyps, max_len = _collect_nbest(score_hyps, bz)
                    hyps_pad = _pad_token_seqs(
                        all_hyps, max_len, asr_model.ignore_id).to(device)
                    hyps_lens = torch.tensor(
                        [len(cand) for cand in all_hyps],
                        dtype=torch.int32).to(device)

                    tiled_out, tiled_lens = _tile_for_beam(
                        encoder_out, encoder_out_lens, bz)
                    decoder_scores = -asr_model.batchify_nll(
                        tiled_out, tiled_lens, hyps_pad, hyps_lens, 320
                    )
                    decoder_scores = torch.reshape(
                        decoder_scores, (num_utts, bz)).cpu()
                    nnlm_scores = _lm_scores(
                        full_lm_model, hyps_pad, hyps_lens,
                        asr_model.ignore_id, num_utts, bz, 64)
                    # FIXME: weights need tuning.
                    # NOTE(review): the decoder term enters with a negative
                    # weight here but a positive one in attention_rescoring -
                    # confirm this is intended.
                    all_scores = ctc_score - 0.05 * decoder_scores + 1.0 * nnlm_scores
                    best_sents = _select_best(all_scores, all_hyps, bz)
                    hyps = map_batch(best_sents, asr_model.token_list, num_processes)
                elif args.mode == 'lm_rescoring':
                    ctc_score, all_hyps, max_len = _collect_nbest(score_hyps, bz)
                    hyps_pad = _pad_token_seqs(
                        all_hyps, max_len, asr_model.ignore_id).to(device)
                    hyps_lens = torch.tensor(
                        [len(cand) for cand in all_hyps],
                        dtype=torch.int32).to(device)
                    nnlm_scores = _lm_scores(
                        full_lm_model, hyps_pad, hyps_lens,
                        asr_model.ignore_id, num_utts, bz, 320)
                    all_scores = ctc_score + 0.9 * nnlm_scores  # FIXME: weight needs tuning
                    best_sents = _select_best(all_scores, all_hyps, bz)
                    hyps = map_batch(best_sents, asr_model.token_list, num_processes)
                else:
                    raise NotImplementedError

            # Record encoder + decoding time for every mode, not only for
            # lm_rescoring.
            count_times.append(time.time() - ll_time)

            for i, key in enumerate(names):
                content = hyps[i]
                fout.write('{} {}\n'.format(key, content))

    time_end = time.perf_counter() - time_start
    # Treat the first five batches as warm-up, unless the run was too short
    # to leave any measurement afterwards.
    steady_times = count_times[5:] or count_times
    if steady_times:
        mean_count_time = np.mean(steady_times)
        print("平均 mean_count_time:", mean_count_time, " fps: ",
              args.batch_size / mean_count_time)

    # Avoid a division by zero when no audio was decoded at all.
    rtf = time_end / audio_sample_len if audio_sample_len else float("nan")
    throughput = audio_sample_len / time_end
    with open(args.log_file, 'w') as log:
        log.write(f"Decoding audio {audio_sample_len} secs, cost {time_end} secs, RTF: {rtf}, process {throughput} secs audio per second, decoding args: {args}")


if __name__ == '__main__':
    main()