Commit a7231794 authored by sanchit-gandhi's avatar sanchit-gandhi
Browse files

make post-processing more efficient

parent 1176f1bb
...@@ -7,6 +7,7 @@ from typing import Any, Dict, List, Optional, Union ...@@ -7,6 +7,7 @@ from typing import Any, Dict, List, Optional, Union
import torch import torch
from accelerate import Accelerator from accelerate import Accelerator
from accelerate.logging import get_logger
from datasets import DatasetDict, load_dataset from datasets import DatasetDict, load_dataset
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from tqdm import tqdm from tqdm import tqdm
...@@ -18,7 +19,7 @@ from transformers import ( ...@@ -18,7 +19,7 @@ from transformers import (
) )
logger = logging.getLogger(__name__) logger = get_logger(__name__, log_level="INFO")
@dataclass @dataclass
...@@ -216,13 +217,14 @@ PROMPT = """You will be given six descriptive keywords related to an audio sampl ...@@ -216,13 +217,14 @@ PROMPT = """You will be given six descriptive keywords related to an audio sampl
5. The pace of the speaker's delivery (e.g., very slowly, quite slowly, slightly slowly, moderate speed, slightly fast, quite fast, very fast) 5. The pace of the speaker's delivery (e.g., very slowly, quite slowly, slightly slowly, moderate speed, slightly fast, quite fast, very fast)
6. The pitch of the speaker's voice (e.g., very low pitch, quite low pitch, slightly low pitch, moderate pitch, slightly high pitch, quite high pitch, very high pitch) 6. The pitch of the speaker's voice (e.g., very low pitch, quite low pitch, slightly low pitch, moderate pitch, slightly high pitch, quite high pitch, very high pitch)
Your task is to create a text description using these keywords that accurately describes the speech sample while ensuring the description remains grammatically correct and easy to understand. You should rearrange the keyword order as necessary, and substitute synonymous terms where appropriate. If the amount of noise is 'very noisy' and the level of reverberation is 'very roomy sounding', include terms like 'very bad recording' in the description. Likewise, if the amount of noise is 'very clear' and the level of reverberation is 'very confined sounding', include terms like 'very good recording' in the description. Otherwise, do not add extra details beyond what has been provided, and only return the generated description. Your task is to create a text description using these keywords that accurately describes the speech sample while ensuring the description remains grammatically correct and easy to understand. You should rearrange the keyword order as necessary, and substitute synonymous terms where appropriate. If the amount of noise is 'very noisy' and the level of reverberation is 'very roomy sounding', include terms like 'very bad recording' in the description. Likewise, if the amount of noise is 'very clear' and the level of reverberation is 'very confined sounding', include terms like 'very good recording' in the description. Otherwise, do not add extra details beyond what has been provided, and only return the generated description.
For example, given the following keywords: 'female', 'slightly roomy sounding', 'slightly noisy', 'very expressive', 'slightly low pitch', 'very slowly', a valid description would be: 'a woman with a deep voice speaks slowly but has an animated delivery in an echoey room with some background noise'. For example, given the following keywords: 'female', 'slightly roomy sounding', 'slightly noisy', 'very expressive', 'slightly low pitch', 'very slowly', a valid description would be: 'a woman with a deep voice speaks slowly but has an animated delivery in an echoey room with some background noise'.
For the keywords: '[gender]', '[reverberation]', '[noise]', '[speech_monotony]', '[pitch]', '[speaking_rate]', the corresponding description is:" For the keywords: '[gender]', '[reverberation]', '[noise]', '[speech_monotony]', '[pitch]', '[speaking_rate]', the corresponding description is:"
""" """
def main(): def main():
# 1. Parse input arguments # 1. Parse input arguments
parser = HfArgumentParser((ModelArguments, DataArguments)) parser = HfArgumentParser((ModelArguments, DataArguments))
...@@ -235,7 +237,6 @@ def main(): ...@@ -235,7 +237,6 @@ def main():
# 2. Setup logging # 2. Setup logging
# Make one log on every process with the configuration for debugging. # Make one log on every process with the configuration for debugging.
logger.setLevel(logging.INFO)
logging.basicConfig( logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S", datefmt="%m/%d/%Y %H:%M:%S",
...@@ -368,6 +369,12 @@ def main(): ...@@ -368,6 +369,12 @@ def main():
output_ids = accelerator.pad_across_processes(output_ids, dim=1, pad_index=tokenizer.pad_token_id) output_ids = accelerator.pad_across_processes(output_ids, dim=1, pad_index=tokenizer.pad_token_id)
return output_ids return output_ids
def postprocess_dataset(sample):
prompt_text = tokenizer.decode(sample["input_ids"], skip_special_tokens=True)
generated_text = tokenizer.decode(sample["generated_ids"], skip_special_tokens=True)
sample["text_description"] = generated_text[len(prompt_text) :]
return sample
for split in vectorized_datasets: for split in vectorized_datasets:
data_loader = DataLoader( data_loader = DataLoader(
vectorized_datasets[split], vectorized_datasets[split],
...@@ -382,21 +389,16 @@ def main(): ...@@ -382,21 +389,16 @@ def main():
for batch in tqdm(data_loader, disable=not accelerator.is_local_main_process): for batch in tqdm(data_loader, disable=not accelerator.is_local_main_process):
generated_ids = generate_step(batch) generated_ids = generate_step(batch)
generated_ids = accelerator.gather_for_metrics(generated_ids) generated_ids = accelerator.gather_for_metrics(generated_ids)
all_generated_ids.extend(generated_ids.cpu()) all_generated_ids.extend(generated_ids.cpu().numpy())
def postprocess_dataset(sample, idx): vectorized_datasets[split] = vectorized_datasets[split].add_column("generated_ids", all_generated_ids)
prompt_text = tokenizer.decode(sample["input_ids"], skip_special_tokens=True)
generated_text = tokenizer.decode(all_generated_ids[idx], skip_special_tokens=True)
sample["text_description"] = generated_text[len(prompt_text) :]
return sample
if accelerator.is_main_process: if accelerator.is_main_process:
vectorized_datasets[split] = vectorized_datasets[split].map( vectorized_datasets[split] = vectorized_datasets[split].map(
postprocess_dataset, postprocess_dataset,
num_proc=data_args.preprocessing_num_workers, num_proc=data_args.preprocessing_num_workers,
desc="Postprocessing dataset", desc="Postprocessing dataset",
remove_columns=["input_ids"], remove_columns=["input_ids", "generated_ids"],
with_indices=True,
) )
if accelerator.is_main_process: if accelerator.is_main_process:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment