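"""
Prepare the Edinburgh International Accents of English Corpus (EdAcc) for use with
the Hugging Face `datasets` library: the long-form recordings are chunked into
utterance-level segments using the STM annotations, enriched with per-speaker accent
metadata, then saved to disk and optionally pushed to the Hub.

Example invocation (a sketch: the paths and repo id are illustrative):

    python prepare_edacc.py \
        --dataset_dir=/path/to/edacc_v1.0 \
        --output_dir=./edacc_processed \
        --push_to_hub \
        --hub_dataset_id=<your-namespace>/edacc
"""
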
import csv
import os
import re
import shutil
import sys
from dataclasses import dataclass, field

import soundfile as sf
from datasets import Audio, Dataset, DatasetDict, load_dataset
from tqdm import tqdm
from transformers import HfArgumentParser


@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to the data we are going to use for preparation.
    """

    dataset_dir: str = field(
        default=None,
        metadata={
            "help": "Path where the EdAcc tar.gz archive is extracted. Leave in it's raw format: the script will "
            "assume it's unchanged from the download and use relative paths to load the relevant audio files."
        },
    )
    output_dir: str = field(
        default=None,
        metadata={
            "help": "Where to save the processed dataset to disk. If unspecified, uses a 'pretty' version of the "
            "original dataset name. E.g. 'facebook/voxpopuli' will be saved under 'voxpopuli'."
        },
    )
    overwrite_output_dir: bool = field(
        default=True,
        metadata={"help": "Overwrite the content of the output directory."},
    )
    push_to_hub: bool = field(
        default=False,
        metadata={"help": "Whether or not to push the processed dataset to the Hub."},
    )
    hub_dataset_id: str = field(
        default=None,
        metadata={"help": "Repository id to push the processed dataset to on the Hugging Face Hub."},
    )
    private_repo: bool = field(
        default=True,
        metadata={"help": "Whether or not to push the processed dataset to a private repository on the Hub"},
    )
    max_samples: int = field(
        default=None,
        metadata={"help": "Maximum number of samples per split. Useful for debugging purposes."},
    )


def main():
    # 1. Parse input arguments
    parser = HfArgumentParser(DataTrainingArguments)
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        data_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))[0]
    else:
        data_args = parser.parse_args_into_dataclasses()[0]

    # 2. Load accents for each speaker
    linguistic_background = {}
    linguistic_background_csv = os.path.join(data_args.dataset_dir, "linguistic_background.csv")
    with open(linguistic_background_csv, encoding="utf-8") as file:
        reader = csv.DictReader(file, delimiter=",")
        for line in reader:
            linguistic_background[line["PARTICIPANT_ID"]] = line[
                "How would you describe your accent in English? (e.g. Italian, Glaswegian)"
            ]

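    # the edinburghcstr/edacc_accents dataset provides normalized accent ("English_Variety")
    # and L1 ("L1_Variety") labels for each speaker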
    accent_dataset = load_dataset("edinburghcstr/edacc_accents", split="train")

    def format_dataset(batch):
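        # remap the EAEC participant ids to the EDACC speaker ids used in the STM files,
        # where the P1/P2 participants become the -A/-B conversation sides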
        batch["speaker_id"] = (
            batch["Final-Participant_ID"].replace("EAEC", "EDACC").replace("P1", "-A").replace("P2", "-B")
        )
        return batch

    accent_dataset = accent_dataset.map(format_dataset, remove_columns=["Final-Participant_ID"])

    # 3. Clean accents for each speaker
    linguistic_background_clean = {
        participant: accent.strip()
        for participant, accent in zip(accent_dataset["speaker_id"], accent_dataset["English_Variety"])
    }
    linguistic_variety = {
        participant: l1.strip() for participant, l1 in zip(accent_dataset["speaker_id"], accent_dataset["L1_Variety"])
    }

    # 4. Initialize dataset dict
    raw_datasets = DatasetDict()

    if data_args.overwrite_output_dir and os.path.exists(data_args.output_dir) and os.path.isdir(data_args.output_dir):
        shutil.rmtree(data_args.output_dir)
    output_dir_processed = os.path.join(data_args.output_dir, "processed")

    # 5. Iterate over dev/test files
    for split, split_formatted in zip(["dev", "test"], ["validation", "test"]):
        data_dir = os.path.join(data_args.dataset_dir, split)
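        # each split ships a single STM file with one time-stamped segment annotation per line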
        metadata = os.path.join(data_dir, "stm")
        output_dir_split = os.path.join(output_dir_processed, split)
        os.makedirs(output_dir_split, exist_ok=True)

        all_speakers = []
        all_genders = []
        all_l1s = []
        all_texts = []
        all_audio_paths = []
        all_normalized_accents = []
        all_raw_accents = []

        current_audio = None
        current_audio_array = None
        current_sampling_rate = None
        current_counter = 1

        gender_pat = r".*?\<(.*),.*"
        l1_pat = r".*?\,(.*)>.*"

        with open(metadata, "r", encoding="utf-8") as f:
            for idx, line in tqdm(enumerate(f), desc=split):
                # example line: 'EDACC-C06 1 EDACC-C06-A 0.00 5.27 <male,l1> C ELEVEN DASH P ONE\n'
                # the transcription always comes to the right of the last right angle bracket
                text_idx = line.find(">") + 1
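                # skip the single space after the '>' and drop the trailing newline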
                all_texts.append(line[text_idx + 1 : -1])
                # the metadata immediately precedes the transcription
                line = line[:text_idx]
                file, channel, speaker, start, end, gender_l1 = line.split(" ")

                # add speaker information to cumulative lists
                all_raw_accents.append(linguistic_background[speaker])
                all_normalized_accents.append(linguistic_background_clean[speaker])
                all_speakers.append(speaker)

                # add gender/l1 information
                all_genders.append(re.search(gender_pat, gender_l1).group(1))
                all_l1s.append(linguistic_variety[speaker])

                # read audio file if different from previous
                if file != current_audio:
                    current_audio_array, current_sampling_rate = sf.read(
                        os.path.join(data_args.dataset_dir, "data", file + ".wav")
                    )
                    current_audio = file
                    current_counter = 1
                else:
                    current_counter += 1

                # chunk audio file according to start/end times
                start = int(float(start) * current_sampling_rate)
                end = int(float(end) * current_sampling_rate)
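                # clamp the end sample so an annotation that overruns the recording stays in bounds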
                end = min(end, len(current_audio_array))
                chunked_audio = current_audio_array[start:end]
                save_path = os.path.join(output_dir_split, f"{file}-{current_counter}.wav")
                sf.write(save_path, chunked_audio, current_sampling_rate)
                all_audio_paths.append(save_path)

                # optionally cap the number of segments per split (useful for debugging)
                if data_args.max_samples is not None and idx == data_args.max_samples - 1:
                    break

        raw_datasets[split_formatted] = Dataset.from_dict(
            {
                "speaker": all_speakers,
                "text": all_texts,
                "accent": all_normalized_accents,
                "raw_accent": all_raw_accents,
                "gender": all_genders,
                "l1": all_l1s,
                "audio": all_audio_paths,
            }
        ).cast_column("audio", Audio())

    if data_args.push_to_hub:
        raw_datasets.push_to_hub(data_args.hub_dataset_id, private=data_args.private_repo, token=True)

    raw_datasets.save_to_disk(data_args.output_dir)
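
    # the saved dataset can later be reloaded with, e.g.:
    #   from datasets import load_from_disk
    #   edacc = load_from_disk(data_args.output_dir)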


if __name__ == "__main__":
    main()