#!/usr/bin/env python3
"""
ASR Inference Pipeline for ESPNet ONNX Models

This script provides a unified Python pipeline for ASR inference that replaces
the shell-based workflow. It handles data preparation, inference, and
evaluation in a single streamlined process.

Key features:
- Simplified configuration management
- Integrated data processing and inference
- Real-time factor calculation
- Automatic evaluation and reporting
- Support for AISHELL dataset
"""

import argparse
import logging
import os
import sys
from pathlib import Path
from typing import Dict, List, Optional, Tuple

from data_processor import AISHELLDataProcessor
from inference_engine import ONNXInferenceEngine
from evaluator import ASREvaluator
from utils.rtf_calculator import RTFCalculator


class ASRPipeline:
    """Main ASR pipeline class that orchestrates the entire workflow."""

    def __init__(self, config: Dict):
        """Store the config, set up logging, and build all pipeline components.

        Args:
            config: Flat configuration dict (see ``get_default_config`` for
                the recognized keys).
        """
        self.config = config
        self.setup_logging()

        # Initialize components; each receives the full config dict.
        self.data_processor = AISHELLDataProcessor(config)
        self.inference_engine = ONNXInferenceEngine(config)
        self.evaluator = ASREvaluator(config)
        self.rtf_calculator = RTFCalculator(config)

    def setup_logging(self) -> None:
        """Configure root logging to both stdout and a log file."""
        # 'log_level' is constrained by argparse choices, so getattr is safe
        # for CLI use; an invalid value from a hand-built config would raise
        # AttributeError here, which is an acceptable fail-fast.
        log_level = getattr(logging, self.config.get('log_level', 'INFO'))
        logging.basicConfig(
            level=log_level,
            format='%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s',
            handlers=[
                logging.StreamHandler(sys.stdout),
                logging.FileHandler(self.config.get('log_file', 'asr_pipeline.log')),
            ],
        )
        self.logger = logging.getLogger(__name__)

    def run(self) -> None:
        """Execute the complete ASR pipeline (data prep -> report)."""
        self.logger.info("Starting ASR Pipeline")

        # Step 1: Data Preparation (optionally skipped; then reuse existing data)
        if not self.config.get('skip_data_prep', False):
            self.logger.info("Step 1: Data Preparation")
            data_info = self.data_processor.prepare_data()
            self.logger.info(
                f"Data preparation completed: {len(data_info['test_files'])} test files"
            )
        else:
            self.logger.info("Skipping data preparation")
            data_info = self.data_processor.load_existing_data()

        # Step 2: Model Loading
        self.logger.info("Step 2: Model Loading")
        self.inference_engine.load_model()

        # Step 3: Inference over all prepared test files
        self.logger.info("Step 3: Running Inference")
        inference_results = self.inference_engine.run_inference(data_info['test_files'])

        # Step 4: Real-Time Factor calculation from per-file timings
        self.logger.info("Step 4: Calculating Real-Time Factor")
        rtf_results = self.rtf_calculator.calculate_rtf(inference_results)

        # Step 5: Evaluation against references (optionally skipped)
        if not self.config.get('skip_eval', False):
            self.logger.info("Step 5: Running Evaluation")
            eval_results = self.evaluator.evaluate(
                inference_results, data_info['references']
            )
        else:
            self.logger.info("Skipping evaluation")
            eval_results = {}

        # Step 6: Generate the human-readable summary report
        self.logger.info("Step 6: Generating Report")
        self.generate_report(inference_results, rtf_results, eval_results)

        self.logger.info("ASR Pipeline completed successfully")

    def generate_report(self, inference_results: Dict, rtf_results: Dict,
                        eval_results: Dict) -> None:
        """Generate a comprehensive report of the pipeline execution.

        Args:
            inference_results: Mapping of utterance id -> per-file result dict
                (expected keys: 'success', 'text').
            rtf_results: RTF summary dict (keys: 'avg_rtf',
                'total_audio_sec', 'total_inference_sec'); may be partial.
            eval_results: Metric name -> value mapping; may be empty when
                evaluation was skipped.
        """
        report_file = Path(self.config.get('output_dir', './results')) / 'pipeline_report.txt'
        report_file.parent.mkdir(parents=True, exist_ok=True)

        with open(report_file, 'w', encoding='utf-8') as f:
            f.write("ASR Pipeline Execution Report\n")
            f.write("=" * 50 + "\n\n")

            # Inference Summary.
            # BUGFIX: derive failures as total - successes so results that
            # lack a 'success' key are not dropped from both buckets
            # (previously success defaulted to False while failure defaulted
            # to True, so a missing key was counted in neither).
            total = len(inference_results)
            successes = sum(
                1 for r in inference_results.values() if r.get('success', False)
            )
            f.write("Inference Summary:\n")
            f.write(f"  Total files processed: {total}\n")
            f.write(f"  Successful inferences: {successes}\n")
            f.write(f"  Failed inferences: {total - successes}\n\n")

            # RTF Results.
            # BUGFIX: applying ':.4f' to the fallback string 'N/A' raised
            # ValueError whenever 'avg_rtf' was missing; branch instead.
            f.write("Real-Time Factor (RTF) Analysis:\n")
            avg_rtf = rtf_results.get('avg_rtf')
            if avg_rtf is not None:
                f.write(f"  Average RTF: {avg_rtf:.4f}\n")
            else:
                f.write("  Average RTF: N/A\n")
            f.write(f"  Total audio duration: {rtf_results.get('total_audio_sec', 0):.2f} seconds\n")
            f.write(f"  Total inference time: {rtf_results.get('total_inference_sec', 0):.2f} seconds\n\n")

            # Evaluation Results (omitted entirely when evaluation was skipped).
            if eval_results:
                f.write("Evaluation Results:\n")
                for metric, value in eval_results.items():
                    f.write(f"  {metric}: {value}\n")

            # Detailed Results (truncated to the first 10 utterances).
            f.write("\nDetailed Results:\n")
            f.write("-" * 30 + "\n")
            for utt_id, result in list(inference_results.items())[:10]:  # Show first 10
                f.write(f"{utt_id}: {result.get('text', 'N/A')}\n")

            if len(inference_results) > 10:
                f.write(f"... and {len(inference_results) - 10} more results\n")

        self.logger.info(f"Report generated: {report_file}")


def get_default_config() -> Dict:
    """Get default configuration for the pipeline.

    Returns:
        A dict with one entry per recognized configuration key; CLI arguments
        override these values in ``main``.
    """
    return {
        # Data configuration
        'data_dir': './data',
        'dataset': 'aishell',
        'aishell_data_dir': '/data/datasets/0',  # Path to the existing AISHELL dataset
        'test_sets': ['test'],

        # Model configuration
        'model_dir': None,
        'tag_name': 'asr_model',
        'cache_dir': './cache',
        'device': 'cpu',
        'batch_size': 1,
        'use_quantized': False,

        # Processing configuration
        'nj': 4,  # Number of parallel jobs
        'skip_data_prep': False,
        'skip_eval': False,

        # Output configuration
        'output_dir': './results',
        'log_level': 'INFO',
        'log_file': 'asr_pipeline.log',
    }


def main() -> None:
    """Main entry point for the ASR pipeline."""
    parser = argparse.ArgumentParser(
        description='ASR Inference Pipeline for ESPNet ONNX Models',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Data configuration
    parser.add_argument('--data-dir', type=str, default='./data',
                        help='Directory for dataset storage')
    parser.add_argument('--dataset', type=str, default='aishell',
                        choices=['aishell'], help='Dataset to use')
    parser.add_argument('--test-sets', type=str, nargs='+', default=['test'],
                        help='Test sets to evaluate')

    # Model configuration
    parser.add_argument('--model-dir', type=str, required=True,
                        help='Directory containing ONNX model files')
    parser.add_argument('--tag-name', type=str, default='asr_model',
                        help='Model tag name')
    parser.add_argument('--cache-dir', type=str, default='./cache',
                        help='Cache directory for models')
    parser.add_argument('--device', type=str, default='cpu',
                        choices=['cpu', 'gpu'], help='Device for inference')
    parser.add_argument('--batch-size', type=int, default=1,
                        help='Batch size for inference')
    parser.add_argument('--use-quantized', action='store_true',
                        help='Use quantized models')

    # Processing configuration
    parser.add_argument('--nj', type=int, default=4,
                        help='Number of parallel jobs')
    parser.add_argument('--skip-data-prep', action='store_true',
                        help='Skip data preparation steps')
    parser.add_argument('--skip-eval', action='store_true',
                        help='Skip evaluation steps')

    # Output configuration
    parser.add_argument('--output-dir', type=str, default='./results',
                        help='Output directory for results')
    parser.add_argument('--log-level', type=str, default='INFO',
                        choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'],
                        help='Logging level')

    args = parser.parse_args()

    # BUGFIX: previously ``get_default_config`` was dead code and the config
    # came only from vars(args), so keys with no CLI flag (e.g.
    # 'aishell_data_dir', 'log_file') never reached the components. Layer
    # CLI values on top of the defaults; CLI always wins.
    config = {**get_default_config(), **vars(args)}

    # Run pipeline
    pipeline = ASRPipeline(config)
    pipeline.run()


if __name__ == '__main__':
    main()