"""Offline batch decoding of a wav.scp list with an ESPnet ASR model.

The script runs the encoder and CTC head of an ESPnet ``Speech2Text`` model on
batches of 16 kHz audio, decodes the CTC posteriors with the ``swig_decoders``
beam search, optionally rescores the n-best list with the attention decoder
and/or a neural language model, writes one ``<utt-id> <text>`` line per
utterance and finally logs the real-time factor.
"""
import torch
from torch.utils.data import DataLoader, Dataset
import soundfile
import time
import numpy as np
import os
import sys
import multiprocessing
import argparse
from typing import Dict, Optional, Tuple
from espnet2.bin.asr_inference import Speech2Text
from espnet2.torch_utils.device_funcs import to_device

torch.set_num_threads(1)

try:
    from swig_decoders import map_batch, \
        ctc_beam_search_decoder_batch, \
        TrieVector, PathTrie
except ImportError:
    print('Please install ctc decoders first by refering to\n' +
          'https://github.com/Slyne/ctc_decoder.git')
    sys.exit(1)

# Sample rate every input wav is required to have.
SAMPLE_RATE = 16000


def lm_batchify_nll(lm_scorer,
                    text: torch.Tensor,
                    text_lengths: torch.Tensor,
                    batch_size: int = 100) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute per-sequence negative log likelihood (nll) with ``lm_scorer``.

    To avoid OOM the input is processed in chunks of ``batch_size`` sequences
    and the per-chunk results are concatenated.

    Args:
        lm_scorer: Language model scorer object exposing ``score``.
        text: (Batch, Length) token ids.
        text_lengths: (Batch,) valid length of every row of ``text``.
        batch_size: Number of sequences scored per chunk; lower it to avoid
            OOM, raise it for throughput.

    Returns:
        ``(nll, x_lengths)``, both of shape (Batch,): the summed nll of every
        sequence and the number of scored transitions (``length - 1``,
        clamped at 0).

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    total_num = text.size(0)
    if total_num <= batch_size:
        return _compute_nll_with_lm_scorer(lm_scorer, text, text_lengths)

    # Every chunk is truncated to the same global maximum length.
    max_length = int(text_lengths.max())
    nlls = []
    x_lengths = []
    for start_idx in range(0, total_num, batch_size):
        end_idx = min(start_idx + batch_size, total_num)
        # batch_nll: (chunk,), one summed nll per sequence
        batch_nll, batch_x_lengths = _compute_nll_with_lm_scorer(
            lm_scorer,
            text[start_idx:end_idx, :],
            text_lengths[start_idx:end_idx],
            max_length=max_length,
        )
        nlls.append(batch_nll)
        x_lengths.append(batch_x_lengths)

    nll = torch.cat(nlls)
    x_lengths = torch.cat(x_lengths)
    assert nll.size(0) == total_num
    assert x_lengths.size(0) == total_num
    return nll, x_lengths


def _compute_nll_with_lm_scorer(lm_scorer,
                                text: torch.Tensor,
                                text_lengths: torch.Tensor,
                                max_length: Optional[int] = None
                                ) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute negative log likelihood using ``lm_scorer.score``.

    Emulates an ``nll`` method on top of the step-wise ``score`` method: each
    sequence is walked token by token and the negative log-probability of the
    following ground-truth token is accumulated.

    Args:
        lm_scorer: Object with ``score(token, state, x) -> (logp, state)``.
        text: (Batch, Length) token ids.
        text_lengths: (Batch,) valid length of every row.
        max_length: Optional length to truncate ``text`` to; defaults to the
            longest length in ``text_lengths``.

    Returns:
        ``(nll, x_lengths)``, both of shape (Batch,).
    """
    batch_size = text.size(0)
    if max_length is None:
        # An empty batch has no maximum; avoid calling max() on it.
        max_length = int(text_lengths.max()) if batch_size > 0 else 0
    text = text[:, :max_length]

    nll = torch.zeros(batch_size, device=text.device)
    for batch_idx in range(batch_size):
        seq_text = text[batch_idx, :int(text_lengths[batch_idx])]
        state = None
        # NOTE(review): only the current token is handed to score(); scorers
        # that expect the whole prefix at every step would need
        # seq_text[:pos + 1] instead - confirm against the scorer in use.
        for pos in range(len(seq_text) - 1):
            logp, state = lm_scorer.score(seq_text[pos].unsqueeze(0), state, None)
            nll[batch_idx] -= logp[seq_text[pos + 1]]

    # Transitions between tokens are scored, hence length - 1 (never negative).
    x_lengths = torch.clamp(text_lengths - 1, min=0)
    return nll, x_lengths


class CustomAishellDataset(Dataset):
    """Dataset pairing the entries of a wav.scp file with a transcript file.

    Both files hold one ``<utt-id> <value>`` entry per line and are matched
    by line position.
    """

    def __init__(self, wav_scp_file, text_file):
        # Transcripts are typically non-ASCII, so do not depend on the locale.
        with open(wav_scp_file, 'r', encoding='utf-8') as wav_scp, \
                open(text_file, 'r', encoding='utf-8') as text:
            # Blank lines carry no entry and would break the field split.
            wavs = [line.split() for line in wav_scp if line.strip()]
            texts = [line.split() for line in text if line.strip()]
        self.wav_names = [fields[0] for fields in wavs]
        self.wav_paths = [fields[1] for fields in wavs]
        # Everything after the utterance id, joined without spaces.
        self.labels = ["".join(fields[1:]) for fields in texts]

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(speech, num_samples, label, utt_id)`` for entry ``idx``.

        Raises:
            ValueError: If the wav file is not sampled at 16 kHz.
        """
        speech, sr = soundfile.read(self.wav_paths[idx])
        if sr != SAMPLE_RATE:
            raise ValueError(
                f"{self.wav_paths[idx]}: expected {SAMPLE_RATE} Hz audio, got {sr}")
        speech = np.array(speech, dtype=np.float32)
        return speech, speech.shape[0], self.labels[idx], self.wav_names[idx]


def collate_wrapper(batch):
    """Zero-pad a list of dataset items into one batch.

    Args:
        batch: List of ``(speech, speech_len, label, name)`` tuples.

    Returns:
        ``(speeches, lengths, labels, names)`` where ``speeches`` is a float32
        array of shape (len(batch), longest speech_len) and ``lengths`` an
        int64 array of shape (len(batch),).
    """
    lengths = np.array([item[1] for item in batch], dtype=np.int64)
    # Size the buffer from the data instead of a fixed 30 s cap, so longer
    # utterances no longer fail on assignment.
    max_len = int(lengths.max()) if len(batch) > 0 else 0
    speeches = np.zeros((len(batch), max_len), dtype=np.float32)
    labels = []
    names = []
    for i, (speech, speech_len, label, name) in enumerate(batch):
        speeches[i, :speech_len] = speech
        labels.append(label)
        names.append(name)
    return speeches, lengths, labels, names


def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """Make mask tensor containing indices of padded part.

    Args:
        lengths (torch.Tensor): Batch of lengths (B,).
        max_len (int): Width of the mask; when not positive the largest
            value in ``lengths`` is used.

    Returns:
        torch.Tensor: Boolean mask (B, max_len), True at padded positions.

    Examples:
        >>> lengths = torch.tensor([5, 3, 2])
        >>> make_pad_mask(lengths).int()
        tensor([[0, 0, 0, 0, 0],
                [0, 0, 0, 1, 1],
                [0, 0, 1, 1, 1]], dtype=torch.int32)
    """
    batch_size = lengths.size(0)
    max_len = max_len if max_len > 0 else lengths.max().item()
    seq_range = torch.arange(0,
                             max_len,
                             dtype=torch.int64,
                             device=lengths.device)
    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
    seq_length_expand = lengths.unsqueeze(-1)
    return seq_range_expand >= seq_length_expand


def get_args():
    """Parse the command-line arguments of the decoding script."""
    parser = argparse.ArgumentParser(description='recognize with your model')
    parser.add_argument('--config', required=True, help='config file')
    parser.add_argument('--lm_config', required=True, help='config file')
    parser.add_argument('--gpu',
                        type=int,
                        default=0,
                        help='gpu id for this rank, -1 for cpu')
    parser.add_argument('--wav_scp', required=True, help='wav scp file')
    parser.add_argument('--text', required=True, help='ground truth text file')
    parser.add_argument('--model_path', required=True, help='torch pt model file')
    parser.add_argument('--lm_path', required=True, help='torch pt model file')
    parser.add_argument('--result_file',
                        default='./predictions.txt',
                        help='asr result file')
    parser.add_argument('--log_file',
                        default='./rtf.txt',
                        help='asr decoding log')
    parser.add_argument('--batch_size',
                        type=int,
                        default=24,
                        help='batch_size')
    parser.add_argument('--beam_size',
                        type=int,
                        default=10,
                        help='beam_size')
    parser.add_argument('--mode',
                        choices=[
                            'ctc_greedy_search', 'ctc_prefix_beam_search',
                            'attention_rescoring', 'attention_lm_rescoring',
                            'lm_rescoring'],
                        default='attention_lm_rescoring',
                        help='decoding mode')
    return parser.parse_args()


def _enable_torch_compile(asr_model, lm_model):
    """Wrap the hot model entry points with torch.compile, backend by backend.

    Backends are tried from most compatible to fastest. When one fails, the
    untouched originals are put back before the next attempt so wrappers are
    never stacked. TF32 matmul precision is enabled in every case.
    """
    print("启用torch.compile优化...")
    backends_to_try = [
        ("aot_eager", {}),  # aot_eager does not accept a mode argument
        ("eager", {"mode": "reduce-overhead"}),
        ("inductor", {"mode": "reduce-overhead", "dynamic": False, "fullgraph": False}),
    ]

    targets = []
    if hasattr(asr_model, 'encode'):
        targets.append((asr_model, 'encode'))
    if hasattr(asr_model.ctc, 'ctc_lo'):
        targets.append((asr_model.ctc, 'ctc_lo'))
    if lm_model is not None and hasattr(lm_model, 'batchify_nll'):
        targets.append((lm_model, 'batchify_nll'))
    # Keep the originals: compiled plain functions expose no `_orig_mod`, so
    # they cannot be unwrapped after the fact.
    originals = [getattr(owner, attr) for owner, attr in targets]

    for backend_name, backend_options in backends_to_try:
        try:
            print(f"尝试使用 {backend_name} 后端进行编译...")
            for (owner, attr), original in zip(targets, originals):
                setattr(owner, attr,
                        torch.compile(original, backend=backend_name, **backend_options))
        except Exception as e:
            print(f"⚠ {backend_name} 后端编译失败: {e}")
            for (owner, attr), original in zip(targets, originals):
                setattr(owner, attr, original)
        else:
            torch.set_float32_matmul_precision('high')
            print(f"✓ 使用 {backend_name} 后端编译成功")
            print("✓ TensorFloat-32加速已启用")
            return

    # Every backend failed: run uncompiled, still with TF32 enabled.
    print("⚠ 所有编译后端都失败,将使用未编译模式运行")
    torch.set_float32_matmul_precision('high')
    print("✓ TensorFloat-32加速已启用(未编译模式)")


def _collect_nbest(score_hyps, beam_size):
    """Flatten beam-search output into scores, token lists and max length.

    Each utterance's candidate list is padded to ``beam_size`` entries with a
    ``(-inf, (0,))`` dummy so the scores form a rectangular tensor.

    Returns:
        ``(ctc_score, all_hyps, max_len)``: a float32 tensor
        (num_utts, beam_size), the flat list of token-id lists ordered
        utterance-major, and the longest hypothesis length (padding included).
    """
    dummy = (-float("INF"), (0,))
    ctc_score, all_hyps = [], []
    max_len = 0
    for cand_hyps in score_hyps:
        cand_hyps = list(cand_hyps)
        cand_hyps += (beam_size - len(cand_hyps)) * [dummy]
        ctc_score.append([hyp[0] for hyp in cand_hyps])
        for hyp in cand_hyps:
            tokens = list(hyp[1])
            all_hyps.append(tokens)
            max_len = max(max_len, len(tokens))
    return torch.tensor(ctc_score, dtype=torch.float32), all_hyps, max_len


