Commit cb109592 authored by Yoach Lacombe's avatar Yoach Lacombe
Browse files

add prompt column name handling

parent 9ef35aa6
......@@ -687,6 +687,7 @@ def load_multiple_datasets(
seed: Optional[int] = None,
id_column_name: Optional[str] = None,
columns_to_keep: Optional[Set[str]] = None,
prompt_column_name: Optional[str] = None,
sampling_rate: Optional[int] = None,
audio_column_name: Optional[str] = None,
**kwargs,
......@@ -753,15 +754,28 @@ def load_multiple_datasets(
elif id_column_name is not None:
metadata_dataset = metadata_dataset.rename_column(id_column_name, f"metadata_{id_column_name}")
metadata_columns_to_remove = set(metadata_dataset.column_names).intersection(set(dataset.column_names))
if prompt_column_name is not None:
# We might have applied some transformations to the prompts (e.g punctuation restoration)
# so we make sure to remove it from the original dataset
if prompt_column_name in dataset.column_names:
print(f"REMOVE {prompt_column_name} from dataset {dataset_dict['name']} - dataset_dict['split']")
dataset.remove_columns(prompt_column_name)
metadata_columns_to_remove = set(metadata_dataset.column_names).intersection(set(dataset.column_names))
metadata_dataset = metadata_dataset.remove_columns(metadata_columns_to_remove)
dataset = concatenate_datasets([dataset, metadata_dataset], axis=1)
if id_column_name is not None and dataset_dict["name"] != "stable-speech/mls_eng_10k":
if len(dataset.filter(lambda id1, id2: id1!=id2, input_columns=[id_column_name, f"metadata_{id_column_name}"])) != 0:
raise ValueError(f"Concatenate didn't work. Some ids don't correspond on dataset {dataset_dict['name']}")
# TODO: remove
print("dataset", dataset)
print(dataset[0][prompt_column_name])
dataset_features = dataset.features.keys()
if columns_to_keep is not None:
......@@ -954,6 +968,7 @@ def main():
num_proc=data_args.preprocessing_num_workers,
id_column_name=data_args.id_column_name,
columns_to_keep=columns_to_keep.values(),
prompt_column_name=data_args.prompt_column_name,
audio_column_name=data_args.target_audio_column_name,
sampling_rate=sampling_rate,
# streaming=data_args.streaming, TODO(SG): optionally enable streaming mode
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment