# clean_wenet_speech.py
# Committed by Lengyue.
import json
from pathlib import Path
import subprocess

import librosa
import soundfile as sf
import torch
import torchaudio
from fish_audio_preprocess.utils.separate_audio import (
    separate_audio,
    merge_tracks,
    init_model,
)
from tqdm import tqdm
import time
import os
import tempfile

# Shard identity of this process, read from the SLURM task environment.
# When the variables are absent the script runs as a single shard (rank 0 of 1).
rank = int(os.getenv("SLURM_PROCID", "0"))
world_size = int(os.getenv("SLURM_NTASKS", "1"))
# Every process targets CUDA device index 0.
device = torch.device("cuda", 0)
print(f"Rank {rank}/{world_size} on {device}")


def _decode_audio(source_path, sample_rate):
    """Decode an audio file to mono PCM via ffmpeg and load its samples.

    Args:
        source_path: Path of the input audio file handed to ffmpeg.
        sample_rate: Sample rate (Hz) ffmpeg is asked to produce.

    Returns:
        The ``(samples, sample_rate)`` tuple returned by ``librosa.load``.

    Raises:
        subprocess.CalledProcessError: If ffmpeg exits with a non-zero status.
    """
    with tempfile.NamedTemporaryFile(suffix=".wav") as tmp:
        subprocess.check_call(
            [
                "ffmpeg",
                "-y",
                "-i",
                str(source_path),
                "-c:a",
                "pcm_s16le",
                "-threads",
                "0",
                "-ar",
                str(sample_rate),
                str(tmp.name),
            ],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )
        # Load while still inside the ``with`` block: the temporary file is
        # removed as soon as it is closed.
        return librosa.load(tmp.name, sr=None, mono=True)


def _extract_vocals(model, samples, orig_sr, target_sr, device):
    """Run source separation and return the vocal track as a NumPy array.

    Args:
        model: Separation model returned by ``init_model``; its ``samplerate``
            attribute is the rate the audio is resampled to before separation.
        samples: Mono samples as a 1-D NumPy array.
        orig_sr: Sample rate of ``samples``.
        target_sr: Sample rate of the returned vocals.
        device: Torch device the tensors are moved to.

    Returns:
        The vocal track resampled to ``target_sr``, as a NumPy array on the CPU.
    """
    mono = torch.from_numpy(samples[None]).to(device)
    mono = torchaudio.functional.resample(
        mono, orig_freq=orig_sr, new_freq=model.samplerate
    )
    # Make it 2 channels by duplicating the mono signal.
    stereo = torch.cat([mono, mono], dim=0)
    tracks = separate_audio(model, stereo, shifts=1, num_workers=0, progress=False)
    # Index 0 is presumably the first channel of the merged track; both
    # channels came from the same mono input. TODO confirm merge_tracks' layout.
    vocals = merge_tracks(tracks, filter=["vocals"])[0]
    vocals = torchaudio.functional.resample(
        vocals, orig_freq=model.samplerate, new_freq=target_sr
    )
    return vocals.cpu().numpy()


def _write_segments(vocals, sr, segments, out_dir, min_confidence=0.95):
    """Write one ``S<index>.wav`` / ``S<index>.txt`` pair per accepted segment.

    Segments whose ``confidence`` is not strictly greater than
    ``min_confidence`` are skipped, as are segments whose time range selects
    no samples from ``vocals``. The index in the file name is the segment's
    position in ``segments``, so skipped segments leave gaps in the numbering.

    Args:
        vocals: 1-D array of vocal samples.
        sr: Sample rate of ``vocals``; also used to convert times to indices.
        segments: Iterable of dicts with ``confidence``, ``begin_time``,
            ``end_time`` and ``text`` keys.
        out_dir: Existing directory the files are written into.
        min_confidence: Exclusive lower bound on segment confidence.
    """
    for idx, segment in enumerate(segments):
        if segment["confidence"] <= min_confidence:
            continue

        begin = int(segment["begin_time"] * sr)
        end = int(segment["end_time"] * sr)
        clip = vocals[begin:end]
        if len(clip) == 0:
            # Time range lies outside the decoded audio; an empty wav paired
            # with a transcript would only pollute the cleaned dataset.
            continue

        wav_path = out_dir / f"S{idx:05d}.wav"
        sf.write(wav_path, clip, samplerate=sr)
        # Transcripts are written as UTF-8 explicitly so the result does not
        # depend on the locale's default encoding.
        wav_path.with_suffix(".txt").write_text(segment["text"], encoding="utf-8")


def main():
    """Extract vocal-only, per-segment clips from the WenetSpeech dataset.

    Reads the ``audios`` list from the dataset metadata, keeps every
    ``world_size``-th entry starting at index ``rank`` (module-level values),
    and for each entry: decodes the audio with ffmpeg, isolates the vocals
    with the ``htdemucs`` model, and writes each segment with confidence above
    0.95 as a wav/txt pair under ``cleaned/<aid>/``. An empty ``done`` file
    marks a finished entry so that reruns skip it.

    Failures on a single entry are printed and the entry is skipped; its
    ``done`` marker is not written, so a later run will retry it.
    """
    meta_path = Path("dataset/tts/WenetSpeech/WenetSpeech.json")
    dataset_path = Path("dataset/tts/WenetSpeech")
    cleaned_path = Path("dataset/tts/WenetSpeech/cleaned")
    # Sample rate of the ffmpeg output and of the written clips.
    target_sr = 24000

    # exist_ok avoids a FileExistsError when several ranks start at once and
    # race to create the same directory.
    cleaned_path.mkdir(parents=True, exist_ok=True)

    demucs = init_model("htdemucs", device)
    print("Model loaded")

    # Explicit UTF-8 so reading the metadata does not depend on the locale.
    with open(meta_path, encoding="utf-8") as f:
        dataset = json.load(f)["audios"]

    print(f"Dataset loaded, {len(dataset)} samples")
    dataset = dataset[rank::world_size]
    print(f"Dataset split, {len(dataset)} samples")

    for data_idx, data in enumerate(dataset):
        out_dir = cleaned_path / data["aid"]
        done_path = out_dir / "done"
        out_dir.mkdir(parents=True, exist_ok=True)

        if done_path.exists():
            continue

        print(f"Processing {data_idx}/{len(dataset)} at rank {rank}")

        try:
            raw_audio, sr = _decode_audio(dataset_path / data["path"], target_sr)
            vocals = _extract_vocals(demucs, raw_audio, sr, target_sr, device)
            _write_segments(vocals, target_sr, data["segments"], out_dir)

            # Write done file only after every segment was written.
            done_path.write_text("")
        except Exception as e:
            # Deliberate best-effort: one broken entry must not stop the whole
            # shard. The 10 s pause is kept from the original; presumably it
            # throttles retries on transient failures.
            print(f"Error {e} on {data_idx}/{len(dataset)} at rank {rank}")
            time.sleep(10)
            continue

    print("Done")


# Run the cleaning pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()