Commit c3cc45a5 authored by sanchit-gandhi

load multiple splits

parent ee4f39db
@@ -7,7 +7,7 @@ from typing import Any, Dict, List, Optional, Union
 import torch
 from accelerate import Accelerator
-from datasets import load_dataset
+from datasets import load_dataset, DatasetDict
 from torch.utils.data import DataLoader
 from tqdm import tqdm
 from transformers import (
@@ -263,20 +263,37 @@ def main():
     model.generation_config.pad_token_id = model.generation_config.eos_token_id
 
     # 4. Load annotation dataset
-    raw_datasets = load_dataset(
-        data_args.dataset_name,
-        data_args.dataset_config_name,
-        split=data_args.dataset_split_name,
-        cache_dir=model_args.cache_dir,
-        token=model_args.token,
-        trust_remote_code=model_args.trust_remote_code,
-        num_proc=data_args.preprocessing_num_workers,
-    )
+    if data_args.dataset_split_name is not None:
+        raw_datasets = DatasetDict()
+        data_splits = data_args.dataset_split_name.split("+")
+        # load on a split-wise basis
+        for split in data_splits:
+            raw_datasets[split] = load_dataset(
+                data_args.dataset_name,
+                data_args.dataset_config_name,
+                split=split,
+                cache_dir=model_args.cache_dir,
+                token=model_args.token,
+                trust_remote_code=model_args.trust_remote_code,
+                num_proc=data_args.preprocessing_num_workers,
+            )
+    else:
+        # load all splits for annotation
+        raw_datasets = load_dataset(
+            data_args.dataset_name,
+            data_args.dataset_config_name,
+            cache_dir=model_args.cache_dir,
+            token=model_args.token,
+            trust_remote_code=model_args.trust_remote_code,
+            num_proc=data_args.preprocessing_num_workers,
+        )
     raw_datasets_features = set(raw_datasets.features.keys())
 
     if data_args.max_eval_samples:
         raw_datasets = raw_datasets.select(range(data_args.max_eval_samples))
 
-    EXPECTED_COLUMNS = {"speaking_rate", "noise", "reverberation", "speech_monotony"}
+    # EXPECTED_COLUMNS = {"speaking_rate", "noise", "reverberation", "speech_monotony", "gender"}
+    EXPECTED_COLUMNS = {"speaking_rate", "gender"}
     if not EXPECTED_COLUMNS.issubset(raw_datasets_features):
         missing_columns = EXPECTED_COLUMNS - raw_datasets_features
         raise ValueError(
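Taken on its own, the split-wise branch in the hunk above boils down to the pattern below. This is a minimal sketch: the dataset repo, config name, and split string are hypothetical placeholders, and the cache_dir, token, trust_remote_code and num_proc arguments the script forwards are omitted for brevity.

from datasets import DatasetDict, load_dataset

def load_splits(dataset_name: str, config_name: str, split_string: str) -> DatasetDict:
    """Load a '+'-separated list of splits into a DatasetDict keyed by split name."""
    raw_datasets = DatasetDict()
    for split in split_string.split("+"):
        raw_datasets[split] = load_dataset(dataset_name, config_name, split=split)
    return raw_datasets

# Hypothetical usage; "username/my_dataset", "default", and "train+validation" are placeholders.
raw_datasets = load_splits("username/my_dataset", "default", "train+validation")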
@@ -289,25 +306,7 @@ def main():
 For the descriptors: {gender}, {accent}, {reverberation}, {noise}, {monotony}, {pitch}, {speaking_rate}, the corresponding sentence is:"""
 
     def prepare_dataset(sample):
-        sample_prompt = PROMPT.replace("{gender}", sample["gender"])
-        sample_prompt = sample_prompt.replace("{accent}", sample["accent"])
-        sample_prompt = sample_prompt.replace("{reverberation}", sample["reverberation"])
-        sample_prompt = sample_prompt.replace("{noise}", sample["noise"])
-        sample_prompt = sample_prompt.replace("{monotony}", sample["monotony"])
-        sample_prompt = sample_prompt.replace("{pitch}", sample["pitch"])
-        sample_prompt = sample_prompt.replace("{speaking_rate}", sample["speaking_rate"])
-        sample_prompt = [{"role": "user", "content": sample_prompt}]
-        token_ids = tokenizer.apply_chat_template(sample_prompt)
-        sample["prompt_ids"] = token_ids
-        return sample
-
-    DUMMY_PROMPT = """ We have seven keywords that describe different attributes of an audio sample spoken by a given speaker: the speaker's gender, the speaker's accent, the amount of reverberation in the sample (high or low reverberation), the amount of noise in the sample (how clear or noisy), how monotone or animated the sample is, the speaker's pitch (high or low voice), the speaker's speed (how fast or slow the speaker is speaking).
-Given these keywords, form a coherent sentence that summarises the seven attributes in a meaningful way. You can change the order of the keywords in the sentence and use common synonyms for these words, provided that the sentence summarises the attributes clearly. Keep the sentence simple - don't introduce additional information other than the keywords provided. Only return the generated sentence, not any other assistant remarks.
-For example, given the following descriptors: 'female', 'Hungarian', 'slightly roomy sounding', 'fairly noisy', 'quite monotone', 'fairly low pitch', 'very slowly', a valid sentence would be: 'a woman with a deep voice speaking slowly and somewhat monotonously with a Hungarian accent in an echoey room with background noise'. Note how the seven attributes have been combined together in a simple sentence, with the ordering changed but no additional information added.
-For the descriptors: [gender], [accent], [reverberation], [noise], [monotony], [pitch], [speaking_rate], the corresponding sentence is:"""
-
-    def prepare_dummy_dataset(sample):
-        sample_prompt = DUMMY_PROMPT
+        sample_prompt = PROMPT
         for key in EXPECTED_COLUMNS:
             sample_prompt = sample_prompt.replace(f"[{key}]", sample[key])
         sample_prompt = [{"role": "user", "content": sample_prompt}]
@@ -317,7 +316,7 @@ def main():
 
     with accelerator.main_process_first():
         vectorized_datasets = raw_datasets.map(
-            prepare_dummy_dataset, num_proc=data_args.preprocessing_num_workers, desc="Preparing prompts"
+            prepare_dataset, num_proc=data_args.preprocessing_num_workers, desc="Preparing prompts"
         )
 
     data_collator = DataCollatorWithPadding(tokenizer)
...
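As background for the unchanged data_collator context line above, DataCollatorWithPadding pads variable-length token id lists into rectangular batches. A minimal, self-contained sketch with a placeholder checkpoint:

from torch.utils.data import DataLoader
from transformers import AutoTokenizer, DataCollatorWithPadding

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint
tokenizer.pad_token = tokenizer.eos_token          # gpt2 ships without a pad token

collator = DataCollatorWithPadding(tokenizer)
features = [
    {"input_ids": tokenizer("a short prompt")["input_ids"]},
    {"input_ids": tokenizer("a noticeably longer prompt that forces padding")["input_ids"]},
]

loader = DataLoader(features, batch_size=2, collate_fn=collator)
batch = next(iter(loader))
# batch["input_ids"] and batch["attention_mask"] are equal-length padded tensors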