Commit 1fe3fc1e authored by yoach@huggingface.co
Browse files

Add the possibility to use the prompt column from the metadata dataset instead of the one from the original dataset

parent cb109592
......@@ -723,6 +723,7 @@ def load_multiple_datasets(
metadata_dataset_name = dataset_dict["metadata_dataset_name"]
if metadata_dataset_name is not None:
logger.info(f'Merging {dataset_dict["name"]} - {dataset_dict["split"]} with {metadata_dataset_name} - {dataset_dict["split"]}')
metadata_dataset = load_dataset(
metadata_dataset_name,
dataset_dict["config"],
......@@ -732,7 +733,7 @@ def load_multiple_datasets(
)
# TODO(YL): I forgot to create unique ids for MLS english.
# To iterate faster, I bypass the original id check and do another one. - Done once
# To iterate faster, I bypass the original id check and do another one. - Done once because assuming it won't change next time
# if dataset_dict["name"] == "stable-speech/mls_eng_10k":
# def concat_ids(book_id, speaker_id, begin_time):
# return {"id": f"{book_id}_{speaker_id}_{str(begin_time).replace('.', '_')}"}
......@@ -760,7 +761,7 @@ def load_multiple_datasets(
# We might have applied some transformations to the prompts (e.g punctuation restoration)
# so we make sure to remove it from the original dataset
if prompt_column_name in dataset.column_names:
print(f"REMOVE {prompt_column_name} from dataset {dataset_dict['name']} - dataset_dict['split']")
logger.info(f"REMOVE {prompt_column_name} from dataset {dataset_dict['name']} - dataset_dict['split']")
dataset.remove_columns(prompt_column_name)
metadata_columns_to_remove = set(metadata_dataset.column_names).intersection(set(dataset.column_names))
......@@ -771,10 +772,6 @@ def load_multiple_datasets(
if id_column_name is not None and dataset_dict["name"] != "stable-speech/mls_eng_10k":
if len(dataset.filter(lambda id1, id2: id1!=id2, input_columns=[id_column_name, f"metadata_{id_column_name}"])) != 0:
raise ValueError(f"Concatenate didn't work. Some ids don't correspond on dataset {dataset_dict['name']}")
# TODO: remove
print("dataset", dataset)
print(dataset[0][prompt_column_name])
dataset_features = dataset.features.keys()
......@@ -996,6 +993,9 @@ def main():
num_proc=data_args.preprocessing_num_workers,
id_column_name=data_args.id_column_name,
columns_to_keep=columns_to_keep.values(),
prompt_column_name=data_args.prompt_column_name,
audio_column_name=data_args.target_audio_column_name,
sampling_rate=sampling_rate,
# streaming=data_args.streaming, TODO(SG): optionally enable streaming mode
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment