Unverified Commit ac17f711 authored by Bhadresh Savani's avatar Bhadresh Savani Committed by GitHub
Browse files

added max_sample args and metrics changes (#10602)

parent c19c811a
......@@ -144,6 +144,20 @@ class DataTrainingArguments:
default=None,
metadata={"help": "The number of processes to use for the preprocessing."},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
"value if set."
},
)
max_val_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
"value if set."
},
)
def __post_init__(self):
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
......@@ -317,7 +331,31 @@ def main():
def tokenize_function(examples):
return tokenizer(examples[text_column_name], padding="max_length", truncation=True)
tokenized_datasets = datasets.map(
if training_args.do_train:
if "train" not in datasets:
raise ValueError("--do_train requires a train dataset")
train_dataset = datasets["train"]
if data_args.max_train_samples is not None:
# Select Sample from Dataset
train_dataset = train_dataset.select(range(data_args.max_train_samples))
# tokenize train dataset in batch
train_dataset = train_dataset.map(
tokenize_function,
batched=True,
num_proc=data_args.preprocessing_num_workers,
remove_columns=[text_column_name],
load_from_cache_file=not data_args.overwrite_cache,
)
if training_args.do_eval:
if "validation" not in datasets:
raise ValueError("--do_eval requires a validation dataset")
eval_dataset = datasets["validation"]
# Selecting samples from dataset
if data_args.max_val_samples is not None:
eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
# tokenize validation dataset
eval_dataset = eval_dataset.map(
tokenize_function,
batched=True,
num_proc=data_args.preprocessing_num_workers,
......@@ -332,8 +370,8 @@ def main():
trainer = Trainer(
model=model,
args=training_args,
train_dataset=tokenized_datasets["train"] if training_args.do_train else None,
eval_dataset=tokenized_datasets["validation"] if training_args.do_eval else None,
train_dataset=train_dataset if training_args.do_train else None,
eval_dataset=eval_dataset if training_args.do_eval else None,
tokenizer=tokenizer,
data_collator=data_collator,
)
......@@ -358,33 +396,27 @@ def main():
train_result = trainer.train(resume_from_checkpoint=checkpoint)
trainer.save_model() # Saves the tokenizer too for easy upload
output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
if trainer.is_world_process_zero():
with open(output_train_file, "w") as writer:
logger.info("***** Train results *****")
for key, value in sorted(train_result.metrics.items()):
logger.info(f" {key} = {value}")
writer.write(f"{key} = {value}\n")
metrics = train_result.metrics
max_train_samples = (
data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
)
metrics["train_samples"] = min(max_train_samples, len(train_dataset))
# Need to save the state, since Trainer.save_model saves only the tokenizer with the model
trainer.state.save_to_json(os.path.join(training_args.output_dir, "trainer_state.json"))
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state()
# Evaluation
results = {}
if training_args.do_eval:
logger.info("*** Evaluate ***")
results = trainer.evaluate()
metrics = trainer.evaluate()
output_eval_file = os.path.join(training_args.output_dir, "eval_results_{{cookiecutter.example_shortcut}}.txt")
if trainer.is_world_process_zero():
with open(output_eval_file, "w") as writer:
logger.info("***** Eval results *****")
for key, value in sorted(results.items()):
logger.info(f" {key} = {value}")
writer.write(f"{key} = {value}\n")
max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
return results
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
def _mp_fn(index):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment