Commit cb109592 authored by Yoach Lacombe

add prompt column name handling

parent 9ef35aa6
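The commit threads a new `prompt_column_name` argument through `load_multiple_datasets`: when the metadata dataset carries its own copy of the prompt column (possibly post-processed, e.g. with punctuation restored), the same-named column is dropped from the original dataset before the two datasets are concatenated column-wise, so the metadata version is the one kept. A minimal sketch of that merge step with the Hugging Face `datasets` API; the toy datasets and column names below are illustrative, not taken from the repository:

```python
from datasets import Dataset, concatenate_datasets

# Toy stand-ins for the original dataset and its metadata counterpart (illustrative only).
dataset = Dataset.from_dict(
    {"id": [0, 1], "audio": ["a.wav", "b.wav"], "text": ["hello world", "good bye"]}
)
metadata_dataset = Dataset.from_dict(
    {"metadata_id": [0, 1], "text": ["Hello, world!", "Goodbye!"]}
)

prompt_column_name = "text"

# The metadata copy of the prompt may have been transformed (e.g. punctuation
# restoration), so drop the original column and keep the metadata version.
if prompt_column_name in dataset.column_names:
    dataset = dataset.remove_columns(prompt_column_name)  # remove_columns is not in-place

# Drop any columns the metadata dataset still shares with the original,
# then concatenate the two column-wise.
metadata_columns_to_remove = set(metadata_dataset.column_names).intersection(dataset.column_names)
metadata_dataset = metadata_dataset.remove_columns(list(metadata_columns_to_remove))
dataset = concatenate_datasets([dataset, metadata_dataset], axis=1)

print(dataset.column_names)  # ['id', 'audio', 'metadata_id', 'text']
```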
@@ -687,6 +687,7 @@ def load_multiple_datasets(
     seed: Optional[int] = None,
     id_column_name: Optional[str] = None,
     columns_to_keep: Optional[Set[str]] = None,
+    prompt_column_name: Optional[str] = None,
     sampling_rate: Optional[int] = None,
     audio_column_name: Optional[str] = None,
     **kwargs,
@@ -753,15 +754,28 @@ def load_multiple_datasets(
         elif id_column_name is not None:
             metadata_dataset = metadata_dataset.rename_column(id_column_name, f"metadata_{id_column_name}")

+        metadata_columns_to_remove = set(metadata_dataset.column_names).intersection(set(dataset.column_names))
+
+        if prompt_column_name is not None:
+            # We might have applied some transformations to the prompts (e.g. punctuation restoration),
+            # so we make sure to remove it from the original dataset.
+            if prompt_column_name in dataset.column_names:
+                print(f"REMOVE {prompt_column_name} from dataset {dataset_dict['name']} - {dataset_dict['split']}")
+                dataset = dataset.remove_columns(prompt_column_name)
+
         metadata_columns_to_remove = set(metadata_dataset.column_names).intersection(set(dataset.column_names))
         metadata_dataset = metadata_dataset.remove_columns(metadata_columns_to_remove)

         dataset = concatenate_datasets([dataset, metadata_dataset], axis=1)

         if id_column_name is not None and dataset_dict["name"] != "stable-speech/mls_eng_10k":
             if len(dataset.filter(lambda id1, id2: id1 != id2, input_columns=[id_column_name, f"metadata_{id_column_name}"])) != 0:
                 raise ValueError(f"Concatenate didn't work. Some ids don't correspond on dataset {dataset_dict['name']}")

+        # TODO: remove
+        print("dataset", dataset)
+        print(dataset[0][prompt_column_name])
+
         dataset_features = dataset.features.keys()

         if columns_to_keep is not None:
@@ -954,6 +968,7 @@ def main():
         num_proc=data_args.preprocessing_num_workers,
         id_column_name=data_args.id_column_name,
         columns_to_keep=columns_to_keep.values(),
+        prompt_column_name=data_args.prompt_column_name,
         audio_column_name=data_args.target_audio_column_name,
         sampling_rate=sampling_rate,
         # streaming=data_args.streaming, TODO(SG): optionally enable streaming mode
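As a companion to the hunk above, the post-concatenation sanity check (the `dataset.filter` call with two `input_columns`) can be exercised on its own: each half of the merged dataset keeps its own id column, and any row where the two disagree means the datasets were not row-aligned. A small, self-contained sketch with toy data (illustrative names only):

```python
from datasets import Dataset, concatenate_datasets

# Two toy datasets that are expected to be row-aligned (illustrative only).
left = Dataset.from_dict({"id": [0, 1, 2], "audio": ["a.wav", "b.wav", "c.wav"]})
right = Dataset.from_dict({"metadata_id": [0, 1, 2], "text": ["x", "y", "z"]})

merged = concatenate_datasets([left, right], axis=1)

# filter passes the requested columns as positional arguments to the lambda;
# a non-empty result means at least one row had mismatched ids.
mismatched = merged.filter(lambda id1, id2: id1 != id2, input_columns=["id", "metadata_id"])
if len(mismatched) != 0:
    raise ValueError("Concatenate didn't work. Some ids don't correspond.")
```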