#!/usr/bin/env python3
"""
Data Processing Module for ASR Pipeline

This module handles data preparation and processing for the ASR pipeline,
including downloading datasets, creating data splits, and preparing input
files for inference.
"""

import logging
import os
import shutil
import subprocess
import tarfile
import urllib.request
from pathlib import Path
from typing import Dict, List, Tuple


class DataProcessor:
    """Base class for dataset-specific data processors."""

    def __init__(self, config: Dict):
        self.config = config
        self.logger = logging.getLogger(__name__)
        # Root directory for all dataset artifacts (raw and processed).
        self.data_dir = Path(config.get('data_dir', './data'))

    def prepare_data(self) -> Dict:
        """Prepare data for inference - to be implemented by subclasses."""
        raise NotImplementedError

    def load_existing_data(self) -> Dict:
        """Load existing prepared data - to be implemented by subclasses."""
        raise NotImplementedError


class AISHELLDataProcessor(DataProcessor):
    """Data processor for the AISHELL dataset (OpenSLR resource 33)."""

    def __init__(self, config: Dict):
        super().__init__(config)
        self.aishell_dir = self.data_dir / 'aishell'
        self.raw_data_dir = self.aishell_dir / 'raw'
        self.processed_data_dir = self.aishell_dir / 'processed'

    def prepare_data(self) -> Dict:
        """Download (if needed) and process AISHELL into Kaldi-style files.

        Returns:
            Dict with keys 'test_files' (utt_id -> absolute audio path) and
            'references' (utt_id -> transcription).
        """
        self.logger.info("Preparing AISHELL dataset")

        # Create directories
        self.raw_data_dir.mkdir(parents=True, exist_ok=True)
        self.processed_data_dir.mkdir(parents=True, exist_ok=True)

        # Download dataset only if the extracted layout is not already present.
        if not self._check_data_exists():
            self.logger.info("Downloading AISHELL dataset")
            self._download_aishell()
        else:
            self.logger.info("AISHELL dataset already exists")

        # Process data into per-set wav.scp / text files.
        self.logger.info("Processing AISHELL dataset")
        data_info = self._process_aishell_data()
        return data_info

    def load_existing_data(self) -> Dict:
        """Load previously processed AISHELL data.

        Returns:
            Same structure as prepare_data().

        Raises:
            FileNotFoundError: if the processed wav.scp/text files are absent.
        """
        self.logger.info("Loading existing AISHELL data")
        if not self._check_processed_data_exists():
            raise FileNotFoundError(
                "Processed data not found. Run data preparation first.")
        return self._load_processed_data()

    def _check_data_exists(self) -> bool:
        """Return True if the raw AISHELL download has already been extracted."""
        required_files = [
            self.raw_data_dir / 'data_aishell' / 'wav',
            self.raw_data_dir / 'data_aishell' / 'transcript' / 'aishell_transcript_v0.8.txt',
        ]
        return all(f.exists() for f in required_files)

    def _check_processed_data_exists(self) -> bool:
        """Return True if wav.scp and text exist for every configured test set."""
        test_sets = self.config.get('test_sets', ['test'])
        required_files = []
        for test_set in test_sets:
            required_files.extend([
                self.processed_data_dir / test_set / 'wav.scp',
                self.processed_data_dir / test_set / 'text',
            ])
        return all(f.exists() for f in required_files)

    def _download_aishell(self):
        """Download and extract the AISHELL archives from OpenSLR.

        NOTE(review): tarfile.extractall() on a downloaded archive is exposed
        to tar path traversal; on Python >= 3.12 pass filter='data'. Flagged
        rather than changed to keep compatibility with older interpreters.
        """
        data_urls = [
            'http://www.openslr.org/resources/33/data_aishell.tgz',
            'http://www.openslr.org/resources/33/resource_aishell.tgz',
        ]

        for url in data_urls:
            filename = Path(url).name
            filepath = self.raw_data_dir / filename

            self.logger.info(f"Downloading {filename}")
            urllib.request.urlretrieve(url, filepath)

            # Extract tar file
            self.logger.info(f"Extracting {filename}")
            with tarfile.open(filepath, 'r:gz') as tar:
                tar.extractall(self.raw_data_dir)

            # Remove the archive once extracted to save disk space.
            filepath.unlink()

    def _process_aishell_data(self) -> Dict:
        """Process AISHELL audio into Kaldi-style wav.scp and text files.

        Returns:
            Dict with 'test_files' (utt_id -> absolute audio path) and
            'references' (utt_id -> transcription) covering all test sets.
        """
        test_sets = self.config.get('test_sets', ['test'])
        data_info = {'test_files': {}, 'references': {}}

        for test_set in test_sets:
            self.logger.info(f"Processing {test_set} set")

            output_dir = self.processed_data_dir / test_set
            output_dir.mkdir(parents=True, exist_ok=True)

            # The AISHELL layout nests wavs under per-split directories.
            if test_set == 'test':
                audio_pattern = '*/test/*/*.wav'
            elif test_set == 'dev':
                audio_pattern = '*/dev/*/*.wav'
            else:  # train
                audio_pattern = '*/train/*/*.wav'

            audio_files = list(self.raw_data_dir.rglob(audio_pattern))
            self.logger.info(f"Found {len(audio_files)} audio files for {test_set}")

            wav_scp_file = output_dir / 'wav.scp'
            text_file = output_dir / 'text'

            with open(wav_scp_file, 'w', encoding='utf-8') as wav_f, \
                    open(text_file, 'w', encoding='utf-8') as text_f:
                for audio_file in audio_files:
                    # Utterance ID is the audio filename without extension.
                    utt_id = audio_file.stem

                    wav_f.write(f"{utt_id} {audio_file.resolve()}\n")

                    # TODO(review): placeholder transcriptions — the real
                    # implementation should parse aishell_transcript_v0.8.txt.
                    transcription = f"dummy transcription for {utt_id}"
                    text_f.write(f"{utt_id} {transcription}\n")

                    data_info['test_files'][utt_id] = str(audio_file.resolve())
                    data_info['references'][utt_id] = transcription

            self.logger.info(f"Created {wav_scp_file} and {text_file}")

        return data_info

    def _load_processed_data(self) -> Dict:
        """Load processed wav.scp/text files back into a data_info dict.

        Malformed or blank lines are skipped rather than raising, so a
        hand-edited file cannot crash the loader.
        """
        test_sets = self.config.get('test_sets', ['test'])
        data_info = {'test_files': {}, 'references': {}}

        for test_set in test_sets:
            wav_scp_file = self.processed_data_dir / test_set / 'wav.scp'
            text_file = self.processed_data_dir / test_set / 'text'

            if not wav_scp_file.exists() or not text_file.exists():
                self.logger.warning(f"Missing files for {test_set} set")
                continue

            # Load wav.scp: "<utt_id> <audio_path>" per line.
            with open(wav_scp_file, 'r', encoding='utf-8') as f:
                for line in f:
                    parts = line.strip().split(maxsplit=1)
                    if len(parts) == 2:
                        utt_id, audio_path = parts
                        data_info['test_files'][utt_id] = audio_path

            # Load text: "<utt_id> <transcription>" per line.
            with open(text_file, 'r', encoding='utf-8') as f:
                for line in f:
                    parts = line.strip().split(maxsplit=1)
                    if len(parts) == 2:
                        utt_id, transcription = parts
                        data_info['references'][utt_id] = transcription

        self.logger.info(f"Loaded {len(data_info['test_files'])} test files")
        return data_info


def split_data_for_parallel_processing(data_info: Dict, nj: int) -> List[Dict]:
    """
    Split data for parallel processing.

    The items are distributed as evenly as possible: the first
    ``len % nj`` jobs receive one extra item, instead of dumping the
    whole remainder onto the last job.

    Args:
        data_info: Dictionary containing test files and references
        nj: Number of parallel jobs

    Returns:
        List of data_info dictionaries, one per job (length nj when nj > 1).
    """
    if nj <= 1:
        return [data_info]

    items = list(data_info['test_files'].items())
    base_size, remainder = divmod(len(items), nj)

    splits = []
    start = 0
    for i in range(nj):
        # Jobs 0..remainder-1 get one extra item for a balanced load.
        size = base_size + (1 if i < remainder else 0)
        split_files = dict(items[start:start + size])
        start += size
        splits.append({
            'test_files': split_files,
            'references': {k: data_info['references'][k] for k in split_files},
        })

    return splits


if __name__ == '__main__':
    # Smoke-test the data processor with a minimal configuration.
    config = {
        'data_dir': './data',
        'dataset': 'aishell',
        'test_sets': ['test'],
    }

    processor = AISHELLDataProcessor(config)
    data_info = processor.prepare_data()
    print(f"Prepared {len(data_info['test_files'])} test files")