#!/usr/bin/env python3
"""
Data Processing Module for ASR Pipeline

This module handles data preparation and processing for the ASR pipeline,
including locating existing datasets, creating data splits, and preparing
input files for inference.
"""

import logging
from pathlib import Path
from typing import Dict, List

# Default test split(s) to process when the config does not specify 'test_sets'.
data_tag = ["dev"]


class DataProcessor:
    """Base class for data processing."""

    def __init__(self, config: Dict):
        self.config = config
        self.logger = logging.getLogger(__name__)
        self.data_dir = Path(config.get('data_dir', './data'))

    def prepare_data(self) -> Dict:
        """Prepare data for inference - to be implemented by subclasses."""
        raise NotImplementedError

    def load_existing_data(self) -> Dict:
        """Load existing prepared data - to be implemented by subclasses."""
        raise NotImplementedError


class AISHELLDataProcessor(DataProcessor):
    """Data processor for the AISHELL dataset."""

    def __init__(self, config: Dict):
        super().__init__(config)
        # Use the existing on-disk dataset directory structure.
        self.aishell_dir = Path(config.get('aishell_data_dir', '/data/datasets/0'))
        self.raw_data_dir = self.aishell_dir / 'data_aishell'
        self.processed_data_dir = self.data_dir / 'processed'

    def prepare_data(self) -> Dict:
        """Prepare the AISHELL dataset for inference from existing data."""
        self.logger.info("Preparing AISHELL dataset from existing data")

        # Verify that the raw dataset is present.
        if not self._check_data_exists():
            raise FileNotFoundError(f"AISHELL dataset not found at {self.raw_data_dir}")

        self.logger.info("AISHELL dataset found, processing data")

        # Create the directory for processed output.
        self.processed_data_dir.mkdir(parents=True, exist_ok=True)

        # Convert the raw data into Kaldi-style files.
        data_info = self._process_aishell_data()

        return data_info

    def load_existing_data(self) -> Dict:
        """Load previously prepared AISHELL data."""
        self.logger.info("Loading existing AISHELL data")

        # Prefer already-processed data if it exists.
        if self._check_processed_data_exists():
            return self._load_processed_data()
        else:
            # Otherwise fall back to preparing it from the raw dataset.
            self.logger.info("No processed data found, loading from raw data")
            return self.prepare_data()

    def _check_data_exists(self) -> bool:
        """Check if the AISHELL dataset exists at the expected location."""
        required_files = [
            self.raw_data_dir / 'wav',
            self.raw_data_dir / 'transcript' / 'aishell_transcript_v0.8.txt'
        ]
        return all(f.exists() for f in required_files)

    def _check_processed_data_exists(self) -> bool:
        """Check if processed data exists for every configured test set."""
        test_sets = self.config.get('test_sets', data_tag)
        required_files = []
        for test_set in test_sets:
            required_files.extend([
                self.processed_data_dir / test_set / 'wav.scp',
                self.processed_data_dir / test_set / 'text'
            ])
        return all(f.exists() for f in required_files)

    def _process_aishell_data(self) -> Dict:
        """Process AISHELL data into Kaldi-style format using the existing dataset structure."""
        test_sets = self.config.get('test_sets', data_tag)
        data_info = {'test_files': {}, 'references': {}}

        # Locate the transcript file.
        transcript_file = self.raw_data_dir / 'transcript' / 'aishell_transcript_v0.8.txt'
        if not transcript_file.exists():
            raise FileNotFoundError(f"Transcript file not found: {transcript_file}")

        # Parse the transcript file into {utterance ID: transcription}.
        transcript_dict = self._parse_transcript_file(transcript_file)

        for test_set in test_sets:
            self.logger.info(f"Processing {test_set} set")

            # Create the output directory for this test set.
            output_dir = self.processed_data_dir / test_set
            output_dir.mkdir(parents=True, exist_ok=True)

            # Audio files for each split live under wav/<split>.
            audio_dir = self.raw_data_dir / 'wav' / test_set
            if not audio_dir.exists():
                self.logger.warning(f"Audio directory not found: {audio_dir}")
                continue

            # Collect audio files, sorted for deterministic output order.
            audio_files = sorted(audio_dir.rglob('*.wav'))
            self.logger.info(f"Found {len(audio_files)} audio files for {test_set}")

            # Create the wav.scp and text files.
            wav_scp_file = output_dir / 'wav.scp'
            text_file = output_dir / 'text'

            with open(wav_scp_file, 'w', encoding='utf-8') as wav_f, \
                 open(text_file, 'w', encoding='utf-8') as text_f:

                for audio_file in audio_files:
                    # Extract the utterance ID from the file name.
                    # AISHELL file names look like: BAC009S0724W0121.wav
                    utt_id = audio_file.stem

                    # Write the wav.scp entry.
                    wav_f.write(f"{utt_id} {audio_file.resolve()}\n")

                    # Look up the transcription for this utterance.
                    transcription = transcript_dict.get(utt_id, "")
                    if not transcription:
                        self.logger.warning(f"No transcription found for {utt_id}")
                        transcription = "[NO_TRANSCRIPTION]"

                    # Write the text entry.
                    text_f.write(f"{utt_id} {transcription}\n")

                    # Record the file and its reference in the returned structure.
                    data_info['test_files'][utt_id] = str(audio_file.resolve())
                    data_info['references'][utt_id] = transcription

            self.logger.info(f"Created {wav_scp_file} and {text_file}")

        return data_info

    def _parse_transcript_file(self, transcript_file: Path) -> Dict[str, str]:
        """Parse the AISHELL transcript file into a dictionary."""
        transcript_dict = {}

        try:
            with open(transcript_file, 'r', encoding='utf-8') as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue

                    # AISHELL transcript format: <utterance ID> <transcription>
                    # e.g.: BAC009S0724W0121 而 对 楼市 成交 抑制 最 明显 的 限 购
                    parts = line.split(maxsplit=1)
                    if len(parts) == 2:
                        utt_id, transcription = parts
                        transcript_dict[utt_id] = transcription
                    else:
                        self.logger.warning(f"Invalid line in transcript file: {line}")

            self.logger.info(f"Parsed {len(transcript_dict)} transcriptions")

        except Exception as e:
            self.logger.error(f"Error parsing transcript file: {e}")
            raise

        return transcript_dict

    def _load_processed_data(self) -> Dict:
        """Load processed data from existing wav.scp and text files."""
        test_sets = self.config.get('test_sets', data_tag)
        data_info = {'test_files': {}, 'references': {}}

        for test_set in test_sets:
            wav_scp_file = self.processed_data_dir / test_set / 'wav.scp'
            text_file = self.processed_data_dir / test_set / 'text'

            if not wav_scp_file.exists() or not text_file.exists():
                self.logger.warning(f"Missing files for {test_set} set")
                continue

            # Load wav.scp, skipping blank or malformed lines.
            with open(wav_scp_file, 'r', encoding='utf-8') as f:
                for line in f:
                    parts = line.strip().split(maxsplit=1)
                    if len(parts) == 2:
                        utt_id, audio_path = parts
                        data_info['test_files'][utt_id] = audio_path

            # Load text, skipping blank or malformed lines.
            with open(text_file, 'r', encoding='utf-8') as f:
                for line in f:
                    parts = line.strip().split(maxsplit=1)
                    if len(parts) == 2:
                        utt_id, transcription = parts
                        data_info['references'][utt_id] = transcription

        self.logger.info(f"Loaded {len(data_info['test_files'])} test files")

        return data_info
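

# Illustrative sketch, not called by the pipeline itself: how a downstream
# consumer might read back a wav.scp file produced above. The function name
# `read_wav_scp` is hypothetical; the pipeline's own loading path is
# AISHELLDataProcessor._load_processed_data.
def read_wav_scp(wav_scp_path: Path) -> Dict[str, str]:
    """Read a Kaldi-style wav.scp ("<utt_id> <audio_path>" per line) into a dict."""
    entries: Dict[str, str] = {}
    with open(wav_scp_path, 'r', encoding='utf-8') as f:
        for line in f:
            parts = line.strip().split(maxsplit=1)
            if len(parts) == 2:
                entries[parts[0]] = parts[1]
    return entries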


def split_data_for_parallel_processing(data_info: Dict, nj: int) -> List[Dict]:
    """
    Split data for parallel processing.

    Args:
        data_info: Dictionary containing test files and references
        nj: Number of parallel jobs

    Returns:
        List of data_info dictionaries, one per job
    """
    if nj <= 1:
        return [data_info]

    test_files = list(data_info['test_files'].items())

    # Distribute items as evenly as possible: the first `extra` jobs each
    # receive one item more than the rest, so job sizes differ by at most
    # one even when len(test_files) is not a multiple of nj.
    base, extra = divmod(len(test_files), nj)

    splits = []
    start_idx = 0
    for i in range(nj):
        end_idx = start_idx + base + (1 if i < extra else 0)

        split_files = dict(test_files[start_idx:end_idx])
        split_references = {k: data_info['references'][k] for k in split_files}

        splits.append({
            'test_files': split_files,
            'references': split_references
        })
        start_idx = end_idx

    return splits


if __name__ == '__main__':
    # Smoke-test the data processor with the default configuration.
    logging.basicConfig(level=logging.INFO)

    config = {
        'data_dir': './data',
        'dataset': 'aishell',
        'test_sets': data_tag
    }

    processor = AISHELLDataProcessor(config)
    data_info = processor.prepare_data()
    print(f"Prepared {len(data_info['test_files'])} test files")
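
    # Hedged example: split the prepared data across parallel jobs. The job
    # count of 4 is illustrative only, not a value taken from the project
    # configuration.
    splits = split_data_for_parallel_processing(data_info, nj=4)
    for i, split in enumerate(splits):
        print(f"Job {i}: {len(split['test_files'])} files")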