#!/usr/bin/env python3
"""
Evaluation Module for ASR Pipeline

This module handles performance evaluation of ASR results, including
Word Error Rate (WER) calculation and other metrics.
"""

import logging
import re
from collections import Counter
from pathlib import Path
from typing import Dict, List

try:
    import jiwer
    from jiwer import wer
    WER_AVAILABLE = True
except ImportError:
    WER_AVAILABLE = False

try:
    import numpy as np
except ImportError:
    np = None


class ASREvaluator:
    """ASR evaluation class for calculating performance metrics."""

    def __init__(self, config: Dict):
        self.config = config
        self.logger = logging.getLogger(__name__)
        self.output_dir = Path(config.get('output_dir', './results'))
        # Make sure the output directory exists before any result files are written
        self.output_dir.mkdir(parents=True, exist_ok=True)

        if not WER_AVAILABLE:
            self.logger.warning("jiwer not available. WER calculation will be limited.")

    def evaluate(self, inference_results: Dict[str, Dict],
                 references: Dict[str, str]) -> Dict:
        """
        Evaluate ASR results against references.

        Args:
            inference_results: Dictionary of inference results
            references: Dictionary of reference transcriptions

        Returns:
            Dictionary containing evaluation metrics
        """
        self.logger.info("Starting evaluation")

        # Keep only successful results that have a matching reference
        successful_results = {
            utt_id: result for utt_id, result in inference_results.items()
            if result.get('success', False) and utt_id in references
        }

        if not successful_results:
            self.logger.warning("No successful results to evaluate")
            return {}

        self.logger.info(f"Evaluating {len(successful_results)} successful results")

        # Prepare aligned hypothesis/reference lists
        hypotheses = []
        refs = []
        utt_ids = []

        for utt_id, result in successful_results.items():
            hypotheses.append(result['text'])
            refs.append(references[utt_id])
            utt_ids.append(utt_id)

        metrics = {}

        # Word Error Rate
        if WER_AVAILABLE:
            try:
                wer_score = self._calculate_wer(hypotheses, refs)
                metrics['WER'] = wer_score
                self.logger.info(f"Word Error Rate: {wer_score:.4f}")
            except Exception as e:
                self.logger.error(f"Error calculating WER: {e}")
                metrics['WER'] = 'Error'
        else:
            # Fall back to a simple character-based comparison
            cer_score = self._calculate_cer(hypotheses, refs)
            metrics['CER'] = cer_score
            self.logger.info(f"Character Error Rate: {cer_score:.4f}")

        # Additional metrics
        metrics.update(self._calculate_additional_metrics(hypotheses, refs, utt_ids))

        # Save detailed evaluation results
        self._save_evaluation_details(successful_results, references, metrics)

        self.logger.info("Evaluation completed")
        return metrics

    def _calculate_wer(self, hypotheses: List[str], references: List[str]) -> float:
        """Calculate Word Error Rate using jiwer."""

        def preprocess_text(text: str) -> str:
            # Lowercase, strip punctuation, and collapse whitespace
            text = text.lower()
            text = re.sub(r'[^\w\s]', '', text)
            return ' '.join(text.split())

        processed_hypotheses = [preprocess_text(h) for h in hypotheses]
        processed_references = [preprocess_text(r) for r in references]

        return wer(processed_references, processed_hypotheses)

    def _calculate_cer(self, hypotheses: List[str], references: List[str]) -> float:
        """Calculate Character Error Rate (fallback when jiwer is not available)."""
        total_chars = 0
        total_errors = 0

        for hyp, ref in zip(hypotheses, references):
            # Positional character comparison plus the length difference.
            # This is a simplified approximation; for production use a proper
            # edit-distance algorithm (see the Levenshtein sketch below).
            hyp_chars = list(hyp)
            ref_chars = list(ref)

            errors = abs(len(hyp_chars) - len(ref_chars))
            for i in range(min(len(hyp_chars), len(ref_chars))):
                if hyp_chars[i] != ref_chars[i]:
                    errors += 1

            total_errors += errors
            total_chars += len(ref_chars)

        return total_errors / total_chars if total_chars > 0 else 1.0
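
    # Hedged sketch of a proper edit-distance CER, offered as an alternative to
    # the positional comparison in _calculate_cer above. It is not wired into the
    # evaluation flow by default, and the method name _levenshtein_cer is an
    # illustrative assumption rather than part of the original pipeline.
    def _levenshtein_cer(self, hypotheses: List[str], references: List[str]) -> float:
        """Character Error Rate based on true Levenshtein distance (sketch)."""

        def edit_distance(a: str, b: str) -> int:
            # Standard dynamic-programming Levenshtein distance over characters
            prev = list(range(len(b) + 1))
            for i, ca in enumerate(a, start=1):
                curr = [i]
                for j, cb in enumerate(b, start=1):
                    cost = 0 if ca == cb else 1
                    curr.append(min(prev[j] + 1,          # deletion
                                    curr[j - 1] + 1,      # insertion
                                    prev[j - 1] + cost))  # substitution
                prev = curr
            return prev[-1]

        total_chars = sum(len(ref) for ref in references)
        total_errors = sum(edit_distance(hyp, ref)
                           for hyp, ref in zip(hypotheses, references))
        return total_errors / total_chars if total_chars > 0 else 1.0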

    def _calculate_additional_metrics(self, hypotheses: List[str],
                                      references: List[str],
                                      utt_ids: List[str]) -> Dict:
        """Calculate additional evaluation metrics."""
        metrics = {}

        # Utterance counts
        total_utterances = len(hypotheses)
        metrics['Total_Utterances'] = total_utterances

        # Crude confidence proxy: fraction of non-empty hypotheses
        scores = [1.0 if hyp.strip() else 0.0 for hyp in hypotheses]
        if scores:
            metrics['Avg_Confidence'] = sum(scores) / len(scores)

        # Text length statistics
        ref_lengths = [len(ref.split()) for ref in references if ref.strip()]
        hyp_lengths = [len(hyp.split()) for hyp in hypotheses if hyp.strip()]

        if ref_lengths:
            metrics['Avg_Ref_Length'] = sum(ref_lengths) / len(ref_lengths)
        if hyp_lengths:
            metrics['Avg_Hyp_Length'] = sum(hyp_lengths) / len(hyp_lengths)

        # Length ratio, computed over aligned (hypothesis, reference) pairs
        length_ratios = [
            len(hyp.split()) / len(ref.split())
            for hyp, ref in zip(hypotheses, references)
            if ref.split()
        ]
        if length_ratios:
            metrics['Avg_Length_Ratio'] = sum(length_ratios) / len(length_ratios)

        return metrics

    def _save_evaluation_details(self, results: Dict[str, Dict],
                                 references: Dict[str, str], metrics: Dict):
        """Save detailed evaluation results to file, including hyp.trn and ref.trn in the specified format."""
        eval_file = self.output_dir / 'evaluation_details.txt'

        # Generate hyp.trn and ref.trn files (text first, utterance ID in parentheses)
        hyp_trn_file = self.output_dir / 'hyp.trn'
        ref_trn_file = self.output_dir / 'ref.trn'

        # Write hyp.trn (model recognition results)
        with open(hyp_trn_file, 'w', encoding='utf-8') as f:
            for utt_id, result in results.items():
                if result.get('success', False):
                    hyp_text = result.get('text', '')
                    # Format: text first, utterance ID in parentheses
                    f.write(f"{hyp_text}\t({utt_id})\n")

        # Write ref.trn (reference transcriptions)
        with open(ref_trn_file, 'w', encoding='utf-8') as f:
            for utt_id, ref_text in references.items():
                if utt_id in results and results[utt_id].get('success', False):
                    # Remove spaces from the text, then format: text first,
                    # utterance ID in parentheses
                    ref_text_no_space = ref_text.replace(" ", "")
                    f.write(f"{ref_text_no_space}\t({utt_id})\n")

        self.logger.info(f"hyp.trn saved to: {hyp_trn_file}")
        self.logger.info(f"ref.trn saved to: {ref_trn_file}")

        # Write the detailed evaluation report
        with open(eval_file, 'w', encoding='utf-8') as f:
            f.write("ASR Evaluation Details\n")
            f.write("=" * 50 + "\n\n")

            # Summary metrics
            f.write("Summary Metrics:\n")
            for metric, value in metrics.items():
                if isinstance(value, float):
                    f.write(f"  {metric}: {value:.4f}\n")
                else:
                    f.write(f"  {metric}: {value}\n")
            f.write("\n")

            # Detailed per-utterance results
            f.write("Detailed Results:\n")
            f.write("-" * 80 + "\n")

            for utt_id, result in results.items():
                ref_text = references.get(utt_id, 'N/A')
                hyp_text = result.get('text', 'N/A')

                f.write(f"\nUtterance: {utt_id}\n")
                f.write(f"Reference: {ref_text}\n")
                f.write(f"Hypothesis: {hyp_text}\n")
                f.write(f"Audio Duration: {result.get('audio_duration', 0):.2f}s\n")
                f.write(f"Inference Time: {result.get('inference_time', 0):.2f}s\n")
                f.write(f"RTF: {result.get('rtf', 0):.4f}\n")

                # Utterance-level WER
                if ref_text != 'N/A' and hyp_text != 'N/A' and WER_AVAILABLE:
                    try:
                        utt_wer = wer([ref_text.lower()], [hyp_text.lower()])
                        f.write(f"WER: {utt_wer:.4f}\n")
                    except Exception:
                        pass

        self.logger.info(f"Evaluation details saved to: {eval_file}")

    def calculate_permutation_free_error(self, results: Dict,
                                         num_speakers: int = 2) -> Dict:
        """
        Calculate permutation-free error for multi-speaker ASR.

        This mirrors the functionality of eval_perm_free_error.py in a
        simplified form; a full implementation would require the complete
        per-speaker scoring matrix. See the sketch after this method for
        one possible approach.
        """
        self.logger.info(
            f"Calculating permutation-free error for {num_speakers} speakers")

        metrics = {}
        metrics['Permutation_Free_Error'] = 'Not implemented'
        metrics['Num_Speakers'] = num_speakers

        return metrics
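
    # Hedged sketch: one way to compute a permutation-free WER, assuming each
    # speaker's hypothesis and reference are provided as parallel lists of
    # strings (same number of streams on both sides). The method name and
    # signature are illustrative assumptions, not part of the original
    # pipeline; it reuses calculate_simple_wer so it does not depend on jiwer.
    def _permutation_free_wer_sketch(self, hyp_streams: List[str],
                                     ref_streams: List[str]) -> float:
        """Minimum average WER over all assignments of hypothesis streams to references."""
        from itertools import permutations

        best = float('inf')
        for perm in permutations(range(len(ref_streams))):
            # Average WER for this particular hypothesis-to-reference assignment
            total = sum(calculate_simple_wer(hyp_streams[i], ref_streams[j])
                        for i, j in enumerate(perm))
            best = min(best, total / max(1, len(ref_streams)))
        return best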

def calculate_simple_wer(hypothesis: str, reference: str) -> float:
    """
    Calculate a simple Word Error Rate without external dependencies.

    This is a rough, bag-of-words approximation for when jiwer is not
    available; it does not account for word order.
    """
    if not hypothesis or not reference:
        return 1.0

    # Split into words
    hyp_words = hypothesis.split()
    ref_words = reference.split()

    if len(hyp_words) == 0 or len(ref_words) == 0:
        return 1.0

    # Count word matches regardless of position
    hyp_counter = Counter(hyp_words)
    ref_counter = Counter(ref_words)
    matches = sum((hyp_counter & ref_counter).values())

    total_ref_words = len(ref_words)

    # WER = (substitutions + deletions + insertions) / total reference words
    errors = total_ref_words - matches + max(0, len(hyp_words) - len(ref_words))

    return errors / total_ref_words if total_ref_words > 0 else 1.0


if __name__ == '__main__':
    # Test the evaluator
    logging.basicConfig(level=logging.INFO)

    config = {
        'output_dir': './test_results'
    }

    evaluator = ASREvaluator(config)

    # Test data
    inference_results = {
        'utt1': {
            'success': True,
            'text': 'this is a test',
            'audio_duration': 2.5,
            'inference_time': 0.5
        },
        'utt2': {
            'success': True,
            'text': 'another test sentence',
            'audio_duration': 3.0,
            'inference_time': 0.6
        }
    }

    references = {
        'utt1': 'this is a test',
        'utt2': 'another test sentence'
    }

    metrics = evaluator.evaluate(inference_results, references)
    print(f"Evaluation metrics: {metrics}")
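
    # Quick check of the dependency-free fallback; the example strings below are
    # illustrative only and not part of the original test data.
    fallback_wer = calculate_simple_wer('this is a test', 'this is the test')
    print(f"Fallback simple WER: {fallback_wer:.4f}")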