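"""prepare_edacc.py

Prepare the Edinburgh International Accents of English Corpus (EdAcc) as a Hugging Face
`DatasetDict`: read the STM transcripts for the dev/test splits, chunk each long-form
recording into utterance-level wav files, attach speaker/accent/gender/L1 metadata from
linguistic_background.csv, then save the result to disk and optionally push it to the Hub.

Example invocation (paths and repository id are illustrative):

    python prepare_edacc.py \
        --dataset_dir /path/to/edacc_v1.0 \
        --output_dir ./edacc_processed \
        --push_to_hub True \
        --hub_dataset_id your-username/edacc
"""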
import csv
import os
import re
import shutil
import sys
from dataclasses import dataclass, field

from datasets import DatasetDict, Dataset, Audio
from tqdm import tqdm
from transformers import HfArgumentParser
import soundfile as sf


@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to the data we are going to prepare.
    """
    dataset_dir: str = field(
        default=None,
        metadata={
            "help": "Path where the EdAcc tar.gz archive is extracted. Leave in it's raw format: the script will "
                    "assume it's unchanged from the download and use relative paths to load the relevant audio files."
        }
    )
    output_dir: str = field(
        default=None,
        metadata={
            "help": "Where to save the processed dataset to disk. If unspecified, uses a 'pretty' version of the "
            "original dataset name. E.g. 'facebook/voxpopuli' will be saved under 'voxpopuli'."
        },
    )
    overwrite_output_dir: bool = field(
        default=True,
        metadata={"help": "Overwrite the content of the output directory."},
    )
    push_to_hub: bool = field(
        default=False,
        metadata={"help": "Whether or not to push the processed dataset to the Hub."},
    )
    hub_dataset_id: str = field(
        default=None,
        metadata={"help": "Repository id to push the processed dataset to on the Hugging Face Hub."},
    )
    private_repo: bool = field(
        default=True,
        metadata={"help": "Whether or not to push the processed dataset to a private repository on the Hub"},
    )
    max_samples: int = field(
        default=None,
        metadata={"help": "Maximum number of samples per split. Useful for debugging purposes."},
    )

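# Mapping from the free-text accent self-descriptions in linguistic_background.csv to a
# normalised set of accent labels; ambiguous or empty descriptions map to "Unknown".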
ACCENT_MAPPING = {
    'Italian': 'Italian',
    'International': 'Unknown',
    'American': 'American',
    'English': 'English',
    'Latin American': 'Latin American',
    'British': 'English',
    'Romanian': 'Romanian',
    'Standard Indian English': 'Indian',
    'Trans-Atlantic': 'Unknown',
    'Slightly American': 'American',
    'European': 'Unknown',
    'Scottish (Fife)': 'Scottish',
    'English with Scottish inflections': 'Scottish',
    'Indian': 'Indian',
    'Asian': 'Asian',
    'NA': 'Unknown',
    'German': 'German',
    'South London': 'English',
    'Dutch': 'Dutch',
    'Mostly West Coast American with some Australian Intonation': 'American',
    'Japanese': 'Japanese',
    'Chinese': 'Chinese',
    'Generic middle class white person': 'English',
    'French': 'French',
    'Chinese accent or mixed accent(US, UK, China..) perhaps': 'Chinese',
    'American accent': 'American',
    'Catalan': 'Catalan',
    'American, I guess.': 'American',
    'Spanish American': 'Latin American',
    'Spanish': 'Spanish',
    'Standard American,Scottish': 'American',
    'Bulgarian': 'Bulgarian',
    'Latin': 'Latin American',
    'Latín American': 'Latin American',
    'Mexican': 'Latin American', # TODO: un-generalise latin american accents?
    'North American': 'American',
    'Afrian': 'African',
    'Nigerian': 'African', # TODO: un-generalise african accents?
    'East-European': 'Eastern European',
    'Eastern European': 'Eastern European',
    'Southern London': 'English',
    'American with a slight accent': 'American',
    'American-ish': 'American',
    'Indian / Pakistani accent': 'Indian',
    'Pakistani/American': 'Pakistani',
    'African accent': 'African',
    'Kenyan': 'African',  # TODO: un-generalise african accents?
    'Ghanaian': 'African', # TODO: un-generalise african accents?
    'Spanish accent': 'Spanish',
    'Lithuanian': 'Lithuanian',
    'Lithuanian (eastern European)': 'Lithuanian',
    'Indonesian': 'Indonesian',
    'Egyptian': 'Egyptian',
    'South African English': 'South African',
    "Neutral": "English",
    'Neutral accent': 'English',
    'Neutral English, Italian': 'English',
    'Fluent': 'Unknown',
    'Glaswegian': 'Scottish',
    'Glaswegian (not slang)': 'Scottish',
    'Irish': 'Irish',
    'Jamaican': 'Jamaican',
    'Jamaican accent': 'Jamaican',
    'Irish/ Dublin': 'Irish',
    'South Dublin Irish': 'Irish',
    'italian': 'Italian',
    'italian mixed with American and British English': 'Italian',
    'Italian mixed with American accent': 'Italian',
    'South American': 'Latin American',
    'Brazilian accent': 'Latin American', # TODO: un-generalise latin american accents?
    'Israeli': 'Israeli',
    'Vietnamese accent': 'Vietnamese',
    'Southern Irish': 'Irish',
    'Slight Vietnamese accent': 'Vietnamese',
    'Midwestern United States': 'American',
    'Vietnamese English': 'Vietnamese',
    "Vietnamese": "Vietnamese",
    "": "Unknown"
}


def main():
    # 1. Parse input arguments
    parser = HfArgumentParser(DataTrainingArguments)
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        data_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))[0]
    else:
        data_args = parser.parse_args_into_dataclasses()[0]

    # 2. Load accents for each speaker
    linguistic_background = dict()
    linguistic_background_csv = os.path.join(data_args.dataset_dir, "linguistic_background.csv")
    with open(linguistic_background_csv, encoding="utf-8") as file:
        reader = csv.DictReader(file, delimiter=",")
        for line in reader:
            linguistic_background[line["PARTICIPANT_ID"]] = line["How would you describe your accent in English? (e.g. Italian, Glaswegian)"]

    # 3. Clean accents for each speaker
    linguistic_background_clean = {participant: ACCENT_MAPPING[accent.strip()] for participant, accent in linguistic_background.items()}

    # 4. Initialize dataset dict
    raw_datasets = DatasetDict()

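    # optionally wipe any previous output directory before re-processing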
    if data_args.overwrite_output_dir and os.path.exists(data_args.output_dir) and os.path.isdir(data_args.output_dir):
        shutil.rmtree(data_args.output_dir)
    output_dir_processed = os.path.join(data_args.output_dir, "processed")

    # 5. Iterate over dev/test files
    for split, split_formatted in zip(["dev", "test"], ["validation", "test"]):
        data_dir = os.path.join(data_args.dataset_dir, split)
        metadata = os.path.join(data_dir, "stm")
        output_dir_split = os.path.join(output_dir_processed, split)
        os.makedirs(output_dir_split, exist_ok=True)

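        # per-utterance metadata accumulated while walking the STM transcript file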
        all_speakers = []
        all_genders = []
        all_l1s = []
        all_texts = []
        all_audio_paths = []
        all_normalized_accents = []
        all_raw_accents = []
        
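        # state for the currently loaded long-form recording: each wav is read from disk
        # once and then re-used for all of its segments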
        current_audio = None
        current_audio_array = None
        current_sampling_rate = None
        current_counter = 1

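        # regexes to pull the gender and L1 fields out of the '<gender,l1>' metadata tag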
        gender_pat = r'.*?\<(.*),.*'
        l1_pat = r'.*?\,(.*)>.*'

        with open(metadata, "r", encoding="utf-8") as metadata_file:
            for idx, line in tqdm(enumerate(metadata_file), desc=split):
                # example line is: 'EDACC-C06 1 EDACC-C06-A 0.00 5.27 <male,l1> C ELEVEN DASH P ONE\n'
                # the transcription always comes to the right of the last closing angle bracket
                text_idx = line.find(">") + 1
                all_texts.append(line[text_idx + 1:-1])
                # the metadata immediately precedes this
                line = line[:text_idx]
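                # the remaining prefix holds the STM fields: audio file id, channel,
                # speaker id, start time, end time and the <gender,l1> tag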
                file, channel, speaker, start, end, gender_l1 = line.split(" ")

                # add speaker information to cumulative lists
                all_raw_accents.append(linguistic_background[speaker])
                all_normalized_accents.append(linguistic_background_clean[speaker])
                all_speakers.append(speaker)

                # add gender/l1 information
                all_genders.append(re.search(gender_pat, gender_l1).group(1))
                all_l1s.append(re.search(l1_pat, gender_l1).group(1))
                
                # read audio file if different from previous
                if file != current_audio:
                    current_audio_array, current_sampling_rate = sf.read(os.path.join(data_args.dataset_dir, "data", file + ".wav"))
                    current_audio = file
                    current_counter = 1
                else:
                    current_counter += 1

                # chunk audio file according to start/end times
                start = int(float(start) * current_sampling_rate)
                end = int(float(end) * current_sampling_rate)
                end = min(end, len(current_audio_array))
                chunked_audio = current_audio_array[start: end]
                save_path = os.path.join(output_dir_split, f"{file}-{current_counter}.wav")
                sf.write(save_path, chunked_audio, current_sampling_rate)
                all_audio_paths.append(save_path)

                if data_args.max_samples is not None and (data_args.max_samples - 1) == idx:
                    break

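        # collect the split into a Dataset, decoding audio lazily via the Audio feature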
        raw_datasets[split_formatted] = Dataset.from_dict(
            {"speaker": all_speakers,
             "text": all_texts,
             "accent": all_normalized_accents,
             "raw_accent": all_raw_accents,
             "gender": all_genders,
             "language": all_l1s,
             "audio": all_audio_paths,
             }
        ).cast_column("audio", Audio())

    if data_args.push_to_hub:
        raw_datasets.push_to_hub(data_args.hub_dataset_id, private=data_args.private_repo, token=True)

    raw_datasets.save_to_disk(data_args.output_dir)

if __name__ == "__main__":
    main()