def _pad_hyps(all_hyps, max_len, pad_value):
    """Pack token lists into an int64 (N, max_len) tensor plus int32 lengths."""
    hyps_pad = torch.full((len(all_hyps), max_len), pad_value, dtype=torch.int64)
    hyps_lens = torch.zeros(len(all_hyps), dtype=torch.int32)
    for k, cand in enumerate(all_hyps):
        hyps_pad[k, :len(cand)] = torch.tensor(cand, dtype=torch.int64)
        hyps_lens[k] = len(cand)
    return hyps_pad, hyps_lens


def _select_best(all_scores, all_hyps, beam_size):
    """Return, per utterance, the hypothesis with the highest combined score."""
    best_index = torch.argmax(all_scores, dim=1).tolist()
    return [all_hyps[utt * beam_size + idx] for utt, idx in enumerate(best_index)]


def main():
    """Decode every utterance listed in ``--wav_scp`` and write the results."""
    args = get_args()
    # NOTE(review): the help text advertises -1 for CPU, but every device
    # below is hard-coded to "cuda" - confirm whether CPU decoding is meant
    # to be supported.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    dataset = CustomAishellDataset(args.wav_scp, args.text)
    test_data_loader = DataLoader(dataset, batch_size=args.batch_size,
                                  collate_fn=collate_wrapper)
    speech2text = Speech2Text(
        args.config,
        args.model_path,
        None,
        args.lm_config,
        args.lm_path,
        device="cuda"
    )
    asr_model = speech2text.asr_model

    # Load the full language-model object separately, because its
    # batchify_nll method is what the rescoring modes call.
    full_lm_model = None
    if args.lm_config is not None and args.lm_path is not None:
        from espnet2.tasks.lm import LMTask
        full_lm_model, _ = LMTask.build_model_from_file(
            args.lm_config, args.lm_path, "cuda"
        )
        full_lm_model.eval()

    # torch.compile only exists on recent PyTorch versions.
    if hasattr(torch, 'compile') and torch.cuda.is_available():
        _enable_torch_compile(asr_model, full_lm_model)

    beam_size = args.beam_size
    audio_sample_len = 0
    total_inference_time = 0
    with torch.no_grad(), open(args.result_file, 'w', encoding='utf-8') as fout:
        for batch in test_data_loader:
            # Per-batch inference timing starts here.
            batch_start_time = time.perf_counter()

            speech, speech_lens, _labels, names = batch
            audio_sample_len += np.sum(speech_lens) / SAMPLE_RATE
            batch = {"speech": speech, "speech_lengths": speech_lens}
            if isinstance(batch["speech"], np.ndarray):
                batch["speech"] = torch.tensor(batch["speech"])
            if isinstance(batch["speech_lengths"], np.ndarray):
                batch["speech_lengths"] = torch.tensor(batch["speech_lengths"])

            # a. To device
            batch = to_device(batch, device='cuda')

            # b. Forward Encoder
            # enc: [N, T, C]
            ll = time.time()
            encoder_out, encoder_out_lens = asr_model.encode(**batch)
            # ctc_log_probs: [N, T, C]
            ctc_logits = asr_model.ctc.ctc_lo(encoder_out)
            ctc_log_probs = torch.nn.functional.log_softmax(ctc_logits, dim=2)
            beam_log_probs, beam_log_probs_idx = torch.topk(ctc_log_probs,
                                                            beam_size, dim=2)

            num_processes = min(multiprocessing.cpu_count(), args.batch_size)
            if args.mode == 'ctc_greedy_search':
                # The top-1 index per frame is the greedy path; this holds
                # for any beam size, including 1.
                log_probs_idx = beam_log_probs_idx[:, :, 0]
                batch_sents = [
                    seq[0:encoder_out_lens[idx]].tolist()
                    for idx, seq in enumerate(log_probs_idx)
                ]
                hyps = map_batch(batch_sents, asr_model.token_list,
                                 num_processes, True, 0)
            else:
                batch_log_probs_seq_list = beam_log_probs.tolist()
                batch_log_probs_idx_list = beam_log_probs_idx.tolist()
                batch_len_list = encoder_out_lens.tolist()
                batch_log_probs_seq = []
                batch_log_probs_ids = []
                batch_start = []  # only effective in streaming deployment
                batch_root = TrieVector()
                root_dict = {}
                for i, num_frames in enumerate(batch_len_list):
                    batch_log_probs_seq.append(
                        batch_log_probs_seq_list[i][0:num_frames])
                    batch_log_probs_ids.append(
                        batch_log_probs_idx_list[i][0:num_frames])
                    root_dict[i] = PathTrie()
                    batch_root.append(root_dict[i])
                    batch_start.append(True)
                score_hyps = ctc_beam_search_decoder_batch(batch_log_probs_seq,
                                                           batch_log_probs_ids,
                                                           batch_root,
                                                           batch_start,
                                                           beam_size,
                                                           num_processes,
                                                           0, -2, 0.99999)

                # Use the real number of utterances: the last batch may hold
                # fewer than args.batch_size items.
                num_utts = len(score_hyps)

                if args.mode == 'ctc_prefix_beam_search':
                    hyps = [cand_hyps[0][1] for cand_hyps in score_hyps]
                    hyps = map_batch(hyps, asr_model.token_list,
                                     num_processes, False, 0)

                elif args.mode == 'attention_rescoring':
                    ctc_score, all_hyps, max_len = _collect_nbest(score_hyps, beam_size)
                    sos, eos = asr_model.sos, asr_model.eos
                    num_hyps = len(all_hyps)
                    # Decoder input (<sos> + tokens, padded with eos) and
                    # target (tokens + <eos>, padded with ignore_id).
                    hyps_pad_sos = torch.full((num_hyps, max_len + 1), eos,
                                              dtype=torch.int64)
                    hyps_pad_eos = torch.full((num_hyps, max_len + 1),
                                              asr_model.ignore_id,
                                              dtype=torch.int64)
                    hyps_lens = torch.ones(num_hyps, dtype=torch.int32)
                    for k, cand in enumerate(all_hyps):
                        n = len(cand) + 1
                        hyps_pad_sos[k, :n] = torch.tensor([sos] + cand,
                                                           dtype=torch.int64)
                        hyps_pad_eos[k, :n] = torch.tensor(cand + [eos],
                                                           dtype=torch.int64)
                        hyps_lens[k] = n

                    # Row u * beam + j of both tensors belongs to utterance
                    # u, matching the order of all_hyps.
                    encoder_out_rep = encoder_out.repeat_interleave(beam_size, dim=0)
                    encoder_lens_rep = encoder_out_lens.repeat_interleave(beam_size)

                    decoder_out, _ = asr_model.decoder(encoder_out_rep,
                                                       encoder_lens_rep,
                                                       hyps_pad_sos.cuda(),
                                                       hyps_lens.cuda())
                    decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1)
                    mask = ~make_pad_mask(hyps_lens, max_len + 1)  # B2 x T2
                    # Zero out ignore_id so gather receives valid indices.
                    index = torch.unsqueeze(hyps_pad_eos * mask, 2)
                    score = decoder_out.cpu().gather(2, index).squeeze(2)  # B2 x T2
                    # Drop the contribution of padded positions.
                    score = torch.sum(score * mask, dim=1)
                    score = torch.reshape(score, (num_utts, beam_size))
                    all_scores = ctc_score + 0.1 * score  # FIX ME need tuned
                    best_sents = _select_best(all_scores, all_hyps, beam_size)
                    hyps = map_batch(best_sents, asr_model.token_list, num_processes)

                elif args.mode == 'attention_lm_rescoring':
                    ctc_score, all_hyps, max_len = _collect_nbest(score_hyps, beam_size)
                    hyps_pad, hyps_lens = _pad_hyps(all_hyps, max_len,
                                                    asr_model.ignore_id)
                    hyps_pad = hyps_pad.cuda()
                    hyps_lens = hyps_lens.cuda()

                    encoder_out_rep = encoder_out.repeat_interleave(beam_size, dim=0)
                    encoder_lens_rep = encoder_out_lens.repeat_interleave(beam_size)

                    decoder_scores = -asr_model.batchify_nll(
                        encoder_out_rep, encoder_lens_rep, hyps_pad, hyps_lens, 320
                    )
                    decoder_scores = torch.reshape(decoder_scores,
                                                   (num_utts, beam_size)).cpu()

                    # Score the hypotheses with the full language model.
                    if full_lm_model is not None:
                        try:
                            # Replace ignore_id by 0 before feeding the LM.
                            hyps_pad_clean = hyps_pad.clone()
                            hyps_pad_clean[hyps_pad_clean == asr_model.ignore_id] = 0
                            # Smaller chunk size to limit memory use.
                            nnlm_nll, _ = full_lm_model.batchify_nll(
                                hyps_pad_clean, hyps_lens, 64)
                        except Exception as e:
                            # Deliberate best effort: fall back to a zero LM
                            # score rather than aborting the whole run.
                            print(f"语言模型评分失败: {e}")
                            nnlm_nll = torch.zeros_like(hyps_pad)
                    else:
                        # No language model: contribute nothing.
                        nnlm_nll = torch.zeros_like(hyps_pad)

                    nnlm_scores = -nnlm_nll.sum(dim=1)
                    nnlm_scores = torch.reshape(nnlm_scores,
                                                (num_utts, beam_size)).cpu()

                    # NOTE(review): decoder_scores is already -nll, so
                    # subtracting it lowers the score of hypotheses the
                    # decoder prefers - confirm the sign is intended.
                    all_scores = ctc_score - 0.05 * decoder_scores + 1.0 * nnlm_scores  # FIX ME need tuned
                    best_sents = _select_best(all_scores, all_hyps, beam_size)
                    hyps = map_batch(best_sents, asr_model.token_list, num_processes)

                elif args.mode == 'lm_rescoring':
                    if full_lm_model is None:
                        raise RuntimeError("lm_rescoring requires a language model")
                    # max_len is taken after padding, so the dummy
                    # hypotheses always fit into hyps_pad.
                    ctc_score, all_hyps, max_len = _collect_nbest(score_hyps, beam_size)
                    hyps_pad, hyps_lens = _pad_hyps(all_hyps, max_len,
                                                    asr_model.ignore_id)
                    hyps_pad = hyps_pad.cuda()
                    hyps_lens = hyps_lens.cuda()
                    hyps_pad[hyps_pad == asr_model.ignore_id] = 0

                    nnlm_nll, _ = full_lm_model.batchify_nll(hyps_pad, hyps_lens, 320)
                    nnlm_scores = -nnlm_nll.sum(dim=1)
                    nnlm_scores = torch.reshape(nnlm_scores, (num_utts, beam_size))

                    # Combine on the GPU; only the chosen indices are
                    # brought back to the CPU.
                    all_scores = ctc_score.cuda() + 0.9 * nnlm_scores  # FIX ME need tuned
                    best_sents = _select_best(all_scores, all_hyps, beam_size)
                    hyps = map_batch(best_sents, asr_model.token_list, num_processes)
                else:
                    raise NotImplementedError

            elapsed = time.time() - ll
            # Throughput uses the real batch size, not a hard-coded one.
            fps = len(names) / elapsed if elapsed > 0 else float("inf")
            print("耗时:", elapsed, "fps:", fps)
            for i, key in enumerate(names):
                content = hyps[i]
                fout.write('{} {}\n'.format(key, content))

            # Accumulate the per-batch inference time.
            batch_end_time = time.perf_counter()
            total_inference_time += batch_end_time - batch_start_time

    # Only the process assigned to GPU 0 writes the timing summary.
    if str(args.gpu) == '0':
        with open(args.log_file, 'w', encoding='utf-8') as log:
            log.write(f"Decoding audio {audio_sample_len} secs, cost {total_inference_time} secs (不包含torch.compile时间), RTF: {total_inference_time/audio_sample_len}, process {audio_sample_len/total_inference_time} secs audio per second, decoding args: {args}")


if __name__ == '__main__':
    main()