import onnxruntime
import numpy as np
import librosa
import soundfile as sf

class ConformerASR:
    """Conformer speech recognizer built from two ONNX models.

    Pipeline: audio file -> log-mel features -> encoder model -> CTC-head
    model -> greedy CTC decoding into a list of token ids.
    """

    def __init__(self, encoder_path, ctc_path):
        """Load the ONNX models.

        Args:
            encoder_path: Path to the encoder ONNX model.
            ctc_path: Path to the CTC-head ONNX model.
        """
        self.encoder = onnxruntime.InferenceSession(encoder_path)
        self.ctc = onnxruntime.InferenceSession(ctc_path)
        self.sr = 16000  # target sample rate (Hz)
        self.n_mels = 80  # number of mel bands

    def extract_features(self, audio_path):
        """Extract log-mel features from an audio file.

        Args:
            audio_path: Path to an audio file readable by soundfile.

        Returns:
            float32 array of shape (1, num_frames, n_mels): a batch of one
            utterance, time-major.
        """
        audio, fs = sf.read(audio_path)

        # soundfile returns multichannel audio as (frames, channels), while
        # librosa treats the LAST axis as time. Downmix to mono so that
        # resampling and the STFT run along the time axis.
        if audio.ndim > 1:
            audio = audio.mean(axis=1)

        if fs != self.sr:
            audio = librosa.resample(audio, orig_sr=fs, target_sr=self.sr)

        # n_fft=400 / hop_length=160 are 25 ms / 10 ms at 16 kHz.
        mel = librosa.feature.melspectrogram(
            y=audio, sr=self.sr, n_mels=self.n_mels,
            n_fft=400, hop_length=160,
        )

        # Power spectrogram -> dB using librosa's defaults.
        # NOTE(review): confirm this matches the feature scale used in training.
        log_mel = librosa.power_to_db(mel)

        # TODO: apply the global mean/variance normalization statistics from
        # training; they are not available here, so features are unnormalized.

        # (n_mels, frames) -> (frames, n_mels), then add a batch axis.
        features = log_mel.T[None, :, :].astype(np.float32)

        return features

    def transcribe(self, audio_path):
        """Run the full recognition pipeline on a single audio file.

        Args:
            audio_path: Path to the audio file to recognize.

        Returns:
            List of decoded token ids (see ``ctc_decode``). Mapping them to
            text requires a vocabulary, which this class does not hold.
        """
        features = self.extract_features(audio_path)
        # Batch of one utterance; its length is the number of feature frames.
        lengths = np.array([features.shape[1]], dtype=np.int64)

        encoder_inputs = {
            'speech': features,
            'speech_lengths': lengths,
        }
        # Index the first output instead of unpacking exactly two values, so
        # encoders exporting a different number of auxiliary outputs still work.
        encoder_out = self.encoder.run(None, encoder_inputs)[0]

        ctc_inputs = {'encoder_out': encoder_out}
        ctc_log_probs = self.ctc.run(None, ctc_inputs)[0]

        # Decode the single utterance in the batch.
        return self.ctc_decode(ctc_log_probs[0])

    def ctc_decode(self, ctc_log_probs, blank_id=0):
        """Greedy (best-path) CTC decoding.

        Takes the arg-max token of every frame, collapses consecutive
        repeats, and then drops blank tokens.

        Args:
            ctc_log_probs: 2-D array of per-frame scores; the arg-max is taken
                over axis 1 (assumed to be (frames, vocab)).
            blank_id: Index of the CTC blank token. Defaults to 0.

        Returns:
            List of token ids as Python ints.
        """
        token_ids = np.argmax(ctc_log_probs, axis=1)

        # A frame starts a new token if it differs from the previous frame ...
        is_new = np.ones(token_ids.shape, dtype=bool)
        is_new[1:] = token_ids[1:] != token_ids[:-1]
        # ... and it is emitted only if it is not the blank token.
        keep = is_new & (token_ids != blank_id)

        # TODO: map ids to text once a vocabulary is available, e.g.
        #       ''.join(self.vocab[t] for t in decoded)
        return [int(t) for t in token_ids[keep]]

# Example usage
if __name__ == "__main__":
    # Build the recognizer from the exported encoder and CTC-head models.
    model_paths = {
        "encoder_path": "conformer_transformer.onnx",
        "ctc_path": "conformer_ctc.onnx",
    }
    recognizer = ConformerASR(**model_paths)

    # Recognize a sample file and print the decoded result.
    result = recognizer.transcribe("test.wav")
    print(f"识别结果: {result}")