scp2jsonl.py 4.95 KB
Newer Older
chenych's avatar
chenych committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import hydra
import json
import os
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
from io import BytesIO
from typing import Dict, Optional, Tuple
from urllib.request import urlopen

import soundfile as sf
from modelscope import AutoTokenizer
from tqdm import tqdm
from omegaconf import DictConfig, OmegaConf, ListConfig


class LineProcessor:
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.lock = threading.Lock()

    def process_line(self, line_pair: Tuple[str, str]) -> Optional[Dict]:
        line1, line2 = line_pair

        line1, line2 = line1.strip(), line2.strip()
        if not line1 or not line2:
            return None

        parts1, parts2 = line1.split(maxsplit=1), line2.split(maxsplit=1)
        if len(parts1) != 2 or len(parts2) != 2:
            return None

        utt1, utt2 = parts1[0], parts2[0]
        wav_path, text = parts1[1], parts2[1]

        if utt1 != utt2:
            return {"error": f"UTT mismatch: {utt1} vs {utt2}"}

        try:
            if wav_path.startswith("http"):
                response = urlopen(wav_path)
                if response.status != 200:
                    return {"error": f"WAV not found: {wav_path}"}
                audio_file = BytesIO(response.read())
                duration = sf.info(audio_file).duration
            else:
                if not os.path.exists(wav_path):
                    return {"error": f"WAV not found: {wav_path}"}
                duration = sf.info(wav_path).duration

            data = {
                "messages": [
                    {"role": "system", "content": "You are a helpful assistant."},
                    {
                        "role": "user",
                        "content": f"语音转写:<|startofspeech|>!{wav_path}<|endofspeech|>",
                    },
                    {"role": "assistant", "content": text},
                ],
                "speech_length": int((duration * 1000 - 25) // 10 + 1),
                "text_length": len(self.tokenizer.tokenize(text)),
            }
            return {"success": data, "utt": utt1}

        except Exception as e:
            return {"error": f"Error processing {wav_path}: {str(e)}"}


@hydra.main(config_name=None, version_base=None)
def main_hydra(cfg: DictConfig):
    def to_plain_list(cfg_item):
        if isinstance(cfg_item, ListConfig):
            return OmegaConf.to_container(cfg_item, resolve=True)
        elif isinstance(cfg_item, DictConfig):
            return {k: to_plain_list(v) for k, v in cfg_item.items()}
        else:
            return cfg_item
    kwargs = to_plain_list(cfg)
    scp_file = kwargs["scp_file"]
    transcript_file = kwargs["transcript_file"]
    max_workers = kwargs.get("max_workers", os.cpu_count())
    jsonl_file = kwargs["jsonl_file"]

    with open(scp_file, "r") as f1, open(transcript_file, "r") as f2:
        scp_lines = f1.readlines()
        transcript_lines = f2.readlines()

    if len(scp_lines) != len(transcript_lines):
        print(
            f"Warning: Line count mismatch - scp: {len(scp_lines)}, transcript: {len(transcript_lines)}"
        )

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
    processor = LineProcessor(tokenizer)

    data_pairs = list(zip(scp_lines, transcript_lines))

    processed_count = 0
    failed_count = 0
    error_messages = []

    with tqdm(total=len(data_pairs), desc="Processing") as pbar:
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            with open(jsonl_file, "w") as f_out:
                futures = {
                    executor.submit(processor.process_line, pair): i
                    for i, pair in enumerate(data_pairs)
                }

                for future in as_completed(futures):
                    result = future.result()

                    if result and "success" in result:
                        with processor.lock:
                            json.dump(result["success"], f_out, ensure_ascii=False)
                            f_out.write("\n")
                        processed_count += 1
                    elif result and "error" in result:
                        failed_count += 1
                        error_messages.append(result["error"])

                    pbar.update(1)
                    pbar.set_postfix(
                        {"processed": processed_count, "failed": failed_count}
                    )

    print(f"\nProcessing completed:")
    print(f"  Total lines: {len(data_pairs)}")
    print(f"  Successfully processed: {processed_count}")
    print(f"  Failed: {failed_count}")

    if error_messages and len(error_messages) <= 10:
        print(f"\nSample errors:")
        for error in error_messages[:10]:
            print(f"  - {error}")
    elif error_messages:
        print(f"\nFirst 10 errors:")
        for error in error_messages[:10]:
            print(f"  - {error}")
        print(f"  ... and {len(error_messages) - 10} more errors")


if __name__ == "__main__":
    main_hydra()