Commit 1fe3fc1e authored by yoach@huggingface.co

add possibility to use the metadata dataset's prompt column name instead of the data's

parent cb109592
@@ -723,6 +723,7 @@ def load_multiple_datasets(
         metadata_dataset_name = dataset_dict["metadata_dataset_name"]
         if metadata_dataset_name is not None:
+            logger.info(f'Merging {dataset_dict["name"]} - {dataset_dict["split"]} with {metadata_dataset_name} - {dataset_dict["split"]}')
             metadata_dataset = load_dataset(
                 metadata_dataset_name,
                 dataset_dict["config"],
@@ -732,7 +733,7 @@ def load_multiple_datasets(
             )
             # TODO(YL): I forgot to create unique ids for MLS english.
-            # To iterate faster, I bypass the original id check and do another one. - Done once
+            # To iterate faster, I bypass the original id check and do another one. - Done once because assuming it won't change next time
             # if dataset_dict["name"] == "stable-speech/mls_eng_10k":
             #     def concat_ids(book_id, speaker_id, begin_time):
             #         return {"id": f"{book_id}_{speaker_id}_{str(begin_time).replace('.', '_')}"}
@@ -760,7 +761,7 @@ def load_multiple_datasets(
             # We might have applied some transformations to the prompts (e.g punctuation restoration)
             # so we make sure to remove it from the original dataset
             if prompt_column_name in dataset.column_names:
-                print(f"REMOVE {prompt_column_name} from dataset {dataset_dict['name']} - dataset_dict['split']")
+                logger.info(f"REMOVE {prompt_column_name} from dataset {dataset_dict['name']} - dataset_dict['split']")
                 dataset.remove_columns(prompt_column_name)
             metadata_columns_to_remove = set(metadata_dataset.column_names).intersection(set(dataset.column_names))
@@ -772,10 +773,6 @@ def load_multiple_datasets(
             if len(dataset.filter(lambda id1, id2: id1!=id2, input_columns=[id_column_name, f"metadata_{id_column_name}"])) != 0:
                 raise ValueError(f"Concatenate didn't work. Some ids don't correspond on dataset {dataset_dict['name']}")
-            # TODO: remove
-            print("dataset", dataset)
-            print(dataset[0][prompt_column_name])
         dataset_features = dataset.features.keys()
         if columns_to_keep is not None:
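The id check in the hunk above guards a column-wise merge. Here is a minimal sketch of the underlying pattern, assuming both datasets are row-aligned; variable names follow the diff, and this is not the full function body:

    from datasets import concatenate_datasets

    # Prefix the metadata id column so it can coexist with the original one.
    metadata_dataset = metadata_dataset.rename_column(id_column_name, f"metadata_{id_column_name}")

    # axis=1 joins columns side by side; rows are matched purely by position.
    dataset = concatenate_datasets([dataset, metadata_dataset], axis=1)

    # Any row where the two id columns disagree means the merge was misaligned.
    mismatched = dataset.filter(lambda id1, id2: id1 != id2, input_columns=[id_column_name, f"metadata_{id_column_name}"])
    if len(mismatched) != 0:
        raise ValueError("Concatenate didn't work. Some ids don't correspond.")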
@@ -996,6 +993,9 @@ def main():
             num_proc=data_args.preprocessing_num_workers,
             id_column_name=data_args.id_column_name,
             columns_to_keep=columns_to_keep.values(),
+            prompt_column_name=data_args.prompt_column_name,
+            audio_column_name=data_args.target_audio_column_name,
+            sampling_rate=sampling_rate,
             # streaming=data_args.streaming, TODO(SG): optionally enable streaming mode
         )
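One caveat about the hunk at -760,7: datasets.Dataset.remove_columns is not in-place; it returns a new Dataset, so as written in the diff the result is discarded and the prompt column survives. A corrected sketch would reassign the result:

    # remove_columns returns a new Dataset rather than mutating the old one,
    # so the result must be reassigned for the column to actually disappear.
    if prompt_column_name in dataset.column_names:
        dataset = dataset.remove_columns(prompt_column_name)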