#!/usr/bin/env python3
"""
Inference Engine for ASR Pipeline

This module handles ONNX model loading and inference for the ASR pipeline.
It provides both single file and batch processing capabilities.
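
Typical usage (a minimal sketch; file paths are placeholders):

    engine = ONNXInferenceEngine({'model_dir': './models', 'output_dir': './results'})
    engine.load_model()
    results = engine.run_inference({'utt_001': './audio/utt_001.wav'})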
"""

import logging
import time
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np


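# espnet_onnx and soundfile are treated as optional at import time: if either
# is missing, the module still imports, and load_model() raises a descriptive
# ImportError instead of failing here.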
try:
    from espnet_onnx import Speech2Text
    from espnet_onnx.asr.beam_search.hyps import Hypothesis
except ImportError:
    Speech2Text = None
    Hypothesis = None


try:
    import soundfile as sf
except ImportError:
    sf = None


class ONNXInferenceEngine:
    """ONNX inference engine for ASR models"""
    
    def __init__(self, config: Dict):
        self.config = config
        self.logger = logging.getLogger(__name__)
        
        # Model configuration
        self.model_dir = config.get('model_dir')
        self.tag_name = config.get('tag_name', 'asr_model')
        self.cache_dir = config.get('cache_dir', './cache')
        self.device = config.get('device', 'cpu')
        self.batch_size = config.get('batch_size', 1)
        self.use_quantized = config.get('use_quantized', False)
        # Optional cap on how many files run_inference() processes
        # (useful for quick smoke tests); None means no limit.
        self.max_files = config.get('max_files')
        
        # Output configuration
        self.output_dir = Path(config.get('output_dir', './results'))
        
        # Model instance
        self.speech2text = None
        
    def load_model(self):
        """Load the ONNX model"""
        if Speech2Text is None:
            raise ImportError("espnet_onnx is not installed. Please install it first.")
            
        if sf is None:
            raise ImportError("soundfile is not installed. Please install it first.")
            
        self.logger.info(f"Loading ONNX model from: {self.model_dir}")
        
        # Set providers based on device
        if self.device.lower() == 'gpu':
            providers = ["ROCMExecutionProvider", "CPUExecutionProvider"]
        else:
            providers = ["CPUExecutionProvider"]
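        # onnxruntime tries providers in the listed order, so the ROCm (AMD GPU)
        # provider falls back to CPUExecutionProvider if it is unavailable.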
            
        try:
            self.speech2text = Speech2Text(
                model_dir=self.model_dir,
                providers=providers,
                # use_quantized=self.use_quantized,
                # cache_dir=self.cache_dir,
            )
            self.logger.info("Model loaded successfully")
            
        except Exception as e:
            self.logger.error(f"Error loading model: {e}")
            raise
            
    def run_inference(self, test_files: Dict[str, str]) -> Dict[str, Dict]:
        """
        Run inference on test files
        
        Args:
            test_files: Dictionary mapping utterance IDs to audio file paths
            
        Returns:
            Dictionary containing inference results for each utterance.
            Per-utterance text files and a combined summary file are also
            written to the configured output directory as a side effect.
        """
        if self.speech2text is None:
            raise RuntimeError("Model not loaded. Call load_model() first.")
            
        self.logger.info(f"Running inference on {len(test_files)} files")
        
        results = {}
        total_files = len(test_files)
        
        # Create output directory
        self.output_dir.mkdir(parents=True, exist_ok=True)
        
        # Process files
        for idx, (utt_id, audio_path) in enumerate(test_files.items(), 1):
            # Honor the optional max_files cap (None means process everything)
            if self.max_files is not None and idx > self.max_files:
                self.logger.info(f"Stopping after {self.max_files} files (max_files limit)")
                break
            self.logger.info(f"Processing [{idx}/{total_files}]: {utt_id}")
            try:
                # Load audio
                audio, rate = self._load_audio(audio_path)
                
                # Perform inference with timing
                start_time = time.time()
                inference_results = self.speech2text(audio)
                inference_time = time.time() - start_time
                
                # Process results
                result = self._process_inference_result(
                    utt_id, audio, rate, inference_results, inference_time
                )
                
                results[utt_id] = result
                
                # Save individual result
                self._save_individual_result(utt_id, result)
                
                self.logger.info(f"  -> Recognized: {result.get('text', 'N/A')}")
                
            except Exception as e:
                self.logger.error(f"Error processing {utt_id}: {e}")
                results[utt_id] = {
                    'success': False,
                    'error': str(e),
                    'audio_duration': 0,
                    'inference_time': 0
                }
                
        # Save combined results
        self._save_combined_results(results)
        
        self.logger.info("Inference completed")
        return results
        
    def _load_audio(self, audio_path: str) -> Tuple[np.ndarray, int]:
        """Load an audio file, downmixing multi-channel input to mono"""
        try:
            audio, rate = sf.read(audio_path)
            # soundfile returns shape (n_samples, n_channels) for
            # multi-channel files; the ASR model expects a mono waveform.
            if audio.ndim > 1:
                audio = audio.mean(axis=1)
            self.logger.debug(f"Loaded audio: {audio_path}, shape: {audio.shape}, rate: {rate}")
            return audio, rate
        except Exception as e:
            self.logger.error(f"Error loading audio file {audio_path}: {e}")
            raise
            
    def _process_inference_result(self, utt_id: str, audio: np.ndarray, rate: int,
                                  inference_results: List, inference_time: float) -> Dict:
        """Process inference results into a structured format"""
        
        # Calculate audio duration
        audio_duration = len(audio) / rate if rate > 0 else 0
        
        # Extract recognition results
        if inference_results:
            # Take the first (best) hypothesis
            text, tokens, token_ids, hyp = inference_results[0]
            
            result = {
                'success': True,
                'text': text,
                'tokens': tokens,
                'token_ids': token_ids,
                'hypothesis_score': hyp.score if hasattr(hyp, 'score') else 0,
                'audio_duration': audio_duration,
                'inference_time': inference_time,
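                # Real-time factor: inference time divided by audio duration;
                # values below 1.0 mean faster-than-real-time decoding.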
                'rtf': inference_time / audio_duration if audio_duration > 0 else 0
            }
            
            # Add additional hypotheses if available
            if len(inference_results) > 1:
                result['alternative_hypotheses'] = []
                for alt_text, alt_tokens, alt_token_ids, alt_hyp in inference_results[1:]:
                    result['alternative_hypotheses'].append({
                        'text': alt_text,
                        'tokens': alt_tokens,
                        'token_ids': alt_token_ids,
                        'score': alt_hyp.score if hasattr(alt_hyp, 'score') else 0
                    })
        else:
            result = {
                'success': False,
                'error': 'No recognition results',
                'audio_duration': audio_duration,
                'inference_time': inference_time
            }
            
        return result
        
    def _save_individual_result(self, utt_id: str, result: Dict):
        """Save individual result to file"""
        output_file = self.output_dir / f"{utt_id}.txt"
        
        with open(output_file, 'w', encoding='utf-8') as f:
            f.write(f"# Recognition Result for {utt_id}\n")
            f.write(f"Success: {result['success']}\n")
            
            if result['success']:
                f.write(f"Text: {result['text']}\n")
                f.write(f"Tokens: {' '.join(result['tokens'])}\n")
                f.write(f"Token IDs: {' '.join(map(str, result['token_ids']))}\n")
                f.write(f"Hypothesis Score: {result['hypothesis_score']:.4f}\n")
                f.write(f"Audio Duration: {result['audio_duration']:.2f}s\n")
                f.write(f"Inference Time: {result['inference_time']:.2f}s\n")
                f.write(f"Real-Time Factor: {result['rtf']:.4f}\n")
                
                # Alternative hypotheses
                if 'alternative_hypotheses' in result:
                    f.write("\nAlternative Hypotheses:\n")
                    for i, alt in enumerate(result['alternative_hypotheses'], 1):
                        f.write(f"  {i}. {alt['text']} (Score: {alt['score']:.4f})\n")
            else:
                f.write(f"Error: {result.get('error', 'Unknown error')}\n")
                
    def _save_combined_results(self, results: Dict[str, Dict]):
        """Save all results to a combined file"""
        output_file = self.output_dir / "all_results.txt"
        
        with open(output_file, 'w', encoding='utf-8') as f:
            f.write("# ASR Inference Results\n")
            f.write("=" * 50 + "\n\n")
            
            successful = sum(1 for r in results.values() if r.get('success', False))
            failed = len(results) - successful
            
            f.write(f"Summary:\n")
            f.write(f"  Total files: {len(results)}\n")
            f.write(f"  Successful: {successful}\n")
            f.write(f"  Failed: {failed}\n\n")
            
            f.write("Detailed Results:\n")
            f.write("-" * 30 + "\n")
            
            for utt_id, result in results.items():
                f.write(f"\n{utt_id}:\n")
                
                if result['success']:
                    f.write(f"  Text: {result['text']}\n")
                    f.write(f"  Duration: {result['audio_duration']:.2f}s\n")
                    f.write(f"  Inference Time: {result['inference_time']:.2f}s\n")
                    f.write(f"  RTF: {result['rtf']:.4f}\n")
                else:
                    f.write(f"  ERROR: {result.get('error', 'Unknown error')}\n")
                    
    def run_batch_inference(self, audio_batch: List[Tuple[str, np.ndarray]]) -> List[Dict]:
        """
        Run inference on multiple in-memory audio samples.

        Note: espnet_onnx's Speech2Text decodes one utterance at a time, so
        batch_size here only controls how many samples are grouped per loop
        iteration; each sample is still decoded individually. In-memory audio
        is assumed to be sampled at 16 kHz.

        Args:
            audio_batch: List of (utterance_id, audio_data) tuples

        Returns:
            List of inference results
        """
        if self.speech2text is None:
            raise RuntimeError("Model not loaded. Call load_model() first.")

        self.logger.info(f"Running batch inference on {len(audio_batch)} samples")

        results = []

        # Process in groups of batch_size
        for i in range(0, len(audio_batch), self.batch_size):
            batch = audio_batch[i:i + self.batch_size]

            for utt_id, audio in batch:
                # Time each utterance individually so the RTF stays meaningful
                start_time = time.time()
                inference_results = self.speech2text(audio)
                inference_time = time.time() - start_time

                results.append(self._process_inference_result(
                    utt_id, audio, 16000, inference_results, inference_time
                ))

        return results


if __name__ == '__main__':
    # Quick manual test of the inference engine
    logging.basicConfig(level=logging.INFO)
    
    config = {
        'model_dir': './models',
        'tag_name': 'asr_model',
        'device': 'cpu',
        'output_dir': './test_results'
    }
    
    engine = ONNXInferenceEngine(config)
    
    # Test with dummy data
    test_files = {
        'test_001': './test_audio.wav',  # Replace with actual file
        'test_002': './test_audio2.wav'
    }
    
    try:
        engine.load_model()
        results = engine.run_inference(test_files)
        print(f"Inference completed: {len(results)} results")
    except Exception as e:
        print(f"Error: {e}")