#!/usr/bin/env python3
"""
ONNX Inference Pipeline for AISHELL Dataset

This script implements a complete ONNX inference pipeline similar to
asr_inference.sh, but using Python instead of shell scripts for easier
maintenance and customization.

Features:
- Data loading and preparation
- ONNX model inference
- Batch processing support
- RTF (Real Time Factor) calculation
- WER (Word Error Rate) evaluation
- Parallel processing

Usage:
    python asr_inference_python.py --onnx_exp exp/conformer_onnx --test_sets "test dev"
    python asr_inference_python.py --onnx_exp exp/conformer_onnx --batch_size 4 --use_quantized
"""

import argparse
import json
import logging
import os
import sys
import time
from multiprocessing import Pool, cpu_count
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import soundfile as sf

# Import espnet_onnx modules (hard requirement for this pipeline).
try:
    from espnet_onnx import Speech2Text
except ImportError:
    print("Error: espnet_onnx is not installed. Please install it first.")
    sys.exit(1)

# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


class ASRInferencePipeline:
    """Complete ONNX inference pipeline for ASR.

    Orchestrates data loading, (optionally parallel) ONNX inference,
    result persistence, RTF computation and WER/CER scoring for one or
    more test sets.
    """

    def __init__(self, args):
        """Store CLI options and prepare the output directory layout.

        Args:
            args: argparse.Namespace with at least: onnx_exp, data_dir,
                test_sets, batch_size, device, inference_nj, use_quantized.
        """
        self.args = args
        self.setup_directories()

    def setup_directories(self):
        """Set up directory structure.

        Raises:
            FileNotFoundError: if the ONNX experiment directory is missing.
        """
        self.onnx_exp = Path(self.args.onnx_exp)
        if not self.onnx_exp.exists():
            raise FileNotFoundError(f"ONNX experiment directory not found: {self.onnx_exp}")

        # Encode the decoding options in the directory name so that runs
        # with different settings don't overwrite each other.
        inference_tag = "decode_onnx"
        if self.args.use_quantized:
            inference_tag += "_quantized"
        inference_tag += f"_batch{self.args.batch_size}"

        self.inference_dir = self.onnx_exp / inference_tag
        self.inference_dir.mkdir(exist_ok=True)
        logger.info(f"Inference directory: {self.inference_dir}")

    def load_data(self, test_set):
        """Load data from directory structure or wav.scp file.

        Tries, in order:
        1. Standard Kaldi format (wav.scp / text / utt2spk files).
        2. Speaker sub-directories (names starting with 'S') of wav files.
        3. wav files directly inside the test-set directory.

        Args:
            test_set: name of the test-set sub-directory under data_dir.

        Returns:
            Tuple (wav_data, text_data, utt2spk): wav_data is a list of
            (utt_id, audio_path) pairs, text_data maps utt_id -> reference
            transcript, utt2spk maps utt_id -> speaker id.

        Raises:
            FileNotFoundError: if the data directory or any audio is missing.
        """
        data_dir = Path(self.args.data_dir) / test_set
        wav_data = []
        text_data = {}
        utt2spk = {}

        # First try to load from standard Kaldi format (wav.scp, text, utt2spk)
        wav_scp_path = data_dir / "wav.scp"
        if wav_scp_path.exists():
            logger.info(f"Loading data from standard Kaldi format: {wav_scp_path}")

            # wav.scp lines: "<utt_id> <audio path, possibly containing spaces>"
            with open(wav_scp_path, 'r', encoding='utf-8') as f:
                for line in f:
                    parts = line.strip().split()
                    if len(parts) < 2:
                        continue
                    utt_id, audio_path = parts[0], ' '.join(parts[1:])
                    wav_data.append((utt_id, audio_path))

            # Optional reference transcripts: "<utt_id> <transcript>"
            text_path = data_dir / "text"
            if text_path.exists():
                with open(text_path, 'r', encoding='utf-8') as f:
                    for line in f:
                        parts = line.strip().split()
                        if len(parts) < 2:
                            continue
                        utt_id, text = parts[0], ' '.join(parts[1:])
                        text_data[utt_id] = text

            # Optional speaker map: "<utt_id> <spk_id>"
            utt2spk_path = data_dir / "utt2spk"
            if utt2spk_path.exists():
                with open(utt2spk_path, 'r', encoding='utf-8') as f:
                    for line in f:
                        parts = line.strip().split()
                        if len(parts) < 2:
                            continue
                        utt_id, spk_id = parts[0], parts[1]
                        utt2spk[utt_id] = spk_id
        else:
            # Try to load from directory structure (speaker directories
            # containing wav files).
            logger.info(f"Loading data from directory structure: {data_dir}")

            if not data_dir.exists():
                raise FileNotFoundError(f"Data directory not found: {data_dir}")

            # Look for speaker directories (like S0724)
            for speaker_dir in data_dir.iterdir():
                if speaker_dir.is_dir() and speaker_dir.name.startswith('S'):
                    speaker_id = speaker_dir.name
                    logger.info(f"Found speaker directory: {speaker_id}")

                    for wav_file in speaker_dir.glob('*.wav'):
                        if wav_file.is_file():
                            # Create utt_id from speaker_id and wav filename
                            utt_id = f"{speaker_id}_{wav_file.stem}"
                            audio_path = str(wav_file)
                            wav_data.append((utt_id, audio_path))
                            utt2spk[utt_id] = speaker_id
                            logger.debug(f"Added utterance: {utt_id} -> {audio_path}")

            if not wav_data:
                # Last resort: wav files directly in the test-set directory.
                logger.info(f"Looking for wav files directly in: {data_dir}")
                for wav_file in data_dir.glob('*.wav'):
                    if wav_file.is_file():
                        utt_id = wav_file.stem
                        audio_path = str(wav_file)
                        wav_data.append((utt_id, audio_path))
                        utt2spk[utt_id] = "unknown"
                        logger.debug(f"Added utterance: {utt_id} -> {audio_path}")

        if not wav_data:
            raise FileNotFoundError(f"No audio files found in: {data_dir}\n" +
                                    "Please check if the directory contains wav files or a wav.scp file")

        logger.info(f"Loaded {len(wav_data)} utterances from {test_set}")
        return wav_data, text_data, utt2spk

    def split_data(self, wav_data, num_jobs):
        """Split data into at most num_jobs contiguous, non-empty chunks."""
        chunk_size = (len(wav_data) + num_jobs - 1) // num_jobs  # ceil division
        chunks = []
        for i in range(num_jobs):
            start = i * chunk_size
            end = min((i + 1) * chunk_size, len(wav_data))
            if start < end:
                chunks.append(wav_data[start:end])
        return chunks

    def initialize_model(self):
        """Initialize an ONNX Speech2Text model in the current process.

        Returns:
            An espnet_onnx Speech2Text instance.

        Raises:
            Exception: re-raises whatever model construction raised.
        """
        try:
            providers = ['CPUExecutionProvider']
            if self.args.device == 'gpu':
                # CPU provider kept as a fallback after CUDA.
                providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']

            model = Speech2Text(
                model_dir=str(self.onnx_exp),
                providers=providers,
                use_quantized=self.args.use_quantized
            )
            logger.info("ONNX model initialized successfully")
            return model
        except Exception as e:
            logger.error(f"Error initializing model: {e}")
            raise

    def process_chunk(self, chunk, onnx_exp, use_quantized, device, test_set, job_id):
        """Process a chunk of (utt_id, audio_path) pairs in one worker.

        All model configuration is passed explicitly (rather than read from
        self) so the call is self-contained and safe to run via
        multiprocessing; each worker builds its own Speech2Text instance.

        Args:
            chunk: list of (utt_id, audio_path) pairs.
            onnx_exp: Path to the ONNX experiment directory.
            use_quantized: whether to load quantized models.
            device: 'cpu' or 'gpu'.
            test_set: name of the test set (for context only).
            job_id: worker index; job 0 emits progress logs.

        Returns:
            Tuple (results, processing_times): results maps utt_id to the
            model's n-best output ([] on failure); processing_times maps
            utt_id to a timing dict or an {'error': ...} entry.
        """
        results = {}
        processing_times = {}

        try:
            providers = ['CPUExecutionProvider']
            if device == 'gpu':
                providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']

            model = Speech2Text(
                tag_name=str(onnx_exp.name),
                model_dir=str(onnx_exp),
                providers=providers,
                use_quantized=use_quantized
            )
            logger.info(f"Model initialized in process {job_id}")
        except Exception as e:
            logger.error(f"Error initializing model in process {job_id}: {e}")
            # Model never came up: mark every utterance in this chunk as
            # errored so downstream accounting still sees each utt_id.
            for utt_id, _ in chunk:
                results[utt_id] = []
                processing_times[utt_id] = {'error': f'Model initialization failed: {e}'}
            return results, processing_times

        for utt_id, audio_path in chunk:
            try:
                # Load audio
                start_time = time.time()
                audio, rate = sf.read(audio_path)
                audio_load_time = time.time() - start_time

                # Perform inference
                infer_start = time.time()
                model_results = model(audio)
                infer_time = time.time() - infer_start

                # Store results and per-utterance timing (audio_length is in
                # seconds and feeds the RTF calculation later).
                results[utt_id] = model_results
                processing_times[utt_id] = {
                    'total': time.time() - start_time,
                    'audio_load': audio_load_time,
                    'inference': infer_time,
                    'audio_length': len(audio) / rate
                }

                # Progress logging from the first job only, every 10 utterances.
                if job_id == 0 and len(results) % 10 == 0:
                    logger.info(f"Processed {len(results)} utterances in job {job_id}")

            except Exception as e:
                logger.error(f"Error processing {utt_id}: {e}")
                results[utt_id] = []
                processing_times[utt_id] = {'error': str(e)}

        return results, processing_times

    def run_inference(self, test_set, wav_data):
        """Run inference on a test set, in parallel when inference_nj > 1.

        Writes per-utterance timings to <test_dir>/logdir/processing_times.json.

        Returns:
            Tuple (results, processing_times) merged across all jobs.
        """
        test_dir = self.inference_dir / test_set
        test_dir.mkdir(exist_ok=True, parents=True)
        log_dir = test_dir / "logdir"
        log_dir.mkdir(exist_ok=True)

        # Never spawn more jobs than there are utterances.
        num_jobs = min(self.args.inference_nj, len(wav_data))
        chunks = self.split_data(wav_data, num_jobs)
        logger.info(f"Processing {test_set} with {num_jobs} parallel jobs")

        results = {}
        processing_times = {}

        if num_jobs > 1:
            # Fan out one chunk per worker process; each worker loads its
            # own copy of the ONNX model (ORT sessions are not picklable).
            with Pool(num_jobs) as pool:
                tasks = [
                    pool.apply_async(
                        self.process_chunk,
                        (chunk, self.onnx_exp, self.args.use_quantized,
                         self.args.device, test_set, i)
                    )
                    for i, chunk in enumerate(chunks)
                ]
                for task in tasks:
                    chunk_results, chunk_times = task.get()
                    results.update(chunk_results)
                    processing_times.update(chunk_times)
        else:
            # Single process: reuse the worker implementation directly
            # instead of duplicating the load/infer/timing loop inline.
            chunk_results, chunk_times = self.process_chunk(
                chunks[0], self.onnx_exp, self.args.use_quantized,
                self.args.device, test_set, 0
            )
            results.update(chunk_results)
            processing_times.update(chunk_times)

        # Save processing times for RTF calculation
        times_path = log_dir / "processing_times.json"
        with open(times_path, 'w', encoding='utf-8') as f:
            json.dump(processing_times, f, indent=2, ensure_ascii=False)

        logger.info(f"Inference completed for {test_set}")
        return results, processing_times

    def save_results(self, test_set, results):
        """Save 1-best recognition results (text / token / score files)."""
        test_dir = self.inference_dir / test_set
        recog_dir = test_dir / "1best_recog"
        recog_dir.mkdir(exist_ok=True)

        text_path = recog_dir / "text"
        token_path = recog_dir / "token"
        score_path = recog_dir / "score"

        with open(text_path, 'w', encoding='utf-8') as f_text, \
             open(token_path, 'w', encoding='utf-8') as f_token, \
             open(score_path, 'w', encoding='utf-8') as f_score:

            for utt_id, model_results in sorted(results.items()):
                if not model_results:
                    # Failed utterance: keep an empty-text placeholder line
                    # so the text file stays aligned with the data.
                    f_text.write(f"{utt_id} \n")
                    continue

                # espnet_onnx returns an n-best list of
                # (text, tokens, token_ids, hypothesis) tuples; take the best.
                text, tokens, token_ids, hyp = model_results[0]

                f_text.write(f"{utt_id} {text}\n")
                f_token.write(f"{utt_id} {' '.join(tokens)}\n")
                f_score.write(f"{utt_id} {hyp.score}\n")

        logger.info(f"Results saved for {test_set}")

    def calculate_rtf(self, test_set, processing_times):
        """Calculate Real Time Factor (processing time / audio duration).

        Writes statistics to <test_dir>/logdir/rtf_results.json.

        Returns:
            dict of RTF statistics, or None when nothing valid was timed.
        """
        test_dir = self.inference_dir / test_set
        log_dir = test_dir / "logdir"

        total_audio_time = 0
        total_processing_time = 0
        valid_utterances = 0

        for utt_id, times in processing_times.items():
            if 'error' in times:
                continue
            if 'audio_length' in times and 'total' in times:
                total_audio_time += times['audio_length']
                total_processing_time += times['total']
                valid_utterances += 1

        # Require non-zero total audio as well, so the RTF division below
        # can never raise ZeroDivisionError on degenerate (empty) audio.
        if valid_utterances > 0 and total_audio_time > 0:
            rtf = total_processing_time / total_audio_time
            avg_audio_time = total_audio_time / valid_utterances
            avg_processing_time = total_processing_time / valid_utterances

            rtf_results = {
                'rtf': rtf,
                'total_audio_time': total_audio_time,
                'total_processing_time': total_processing_time,
                'valid_utterances': valid_utterances,
                'avg_audio_time': avg_audio_time,
                'avg_processing_time': avg_processing_time
            }

            rtf_path = log_dir / "rtf_results.json"
            with open(rtf_path, 'w', encoding='utf-8') as f:
                json.dump(rtf_results, f, indent=2)

            logger.info(f"RTF for {test_set}: {rtf:.4f}")
            logger.info(f"Average audio length: {avg_audio_time:.2f}s")
            logger.info(f"Average processing time: {avg_processing_time:.2f}s")
            return rtf_results
        else:
            logger.warning(f"No valid utterances for RTF calculation in {test_set}")
            return None

    def calculate_wer(self, test_set, results, text_data):
        """Calculate Word Error Rate (and CER) against the reference text.

        Requires the optional 'jiwer' package; results are written to
        <test_dir>/score/.

        Returns:
            dict of WER statistics, or None when jiwer is unavailable or
            no utterance has both a reference and a hypothesis.
        """
        test_dir = self.inference_dir / test_set
        score_dir = test_dir / "score"
        score_dir.mkdir(exist_ok=True)

        # Pair up references and hypotheses for utterances that have both.
        references = []
        hypotheses = []
        valid_utterances = 0

        for utt_id, model_results in results.items():
            if utt_id not in text_data:
                continue
            if not model_results:
                continue

            reference = text_data[utt_id]
            hypothesis = model_results[0][0] if model_results else ""

            if reference and hypothesis:
                references.append(reference)
                hypotheses.append(hypothesis)
                valid_utterances += 1

        if valid_utterances == 0:
            logger.warning(f"No valid utterances for WER calculation in {test_set}")
            return None

        try:
            # jiwer is an optional dependency; imported lazily on purpose.
            import jiwer

            wer = jiwer.wer(references, hypotheses)
            cer = jiwer.cer(references, hypotheses)

            wer_results = {
                'wer': wer * 100,
                'cer': cer * 100,
                'num_utterances': valid_utterances,
                'references': references,
                'hypotheses': hypotheses
            }

            wer_path = score_dir / "wer_results.json"
            with open(wer_path, 'w', encoding='utf-8') as f:
                json.dump(wer_results, f, indent=2, ensure_ascii=False)

            # Human-readable per-utterance breakdown.
            detail_path = score_dir / "wer_details.txt"
            with open(detail_path, 'w', encoding='utf-8') as f:
                f.write(f"WER: {wer*100:.2f}%\n")
                f.write(f"CER: {cer*100:.2f}%\n")
                f.write(f"Number of utterances: {valid_utterances}\n\n")
                for i, (ref, hyp) in enumerate(zip(references, hypotheses)):
                    f.write(f"Utterance {i+1}:\n")
                    f.write(f"Reference: {ref}\n")
                    f.write(f"Hypothesis: {hyp}\n\n")

            logger.info(f"WER for {test_set}: {wer*100:.2f}%")
            logger.info(f"CER for {test_set}: {cer*100:.2f}%")
            return wer_results

        except ImportError:
            logger.warning("jiwer not installed, skipping WER calculation")
            logger.warning("Install with: pip install jiwer")
            return None
        except Exception as e:
            logger.error(f"Error calculating WER: {e}")
            return None

    def process_test_set(self, test_set):
        """Run the full inference + scoring flow for a single test set."""
        logger.info(f"Processing test set: {test_set}")

        wav_data, text_data, utt2spk = self.load_data(test_set)
        if not wav_data:
            logger.warning(f"No data found for {test_set}")
            return

        results, processing_times = self.run_inference(test_set, wav_data)
        self.save_results(test_set, results)
        self.calculate_rtf(test_set, processing_times)

        # WER needs reference transcripts, which only the Kaldi-format
        # loader provides.
        if text_data:
            self.calculate_wer(test_set, results, text_data)
        else:
            logger.warning(f"No reference text found for WER calculation in {test_set}")

        logger.info(f"Completed processing {test_set}")

    def run(self):
        """Run the complete pipeline over every configured test set."""
        logger.info("Starting ONNX inference pipeline")
        start_time = time.time()

        for test_set in self.args.test_sets.split():
            self.process_test_set(test_set)

        # Print summary
        total_time = time.time() - start_time
        logger.info(f"\n=== Pipeline Summary ===")
        logger.info(f"Total processing time: {total_time:.2f} seconds")
        logger.info(f"Test sets processed: {self.args.test_sets}")
        logger.info(f"ONNX experiment: {self.args.onnx_exp}")
        logger.info(f"Batch size: {self.args.batch_size}")
        logger.info(f"Device: {self.args.device}")
        logger.info(f"Parallel jobs: {self.args.inference_nj}")

        # Print detailed results
        for test_set in self.args.test_sets.split():
            test_dir = self.inference_dir / test_set

            wer_path = test_dir / "score" / "wer_results.json"
            if wer_path.exists():
                with open(wer_path, 'r', encoding='utf-8') as f:
                    wer_results = json.load(f)
                logger.info(f"\n=== {test_set} WER Results ===")
                # Guard against missing keys: formatting the 'N/A' fallback
                # string with :.2f would raise ValueError.
                wer = wer_results.get('wer')
                cer = wer_results.get('cer')
                logger.info(f"WER: {wer:.2f}%" if wer is not None else "WER: N/A")
                logger.info(f"CER: {cer:.2f}%" if cer is not None else "CER: N/A")
                logger.info(f"Utterances: {wer_results.get('num_utterances', 0)}")

            rtf_path = test_dir / "logdir" / "rtf_results.json"
            if rtf_path.exists():
                with open(rtf_path, 'r', encoding='utf-8') as f:
                    rtf_results = json.load(f)
                logger.info(f"\n=== {test_set} RTF Results ===")
                rtf = rtf_results.get('rtf')
                logger.info(f"RTF: {rtf:.4f}" if rtf is not None else "RTF: N/A")
                logger.info(f"Total audio time: {rtf_results.get('total_audio_time', 0):.2f}s")
                logger.info(f"Total processing time: {rtf_results.get('total_processing_time', 0):.2f}s")

        logger.info("\nPipeline completed successfully!")


def main():
    """Parse CLI arguments and run the pipeline."""
    parser = argparse.ArgumentParser(
        description="ONNX Inference Pipeline for AISHELL Dataset",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # Required arguments
    parser.add_argument(
        "--onnx_exp",
        type=str,
        required=True,
        help="ONNX experiment directory"
    )

    # Data arguments
    parser.add_argument(
        "--data_dir",
        type=str,
        default="./data",
        help="Data directory containing test sets"
    )
    parser.add_argument(
        "--test_sets",
        type=str,
        default="test",
        help="Test sets to process (space-separated)"
    )

    # Inference arguments
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="Batch size for inference"
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        choices=["cpu", "gpu"],
        help="Device to use for inference"
    )
    parser.add_argument(
        "--inference_nj",
        type=int,
        default=4,
        help="Number of parallel jobs for inference"
    )
    parser.add_argument(
        "--use_quantized",
        action="store_true",
        help="Use quantized ONNX models"
    )

    # Logging
    parser.add_argument(
        "--log_level",
        type=str,
        default="INFO",
        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
        help="Logging level"
    )

    args = parser.parse_args()

    # Set logging level
    logging.getLogger().setLevel(args.log_level)

    # Run pipeline
    try:
        pipeline = ASRInferencePipeline(args)
        pipeline.run()
    except Exception as e:
        logger.error(f"Error running pipeline: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()