Unverified Commit ac17f711 authored by Bhadresh Savani, committed by GitHub

added max_sample args and metrics changes (#10602)

parent c19c811a
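Note on the mechanism: the two dataclass fields added in the first hunk surface as --max_train_samples / --max_val_samples command-line flags through HfArgumentParser, the same pattern the other example scripts use. A minimal sketch of that mechanism follows; DebugArguments is a hypothetical stand-in that reproduces only the two new fields, not the template's full DataTrainingArguments:

# Minimal sketch, assuming transformers' HfArgumentParser; DebugArguments
# is a hypothetical stand-in for the template's DataTrainingArguments.
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class DebugArguments:
    max_train_samples: Optional[int] = field(default=None)
    max_val_samples: Optional[int] = field(default=None)


parser = HfArgumentParser(DebugArguments)
(args,) = parser.parse_args_into_dataclasses(["--max_train_samples", "100"])
print(args.max_train_samples, args.max_val_samples)  # -> 100 None

Omitting a flag leaves the field at its default of None, which the script below treats as "use the full split".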
@@ -144,6 +144,20 @@ class DataTrainingArguments:
         default=None,
         metadata={"help": "The number of processes to use for the preprocessing."},
     )
+    max_train_samples: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
+            "value if set."
+        },
+    )
+    max_val_samples: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
+            "value if set."
+        },
+    )
 
     def __post_init__(self):
         if self.dataset_name is None and self.train_file is None and self.validation_file is None:
@@ -317,7 +331,31 @@ def main():
     def tokenize_function(examples):
         return tokenizer(examples[text_column_name], padding="max_length", truncation=True)
 
-    tokenized_datasets = datasets.map(
+    if training_args.do_train:
+        if "train" not in datasets:
+            raise ValueError("--do_train requires a train dataset")
+        train_dataset = datasets["train"]
+        if data_args.max_train_samples is not None:
+            # Select Sample from Dataset
+            train_dataset = train_dataset.select(range(data_args.max_train_samples))
+        # tokenize train dataset in batch
+        train_dataset = train_dataset.map(
+            tokenize_function,
+            batched=True,
+            num_proc=data_args.preprocessing_num_workers,
+            remove_columns=[text_column_name],
+            load_from_cache_file=not data_args.overwrite_cache,
+        )
+
+    if training_args.do_eval:
+        if "validation" not in datasets:
+            raise ValueError("--do_eval requires a validation dataset")
+        eval_dataset = datasets["validation"]
+        # Selecting samples from dataset
+        if data_args.max_val_samples is not None:
+            eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
+        # tokenize validation dataset
+        eval_dataset = eval_dataset.map(
             tokenize_function,
             batched=True,
             num_proc=data_args.preprocessing_num_workers,
@@ -332,8 +370,8 @@ def main():
     trainer = Trainer(
         model=model,
         args=training_args,
-        train_dataset=tokenized_datasets["train"] if training_args.do_train else None,
-        eval_dataset=tokenized_datasets["validation"] if training_args.do_eval else None,
+        train_dataset=train_dataset if training_args.do_train else None,
+        eval_dataset=eval_dataset if training_args.do_eval else None,
         tokenizer=tokenizer,
         data_collator=data_collator,
     )
@@ -358,33 +396,27 @@ def main():
         train_result = trainer.train(resume_from_checkpoint=checkpoint)
         trainer.save_model()  # Saves the tokenizer too for easy upload
 
-        output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
-        if trainer.is_world_process_zero():
-            with open(output_train_file, "w") as writer:
-                logger.info("***** Train results *****")
-                for key, value in sorted(train_result.metrics.items()):
-                    logger.info(f"  {key} = {value}")
-                    writer.write(f"{key} = {value}\n")
+        metrics = train_result.metrics
+        max_train_samples = (
+            data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
+        )
+        metrics["train_samples"] = min(max_train_samples, len(train_dataset))
 
-            # Need to save the state, since Trainer.save_model saves only the tokenizer with the model
-            trainer.state.save_to_json(os.path.join(training_args.output_dir, "trainer_state.json"))
+        trainer.log_metrics("train", metrics)
+        trainer.save_metrics("train", metrics)
+        trainer.save_state()
 
     # Evaluation
-    results = {}
     if training_args.do_eval:
         logger.info("*** Evaluate ***")
 
-        results = trainer.evaluate()
+        metrics = trainer.evaluate()
 
-        output_eval_file = os.path.join(training_args.output_dir, "eval_results_{{cookiecutter.example_shortcut}}.txt")
-        if trainer.is_world_process_zero():
-            with open(output_eval_file, "w") as writer:
-                logger.info("***** Eval results *****")
-                for key, value in sorted(results.items()):
-                    logger.info(f"  {key} = {value}")
-                    writer.write(f"{key} = {value}\n")
+        max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
+        metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
 
-    return results
+        trainer.log_metrics("eval", metrics)
+        trainer.save_metrics("eval", metrics)
 
 
 def _mp_fn(index):
...
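The min(...) bookkeeping in the last hunk exists because the requested cap may exceed the split's size, and because the dataset was already truncated by select(range(...)) earlier, so the recorded train_samples / eval_samples never overstate how much data was actually used. A minimal, self-contained sketch of the truncation plus metrics logic, assuming the datasets library; the toy data, cap value, and variable names here are hypothetical:

# Sketch of the pattern from the diff, not the template script itself.
from datasets import Dataset

train_dataset = Dataset.from_dict({"text": ["a", "b", "c"]})
max_train_samples = 2  # what --max_train_samples would supply

if max_train_samples is not None:
    # select(range(n)) keeps the first n rows, as in the diff above
    train_dataset = train_dataset.select(range(max_train_samples))

metrics = {}
# clamp so the recorded count never exceeds the (already truncated) dataset
metrics["train_samples"] = min(max_train_samples, len(train_dataset))
print(metrics)  # -> {'train_samples': 2}

On the metrics side, trainer.log_metrics(...) and trainer.save_metrics(...) are the Trainer helpers that replace the hand-rolled writer deleted above: save_metrics writes a <split>_results.json into output_dir, and trainer.save_state() persists trainer_state.json, which the old code wrote by hand.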