Commit d175b832 authored by sanchit-gandhi

finish prompt and launch

parent c3cc45a5
 #!/usr/bin/env bash
 python run_prompt_creation.py \
-  --dataset_name "ylacombe/libritts_r_test_tag" \
-  --dataset_config_name "default" \
+  --dataset_name "ylacombe/libritts_r_tags_and_text" \
+  --dataset_config_name "clean" \
+  --dataset_split_name "dev.clean" \
   --model_name_or_path "mistralai/Mistral-7B-Instruct-v0.2" \
-  --per_device_eval_batch_size 512 \
+  --per_device_eval_batch_size 256 \
+  --attn_implementation "flash_attention_2" \
   --dataloader_num_workers 4 \
   --output_dir "./" \
   --load_in_4bit \
   --push_to_hub \
-  --hub_dataset_id "sanchit-gandhi/libritts_r_test_tag_generated"
+  --hub_dataset_id "sanchit-gandhi/libritts_r_tags_and_text_generated"
@@ -95,6 +95,9 @@ class ModelArguments:
     max_new_tokens: Optional[int] = field(
         default=256, metadata={"help": "Maximum number of new tokens during generation"}
     )
+    compile_generate: Optional[bool] = field(
+        default=False, metadata={"help": "Whether to compile the forward pass (not sampling) in generate."}
+    )

 @dataclass
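For context on how the new field surfaces on the command line: a minimal sketch, assuming HfArgumentParser's usual handling of boolean fields (the toy dataclass below is illustrative, not the script's full ModelArguments):

from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser

@dataclass
class ToyModelArguments:
    # mirrors the field added in the hunk above
    compile_generate: Optional[bool] = field(
        default=False, metadata={"help": "Whether to compile the forward pass (not sampling) in generate."}
    )

parser = HfArgumentParser(ToyModelArguments)
# boolean fields become flags, so "--compile_generate True" (or the bare flag)
# should toggle the option on
(toy_args,) = parser.parse_args_into_dataclasses(["--compile_generate", "True"])
assert toy_args.compile_generate is True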
@@ -194,7 +197,7 @@ def get_kbit_device_map() -> Union[Dict[str, int], None]:
 @dataclass
 class DataCollatorWithPadding:
     """
-    Data collator that will dynamically pad the inputs received.
+    Data collator that will dynamically pad the inputs received to the longest sequence in the batch.
     """

     tokenizer: Any
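As a quick illustration of the clarified docstring, here is a minimal sketch of per-batch dynamic padding using the tokenizer's built-in pad method; the model and the bos-as-pad fallback mirror the script, and loading the tokenizer may require Hub credentials:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
tok.pad_token_id = tok.bos_token_id  # same fallback the script applies later

# two prompts of different length: padding="longest" pads only up to the
# longest sequence in this batch, not to a fixed maximum length
batch = tok.pad(
    {"input_ids": [[101, 102, 103], [104, 105]]},
    padding="longest",
    return_tensors="pt",
)
print(batch["input_ids"].shape)  # torch.Size([2, 3])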
@@ -232,7 +235,46 @@ def main():
logger.info("Cleaning output dir from previous run...") logger.info("Cleaning output dir from previous run...")
shutil.rmtree(data_args.output_dir) shutil.rmtree(data_args.output_dir)
# 3. Load pre-trained model # 3. Load annotated dataset
logger.info("*** Load annotated dataset ***")
if data_args.dataset_split_name is not None:
raw_datasets = DatasetDict()
data_splits = data_args.dataset_split_name.split("+")
# load on a split-wise basis
for split in data_splits:
raw_datasets[split] = load_dataset(
data_args.dataset_name,
data_args.dataset_config_name,
split=split,
cache_dir=model_args.cache_dir,
token=model_args.token,
num_proc=data_args.preprocessing_num_workers,
)
else:
# load all splits for annotation
raw_datasets = load_dataset(
data_args.dataset_name,
data_args.dataset_config_name,
cache_dir=model_args.cache_dir,
token=model_args.token,
num_proc=data_args.preprocessing_num_workers,
)
raw_datasets_features = set(raw_datasets[next(iter(raw_datasets))].features.keys())
if data_args.max_eval_samples is not None:
for split in raw_datasets:
raw_datasets[split] = raw_datasets[split].select(range(data_args.max_eval_samples))
# TODO(SG): add accent
EXPECTED_COLUMNS = {"gender", "pitch", "noise", "reverberation", "speech_monotony", "speaking_rate"}
if not EXPECTED_COLUMNS.issubset(raw_datasets_features):
missing_columns = EXPECTED_COLUMNS - raw_datasets_features
raise ValueError(
f"Missing columns {missing_columns} from the dataset features. Got dataset features {raw_datasets_features}"
)
# 4. Load pre-trained model
logger.info("*** Load pretrained model ***") logger.info("*** Load pretrained model ***")
torch_dtype = ( torch_dtype = (
model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
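A note on the split handling added above: dataset_split_name is "+"-separated, so one flag can request several splits, each loaded under its own DatasetDict key. A tiny sketch (the split names are illustrative):

# e.g. --dataset_split_name "dev.clean+test.clean"
data_splits = "dev.clean+test.clean".split("+")
print(data_splits)  # ['dev.clean', 'test.clean'] -> one load_dataset call per split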
@@ -251,6 +293,28 @@ def main():
         low_cpu_mem_usage=True,
         token=model_args.token,
     ).eval()

+    if model_args.compile_generate:
+        if not callable(getattr(model, "_setup_cache", None)):
+            raise ValueError(
+                f"Static k/v cache is not compatible with the model {model.__class__.__name__}. "
+                "Set `--compile_generate=False` for dynamic k/v cache."
+            )
+        model.generation_config.cache_implementation = "static"
+        model._forward = model.forward
+        compiled_forward = torch.compile(model.forward)
+
+        def compiled(func, input_ids, **kwargs):
+            return func(input_ids, **kwargs)
+
+        def call(input_ids, **kwargs):
+            if input_ids.shape[-1] == 1:
+                return compiled(compiled_forward, input_ids, **kwargs)
+            return model._forward(input_ids, **kwargs)
+
+        model.forward = call
+
     tokenizer = AutoTokenizer.from_pretrained(
         model_args.model_name_or_path,
         revision=model_args.model_revision,
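The compile_generate branch above routes single-token decode steps through a compiled forward while leaving the variable-length pre-fill on the eager path: compiled kernels want static shapes, which the static k/v cache provides. A self-contained sketch of the same routing idea on a toy module (all names here are illustrative, not the script's):

import torch

class ToyLM(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(16, 16)

    def forward(self, hidden_states):
        return self.proj(hidden_states)

model = ToyLM().eval()
eager_forward = model.forward                 # keep a handle on the original forward
compiled_forward = torch.compile(model.forward)

def routed_forward(hidden_states, **kwargs):
    if hidden_states.shape[-2] == 1:
        # decode step: one query position, static shape -> compiled path
        return compiled_forward(hidden_states, **kwargs)
    # pre-fill: prompt length varies -> eager path avoids recompilation
    return eager_forward(hidden_states, **kwargs)

model.forward = routed_forward

with torch.no_grad():
    model.forward(torch.randn(1, 8, 16))   # pre-fill (eager)
    model.forward(torch.randn(1, 1, 16))   # decode step (compiled)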
@@ -262,51 +326,20 @@ def main():
         tokenizer.pad_token_id = tokenizer.bos_token_id
         model.generation_config.pad_token_id = model.generation_config.eos_token_id

-    # 4. Load annotation dataset
-    if data_args.dataset_split_name is not None:
-        raw_datasets = DatasetDict()
-        data_splits = data_args.data_split_name.split("+")
-        # load on a split-wise basis
-        for split in data_splits:
-            raw_datasets[split] = load_dataset(
-                data_args.dataset_name,
-                data_args.dataset_config_name,
-                split=split,
-                cache_dir=model_args.cache_dir,
-                token=model_args.token,
-                trust_remote_code=model_args.trust_remote_code,
-                num_proc=data_args.preprocessing_num_workers,
-            )
-    else:
-        # load all splits for annotation
-        raw_datasets = load_dataset(
-            data_args.dataset_name,
-            data_args.dataset_config_name,
-            cache_dir=model_args.cache_dir,
-            token=model_args.token,
-            trust_remote_code=model_args.trust_remote_code,
-            num_proc=data_args.preprocessing_num_workers,
-        )
-
-    raw_datasets_features = set(raw_datasets.features.keys())
-
-    if data_args.max_eval_samples:
-        raw_datasets = raw_datasets.select(range(data_args.max_eval_samples))
-
-    # EXPECTED_COLUMNS = {"speaking_rate", "noise", "reverberation", "speech_monotony", "gender"}
-    EXPECTED_COLUMNS = {"speaking_rate", "gender"}
-    if not EXPECTED_COLUMNS.issubset(raw_datasets_features):
-        missing_columns = EXPECTED_COLUMNS - raw_datasets_features
-        raise ValueError(
-            f"Missing columns {missing_columns} from the dataset features. Got dataset features {raw_datasets_features}"
-        )
PROMPT = """ We have seven keywords that describe different attributes of an audio sample spoken by a given speaker: the speaker's gender, the speaker's accent, the amount of reverberation in the sample (high or low reverberation), the amount of noise in the sample (how clear or noisy), how monotone or animated the sample is, the speaker's pitch (high or low voice), the speaker's speed (how fast or slow the speaker is speaking). PROMPT = """ We have seven keywords that describe different attributes of an audio sample spoken by a given speaker: the speaker's gender, the speaker's accent, the amount of reverberation in the sample (high or low reverberation), the amount of noise in the sample (how clear or noisy), how monotone or animated the sample is, the speaker's pitch (high or low voice), the speaker's speed (how fast or slow the speaker is speaking).
Given these keywords, form a coherent sentence that summarises the seven attributes in a meaningful way. You can change the order of the keywords in the sentence and use common synonyms for these words, provided that the sentence summarises the attributes clearly. Keep the sentence simple - don't introduce additional information other than the keywords provided. Only return the generated sentence, not any other assistant remarks. Given these keywords, form a coherent sentence that summarises the seven attributes in a meaningful way. You can change the order of the keywords in the sentence and use common synonyms for these words, provided that the sentence summarises the attributes clearly. Keep the sentence simple - don't introduce additional information other than the keywords provided. Only return the generated sentence, not any other assistant remarks.
For example, given the following descriptors: 'female', 'Hungarian', 'slightly roomy sounding', 'fairly noisy', 'quite monotone', 'fairly low pitch', 'very slowly', a valid sentence would be: 'a woman with a deep voice speaking slowly and somewhat monotonously with a Hungarian accent in an echoey room with background noise'. Note how the seven attributes have been combined together in a simple sentence, with the ordering changed but no additional information added. For example, given the following descriptors: 'female', 'Hungarian', 'slightly roomy sounding', 'fairly noisy', 'quite monotone', 'fairly low pitch', 'very slowly', a valid sentence would be: 'a woman with a deep voice speaking slowly and somewhat monotonously with a Hungarian accent in an echoey room with background noise'. Note how the seven attributes have been combined together in a simple sentence, with the ordering changed but no additional information added.
For the descriptors: {gender}, {accent}, {reverberation}, {noise}, {monotony}, {pitch}, {speaking_rate}, the corresponding sentence is:""" For the descriptors: {gender}, {accent}, {reverberation}, {noise}, {monotony}, {pitch}, {speaking_rate}, the corresponding sentence is:"""
SUBSET_PROMPT = """ We have six keywords that describe different attributes of an audio sample spoken by a given speaker: the speaker's gender, the amount of reverberation in the sample (high or low reverberation), the amount of noise in the sample (how clear or noisy), how monotone or animated the sample is, the speaker's pitch (high or low voice), the speaker's speed (how fast or slow the speaker is speaking).
Given these keywords, form a coherent sentence that summarises the six attributes in a meaningful way. You can change the order of the keywords in the sentence and use common synonyms for these words, provided that the sentence summarises the attributes clearly. Keep the sentence simple - don't introduce additional information other than the keywords provided. Only return the generated sentence, not any other assistant remarks.
For example, given the following descriptors: 'female', 'slightly roomy sounding', 'fairly noisy', 'quite monotone', 'fairly low pitch', 'very slowly', a valid sentence would be: 'a woman with a deep voice speaking slowly and somewhat monotonously in an echoey room with background noise'. Note how the six attributes have been combined together in a simple sentence, with the ordering changed but no additional information added.
For the descriptors: {gender}, {accent}, {reverberation}, {noise}, {monotony}, {pitch}, {speaking_rate}, the corresponding sentence is:"""
     def prepare_dataset(sample):
-        sample_prompt = PROMPT
+        sample_prompt = SUBSET_PROMPT
         for key in EXPECTED_COLUMNS:
             # placeholders are brace-delimited; the speech_monotony column fills {monotony}
             placeholder = "{monotony}" if key == "speech_monotony" else f"{{{key}}}"
             sample_prompt = sample_prompt.replace(placeholder, sample[key])
         sample_prompt = [{"role": "user", "content": sample_prompt}]
@@ -319,17 +352,9 @@ def main():
         prepare_dataset, num_proc=data_args.preprocessing_num_workers, desc="Preparing prompts"
     )

-    data_collator = DataCollatorWithPadding(tokenizer)
-    data_loader = DataLoader(
-        vectorized_datasets,
-        batch_size=model_args.per_device_eval_batch_size,
-        collate_fn=data_collator,
-        num_workers=data_args.dataloader_num_workers,
-        pin_memory=True,
-    )
-
     # Prepare everything with our `accelerator`
-    model, data_loader = accelerator.prepare(model, data_loader)
+    model = accelerator.prepare(model)
+    data_collator = DataCollatorWithPadding(tokenizer)

     def generate_step(batch):
         output_ids = accelerator.unwrap_model(model).generate(
@@ -342,13 +367,21 @@ def main():
         output_ids = accelerator.pad_across_processes(output_ids, dim=1, pad_index=tokenizer.pad_token_id)
         return output_ids

+    for split in vectorized_datasets:
+        data_loader = DataLoader(
+            vectorized_datasets[split],
+            batch_size=model_args.per_device_eval_batch_size,
+            collate_fn=data_collator,
+            num_workers=data_args.dataloader_num_workers,
+            pin_memory=True,
+        )
+        data_loader = accelerator.prepare(data_loader)
+
         all_generated_ids = []

         for batch in tqdm(data_loader, disable=not accelerator.is_local_main_process):
             generated_ids = generate_step(batch)
             all_generated_ids.extend(generated_ids.cpu())

-    accelerator.end_training()

         def postprocess_dataset(sample, idx):
             prompt_text = tokenizer.decode(sample["input_ids"], skip_special_tokens=True)
             generated_text = tokenizer.decode(all_generated_ids[idx], skip_special_tokens=True)
@@ -356,13 +389,17 @@ def main():
             return sample

         if accelerator.is_main_process:
-        vectorized_datasets = vectorized_datasets.map(
+            vectorized_datasets[split] = vectorized_datasets[split].map(
                 postprocess_dataset,
                 num_proc=data_args.preprocessing_num_workers,
                 desc="Postprocessing dataset",
                 remove_columns=["input_ids"],
                 with_indices=True,
             )

+    accelerator.end_training()
+
+    if accelerator.is_main_process:
         vectorized_datasets.save_to_disk(data_args.output_dir)
         if data_args.push_to_hub:
             vectorized_datasets.push_to_hub(data_args.hub_dataset_id)
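Once a run completes with --push_to_hub set, reading the generated descriptions back is a one-liner; the repo id below comes from the launch script above and assumes the push succeeded:

from datasets import load_dataset

generated = load_dataset("sanchit-gandhi/libritts_r_tags_and_text_generated")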
...