#!/usr/bin/env python3
"""
Evaluation Module for ASR Pipeline

This module handles performance evaluation of ASR results, including
Word Error Rate (WER) calculation and other metrics.
"""

import logging
import re
from collections import Counter
from pathlib import Path
from typing import Dict, List, Tuple


try:
    import jiwer
    from jiwer import wer
    WER_AVAILABLE = True
except ImportError:
    WER_AVAILABLE = False
    

try:
    import numpy as np
except ImportError:
    np = None


class ASREvaluator:
    """ASR evaluation class for calculating performance metrics"""
    
    def __init__(self, config: Dict):
        self.config = config
        self.logger = logging.getLogger(__name__)
        self.output_dir = Path(config.get('output_dir', './results'))
        
        if not WER_AVAILABLE:
            self.logger.warning("jiwer not available. WER calculation will be limited.")
            
    def evaluate(self, inference_results: Dict[str, Dict], references: Dict[str, str]) -> Dict:
        """
        Evaluate ASR results against references
        
        Args:
            inference_results: Dictionary of inference results
            references: Dictionary of reference transcriptions
            
        Returns:
            Dictionary containing evaluation metrics
        """
        self.logger.info("Starting evaluation")
        
        # Filter successful results
        successful_results = {
            utt_id: result for utt_id, result in inference_results.items() 
            if result.get('success', False) and utt_id in references
        }
        
        if not successful_results:
            self.logger.warning("No successful results to evaluate")
            return {}
            
        self.logger.info(f"Evaluating {len(successful_results)} successful results")
        
        # Prepare data for evaluation
        hypotheses = []
        refs = []
        utt_ids = []
        
        for utt_id, result in successful_results.items():
            hypotheses.append(result['text'])
            refs.append(references[utt_id])
            utt_ids.append(utt_id)
            
        # Calculate metrics
        metrics = {}
        
        # Word Error Rate
        if WER_AVAILABLE:
            try:
                wer_score = self._calculate_wer(hypotheses, refs)
                metrics['WER'] = wer_score
                self.logger.info(f"Word Error Rate: {wer_score:.4f}")
            except Exception as e:
                self.logger.error(f"Error calculating WER: {e}")
                metrics['WER'] = 'Error'
        else:
            # Fallback to simple character-based comparison
            cer_score = self._calculate_cer(hypotheses, refs)
            metrics['CER'] = cer_score
            self.logger.info(f"Character Error Rate: {cer_score:.4f}")
            
        # Additional metrics
        metrics.update(self._calculate_additional_metrics(hypotheses, refs, utt_ids))
        
        # Save detailed evaluation results
        self._save_evaluation_details(successful_results, references, metrics)
        
        self.logger.info("Evaluation completed")
        return metrics
        
    def _calculate_wer(self, hypotheses: List[str], references: List[str]) -> float:
        """Calculate Word Error Rate using jiwer"""
        # Preprocess text for WER calculation
        def preprocess_text(text):
            # Convert to lowercase and remove punctuation
            text = text.lower()
            text = re.sub(r'[^\w\s]', '', text)
            # Remove extra whitespace
            text = ' '.join(text.split())
            return text
            
        processed_hypotheses = [preprocess_text(h) for h in hypotheses]
        processed_references = [preprocess_text(r) for r in references]
        
        # Calculate WER
        return wer(processed_references, processed_hypotheses)
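
    def _calculate_wer_with_transforms(self, hypotheses: List[str], references: List[str]) -> float:
        """
        Illustrative sketch (not part of the original pipeline): delegate text
        normalization to jiwer's transformation pipeline instead of the manual
        regex preprocessing in _calculate_wer. The keyword names accepted by
        jiwer.wer() differ across versions (truth_transform in jiwer 2.x,
        reference_transform in 3.x), so adjust for the installed version.
        """
        transform = jiwer.Compose([
            jiwer.ToLowerCase(),
            jiwer.RemovePunctuation(),
            jiwer.RemoveMultipleSpaces(),
            jiwer.Strip(),
            jiwer.ReduceToListOfListOfWords(),
        ])
        # jiwer 2.x style keywords; use reference_transform/hypothesis_transform on 3.x
        return wer(references, hypotheses,
                   truth_transform=transform, hypothesis_transform=transform)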
        
    def _calculate_cer(self, hypotheses: List[str], references: List[str]) -> float:
        """Calculate Character Error Rate (fallback when WER is not available)"""
        total_chars = 0
        total_errors = 0
        
        for hyp, ref in zip(hypotheses, references):
            # Simple character-level comparison
            hyp_chars = list(hyp)
            ref_chars = list(ref)
            
            # Simplified positional comparison rather than a true edit distance;
            # see the _levenshtein_cer sketch below for a proper algorithm
            errors = abs(len(hyp_chars) - len(ref_chars))
            
            for i in range(min(len(hyp_chars), len(ref_chars))):
                if hyp_chars[i] != ref_chars[i]:
                    errors += 1
                    
            total_errors += errors
            total_chars += len(ref_chars)
            
        return total_errors / total_chars if total_chars > 0 else 1.0
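
    def _levenshtein_cer(self, hypotheses: List[str], references: List[str]) -> float:
        """
        Illustrative sketch (not part of the original module): character error
        rate based on a standard Levenshtein edit-distance dynamic program, as
        the comment in _calculate_cer suggests using in production.
        """
        def edit_distance(a: str, b: str) -> int:
            # Classic O(len(a) * len(b)) DP over string prefixes
            prev = list(range(len(b) + 1))
            for i, ca in enumerate(a, start=1):
                curr = [i]
                for j, cb in enumerate(b, start=1):
                    cost = 0 if ca == cb else 1
                    curr.append(min(prev[j] + 1,          # deletion
                                    curr[j - 1] + 1,      # insertion
                                    prev[j - 1] + cost))  # substitution
                prev = curr
            return prev[-1]

        total_errors = sum(edit_distance(hyp, ref) for hyp, ref in zip(hypotheses, references))
        total_chars = sum(len(ref) for ref in references)
        return total_errors / total_chars if total_chars > 0 else 1.0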
        
    def _calculate_additional_metrics(self, hypotheses: List[str], references: List[str], 
                                    utt_ids: List[str]) -> Dict:
        """Calculate additional evaluation metrics"""
        metrics = {}
        
        # Success rate
        total_utterances = len(hypotheses)
        metrics['Total_Utterances'] = total_utterances
        
        # Average hypothesis pseudo-confidence: word count normalized by the
        # longest hypothesis, so longer non-empty outputs score higher
        word_counts = [len(hyp.split()) for hyp in hypotheses if hyp.strip()]
        if word_counts:
            max_words = max(word_counts)
            scores = [count / max_words for count in word_counts]
            metrics['Avg_Confidence'] = sum(scores) / len(scores)
            
        # Text length statistics
        ref_lengths = [len(ref.split()) for ref in references if ref.strip()]
        hyp_lengths = [len(hyp.split()) for hyp in hypotheses if hyp.strip()]
        
        if ref_lengths:
            metrics['Avg_Ref_Length'] = sum(ref_lengths) / len(ref_lengths)
        if hyp_lengths:
            metrics['Avg_Hyp_Length'] = sum(hyp_lengths) / len(hyp_lengths)
            
        # Length ratio
        if ref_lengths and hyp_lengths:
            length_ratios = [h_len / r_len if r_len > 0 else 0 
                           for h_len, r_len in zip(hyp_lengths, ref_lengths)]
            metrics['Avg_Length_Ratio'] = sum(length_ratios) / len(length_ratios)
            
        return metrics
        
    def _save_evaluation_details(self, results: Dict[str, Dict], references: Dict[str, str], 
                                metrics: Dict):
        """Save detailed evaluation results to file"""
        eval_file = self.output_dir / 'evaluation_details.txt'
        # Make sure the output directory exists before writing
        self.output_dir.mkdir(parents=True, exist_ok=True)
        
        with open(eval_file, 'w', encoding='utf-8') as f:
            f.write("ASR Evaluation Details\n")
            f.write("=" * 50 + "\n\n")
            
            # Summary metrics
            f.write("Summary Metrics:\n")
            for metric, value in metrics.items():
                if isinstance(value, float):
                    f.write(f"  {metric}: {value:.4f}\n")
                else:
                    f.write(f"  {metric}: {value}\n")
            f.write("\n")
            
            # Detailed results
            f.write("Detailed Results:\n")
            f.write("-" * 80 + "\n")
            
            for utt_id, result in results.items():
                ref_text = references.get(utt_id, 'N/A')
                hyp_text = result.get('text', 'N/A')
                
                f.write(f"\nUtterance: {utt_id}\n")
                f.write(f"Reference: {ref_text}\n")
                f.write(f"Hypothesis: {hyp_text}\n")
                f.write(f"Audio Duration: {result.get('audio_duration', 0):.2f}s\n")
                f.write(f"Inference Time: {result.get('inference_time', 0):.2f}s\n")
                f.write(f"RTF: {result.get('rtf', 0):.4f}\n")
                
                # Calculate utterance-level metrics
                if ref_text != 'N/A' and hyp_text != 'N/A':
                    if WER_AVAILABLE:
                        try:
                            utt_wer = wer([ref_text.lower()], [hyp_text.lower()])
                            f.write(f"WER: {utt_wer:.4f}\n")
                        except Exception as e:
                            self.logger.debug(f"Utterance-level WER failed for {utt_id}: {e}")
                            
        self.logger.info(f"Evaluation details saved to: {eval_file}")
        
    def calculate_permutation_free_error(self, results: Dict, num_speakers: int = 2) -> Dict:
        """
        Calculate permutation-free error for multi-speaker ASR
        
        This implements functionality similar to eval_perm_free_error.py
        """
        self.logger.info(f"Calculating permutation-free error for {num_speakers} speakers")
        
        # Placeholder only: a complete implementation would score every
        # hypothesis/reference speaker assignment and keep the best one
        # (see the permutation_free_wer sketch at module level below)
        
        metrics = {}
        
        # For now, return basic metrics
        metrics['Permutation_Free_Error'] = 'Not implemented'
        metrics['Num_Speakers'] = num_speakers
        
        return metrics


def calculate_simple_wer(hypothesis: str, reference: str) -> float:
    """
    Calculate simple Word Error Rate without external dependencies
    
    This is a basic implementation for when jiwer is not available
    """
    if not hypothesis or not reference:
        return 1.0
        
    # Split into words
    hyp_words = hypothesis.split()
    ref_words = reference.split()
    
    # Simple word-level comparison
    if len(hyp_words) == 0 or len(ref_words) == 0:
        return 1.0
        
    # Count matches
    hyp_counter = Counter(hyp_words)
    ref_counter = Counter(ref_words)
    
    # Calculate intersection
    matches = sum((hyp_counter & ref_counter).values())
    total_ref_words = len(ref_words)
    
    # WER = (substitutions + deletions + insertions) / total reference words
    errors = total_ref_words - matches + max(0, len(hyp_words) - len(ref_words))
    
    return errors / total_ref_words if total_ref_words > 0 else 1.0
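

def permutation_free_wer(hypotheses: List[str], references: List[str]) -> float:
    """
    Illustrative sketch (not part of the original module): a permutation-free
    word error rate for multi-speaker ASR that tries every assignment of
    hypothesis channels to reference channels and keeps the best-scoring one.
    It reuses calculate_simple_wer rather than a full alignment-based scorer
    and assumes one text per speaker per list entry.
    """
    from itertools import permutations

    if not hypotheses or not references or len(hypotheses) != len(references):
        return 1.0

    best = None
    for perm in permutations(range(len(hypotheses))):
        # Average WER for this particular speaker assignment
        score = sum(calculate_simple_wer(hypotheses[hyp_idx], references[ref_idx])
                    for ref_idx, hyp_idx in enumerate(perm)) / len(references)
        best = score if best is None else min(best, score)
    return best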


if __name__ == '__main__':
    # Test the evaluator
    import logging
    logging.basicConfig(level=logging.INFO)
    
    config = {
        'output_dir': './test_results'
    }
    
    evaluator = ASREvaluator(config)
    
    # Test data
    inference_results = {
        'utt1': {
            'success': True,
            'text': 'this is a test',
            'audio_duration': 2.5,
            'inference_time': 0.5
        },
        'utt2': {
            'success': True,
            'text': 'another test sentence',
            'audio_duration': 3.0,
            'inference_time': 0.6
        }
    }
    
    references = {
        'utt1': 'this is a test',
        'utt2': 'another test sentence'
    }
    
    metrics = evaluator.evaluate(inference_results, references)
    print(f"Evaluation metrics: {metrics}")