Commit a7231794 authored by sanchit-gandhi's avatar sanchit-gandhi
Browse files

make post-processing more efficient

parent 1176f1bb
...@@ -7,6 +7,7 @@ from typing import Any, Dict, List, Optional, Union ...@@ -7,6 +7,7 @@ from typing import Any, Dict, List, Optional, Union
import torch import torch
from accelerate import Accelerator from accelerate import Accelerator
from accelerate.logging import get_logger
from datasets import DatasetDict, load_dataset from datasets import DatasetDict, load_dataset
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from tqdm import tqdm from tqdm import tqdm
...@@ -18,7 +19,7 @@ from transformers import ( ...@@ -18,7 +19,7 @@ from transformers import (
) )
logger = logging.getLogger(__name__) logger = get_logger(__name__, log_level="INFO")
@dataclass @dataclass
...@@ -223,6 +224,7 @@ For example, given the following keywords: 'female', 'slightly roomy sounding', ...@@ -223,6 +224,7 @@ For example, given the following keywords: 'female', 'slightly roomy sounding',
For the keywords: '[gender]', '[reverberation]', '[noise]', '[speech_monotony]', '[pitch]', '[speaking_rate]', the corresponding description is:" For the keywords: '[gender]', '[reverberation]', '[noise]', '[speech_monotony]', '[pitch]', '[speaking_rate]', the corresponding description is:"
""" """
def main(): def main():
# 1. Parse input arguments # 1. Parse input arguments
parser = HfArgumentParser((ModelArguments, DataArguments)) parser = HfArgumentParser((ModelArguments, DataArguments))
...@@ -235,7 +237,6 @@ def main(): ...@@ -235,7 +237,6 @@ def main():
# 2. Setup logging # 2. Setup logging
# Make one log on every process with the configuration for debugging. # Make one log on every process with the configuration for debugging.
logger.setLevel(logging.INFO)
logging.basicConfig( logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S", datefmt="%m/%d/%Y %H:%M:%S",
...@@ -368,6 +369,12 @@ def main(): ...@@ -368,6 +369,12 @@ def main():
output_ids = accelerator.pad_across_processes(output_ids, dim=1, pad_index=tokenizer.pad_token_id) output_ids = accelerator.pad_across_processes(output_ids, dim=1, pad_index=tokenizer.pad_token_id)
return output_ids return output_ids
def postprocess_dataset(sample):
    """Attach the generated description text to *sample*.

    Decodes both the prompt (``input_ids``) and the full model output
    (``generated_ids``), then stores only the newly generated continuation —
    the part after the prompt prefix — under ``text_description``.
    """
    # NOTE(review): relies on the decoded output starting with the decoded
    # prompt verbatim — holds for causal-LM generate() with the same tokenizer.
    decoded_prompt = tokenizer.decode(sample["input_ids"], skip_special_tokens=True)
    decoded_output = tokenizer.decode(sample["generated_ids"], skip_special_tokens=True)
    sample["text_description"] = decoded_output[len(decoded_prompt):]
    return sample
for split in vectorized_datasets: for split in vectorized_datasets:
data_loader = DataLoader( data_loader = DataLoader(
vectorized_datasets[split], vectorized_datasets[split],
...@@ -382,21 +389,16 @@ def main(): ...@@ -382,21 +389,16 @@ def main():
for batch in tqdm(data_loader, disable=not accelerator.is_local_main_process): for batch in tqdm(data_loader, disable=not accelerator.is_local_main_process):
generated_ids = generate_step(batch) generated_ids = generate_step(batch)
generated_ids = accelerator.gather_for_metrics(generated_ids) generated_ids = accelerator.gather_for_metrics(generated_ids)
all_generated_ids.extend(generated_ids.cpu()) all_generated_ids.extend(generated_ids.cpu().numpy())
def postprocess_dataset(sample, idx): vectorized_datasets[split] = vectorized_datasets[split].add_column("generated_ids", all_generated_ids)
prompt_text = tokenizer.decode(sample["input_ids"], skip_special_tokens=True)
generated_text = tokenizer.decode(all_generated_ids[idx], skip_special_tokens=True)
sample["text_description"] = generated_text[len(prompt_text) :]
return sample
if accelerator.is_main_process: if accelerator.is_main_process:
vectorized_datasets[split] = vectorized_datasets[split].map( vectorized_datasets[split] = vectorized_datasets[split].map(
postprocess_dataset, postprocess_dataset,
num_proc=data_args.preprocessing_num_workers, num_proc=data_args.preprocessing_num_workers,
desc="Postprocessing dataset", desc="Postprocessing dataset",
remove_columns=["input_ids"], remove_columns=["input_ids", "generated_ids"],
with_indices=True,
) )
if accelerator.is_main_process: if accelerator.is_main_process:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment