#!/usr/bin/env python3
"""Offline batch ASR recognition with an ONNX-exported encoder.

Reads a Kaldi-style ``wav.scp`` / ``text`` pair, runs the acoustic
encoder through onnxruntime, decodes with the swig CTC decoders
(greedy or prefix beam search), optionally rescores candidates with
the attention decoder and/or an external neural LM, and finally
reports RTF and throughput statistics.
"""
import argparse
import multiprocessing
import os
import sys
import time
from typing import Dict, Optional, Tuple

import numpy as np
import soundfile
import torch
from torch.utils.data import DataLoader, Dataset

from espnet2.bin.asr_inference import Speech2Text
from espnet2.torch_utils.device_funcs import to_device

torch.set_num_threads(1)

try:
    from swig_decoders import map_batch, \
        ctc_beam_search_decoder_batch, \
        TrieVector, PathTrie
except ImportError:
    # FIX: `sys` was used here without ever being imported, turning this
    # friendly message into a NameError when swig_decoders is missing.
    print('Please install ctc decoders first by refering to\n' +
          'https://github.com/Slyne/ctc_decoder.git')
    sys.exit(1)


class CustomAishellDataset(Dataset):
    """Dataset of (waveform, length, reference text, utterance id).

    Reads utterance ids and wav paths from a Kaldi-style wav.scp and
    reference transcripts from a matching text file.
    """

    def __init__(self, wav_scp_file, text_file):
        with open(wav_scp_file, 'r') as wav_scp, open(text_file, 'r') as text:
            wavs = wav_scp.readlines()
            texts = text.readlines()
        self.wav_names = [item.split()[0] for item in wavs]
        self.wav_paths = [item.split()[1] for item in wavs]
        # Everything after the utt-id, joined without spaces (Chinese chars).
        self.labels = ["".join(item.split()[1:]) for item in texts]

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        speech, sr = soundfile.read(self.wav_paths[idx])
        assert sr == 16000, sr  # the model expects 16 kHz audio
        speech = np.array(speech, dtype=np.float32)
        speech_len = speech.shape[0]
        label = self.labels[idx]
        name = self.wav_names[idx]
        return speech, speech_len, label, name


def collate_wrapper(batch):
    """Zero-pad variable-length waveforms into one float32 batch.

    The scratch buffer is capped at 30 s of 16 kHz audio and then
    trimmed to the longest utterance actually present in the batch.

    Returns:
        (speeches [B, T_max] float32, lengths [B] int64, labels, names)
    """
    speeches = np.zeros((len(batch), 16000 * 30), dtype=np.float32)
    lengths = np.zeros(len(batch), dtype=np.int64)
    labels = []
    names = []
    for i, (speech, speech_len, label, name) in enumerate(batch):
        speeches[i, :speech_len] = speech
        lengths[i] = speech_len
        labels.append(label)
        names.append(name)
    speeches = speeches[:, :max(lengths)]
    return speeches, lengths, labels, names


def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """Make mask tensor containing indices of padded part.

    Args:
        lengths (torch.Tensor): Batch of lengths (B,).
        max_len (int): Pad to this length; 0 means use ``lengths.max()``.

    Returns:
        torch.Tensor: Bool mask (B, max_len); True at padded positions.

    Examples:
        >>> lengths = [5, 3, 2]
        >>> make_pad_mask(lengths)
        masks = [[0, 0, 0, 0, 0],
                 [0, 0, 0, 1, 1],
                 [0, 0, 1, 1, 1]]
    """
    batch_size = lengths.size(0)
    max_len = max_len if max_len > 0 else lengths.max().item()
    seq_range = torch.arange(0, max_len, dtype=torch.int64,
                             device=lengths.device)
    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
    seq_length_expand = lengths.unsqueeze(-1)
    return seq_range_expand >= seq_length_expand


def process_batch_data(batch, speech2text):
    """Move a collated batch to the GPU and run the model front-end.

    Returns:
        (encoder_inputs, encoder_out_lens, labels, names, audio_seconds)
        where ``encoder_inputs`` is the numpy feed dict for the ONNX
        encoder session and ``audio_seconds`` is the batch's total audio
        duration (used for RTF accounting).
    """
    speech, speech_lens, labels, names = batch
    audio_sample_len = np.sum(speech_lens) / 16000  # seconds of raw audio
    batch_data = {"speech": speech, "speech_lengths": speech_lens}
    if isinstance(batch_data["speech"], np.ndarray):
        batch_data["speech"] = torch.tensor(batch_data["speech"])
    if isinstance(batch_data["speech_lengths"], np.ndarray):
        batch_data["speech_lengths"] = torch.tensor(batch_data["speech_lengths"])
    batch_data = to_device(batch_data, device='cuda')
    feats, encoder_out_lens = speech2text.asr_model.pre_data(**batch_data)
    # NOTE(review): assumes the encoder subsamples time by exactly 4 —
    # confirm against the exported model.
    encoder_out_lens = torch.ceil(encoder_out_lens.float() / 4).long()
    encoder_inputs = {'feats': feats.cpu().numpy().astype(np.float32)}
    return encoder_inputs, encoder_out_lens, labels, names, audio_sample_len


def _collect_beam_hyps(score_hyps, beam_size):
    """Flatten beam-search output, padding short beams with a -inf dummy.

    Returns:
        (ctc_score [B, beam] float32 tensor, flat list of token lists,
        max token length over all candidates)
    """
    ctc_score, all_hyps = [], []
    max_len = 0
    for hyps in score_hyps:
        cur_len = len(hyps)
        if cur_len < beam_size:
            # Pad with an impossible candidate so shapes stay rectangular.
            hyps += (beam_size - cur_len) * [(-float("INF"), (0,))]
        cur_ctc_score = []
        for hyp in hyps:
            cur_ctc_score.append(hyp[0])
            all_hyps.append(list(hyp[1]))
            if len(hyp[1]) > max_len:
                max_len = len(hyp[1])
        ctc_score.append(cur_ctc_score)
    return torch.tensor(ctc_score, dtype=torch.float32), all_hyps, max_len


def _pad_plain_hyps(all_hyps, batch_size, beam_size, max_len, ignore_id):
    """Pack flat candidate token lists into a padded [B, beam, max_len] tensor.

    Returns (hyps_pad int64, hyps_lens int32).
    """
    hyps_pad = torch.ones((batch_size, beam_size, max_len),
                          dtype=torch.int64) * ignore_id  # FIXME: ignore id
    hyps_lens = torch.ones((batch_size, beam_size), dtype=torch.int32)
    k = 0
    for i in range(batch_size):
        for j in range(beam_size):
            cand = all_hyps[k]
            hyps_pad[i][j][0:len(cand)] = torch.tensor(cand)
            hyps_lens[i][j] = len(cand)
            k += 1
    return hyps_pad, hyps_lens


def _select_best_sents(all_scores, all_hyps, beam_size, token_list,
                       num_processes):
    """Pick the best-scoring candidate per utterance and detokenize."""
    best_index = torch.argmax(all_scores, dim=1)
    best_sents = []
    k = 0
    for idx in best_index:
        best_sents.append(all_hyps[k: k + beam_size][idx])
        k += beam_size
    return map_batch(best_sents, token_list, num_processes)


