#!/usr/bin/env python3
"""
ASR Inference Pipeline for ESPNet ONNX Models

This script provides a unified Python pipeline for ASR inference that replaces
the shell-based workflow. It handles data preparation, inference, and evaluation
in a single streamlined process.

Key features:
- Simplified configuration management
- Integrated data processing and inference
- Real-time factor calculation
- Automatic evaluation and reporting
- Support for AISHELL dataset
"""

import argparse
import logging
import os
import sys
from pathlib import Path
from typing import Dict, List, Optional, Tuple

from data_processor import AISHELLDataProcessor
from inference_engine import ONNXInferenceEngine
from evaluator import ASREvaluator
from utils.rtf_calculator import RTFCalculator


class ASRPipeline:
    """Main ASR pipeline class that orchestrates the entire workflow.

    Runs six steps in order: data preparation, model loading, inference,
    real-time-factor (RTF) calculation, evaluation, and report generation.
    Behavior is driven by a single flat ``config`` dict (see
    ``get_default_config`` for the expected keys).
    """

    def __init__(self, config: Dict):
        """Create the pipeline and its components.

        Args:
            config: Flat configuration dict. Keys read directly here:
                'log_level', 'log_file', 'output_dir', 'skip_data_prep',
                'skip_eval'. The component classes read the remainder.
        """
        self.config = config
        self.setup_logging()

        # Each component receives the full config and pulls what it needs.
        self.data_processor = AISHELLDataProcessor(config)
        self.inference_engine = ONNXInferenceEngine(config)
        self.evaluator = ASREvaluator(config)
        self.rtf_calculator = RTFCalculator(config)

    def setup_logging(self):
        """Configure root logging to both stdout and a log file."""
        # Fall back to INFO for unknown level names instead of raising
        # AttributeError (also accepts lowercase names like 'debug').
        level_name = str(self.config.get('log_level', 'INFO')).upper()
        log_level = getattr(logging, level_name, logging.INFO)
        logging.basicConfig(
            level=log_level,
            format='%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s',
            handlers=[
                logging.StreamHandler(sys.stdout),
                logging.FileHandler(self.config.get('log_file', 'asr_pipeline.log'))
            ]
        )
        self.logger = logging.getLogger(__name__)

    def run(self):
        """Execute the complete ASR pipeline end to end."""
        self.logger.info("Starting ASR Pipeline")

        # Step 1: Data Preparation (optionally reuse previously prepared data)
        if not self.config.get('skip_data_prep', False):
            self.logger.info("Step 1: Data Preparation")
            data_info = self.data_processor.prepare_data()
            self.logger.info(f"Data preparation completed: {len(data_info['test_files'])} test files")
        else:
            self.logger.info("Skipping data preparation")
            data_info = self.data_processor.load_existing_data()

        # Step 2: Model Loading
        self.logger.info("Step 2: Model Loading")
        self.inference_engine.load_model()

        # Step 3: Inference
        self.logger.info("Step 3: Running Inference")
        inference_results = self.inference_engine.run_inference(data_info['test_files'])

        # Step 4: RTF Calculation
        self.logger.info("Step 4: Calculating Real-Time Factor")
        rtf_results = self.rtf_calculator.calculate_rtf(inference_results)

        # Step 5: Evaluation (needs references from data preparation)
        if not self.config.get('skip_eval', False):
            self.logger.info("Step 5: Running Evaluation")
            eval_results = self.evaluator.evaluate(inference_results, data_info['references'])
        else:
            self.logger.info("Skipping evaluation")
            eval_results = {}

        # Step 6: Generate Report
        self.logger.info("Step 6: Generating Report")
        self.generate_report(inference_results, rtf_results, eval_results)

        self.logger.info("ASR Pipeline completed successfully")

    def generate_report(self, inference_results: Dict, rtf_results: Dict, eval_results: Dict):
        """Write a human-readable execution report to the output directory.

        Args:
            inference_results: Mapping utt_id -> per-utterance result dict;
                optional keys 'success' (bool) and 'text' (str) are read.
            rtf_results: RTF summary dict; keys 'avg_rtf',
                'total_audio_sec', 'total_inference_sec' are read, with
                missing values reported as N/A or 0.
            eval_results: Metric name -> value mapping; may be empty.
        """
        report_file = Path(self.config.get('output_dir', './results')) / 'pipeline_report.txt'
        report_file.parent.mkdir(parents=True, exist_ok=True)

        # Bug fix: the original applied ':.4f' to rtf_results.get('avg_rtf',
        # 'N/A'), which raises ValueError when the key is missing (a str
        # cannot take a float format spec). Format only when numeric.
        avg_rtf = rtf_results.get('avg_rtf')
        avg_rtf_str = f"{avg_rtf:.4f}" if isinstance(avg_rtf, (int, float)) else 'N/A'

        with open(report_file, 'w', encoding='utf-8') as f:
            f.write("ASR Pipeline Execution Report\n")
            f.write("=" * 50 + "\n\n")

            # Inference Summary. NOTE: results lacking a 'success' key count
            # as neither successful nor failed (defaults differ on purpose).
            f.write("Inference Summary:\n")
            f.write(f"  Total files processed: {len(inference_results)}\n")
            f.write(f"  Successful inferences: {sum(1 for r in inference_results.values() if r.get('success', False))}\n")
            f.write(f"  Failed inferences: {sum(1 for r in inference_results.values() if not r.get('success', True))}\n\n")

            # RTF Results
            f.write("Real-Time Factor (RTF) Analysis:\n")
            f.write(f"  Average RTF: {avg_rtf_str}\n")
            f.write(f"  Total audio duration: {rtf_results.get('total_audio_sec', 0):.2f} seconds\n")
            f.write(f"  Total inference time: {rtf_results.get('total_inference_sec', 0):.2f} seconds\n\n")

            # Evaluation Results (omitted entirely when evaluation was skipped)
            if eval_results:
                f.write("Evaluation Results:\n")
                for metric, value in eval_results.items():
                    f.write(f"  {metric}: {value}\n")

            # Detailed Results
            f.write("\nDetailed Results:\n")
            f.write("-" * 30 + "\n")
            for utt_id, result in list(inference_results.items())[:10]:  # Show first 10
                f.write(f"{utt_id}: {result.get('text', 'N/A')}\n")

            if len(inference_results) > 10:
                f.write(f"... and {len(inference_results) - 10} more results\n")

        self.logger.info(f"Report generated: {report_file}")


def get_default_config():
    """Return the pipeline's default configuration as a flat dict.

    CLI arguments (see ``main``) may override these values; keys without
    a corresponding CLI flag (e.g. 'aishell_data_dir', 'log_file') are
    defined only here.
    """
    data_defaults = {
        'data_dir': './data',
        'dataset': 'aishell',
        'aishell_data_dir': '/data/datasets/0',  # path to an existing AISHELL dataset
        'test_sets': ['test'],
    }
    model_defaults = {
        'model_dir': None,
        'tag_name': 'asr_model',
        'cache_dir': './cache',
        'device': 'cpu',
        'batch_size': 1,
        'use_quantized': False,
    }
    processing_defaults = {
        'nj': 4,  # number of parallel jobs
        'skip_data_prep': False,
        'skip_eval': False,
    }
    output_defaults = {
        'output_dir': './results',
        'log_level': 'INFO',
        'log_file': 'asr_pipeline.log',
    }
    return {
        **data_defaults,
        **model_defaults,
        **processing_defaults,
        **output_defaults,
    }


def main():
    """Main entry point: parse CLI arguments and run the ASR pipeline."""
    parser = argparse.ArgumentParser(
        description='ASR Inference Pipeline for ESPNet ONNX Models',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # Data configuration
    parser.add_argument('--data-dir', type=str, default='./data',
                       help='Directory for dataset storage')
    parser.add_argument('--dataset', type=str, default='aishell',
                       choices=['aishell'], help='Dataset to use')
    parser.add_argument('--test-sets', type=str, nargs='+', default=['test'],
                       help='Test sets to evaluate')

    # Model configuration
    parser.add_argument('--model-dir', type=str, required=True,
                       help='Directory containing ONNX model files')
    parser.add_argument('--tag-name', type=str, default='asr_model',
                       help='Model tag name')
    parser.add_argument('--cache-dir', type=str, default='./cache',
                       help='Cache directory for models')
    parser.add_argument('--device', type=str, default='cpu',
                       choices=['cpu', 'gpu'], help='Device for inference')
    parser.add_argument('--batch-size', type=int, default=1,
                       help='Batch size for inference')
    parser.add_argument('--use-quantized', action='store_true',
                       help='Use quantized models')

    # Processing configuration
    parser.add_argument('--nj', type=int, default=4,
                       help='Number of parallel jobs')
    parser.add_argument('--skip-data-prep', action='store_true',
                       help='Skip data preparation steps')
    parser.add_argument('--skip-eval', action='store_true',
                       help='Skip evaluation steps')

    # Output configuration
    parser.add_argument('--output-dir', type=str, default='./results',
                       help='Output directory for results')
    parser.add_argument('--log-level', type=str, default='INFO',
                       choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'],
                       help='Logging level')

    args = parser.parse_args()

    # Bug fix: the original built the config from CLI args alone, so
    # defaults without a CLI flag ('aishell_data_dir', 'log_file') never
    # reached the pipeline. Start from the library defaults, then let the
    # parsed arguments take precedence (args always win for their keys).
    config = get_default_config()
    config.update(vars(args))

    # Run pipeline
    pipeline = ASRPipeline(config)
    pipeline.run()


if __name__ == '__main__':
    main()