#!/usr/bin/env python3
"""
Inference Engine for ASR Pipeline

This module handles ONNX model loading and inference for the ASR pipeline.
It provides both single-file and batch processing capabilities.
"""

import logging
import time
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np

try:
    from espnet_onnx import Speech2Text
except ImportError:
    Speech2Text = None

try:
    import soundfile as sf
except ImportError:
    sf = None


class ONNXInferenceEngine:
    """ONNX inference engine for ASR models."""

    def __init__(self, config: Dict):
        self.config = config
        self.logger = logging.getLogger(__name__)

        # Model configuration
        self.model_dir = config.get('model_dir')
        self.tag_name = config.get('tag_name', 'asr_model')
        self.cache_dir = config.get('cache_dir', './cache')
        self.device = config.get('device', 'cpu')
        self.batch_size = config.get('batch_size', 1)
        self.use_quantized = config.get('use_quantized', False)

        # Output configuration
        self.output_dir = Path(config.get('output_dir', './results'))

        # Model instance
        self.speech2text = None

    def load_model(self):
        """Load the ONNX model."""
        if Speech2Text is None:
            raise ImportError("espnet_onnx is not installed. Please install it first.")
        if sf is None:
            raise ImportError("soundfile is not installed. Please install it first.")

        self.logger.info(f"Loading ONNX model from: {self.model_dir}")

        # Select ONNX Runtime execution providers based on the configured device
        if self.device.lower() == 'gpu':
            providers = ["ROCMExecutionProvider", "CPUExecutionProvider"]
        else:
            providers = ["CPUExecutionProvider"]

        try:
            self.speech2text = Speech2Text(
                model_dir=self.model_dir,
                providers=providers,
                use_quantized=self.use_quantized,
            )
            self.logger.info("Model loaded successfully")
        except Exception as e:
            self.logger.error(f"Error loading model: {e}")
            raise
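
    # A minimal pre-flight check, sketched here as a helper. The method name
    # `available_providers` is illustrative, not part of espnet_onnx; onnxruntime
    # itself is assumed importable, since espnet_onnx depends on it.
    @staticmethod
    def available_providers() -> List[str]:
        """Return the execution providers available in this onnxruntime build."""
        import onnxruntime as ort
        # e.g. ["ROCMExecutionProvider", "CPUExecutionProvider"] on a ROCm build
        return ort.get_available_providers()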

    def run_inference(self, test_files: Dict[str, str]) -> Dict[str, Dict]:
        """
        Run inference on test files.

        Args:
            test_files: Dictionary mapping utterance IDs to audio file paths

        Returns:
            Dictionary containing inference results for each utterance
        """
        if self.speech2text is None:
            raise RuntimeError("Model not loaded. Call load_model() first.")

        self.logger.info(f"Running inference on {len(test_files)} files")

        results = {}
        total_files = len(test_files)

        # Create output directory
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Process files one utterance at a time
        for idx, (utt_id, audio_path) in enumerate(test_files.items(), 1):
            self.logger.info(f"Processing [{idx}/{total_files}]: {utt_id}")

            try:
                # Load audio
                audio, rate = self._load_audio(audio_path)

                # Perform inference with timing
                start_time = time.time()
                inference_results = self.speech2text(audio)
                inference_time = time.time() - start_time

                # Process results
                result = self._process_inference_result(
                    utt_id, audio, rate, inference_results, inference_time
                )
                results[utt_id] = result

                # Save individual result
                self._save_individual_result(utt_id, result)

                self.logger.info(f" -> Recognized: {result.get('text', 'N/A')}")

            except Exception as e:
                self.logger.error(f"Error processing {utt_id}: {e}")
                results[utt_id] = {
                    'success': False,
                    'error': str(e),
                    'audio_duration': 0,
                    'inference_time': 0
                }

        # Save combined results
        self._save_combined_results(results)

        self.logger.info("Inference completed")
        return results

    def _load_audio(self, audio_path: str) -> Tuple[np.ndarray, int]:
        """Load an audio file as a mono float array."""
        try:
            audio, rate = sf.read(audio_path)
            # Downmix multi-channel audio to mono; the model expects a 1-D signal
            if audio.ndim > 1:
                audio = audio.mean(axis=1)
            self.logger.debug(f"Loaded audio: {audio_path}, shape: {audio.shape}, rate: {rate}")
            return audio, rate
        except Exception as e:
            self.logger.error(f"Error loading audio file {audio_path}: {e}")
            raise

    def _process_inference_result(self, utt_id: str, audio: np.ndarray, rate: int,
                                  inference_results: List, inference_time: float) -> Dict:
        """Process inference results into a structured format."""
        # Calculate audio duration
        audio_duration = len(audio) / rate if rate > 0 else 0

        # Extract recognition results
        if inference_results:
            # Take the first (best) hypothesis
            text, tokens, token_ids, hyp = inference_results[0]

            result = {
                'success': True,
                'text': text,
                'tokens': tokens,
                'token_ids': token_ids,
                'hypothesis_score': hyp.score if hasattr(hyp, 'score') else 0,
                'audio_duration': audio_duration,
                'inference_time': inference_time,
                'rtf': inference_time / audio_duration if audio_duration > 0 else 0
            }

            # Add alternative hypotheses if available
            if len(inference_results) > 1:
                result['alternative_hypotheses'] = [
                    {
                        'text': alt_text,
                        'tokens': alt_tokens,
                        'token_ids': alt_token_ids,
                        'score': alt_hyp.score if hasattr(alt_hyp, 'score') else 0
                    }
                    for alt_text, alt_tokens, alt_token_ids, alt_hyp in inference_results[1:]
                ]
        else:
            result = {
                'success': False,
                'error': 'No recognition results',
                'audio_duration': audio_duration,
                'inference_time': inference_time
            }

        return result
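
    # Note on RTF: the real-time factor above is inference_time / audio_duration,
    # so decoding a 10.0 s utterance in 2.5 s gives RTF = 0.25, i.e. 4x faster
    # than real time; RTF < 1.0 means the engine keeps up with live audio.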

    def _save_individual_result(self, utt_id: str, result: Dict):
        """Save an individual result to a text file."""
        output_file = self.output_dir / f"{utt_id}.txt"

        with open(output_file, 'w', encoding='utf-8') as f:
            f.write(f"# Recognition Result for {utt_id}\n")
            f.write(f"Success: {result['success']}\n")

            if result['success']:
                f.write(f"Text: {result['text']}\n")
                f.write(f"Tokens: {' '.join(result['tokens'])}\n")
                f.write(f"Token IDs: {' '.join(map(str, result['token_ids']))}\n")
                f.write(f"Hypothesis Score: {result['hypothesis_score']:.4f}\n")
                f.write(f"Audio Duration: {result['audio_duration']:.2f}s\n")
                f.write(f"Inference Time: {result['inference_time']:.2f}s\n")
                f.write(f"Real-Time Factor: {result['rtf']:.4f}\n")

                # Alternative hypotheses
                if 'alternative_hypotheses' in result:
                    f.write("\nAlternative Hypotheses:\n")
                    for i, alt in enumerate(result['alternative_hypotheses'], 1):
                        f.write(f"  {i}. {alt['text']} (Score: {alt['score']:.4f})\n")
            else:
                f.write(f"Error: {result.get('error', 'Unknown error')}\n")

    def _save_combined_results(self, results: Dict[str, Dict]):
        """Save all results to a combined file."""
        output_file = self.output_dir / "all_results.txt"

        with open(output_file, 'w', encoding='utf-8') as f:
            f.write("# ASR Inference Results\n")
            f.write("=" * 50 + "\n\n")

            successful = sum(1 for r in results.values() if r.get('success', False))
            failed = len(results) - successful

            f.write("Summary:\n")
            f.write(f"  Total files: {len(results)}\n")
            f.write(f"  Successful: {successful}\n")
            f.write(f"  Failed: {failed}\n\n")

            f.write("Detailed Results:\n")
            f.write("-" * 30 + "\n")

            for utt_id, result in results.items():
                f.write(f"\n{utt_id}:\n")
                if result['success']:
                    f.write(f"  Text: {result['text']}\n")
                    f.write(f"  Duration: {result['audio_duration']:.2f}s\n")
                    f.write(f"  Inference Time: {result['inference_time']:.2f}s\n")
                    f.write(f"  RTF: {result['rtf']:.4f}\n")
                else:
                    f.write(f"  ERROR: {result.get('error', 'Unknown error')}\n")

    def run_batch_inference(self, audio_batch: List[Tuple[str, np.ndarray]],
                            rate: int = 16000) -> List[Dict]:
        """
        Run batch inference on multiple audio samples.

        Args:
            audio_batch: List of (utterance_id, audio_data) tuples
            rate: Sample rate of the audio data (assumed uniform across the batch)

        Returns:
            List of inference results
        """
        if self.speech2text is None:
            raise RuntimeError("Model not loaded. Call load_model() first.")

        self.logger.info(f"Running batch inference on {len(audio_batch)} samples")

        results = []

        # Process in chunks of self.batch_size. espnet_onnx's Speech2Text decodes
        # one utterance per call, so each sample in the chunk is decoded in turn.
        for i in range(0, len(audio_batch), self.batch_size):
            batch = audio_batch[i:i + self.batch_size]

            for utt_id, audio in batch:
                start_time = time.time()
                inference_results = self.speech2text(audio)
                inference_time = time.time() - start_time

                results.append(self._process_inference_result(
                    utt_id, audio, rate, inference_results, inference_time
                ))

        return results


if __name__ == '__main__':
    # Test the inference engine
    logging.basicConfig(level=logging.INFO)

    config = {
        'model_dir': './models',
        'tag_name': 'asr_model',
        'device': 'cpu',
        'output_dir': './test_results'
    }

    engine = ONNXInferenceEngine(config)

    # Test with dummy data
    test_files = {
        'test_001': './test_audio.wav',  # Replace with actual file
        'test_002': './test_audio2.wav'
    }

    try:
        engine.load_model()
        results = engine.run_inference(test_files)
        print(f"Inference completed: {len(results)} results")
    except Exception as e:
        print(f"Error: {e}")
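
# Batch usage sketch (hypothetical paths; assumes 16 kHz mono WAV files):
#
#   engine = ONNXInferenceEngine({'model_dir': './models', 'batch_size': 4})
#   engine.load_model()
#   batch = [(p.stem, sf.read(p)[0]) for p in Path('./wavs').glob('*.wav')]
#   results = engine.run_batch_inference(batch, rate=16000)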