def inference_step(encoder_inputs, encoder_out_lens, speech2text,
                   full_lm_model, args, encoder_session):
    """Run the ONNX encoder and decode one batch.

    Args:
        encoder_inputs: numpy feed dict for the encoder session.
        encoder_out_lens: per-utterance encoder frame counts.
        speech2text: ESPnet Speech2Text wrapper (provides ctc / decoder).
        full_lm_model: external neural LM; only used by the *_lm_rescoring
            modes (may be None otherwise).
        args: parsed CLI arguments (mode, beam_size, batch_size).
        encoder_session: onnxruntime InferenceSession for the encoder.

    Returns:
        List of decoded text strings, one per utterance.

    Raises:
        NotImplementedError: for an unknown ``args.mode``.
    """
    encoder_outputs = encoder_session.run(None, encoder_inputs)
    encoder_out = torch.from_numpy(encoder_outputs[0]).float().cuda()

    ctc_log_probs = torch.nn.functional.log_softmax(
        speech2text.asr_model.ctc.ctc_lo(encoder_out), dim=2)
    beam_log_probs, beam_log_probs_idx = torch.topk(
        ctc_log_probs, args.beam_size, dim=2)
    num_processes = min(multiprocessing.cpu_count(), args.batch_size)
    token_list = speech2text.asr_model.token_list

    if args.mode == 'ctc_greedy_search':
        # NOTE(review): this rejects beam_size == 1 although greedy search
        # only reads the top-1 index — confirm the intent of the assert.
        assert args.beam_size != 1
        log_probs_idx = beam_log_probs_idx[:, :, 0]
        batch_sents = [seq[0:encoder_out_lens[idx]].tolist()
                       for idx, seq in enumerate(log_probs_idx)]
        hyps = map_batch(batch_sents, token_list, num_processes, True, 0)
    else:
        batch_log_probs_seq_list = beam_log_probs.tolist()
        batch_log_probs_idx_list = beam_log_probs_idx.tolist()
        batch_len_list = encoder_out_lens.tolist()
        batch_log_probs_seq = []
        batch_log_probs_ids = []
        batch_start = []  # only effective in streaming deployment
        batch_root = TrieVector()
        root_dict = {}
        # NOTE(review): candidates are sliced to the *padded* time length
        # (encoder_out.size(1)); the per-utterance batch_len_list values
        # are ignored, so padded frames enter the beam search — confirm.
        num_sent = encoder_out.size()[1]
        for i in range(len(batch_len_list)):
            batch_log_probs_seq.append(batch_log_probs_seq_list[i][0:num_sent])
            batch_log_probs_ids.append(batch_log_probs_idx_list[i][0:num_sent])
            root_dict[i] = PathTrie()
            batch_root.append(root_dict[i])
            batch_start.append(True)
        score_hyps = ctc_beam_search_decoder_batch(
            batch_log_probs_seq, batch_log_probs_ids, batch_root,
            batch_start, args.beam_size, num_processes, 0, -2, 0.99999)

        if args.mode == 'ctc_prefix_beam_search':
            hyps = [cand_hyps[0][1] for cand_hyps in score_hyps]
            hyps = map_batch(hyps, token_list, num_processes, False, 0)

        elif args.mode == 'attention_rescoring':
            ctc_score, all_hyps, max_len = _collect_beam_hyps(
                score_hyps, args.beam_size)
            sos = speech2text.asr_model.sos
            eos = speech2text.asr_model.eos
            ignore_id = speech2text.asr_model.ignore_id
            # NOTE(review): these tensors assume a full batch of exactly
            # args.batch_size utterances — a partial final batch would
            # index past all_hyps. Confirm the loader guarantees this.
            hyps_pad_sos_eos = torch.ones(
                (args.batch_size, args.beam_size, max_len + 2),
                dtype=torch.int64) * ignore_id  # FIXME: ignore id
            hyps_pad_sos = torch.ones(
                (args.batch_size, args.beam_size, max_len + 1),
                dtype=torch.int64) * eos  # FIXME: eos
            hyps_pad_eos = torch.ones(
                (args.batch_size, args.beam_size, max_len + 1),
                dtype=torch.int64) * ignore_id  # FIXME: ignore id
            hyps_lens_sos = torch.ones((args.batch_size, args.beam_size),
                                       dtype=torch.int32)
            k = 0
            for i in range(args.batch_size):
                for j in range(args.beam_size):
                    cand = all_hyps[k]
                    length = len(cand) + 2
                    hyps_pad_sos_eos[i][j][0:length] = torch.tensor(
                        [sos] + cand + [eos])
                    hyps_pad_sos[i][j][0:length - 1] = torch.tensor([sos] + cand)
                    hyps_pad_eos[i][j][0:length - 1] = torch.tensor(cand + [eos])
                    hyps_lens_sos[i][j] = len(cand) + 1
                    k += 1

            bz = args.beam_size
            B, T, F = encoder_out.shape
            B2 = B * bz
            # Rows become [x0 x beam, x1 x beam, ...] (interleaved).
            encoder_out = encoder_out.repeat(1, bz, 1).view(B2, T, F)
            # FIX: .repeat(bz) tiles lengths [l0..lB-1, l0..lB-1, ...],
            # which does NOT match the interleaved encoder_out rows above;
            # repeat_interleave pairs every row with its own length.
            encoder_out_lens = encoder_out_lens.repeat_interleave(bz)
            hyps_lens = hyps_lens_sos.view(B2,)
            hyps_pad_sos = hyps_pad_sos.view(B2, max_len + 1)
            hyps_pad_eos = hyps_pad_eos.view(B2, max_len + 1)
            decoder_out, _ = speech2text.asr_model.decoder(
                encoder_out, encoder_out_lens, hyps_pad_sos.cuda(),
                hyps_lens.cuda())
            decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1)
            mask = ~make_pad_mask(hyps_lens, max_len + 1)  # B2 x T2
            # Mask index, remove ignore id before gathering token scores.
            index = torch.unsqueeze(hyps_pad_eos * mask, 2)
            score = decoder_out.cpu().gather(2, index).squeeze(2)  # B2 x T2
            score = score * mask  # zero out padded positions
            score = torch.sum(score, axis=1)
            score = torch.reshape(score, (B, bz))
            all_scores = ctc_score + 0.1 * score  # FIXME: weight needs tuning
            hyps = _select_best_sents(all_scores, all_hyps, args.beam_size,
                                      token_list, num_processes)

        elif args.mode == 'attention_lm_rescoring':
            ctc_score, all_hyps, max_len = _collect_beam_hyps(
                score_hyps, args.beam_size)
            ignore_id = speech2text.asr_model.ignore_id
            hyps_pad, hyps_lens = _pad_plain_hyps(
                all_hyps, args.batch_size, args.beam_size, max_len, ignore_id)
            bz = args.beam_size
            B, T, F = encoder_out.shape
            B2 = B * bz
            encoder_out = encoder_out.repeat(1, bz, 1).view(B2, T, F)
            # FIX: see attention_rescoring — lengths must be interleaved
            # to match the row order of the repeated encoder output.
            encoder_out_lens = encoder_out_lens.repeat_interleave(bz)
            hyps_pad = hyps_pad.view(B2, max_len).cuda()
            hyps_lens = hyps_lens.view(B2,).cuda()
            decoder_scores = -speech2text.asr_model.batchify_nll(
                encoder_out, encoder_out_lens, hyps_pad, hyps_lens, 320)
            decoder_scores = torch.reshape(decoder_scores, (B, bz)).cpu()
            # The LM must not see ignore_id tokens; remap padding to 0.
            hyps_pad[hyps_pad == ignore_id] = 0
            nnlm_nll, x_lengths = full_lm_model.batchify_nll(
                hyps_pad, hyps_lens, 64)
            nnlm_scores = -nnlm_nll.sum(dim=1)
            nnlm_scores = torch.reshape(nnlm_scores, (B, bz)).cpu()
            # FIX ME need tuned
            all_scores = ctc_score - 0.05 * decoder_scores + 1.0 * nnlm_scores
            hyps = _select_best_sents(all_scores, all_hyps, args.beam_size,
                                      token_list, num_processes)

        elif args.mode == 'lm_rescoring':
            ctc_score, all_hyps, max_len = _collect_beam_hyps(
                score_hyps, args.beam_size)
            ignore_id = speech2text.asr_model.ignore_id
            hyps_pad, hyps_lens = _pad_plain_hyps(
                all_hyps, args.batch_size, args.beam_size, max_len, ignore_id)
            bz = args.beam_size
            B, T, F = encoder_out.shape
            B2 = B * bz
            hyps_pad = hyps_pad.view(B2, max_len).cuda()
            hyps_lens = hyps_lens.view(B2,).cuda()
            hyps_pad[hyps_pad == ignore_id] = 0
            nnlm_nll, x_lengths = full_lm_model.batchify_nll(
                hyps_pad, hyps_lens, 320)
            nnlm_scores = -nnlm_nll.sum(dim=1)
            nnlm_scores = torch.reshape(nnlm_scores, (B, bz)).cpu()
            all_scores = ctc_score + 0.9 * nnlm_scores  # FIX ME need tuned
            hyps = _select_best_sents(all_scores, all_hyps, args.beam_size,
                                      token_list, num_processes)
        else:
            raise NotImplementedError
    return hyps


def get_args():
    """Parse and return command-line arguments."""
    parser = argparse.ArgumentParser(description='recognize with your model')
    parser.add_argument('--config', required=True, help='config file')
    parser.add_argument('--lm_config', required=True, help='config file')
    parser.add_argument('--gpu', type=int, default=0,
                        help='gpu id for this rank, -1 for cpu')
    parser.add_argument('--wav_scp', required=True, help='wav scp file')
    parser.add_argument('--text', required=True, help='ground truth text file')
    parser.add_argument('--model_path', required=True,
                        help='torch pt model file')
    parser.add_argument('--lm_path', required=True,
                        help='torch pt model file')
    parser.add_argument('--result_file', default='./predictions.txt',
                        help='asr result file')
    parser.add_argument('--log_file', default='./rtf.txt',
                        help='asr decoding log')
    parser.add_argument('--batch_size', type=int, default=24,
                        help='batch_size')
    parser.add_argument('--beam_size', type=int, default=10,
                        help='beam_size')
    parser.add_argument('--mode',
                        choices=['ctc_greedy_search',
                                 'ctc_prefix_beam_search',
                                 'attention_rescoring',
                                 'attention_lm_rescoring',
                                 'lm_rescoring'],
                        default='attention_lm_rescoring',
                        help='decoding mode')
    # FIX: the encoder ONNX path was hard-coded inside __main__; expose it
    # as an option, keeping the old value as default for compatibility.
    parser.add_argument('--encoder_onnx',
                        default='/home/sunzhq/workspace/yidong-infer/conformer/'
                                'onnx_models_batch24_1/transformer_lm/full/'
                                'default_encoder_fp16.onnx',
                        help='path to the exported encoder ONNX model')
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = get_args()
    # NOTE(review): help text says -1 means cpu, but the pipeline below
    # calls .cuda() unconditionally — confirm CPU mode is really supported.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)

    dataset = CustomAishellDataset(args.wav_scp, args.text)
    test_data_loader = DataLoader(dataset, batch_size=args.batch_size,
                                  collate_fn=collate_wrapper)
    speech2text = Speech2Text(args.config, args.model_path, None,
                              args.lm_config, args.lm_path, device="cuda")

    full_lm_model = None
    if args.lm_config is not None and args.lm_path is not None:
        from espnet2.tasks.lm import LMTask
        full_lm_model, _ = LMTask.build_model_from_file(
            args.lm_config, args.lm_path, "cuda")
        full_lm_model.eval()

    import onnxruntime as ort
    providers = ['ROCMExecutionProvider']
    # Encoder session outputs: ["encoder_out", "encoder_out_lens"].
    encoder_session = ort.InferenceSession(args.encoder_onnx,
                                           providers=providers)

    # Warmup: run one batch to initialize models and kernel caches so the
    # timed loop below is not polluted by one-time setup cost.
    print("Starting warmup...")
    warmup_start = time.time()
    with torch.no_grad():
        for i, batch in enumerate(test_data_loader):
            if i >= 1:  # warmup with the first batch only
                break
            encoder_inputs, encoder_out_lens, labels, names, audio_sample_len = \
                process_batch_data(batch, speech2text)
            inference_step(encoder_inputs, encoder_out_lens, speech2text,
                           full_lm_model, args, encoder_session)
    print(f"Warmup completed in {time.time() - warmup_start:.2f} seconds")

    # Main inference loop.
    time_start = time.perf_counter()
    audio_sample_len_total = 0
    infer_times = []        # pure inference time per batch
    total_infer_times = []  # inference + data-loading time per batch
    total_start = time.time()

    # Results are written both as "utt-id text" and as sclite-style
    # ref.trn / hyp.trn files with space-separated characters.
    with torch.no_grad(), open(args.result_file, 'w') as fout, \
            open('ref.trn', 'w') as ref_file, open('hyp.trn', 'w') as hyp_file:
        for batch_idx, batch in enumerate(test_data_loader):
            encoder_inputs, encoder_out_lens, labels, names, audio_sample_len = \
                process_batch_data(batch, speech2text)
            audio_sample_len_total += audio_sample_len

            infer_start = time.time()
            hyps = inference_step(encoder_inputs, encoder_out_lens,
                                  speech2text, full_lm_model, args,
                                  encoder_session)
            infer_times.append(time.time() - infer_start)

            for i, key in enumerate(names):
                content = hyps[i]
                fout.write('{} {}\n'.format(key, content))
                # Convert continuous Chinese text to space-separated chars.
                ref_text = ' '.join(labels[i])
                hyp_text = ' '.join(content)
                ref_file.write('{} \t ({})\n'.format(ref_text, key))
                hyp_file.write('{} \t ({})\n'.format(hyp_text, key))

            total_infer_times.append(time.time() - total_start)
            total_start = time.time()

    time_end = time.perf_counter() - time_start
    print(f"Total audio processed: {audio_sample_len_total:.1f} seconds")
    print(f"Total time: {time_end:.1f} seconds")
    print(f"Real-time factor (RTF): {time_end/audio_sample_len_total:.3f}")
    print("***************************")
    infer_time = sum(infer_times)
    # FIX: throughput previously hard-coded 24; use the configured batch
    # size. NOTE(review): still over-counts if the last batch is partial.
    avg_infer_fps = args.batch_size * len(infer_times) / infer_time
    print(f"total_infer_time: {infer_time}s")
    print(f'avg_infer_fps: {avg_infer_fps}samples/s')
    load_data_infer_time = sum(total_infer_times)
    load_data_avg_infer_fps = (len(total_infer_times) * args.batch_size
                               / load_data_infer_time)
    print(f'load_data_total_infer_time: {load_data_infer_time}s')
    print(f'load_data_avg_total_Infer_fps: {load_data_avg_infer_fps} samples/s')
    print("******************************")
    with open(args.log_file, 'w') as log:
        log.write(f"Decoding audio {audio_sample_len_total} secs, "
                  f"cost {time_end} secs, "
                  f"RTF: {time_end/audio_sample_len_total}, "
                  f"process {audio_sample_len_total/time_end} secs audio "
                  f"per second, decoding args: {args}")