"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "accccdd0087263a1e494e9c9ec30a43043ff3905"
Unverified commit 1d30ec95, authored by Bhadresh Savani and committed by GitHub

[Examples] Fixes inconsistency around eval vs val and predict vs test (#11380)

* added changes for uniformity

* modified files

* corrected typo

* fixed qa scripts

* fix typos

* fixed predict typo in qa no trainer

* fixed test file

* reverted trainer changes

* reverted trainer changes in custom examples

* updated readme

* added changes in deepspeed test

* added changes for predict and eval
parent 7959d835
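All of the renames below follow one convention: anything tied to `--do_eval` uses `eval` (`max_eval_samples`, `eval_samples`) and anything tied to `--do_predict` uses `predict` (`max_predict_samples`, `predict_dataset`, `predict_samples`). A minimal sketch of the resulting argument block, in the same `dataclasses` style the example scripts already use; the class below is illustrative rather than copied from any single script:

```python
# Illustrative sketch of the naming convention the diffs below converge on.
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class DataTrainingArguments:
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={"help": "Truncate the number of training examples for debugging or quicker training."},
    )
    # Previously max_val_samples; used when --do_eval is passed.
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={"help": "Truncate the number of evaluation examples for debugging or quicker training."},
    )
    # Previously max_test_samples; used when --do_predict is passed.
    max_predict_samples: Optional[int] = field(
        default=None,
        metadata={"help": "Truncate the number of prediction examples for debugging or quicker training."},
    )
```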
......@@ -50,8 +50,8 @@ For example here is how to truncate all three splits to just 50 samples each:
```
examples/pytorch/token-classification/run_ner.py \
--max_train_samples 50 \
--max_val_samples 50 \
--max_test_samples 50 \
--max_eval_samples 50 \
--max_predict_samples 50 \
[...]
```
......
......@@ -126,10 +126,10 @@ class DataTrainingArguments:
"value if set."
},
)
max_val_samples: Optional[int] = field(
max_eval_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
"value if set."
},
)
......@@ -397,8 +397,8 @@ def main():
if "validation" not in tokenized_datasets:
raise ValueError("--do_eval requires a validation dataset")
eval_dataset = lm_datasets["validation"]
if data_args.max_val_samples is not None:
eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
if data_args.max_eval_samples is not None:
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
# Initialize our Trainer
trainer = Trainer(
......@@ -439,8 +439,8 @@ def main():
metrics = trainer.evaluate()
max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
perplexity = math.exp(metrics["eval_loss"])
metrics["perplexity"] = perplexity
......
......@@ -157,10 +157,10 @@ class DataTrainingArguments:
"value if set."
},
)
max_val_samples: Optional[int] = field(
max_eval_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
"value if set."
},
)
......@@ -419,8 +419,8 @@ def main():
if "validation" not in tokenized_datasets:
raise ValueError("--do_eval requires a validation dataset")
eval_dataset = tokenized_datasets["validation"]
if data_args.max_val_samples is not None:
eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
if data_args.max_eval_samples is not None:
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
# Data collator
# This one will take care of randomly masking the tokens.
......@@ -468,8 +468,8 @@ def main():
metrics = trainer.evaluate()
max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
perplexity = math.exp(metrics["eval_loss"])
metrics["perplexity"] = perplexity
......
......@@ -154,10 +154,10 @@ class DataTrainingArguments:
"value if set."
},
)
max_val_samples: Optional[int] = field(
max_eval_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
"value if set."
},
)
......@@ -397,8 +397,8 @@ def main():
if "validation" not in tokenized_datasets:
raise ValueError("--do_eval requires a validation dataset")
eval_dataset = tokenized_datasets["validation"]
if data_args.max_val_samples is not None:
eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
if data_args.max_eval_samples is not None:
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
# Data collator
data_collator = DataCollatorForPermutationLanguageModeling(
......@@ -444,8 +444,8 @@ def main():
metrics = trainer.evaluate()
max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
perplexity = math.exp(metrics["eval_loss"])
metrics["perplexity"] = perplexity
......
......@@ -127,10 +127,10 @@ class DataTrainingArguments:
"value if set."
},
)
max_val_samples: Optional[int] = field(
max_eval_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
"value if set."
},
)
......@@ -363,8 +363,8 @@ def main():
if "validation" not in datasets:
raise ValueError("--do_eval requires a validation dataset")
eval_dataset = datasets["validation"]
if data_args.max_val_samples is not None:
eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
if data_args.max_eval_samples is not None:
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
eval_dataset = eval_dataset.map(
preprocess_function,
batched=True,
......@@ -422,8 +422,8 @@ def main():
logger.info("*** Evaluate ***")
metrics = trainer.evaluate()
max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
......
......@@ -133,17 +133,17 @@ class DataTrainingArguments:
"value if set."
},
)
max_val_samples: Optional[int] = field(
max_eval_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
"value if set."
},
)
max_test_samples: Optional[int] = field(
max_predict_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of test examples to this "
"help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
"value if set."
},
)
......@@ -468,9 +468,9 @@ def main():
if "validation" not in datasets:
raise ValueError("--do_eval requires a validation dataset")
eval_examples = datasets["validation"]
if data_args.max_val_samples is not None:
if data_args.max_eval_samples is not None:
# We will select sample from whole data
eval_examples = eval_examples.select(range(data_args.max_val_samples))
eval_examples = eval_examples.select(range(data_args.max_eval_samples))
# Validation Feature Creation
eval_dataset = eval_examples.map(
prepare_validation_features,
......@@ -479,28 +479,28 @@ def main():
remove_columns=column_names,
load_from_cache_file=not data_args.overwrite_cache,
)
if data_args.max_val_samples is not None:
if data_args.max_eval_samples is not None:
# During Feature creation dataset samples might increase, we will select required samples again
eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
if training_args.do_predict:
if "test" not in datasets:
raise ValueError("--do_predict requires a test dataset")
test_examples = datasets["test"]
if data_args.max_test_samples is not None:
predict_examples = datasets["test"]
if data_args.max_predict_samples is not None:
# We will select sample from whole data
test_examples = test_examples.select(range(data_args.max_test_samples))
# Test Feature Creation
test_dataset = test_examples.map(
predict_examples = predict_examples.select(range(data_args.max_predict_samples))
# Predict Feature Creation
predict_dataset = predict_examples.map(
prepare_validation_features,
batched=True,
num_proc=data_args.preprocessing_num_workers,
remove_columns=column_names,
load_from_cache_file=not data_args.overwrite_cache,
)
if data_args.max_test_samples is not None:
if data_args.max_predict_samples is not None:
# During Feature creation dataset samples might increase, we will select required samples again
test_dataset = test_dataset.select(range(data_args.max_test_samples))
predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
# Data collator
# We have already padded to max length if the corresponding flag is True, otherwise we need to pad in the data
......@@ -581,8 +581,8 @@ def main():
logger.info("*** Evaluate ***")
metrics = trainer.evaluate()
max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
......@@ -590,14 +590,16 @@ def main():
# Prediction
if training_args.do_predict:
logger.info("*** Predict ***")
results = trainer.predict(test_dataset, test_examples)
results = trainer.predict(predict_dataset, predict_examples)
metrics = results.metrics
max_test_samples = data_args.max_test_samples if data_args.max_test_samples is not None else len(test_dataset)
metrics["test_samples"] = min(max_test_samples, len(test_dataset))
max_predict_samples = (
data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
)
metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset))
trainer.log_metrics("test", metrics)
trainer.save_metrics("test", metrics)
trainer.log_metrics("predict", metrics)
trainer.save_metrics("predict", metrics)
if training_args.push_to_hub:
trainer.push_to_hub()
......
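The evaluate/predict bookkeeping above recurs in nearly every script touched by this commit. A hedged sketch of the shared pattern, assuming a stock `Trainer` (the QA scripts' custom trainer takes `(predict_dataset, predict_examples)` instead, as shown above) and objects named as in the scripts:

```python
# Hedged sketch of the shared evaluate/predict bookkeeping; `trainer`,
# `data_args`, `eval_dataset`, and `predict_dataset` are assumed to be set up
# as in the example scripts, so this is a pattern, not a standalone script.
def log_eval_and_predict_metrics(trainer, data_args, eval_dataset, predict_dataset):
    # Evaluation: metrics go under "eval" keys, truncated by max_eval_samples.
    metrics = trainer.evaluate()
    max_eval_samples = (
        data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
    )
    metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
    trainer.log_metrics("eval", metrics)
    trainer.save_metrics("eval", metrics)

    # Prediction: metrics go under "predict" keys, truncated by max_predict_samples.
    predict_results = trainer.predict(predict_dataset, metric_key_prefix="predict")
    metrics = predict_results.metrics
    max_predict_samples = (
        data_args.max_predict_samples
        if data_args.max_predict_samples is not None
        else len(predict_dataset)
    )
    metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset))
    trainer.log_metrics("predict", metrics)
    trainer.save_metrics("predict", metrics)
```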
......@@ -132,17 +132,17 @@ class DataTrainingArguments:
"value if set."
},
)
max_val_samples: Optional[int] = field(
max_eval_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
"value if set."
},
)
max_test_samples: Optional[int] = field(
max_predict_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of test examples to this "
"help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
"value if set."
},
)
......@@ -504,9 +504,9 @@ def main():
if "validation" not in datasets:
raise ValueError("--do_eval requires a validation dataset")
eval_examples = datasets["validation"]
if data_args.max_val_samples is not None:
if data_args.max_eval_samples is not None:
# Selecting Eval Samples from Dataset
eval_examples = eval_examples.select(range(data_args.max_val_samples))
eval_examples = eval_examples.select(range(data_args.max_eval_samples))
# Create Features from Eval Dataset
eval_dataset = eval_examples.map(
prepare_validation_features,
......@@ -515,28 +515,28 @@ def main():
remove_columns=column_names,
load_from_cache_file=not data_args.overwrite_cache,
)
if data_args.max_val_samples is not None:
if data_args.max_eval_samples is not None:
# Selecting Samples from Dataset again since Feature Creation might increase samples size
eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
if training_args.do_predict:
if "test" not in datasets:
raise ValueError("--do_predict requires a test dataset")
test_examples = datasets["test"]
if data_args.max_test_samples is not None:
predict_examples = datasets["test"]
if data_args.max_predict_samples is not None:
# We will select sample from whole data
test_examples = test_examples.select(range(data_args.max_test_samples))
predict_examples = predict_examples.select(range(data_args.max_predict_samples))
# Test Feature Creation
test_dataset = test_examples.map(
predict_dataset = predict_examples.map(
prepare_validation_features,
batched=True,
num_proc=data_args.preprocessing_num_workers,
remove_columns=column_names,
load_from_cache_file=not data_args.overwrite_cache,
)
if data_args.max_test_samples is not None:
if data_args.max_predict_samples is not None:
# During Feature creation dataset samples might increase, we will select required samples again
test_dataset = test_dataset.select(range(data_args.max_test_samples))
predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
# Data collator
# We have already padded to max length if the corresponding flag is True, otherwise we need to pad in the data
......@@ -620,8 +620,8 @@ def main():
logger.info("*** Evaluate ***")
metrics = trainer.evaluate()
max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
......@@ -629,14 +629,16 @@ def main():
# Prediction
if training_args.do_predict:
logger.info("*** Predict ***")
results = trainer.predict(test_dataset, test_examples)
results = trainer.predict(predict_dataset, predict_examples)
metrics = results.metrics
max_test_samples = data_args.max_test_samples if data_args.max_test_samples is not None else len(test_dataset)
metrics["test_samples"] = min(max_test_samples, len(test_dataset))
max_predict_samples = (
data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
)
metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset))
trainer.log_metrics("test", metrics)
trainer.save_metrics("test", metrics)
trainer.log_metrics("predict", metrics)
trainer.save_metrics("predict", metrics)
if training_args.push_to_hub:
trainer.push_to_hub()
......
......@@ -183,20 +183,20 @@ def parse_args():
"value if set.",
)
parser.add_argument(
"--max_val_samples",
"--max_eval_samples",
type=int,
default=None,
help="For debugging purposes or quicker training, truncate the number of validation examples to this "
help="For debugging purposes or quicker training, truncate the number of evaluation examples to this "
"value if set.",
)
parser.add_argument(
"--overwrite_cache", type=bool, default=False, help="Overwrite the cached training and evaluation sets"
)
parser.add_argument(
"--max_test_samples",
"--max_predict_samples",
type=int,
default=None,
help="For debugging purposes or quicker training, truncate the number of test examples to this",
help="For debugging purposes or quicker training, truncate the number of prediction examples to this",
)
args = parser.parse_args()
......@@ -481,9 +481,9 @@ def main():
if "validation" not in raw_datasets:
raise ValueError("--do_eval requires a validation dataset")
eval_examples = raw_datasets["validation"]
if args.max_val_samples is not None:
if args.max_eval_samples is not None:
# We will select sample from whole data
eval_examples = eval_examples.select(range(args.max_val_samples))
eval_examples = eval_examples.select(range(args.max_eval_samples))
# Validation Feature Creation
eval_dataset = eval_examples.map(
prepare_validation_features,
......@@ -493,28 +493,28 @@ def main():
load_from_cache_file=not args.overwrite_cache,
)
if args.max_val_samples is not None:
if args.max_eval_samples is not None:
# During Feature creation dataset samples might increase, we will select required samples again
eval_dataset = eval_dataset.select(range(args.max_val_samples))
eval_dataset = eval_dataset.select(range(args.max_eval_samples))
if args.do_predict:
if "test" not in raw_datasets:
raise ValueError("--do_predict requires a test dataset")
test_examples = raw_datasets["test"]
if args.max_test_samples is not None:
predict_examples = raw_datasets["test"]
if args.max_predict_samples is not None:
# We will select sample from whole data
test_examples = test_examples.select(range(args.max_test_samples))
# Test Feature Creation
test_dataset = test_examples.map(
predict_examples = predict_examples.select(range(args.max_predict_samples))
# Predict Feature Creation
predict_dataset = predict_examples.map(
prepare_validation_features,
batched=True,
num_proc=args.preprocessing_num_workers,
remove_columns=column_names,
load_from_cache_file=not args.overwrite_cache,
)
if args.max_test_samples is not None:
if args.max_predict_samples is not None:
# During Feature creation dataset samples might increase, we will select required samples again
test_dataset = test_dataset.select(range(args.max_test_samples))
predict_dataset = predict_dataset.select(range(args.max_predict_samples))
# Log a few random samples from the training set:
for index in random.sample(range(len(train_dataset)), 3):
......@@ -539,9 +539,9 @@ def main():
eval_dataloader = DataLoader(eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size)
if args.do_predict:
test_dataset.set_format(type="torch", columns=["attention_mask", "input_ids", "token_type_ids"])
test_dataloader = DataLoader(
test_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size
predict_dataset.set_format(type="torch", columns=["attention_mask", "input_ids", "token_type_ids"])
predict_dataloader = DataLoader(
predict_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size
)
# Post-processing:
......@@ -737,7 +737,7 @@ def main():
all_end_top_log_probs = []
all_end_top_index = []
all_cls_logits = []
for step, batch in enumerate(test_dataloader):
for step, batch in enumerate(predict_dataloader):
with torch.no_grad():
outputs = model(**batch)
start_top_log_probs = outputs.start_top_log_probs
......@@ -762,10 +762,10 @@ def main():
max_len = max([x.shape[1] for x in all_end_top_log_probs]) # Get the max_length of the tensor
# concatenate all numpy arrays collected above
start_top_log_probs_concat = create_and_fill_np_array(all_start_top_log_probs, test_dataset, max_len)
start_top_index_concat = create_and_fill_np_array(all_start_top_index, test_dataset, max_len)
end_top_log_probs_concat = create_and_fill_np_array(all_end_top_log_probs, test_dataset, max_len)
end_top_index_concat = create_and_fill_np_array(all_end_top_index, test_dataset, max_len)
start_top_log_probs_concat = create_and_fill_np_array(all_start_top_log_probs, predict_dataset, max_len)
start_top_index_concat = create_and_fill_np_array(all_start_top_index, predict_dataset, max_len)
end_top_log_probs_concat = create_and_fill_np_array(all_end_top_log_probs, predict_dataset, max_len)
end_top_index_concat = create_and_fill_np_array(all_end_top_index, predict_dataset, max_len)
all_cls_logits = np.concatenate(all_cls_logits, axis=0)
# delete the list of numpy arrays
......@@ -774,7 +774,7 @@ def main():
del end_top_log_probs
del end_top_index
test_dataset.set_format(type=None, columns=list(test_dataset.features.keys()))
predict_dataset.set_format(type=None, columns=list(predict_dataset.features.keys()))
outputs_numpy = (
start_top_log_probs_concat,
start_top_index_concat,
......@@ -783,9 +783,9 @@ def main():
cls_logits,
)
prediction = post_processing_function(test_examples, test_dataset, outputs_numpy)
test_metric = metric.compute(predictions=prediction.predictions, references=prediction.label_ids)
logger.info(f"Test metrics: {test_metric}")
prediction = post_processing_function(predict_examples, predict_dataset, outputs_numpy)
predict_metric = metric.compute(predictions=prediction.predictions, references=prediction.label_ids)
logger.info(f"Predict metrics: {predict_metric}")
if args.output_dir is not None:
accelerator.wait_for_everyone()
......
......@@ -205,20 +205,20 @@ def parse_args():
"value if set.",
)
parser.add_argument(
"--max_val_samples",
"--max_eval_samples",
type=int,
default=None,
help="For debugging purposes or quicker training, truncate the number of validation examples to this "
help="For debugging purposes or quicker training, truncate the number of evaluation examples to this "
"value if set.",
)
parser.add_argument(
"--overwrite_cache", type=bool, default=False, help="Overwrite the cached training and evaluation sets"
)
parser.add_argument(
"--max_test_samples",
"--max_predict_samples",
type=int,
default=None,
help="For debugging purposes or quicker training, truncate the number of test examples to this",
help="For debugging purposes or quicker training, truncate the number of prediction examples to this",
)
parser.add_argument(
"--model_type",
......@@ -486,9 +486,9 @@ def main():
if "validation" not in raw_datasets:
raise ValueError("--do_eval requires a validation dataset")
eval_examples = raw_datasets["validation"]
if args.max_val_samples is not None:
if args.max_eval_samples is not None:
# We will select sample from whole data
eval_examples = eval_examples.select(range(args.max_val_samples))
eval_examples = eval_examples.select(range(args.max_eval_samples))
# Validation Feature Creation
eval_dataset = eval_examples.map(
prepare_validation_features,
......@@ -498,28 +498,28 @@ def main():
load_from_cache_file=not args.overwrite_cache,
)
if args.max_val_samples is not None:
if args.max_eval_samples is not None:
# During Feature creation dataset samples might increase, we will select required samples again
eval_dataset = eval_dataset.select(range(args.max_val_samples))
eval_dataset = eval_dataset.select(range(args.max_eval_samples))
if args.do_predict:
if "test" not in raw_datasets:
raise ValueError("--do_predict requires a test dataset")
test_examples = raw_datasets["test"]
if args.max_test_samples is not None:
predict_examples = raw_datasets["test"]
if args.max_predict_samples is not None:
# We will select sample from whole data
test_examples = test_examples.select(range(args.max_test_samples))
# Test Feature Creation
test_dataset = test_examples.map(
predict_examples = predict_examples.select(range(args.max_predict_samples))
# Predict Feature Creation
predict_dataset = predict_examples.map(
prepare_validation_features,
batched=True,
num_proc=args.preprocessing_num_workers,
remove_columns=column_names,
load_from_cache_file=not args.overwrite_cache,
)
if args.max_test_samples is not None:
if args.max_predict_samples is not None:
# During Feature creation dataset samples might increase, we will select required samples again
test_dataset = test_dataset.select(range(args.max_test_samples))
predict_dataset = predict_dataset.select(range(args.max_predict_samples))
# Log a few random samples from the training set:
for index in random.sample(range(len(train_dataset)), 3):
......@@ -544,9 +544,9 @@ def main():
eval_dataloader = DataLoader(eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size)
if args.do_predict:
test_dataset.set_format(type="torch", columns=["attention_mask", "input_ids", "token_type_ids"])
test_dataloader = DataLoader(
test_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size
predict_dataset.set_format(type="torch", columns=["attention_mask", "input_ids", "token_type_ids"])
predict_dataloader = DataLoader(
predict_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size
)
# Post-processing:
......@@ -714,7 +714,7 @@ def main():
if args.do_predict:
all_start_logits = []
all_end_logits = []
for step, batch in enumerate(test_dataloader):
for step, batch in enumerate(predict_dataloader):
with torch.no_grad():
outputs = model(**batch)
start_logits = outputs.start_logits
......@@ -729,19 +729,19 @@ def main():
max_len = max([x.shape[1] for x in all_start_logits]) # Get the max_length of the tensor
# concatenate the numpy array
start_logits_concat = create_and_fill_np_array(all_start_logits, test_dataset, max_len)
end_logits_concat = create_and_fill_np_array(all_end_logits, test_dataset, max_len)
start_logits_concat = create_and_fill_np_array(all_start_logits, predict_dataset, max_len)
end_logits_concat = create_and_fill_np_array(all_end_logits, predict_dataset, max_len)
# delete the list of numpy arrays
del all_start_logits
del all_end_logits
# Now we need to add extra columns which we removed for post processing
test_dataset.set_format(type=None, columns=list(test_dataset.features.keys()))
predict_dataset.set_format(type=None, columns=list(predict_dataset.features.keys()))
outputs_numpy = (start_logits_concat, end_logits_concat)
prediction = post_processing_function(test_examples, test_dataset, outputs_numpy)
eval_metric = metric.compute(predictions=prediction.predictions, references=prediction.label_ids)
logger.info(f"Test metrics: {eval_metric}")
prediction = post_processing_function(predict_examples, predict_dataset, outputs_numpy)
predict_metric = metric.compute(predictions=prediction.predictions, references=prediction.label_ids)
logger.info(f"Predict metrics: {predict_metric}")
if args.output_dir is not None:
accelerator.wait_for_everyone()
......
......@@ -66,16 +66,16 @@ class QuestionAnsweringTrainer(Trainer):
self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics)
return metrics
def predict(self, test_dataset, test_examples, ignore_keys=None):
test_dataloader = self.get_test_dataloader(test_dataset)
def predict(self, predict_dataset, predict_examples, ignore_keys=None):
predict_dataloader = self.get_test_dataloader(predict_dataset)
# Temporarily disable metric computation, we will do it in the loop here.
compute_metrics = self.compute_metrics
self.compute_metrics = None
try:
output = self.prediction_loop(
test_dataloader,
description="Evaluation",
predict_dataloader,
description="Prediction",
# No point gathering the predictions if there are no metrics, otherwise we defer to
# self.args.prediction_loss_only
prediction_loss_only=True if compute_metrics is None else None,
......@@ -87,7 +87,7 @@ class QuestionAnsweringTrainer(Trainer):
if self.post_process_function is None or self.compute_metrics is None:
return output
eval_preds = self.post_process_function(test_examples, test_dataset, output.predictions, "test")
metrics = self.compute_metrics(eval_preds)
predictions = self.post_process_function(predict_examples, predict_dataset, output.predictions, "predict")
metrics = self.compute_metrics(predictions)
return PredictionOutput(predictions=eval_preds.predictions, label_ids=eval_preds.label_ids, metrics=metrics)
return PredictionOutput(predictions=predictions.predictions, label_ids=predictions.label_ids, metrics=metrics)
......@@ -178,17 +178,17 @@ class DataTrainingArguments:
"value if set."
},
)
max_val_samples: Optional[int] = field(
max_eval_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
"value if set."
},
)
max_test_samples: Optional[int] = field(
max_predict_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of test examples to this "
"help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
"value if set."
},
)
......@@ -438,8 +438,8 @@ def main():
if "validation" not in datasets:
raise ValueError("--do_eval requires a validation dataset")
eval_dataset = datasets["validation"]
if data_args.max_val_samples is not None:
eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
if data_args.max_eval_samples is not None:
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
eval_dataset = eval_dataset.map(
preprocess_function,
batched=True,
......@@ -452,10 +452,10 @@ def main():
max_target_length = data_args.val_max_target_length
if "test" not in datasets:
raise ValueError("--do_predict requires a test dataset")
test_dataset = datasets["test"]
if data_args.max_test_samples is not None:
test_dataset = test_dataset.select(range(data_args.max_test_samples))
test_dataset = test_dataset.map(
predict_dataset = datasets["test"]
if data_args.max_predict_samples is not None:
predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
predict_dataset = predict_dataset.map(
preprocess_function,
batched=True,
num_proc=data_args.preprocessing_num_workers,
......@@ -547,37 +547,39 @@ def main():
metrics = trainer.evaluate(
max_length=data_args.val_max_target_length, num_beams=data_args.num_beams, metric_key_prefix="eval"
)
max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
if training_args.do_predict:
logger.info("*** Test ***")
logger.info("*** Predict ***")
test_results = trainer.predict(
test_dataset,
metric_key_prefix="test",
predict_results = trainer.predict(
predict_dataset,
metric_key_prefix="predict",
max_length=data_args.val_max_target_length,
num_beams=data_args.num_beams,
)
metrics = test_results.metrics
max_test_samples = data_args.max_test_samples if data_args.max_test_samples is not None else len(test_dataset)
metrics["test_samples"] = min(max_test_samples, len(test_dataset))
metrics = predict_results.metrics
max_predict_samples = (
data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
)
metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset))
trainer.log_metrics("test", metrics)
trainer.save_metrics("test", metrics)
trainer.log_metrics("predict", metrics)
trainer.save_metrics("predict", metrics)
if trainer.is_world_process_zero():
if training_args.predict_with_generate:
test_preds = tokenizer.batch_decode(
test_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
predictions = tokenizer.batch_decode(
predict_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
)
test_preds = [pred.strip() for pred in test_preds]
output_test_preds_file = os.path.join(training_args.output_dir, "test_generations.txt")
with open(output_test_preds_file, "w") as writer:
writer.write("\n".join(test_preds))
predictions = [pred.strip() for pred in predictions]
output_prediction_file = os.path.join(training_args.output_dir, "generated_predictions.txt")
with open(output_prediction_file, "w") as writer:
writer.write("\n".join(predictions))
if training_args.push_to_hub:
trainer.push_to_hub()
......
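For the seq2seq examples (this script and the similar one further down), generated predictions are now written to `generated_predictions.txt` rather than `test_generations.txt`. A small sketch of that step, assuming `tokenizer`, `predict_results`, and `training_args` come from the surrounding script:

```python
import os


# Hedged sketch of the renamed prediction-writing step; the three arguments are
# assumed to be the tokenizer, trainer.predict output, and training arguments
# already defined in the example script.
def write_generated_predictions(tokenizer, predict_results, training_args):
    predictions = tokenizer.batch_decode(
        predict_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )
    predictions = [pred.strip() for pred in predictions]
    output_prediction_file = os.path.join(training_args.output_dir, "generated_predictions.txt")
    with open(output_prediction_file, "w") as writer:
        writer.write("\n".join(predictions))
```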
......@@ -100,17 +100,17 @@ class DataTrainingArguments:
"value if set."
},
)
max_val_samples: Optional[int] = field(
max_eval_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
"value if set."
},
)
max_test_samples: Optional[int] = field(
max_predict_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of test examples to this "
"help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
"value if set."
},
)
......@@ -390,15 +390,15 @@ def main():
if "validation" not in datasets and "validation_matched" not in datasets:
raise ValueError("--do_eval requires a validation dataset")
eval_dataset = datasets["validation_matched" if data_args.task_name == "mnli" else "validation"]
if data_args.max_val_samples is not None:
eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
if data_args.max_eval_samples is not None:
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
if training_args.do_predict or data_args.task_name is not None or data_args.test_file is not None:
if "test" not in datasets and "test_matched" not in datasets:
raise ValueError("--do_predict requires a test dataset")
test_dataset = datasets["test_matched" if data_args.task_name == "mnli" else "test"]
if data_args.max_test_samples is not None:
test_dataset = test_dataset.select(range(data_args.max_test_samples))
predict_dataset = datasets["test_matched" if data_args.task_name == "mnli" else "test"]
if data_args.max_predict_samples is not None:
predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
# Log a few random samples from the training set:
if training_args.do_train:
......@@ -483,32 +483,34 @@ def main():
for eval_dataset, task in zip(eval_datasets, tasks):
metrics = trainer.evaluate(eval_dataset=eval_dataset)
max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
max_eval_samples = (
data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
)
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
if training_args.do_predict:
logger.info("*** Test ***")
logger.info("*** Predict ***")
# Loop to handle MNLI double evaluation (matched, mis-matched)
tasks = [data_args.task_name]
test_datasets = [test_dataset]
predict_datasets = [predict_dataset]
if data_args.task_name == "mnli":
tasks.append("mnli-mm")
test_datasets.append(datasets["test_mismatched"])
predict_datasets.append(datasets["test_mismatched"])
for test_dataset, task in zip(test_datasets, tasks):
for predict_dataset, task in zip(predict_datasets, tasks):
# Removing the `label` columns because it contains -1 and Trainer won't like that.
test_dataset.remove_columns_("label")
predictions = trainer.predict(test_dataset=test_dataset).predictions
predict_dataset.remove_columns_("label")
predictions = trainer.predict(predict_dataset, metric_key_prefix="predict").predictions
predictions = np.squeeze(predictions) if is_regression else np.argmax(predictions, axis=1)
output_test_file = os.path.join(training_args.output_dir, f"test_results_{task}.txt")
output_predict_file = os.path.join(training_args.output_dir, f"predict_results_{task}.txt")
if trainer.is_world_process_zero():
with open(output_test_file, "w") as writer:
logger.info(f"***** Test results {task} *****")
with open(output_predict_file, "w") as writer:
logger.info(f"***** Predict results {task} *****")
writer.write("index\tprediction\n")
for index, item in enumerate(predictions):
if is_regression:
......
......@@ -84,17 +84,17 @@ class DataTrainingArguments:
"value if set."
},
)
max_val_samples: Optional[int] = field(
max_eval_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
"value if set."
},
)
max_test_samples: Optional[int] = field(
max_predict_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of test examples to this "
"help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
"value if set."
},
)
......@@ -221,8 +221,8 @@ def main():
label_list = eval_dataset.features["label"].names
if training_args.do_predict:
test_dataset = load_dataset("xnli", model_args.language, split="test", cache_dir=model_args.cache_dir)
label_list = test_dataset.features["label"].names
predict_dataset = load_dataset("xnli", model_args.language, split="test", cache_dir=model_args.cache_dir)
label_list = predict_dataset.features["label"].names
# Labels
num_labels = len(label_list)
......@@ -286,8 +286,8 @@ def main():
logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
if training_args.do_eval:
if data_args.max_val_samples is not None:
eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
if data_args.max_eval_samples is not None:
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
eval_dataset = eval_dataset.map(
preprocess_function,
batched=True,
......@@ -295,9 +295,9 @@ def main():
)
if training_args.do_predict:
if data_args.max_test_samples is not None:
test_dataset = test_dataset.select(range(data_args.max_test_samples))
test_dataset = test_dataset.map(
if data_args.max_predict_samples is not None:
predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
predict_dataset = predict_dataset.map(
preprocess_function,
batched=True,
load_from_cache_file=not data_args.overwrite_cache,
......@@ -360,8 +360,8 @@ def main():
logger.info("*** Evaluate ***")
metrics = trainer.evaluate(eval_dataset=eval_dataset)
max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
......@@ -369,18 +369,20 @@ def main():
# Prediction
if training_args.do_predict:
logger.info("*** Predict ***")
predictions, labels, metrics = trainer.predict(test_dataset)
predictions, labels, metrics = trainer.predict(predict_dataset, metric_key_prefix="predict")
max_test_samples = data_args.max_test_samples if data_args.max_test_samples is not None else len(test_dataset)
metrics["test_samples"] = min(max_test_samples, len(test_dataset))
max_predict_samples = (
data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
)
metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset))
trainer.log_metrics("test", metrics)
trainer.save_metrics("test", metrics)
trainer.log_metrics("predict", metrics)
trainer.save_metrics("predict", metrics)
predictions = np.argmax(predictions, axis=1)
output_test_file = os.path.join(training_args.output_dir, "test_predictions.txt")
output_predict_file = os.path.join(training_args.output_dir, "predictions.txt")
if trainer.is_world_process_zero():
with open(output_test_file, "w") as writer:
with open(output_predict_file, "w") as writer:
writer.write("index\tprediction\n")
for index, item in enumerate(predictions):
item = label_list[item]
......
......@@ -128,17 +128,17 @@ class DataTrainingArguments:
"value if set."
},
)
max_val_samples: Optional[int] = field(
max_eval_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
"value if set."
},
)
max_test_samples: Optional[int] = field(
max_predict_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of test examples to this "
"help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
"value if set."
},
)
......@@ -363,8 +363,8 @@ def main():
if "validation" not in datasets:
raise ValueError("--do_eval requires a validation dataset")
eval_dataset = datasets["validation"]
if data_args.max_val_samples is not None:
eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
if data_args.max_eval_samples is not None:
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
eval_dataset = eval_dataset.map(
tokenize_and_align_labels,
batched=True,
......@@ -375,10 +375,10 @@ def main():
if training_args.do_predict:
if "test" not in datasets:
raise ValueError("--do_predict requires a test dataset")
test_dataset = datasets["test"]
if data_args.max_test_samples is not None:
test_dataset = test_dataset.select(range(data_args.max_test_samples))
test_dataset = test_dataset.map(
predict_dataset = datasets["test"]
if data_args.max_predict_samples is not None:
predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
predict_dataset = predict_dataset.map(
tokenize_and_align_labels,
batched=True,
num_proc=data_args.preprocessing_num_workers,
......@@ -462,8 +462,8 @@ def main():
metrics = trainer.evaluate()
max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
......@@ -472,7 +472,7 @@ def main():
if training_args.do_predict:
logger.info("*** Predict ***")
predictions, labels, metrics = trainer.predict(test_dataset)
predictions, labels, metrics = trainer.predict(predict_dataset, metric_key_prefix="predict")
predictions = np.argmax(predictions, axis=2)
# Remove ignored index (special tokens)
......@@ -481,13 +481,13 @@ def main():
for prediction, label in zip(predictions, labels)
]
trainer.log_metrics("test", metrics)
trainer.save_metrics("test", metrics)
trainer.log_metrics("predict", metrics)
trainer.save_metrics("predict", metrics)
# Save predictions
output_test_predictions_file = os.path.join(training_args.output_dir, "test_predictions.txt")
output_predictions_file = os.path.join(training_args.output_dir, "predictions.txt")
if trainer.is_world_process_zero():
with open(output_test_predictions_file, "w") as writer:
with open(output_predictions_file, "w") as writer:
for prediction in true_predictions:
writer.write(" ".join(prediction) + "\n")
......
......@@ -167,17 +167,17 @@ class DataTrainingArguments:
"value if set."
},
)
max_val_samples: Optional[int] = field(
max_eval_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
"value if set."
},
)
max_test_samples: Optional[int] = field(
max_predict_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of test examples to this "
"help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
"value if set."
},
)
......@@ -432,8 +432,8 @@ def main():
if "validation" not in datasets:
raise ValueError("--do_eval requires a validation dataset")
eval_dataset = datasets["validation"]
if data_args.max_val_samples is not None:
eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
if data_args.max_eval_samples is not None:
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
eval_dataset = eval_dataset.map(
preprocess_function,
batched=True,
......@@ -446,10 +446,10 @@ def main():
max_target_length = data_args.val_max_target_length
if "test" not in datasets:
raise ValueError("--do_predict requires a test dataset")
test_dataset = datasets["test"]
if data_args.max_test_samples is not None:
test_dataset = test_dataset.select(range(data_args.max_test_samples))
test_dataset = test_dataset.map(
predict_dataset = datasets["test"]
if data_args.max_predict_samples is not None:
predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
predict_dataset = predict_dataset.map(
preprocess_function,
batched=True,
num_proc=data_args.preprocessing_num_workers,
......@@ -539,37 +539,39 @@ def main():
metrics = trainer.evaluate(
max_length=data_args.val_max_target_length, num_beams=data_args.num_beams, metric_key_prefix="eval"
)
max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
if training_args.do_predict:
logger.info("*** Test ***")
logger.info("*** Predict ***")
test_results = trainer.predict(
test_dataset,
metric_key_prefix="test",
predict_results = trainer.predict(
predict_dataset,
metric_key_prefix="predict",
max_length=data_args.val_max_target_length,
num_beams=data_args.num_beams,
)
metrics = test_results.metrics
max_test_samples = data_args.max_test_samples if data_args.max_test_samples is not None else len(test_dataset)
metrics["test_samples"] = min(max_test_samples, len(test_dataset))
metrics = predict_results.metrics
max_predict_samples = (
data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
)
metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset))
trainer.log_metrics("test", metrics)
trainer.save_metrics("test", metrics)
trainer.log_metrics("predict", metrics)
trainer.save_metrics("predict", metrics)
if trainer.is_world_process_zero():
if training_args.predict_with_generate:
test_preds = tokenizer.batch_decode(
test_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
predictions = tokenizer.batch_decode(
predict_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
)
test_preds = [pred.strip() for pred in test_preds]
output_test_preds_file = os.path.join(training_args.output_dir, "test_generations.txt")
with open(output_test_preds_file, "w") as writer:
writer.write("\n".join(test_preds))
predictions = [pred.strip() for pred in predictions]
output_prediction_file = os.path.join(training_args.output_dir, "generated_predictions.txt")
with open(output_prediction_file, "w") as writer:
writer.write("\n".join(predictions))
if training_args.push_to_hub:
trainer.push_to_hub()
......
......@@ -164,17 +164,17 @@ class DataTrainingArguments:
"value if set."
},
)
max_val_samples: Optional[int] = field(
max_eval_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
"value if set."
},
)
max_test_samples: Optional[int] = field(
max_predict_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of test examples to this "
"help": "For debugging purposes or quicker training, truncate the number of predict examples to this "
"value if set."
},
)
......@@ -468,13 +468,13 @@ def main():
if "validation" in datasets:
eval_dataset = datasets["validation"]
if data_args.max_val_samples is not None:
eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
if data_args.max_eval_samples is not None:
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
if "test" in datasets:
test_dataset = datasets["test"]
if data_args.max_test_samples is not None:
test_dataset = test_dataset.select(range(data_args.max_test_samples))
predict_dataset = datasets["test"]
if data_args.max_predict_samples is not None:
predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
# endregion
......@@ -513,15 +513,15 @@ def main():
# region Prediction
if "test" in datasets:
logger.info("Doing predictions on test dataset...")
logger.info("Doing predictions on Predict dataset...")
test_dataset = DataSequence(
test_dataset, non_label_column_names, batch_size=training_args.per_device_eval_batch_size, labels=False
predict_dataset = DataSequence(
predict_dataset, non_label_column_names, batch_size=training_args.per_device_eval_batch_size, labels=False
)
predictions = model.predict(test_dataset)["logits"]
predictions = model.predict(predict_dataset)["logits"]
predictions = np.squeeze(predictions) if is_regression else np.argmax(predictions, axis=1)
output_test_file = os.path.join(training_args.output_dir, "test_results.txt")
with open(output_test_file, "w") as writer:
output_predict_file = os.path.join(training_args.output_dir, "predict_results.txt")
with open(output_predict_file, "w") as writer:
writer.write("index\tprediction\n")
for index, item in enumerate(predictions):
if is_regression:
......@@ -529,7 +529,7 @@ def main():
else:
item = model.config.id2label[item]
writer.write(f"{index}\t{item}\n")
logger.info(f"Wrote predictions to {output_test_file}!")
logger.info(f"Wrote predictions to {output_predict_file}!")
# endregion
......
......@@ -157,17 +157,17 @@ class DataTrainingArguments:
"value if set."
},
)
max_val_samples: Optional[int] = field(
max_eval_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
"value if set."
},
)
max_test_samples: Optional[int] = field(
max_predict_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of test examples to this "
"help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
"value if set."
},
)
......@@ -379,8 +379,8 @@ def main():
raise ValueError("--do_eval requires a validation dataset")
eval_dataset = datasets["validation"]
# Selecting samples from dataset
if data_args.max_val_samples is not None:
eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
if data_args.max_eval_samples is not None:
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
# tokenize validation dataset
eval_dataset = eval_dataset.map(
tokenize_function,
......@@ -393,12 +393,12 @@ def main():
if training_args.do_predict:
if "test" not in datasets:
raise ValueError("--do_predict requires a test dataset")
test_dataset = datasets["test"]
predict_dataset = datasets["test"]
# Selecting samples from dataset
if data_args.max_test_samples is not None:
test_dataset = test_dataset.select(range(data_args.max_test_samples))
# tokenize test dataset
test_dataset = test_dataset.map(
if data_args.max_predict_samples is not None:
predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
# tokenize predict dataset
predict_dataset = predict_dataset.map(
tokenize_function,
batched=True,
num_proc=data_args.preprocessing_num_workers,
......@@ -455,8 +455,8 @@ def main():
metrics = trainer.evaluate()
max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
......@@ -464,13 +464,13 @@ def main():
# Prediction
if training_args.do_predict:
logger.info("*** Predict ***")
predictions, labels, metrics = trainer.predict(test_dataset)
predictions, labels, metrics = trainer.predict(predict_dataset)
max_test_samples = data_args.max_test_samples if data_args.max_test_samples is not None else len(test_dataset)
metrics["test_samples"] = min(max_test_samples, len(test_dataset))
max_predict_samples = data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset))
trainer.log_metrics("test", metrics)
trainer.save_metrics("test", metrics)
trainer.log_metrics("predict", metrics)
trainer.save_metrics("predict", metrics)
# write custom code for saving predictions according to task
......
......@@ -578,7 +578,7 @@ class TestDeepSpeedWithLauncher(TestCasePlus):
args.extend(
"""
--do_eval
--max_val_samples 100
--max_eval_samples 100
--per_device_eval_batch_size 2
""".split()
)
......@@ -620,7 +620,7 @@ class TestDeepSpeedWithLauncher(TestCasePlus):
--do_train
--do_eval
--max_train_samples 10
--max_val_samples 10
--max_eval_samples 10
--per_device_train_batch_size 5
--per_device_eval_batch_size 5
--num_train_epochs 1
......
......@@ -191,7 +191,7 @@ class TestTrainerExt(TestCasePlus):
--output_dir {output_dir}
--overwrite_output_dir
--max_train_samples 8
--max_val_samples 8
--max_eval_samples 8
--max_source_length {max_len}
--max_target_length {max_len}
--val_max_target_length {max_len}
......