Unverified Commit 1d30ec95 authored by Bhadresh Savani, committed by GitHub

[Examples] Fixes inconsistency around eval vs val and predict vs test (#11380)

* added changes for uniformity

* modified files

* corrected typo

* fixed qa scripts

* fix typos

* fixed predict typo in qa no trainer

* fixed test file

* reverted trainer changes

* reverted trainer changes in custom examples

* updated readme

* added changes in deepspeed test

* added changes for predict and eval
parent 7959d835
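
The net effect of the change, summarized: the command-line flags `--max_val_samples` and `--max_test_samples` are renamed to `--max_eval_samples` and `--max_predict_samples`, the corresponding `DataTrainingArguments` fields follow suit, and prediction-time metrics and output files use a `predict` prefix instead of `test`. A minimal sketch of an invocation with the new flag names (the script path and sample counts are illustrative, mirroring the README snippet in the diff below):

```
examples/pytorch/token-classification/run_ner.py \
  --max_train_samples 50 \
  --max_eval_samples 50 \
  --max_predict_samples 50 \
  [...]
```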
@@ -50,8 +50,8 @@ For example here is how to truncate all three splits to just 50 samples each:
 ```
 examples/pytorch/token-classification/run_ner.py \
 --max_train_samples 50 \
---max_val_samples 50 \
---max_test_samples 50 \
+--max_eval_samples 50 \
+--max_predict_samples 50 \
 [...]
 ```
...
@@ -126,10 +126,10 @@ class DataTrainingArguments:
             "value if set."
         },
     )
-    max_val_samples: Optional[int] = field(
+    max_eval_samples: Optional[int] = field(
         default=None,
         metadata={
-            "help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
+            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
             "value if set."
         },
     )
@@ -397,8 +397,8 @@ def main():
         if "validation" not in tokenized_datasets:
             raise ValueError("--do_eval requires a validation dataset")
         eval_dataset = lm_datasets["validation"]
-        if data_args.max_val_samples is not None:
-            eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
+        if data_args.max_eval_samples is not None:
+            eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
     # Initialize our Trainer
     trainer = Trainer(
@@ -439,8 +439,8 @@ def main():
         metrics = trainer.evaluate()
-        max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
-        metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
+        max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
+        metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
         perplexity = math.exp(metrics["eval_loss"])
         metrics["perplexity"] = perplexity
...
@@ -157,10 +157,10 @@ class DataTrainingArguments:
             "value if set."
         },
     )
-    max_val_samples: Optional[int] = field(
+    max_eval_samples: Optional[int] = field(
         default=None,
         metadata={
-            "help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
+            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
             "value if set."
         },
     )
@@ -419,8 +419,8 @@ def main():
         if "validation" not in tokenized_datasets:
             raise ValueError("--do_eval requires a validation dataset")
         eval_dataset = tokenized_datasets["validation"]
-        if data_args.max_val_samples is not None:
-            eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
+        if data_args.max_eval_samples is not None:
+            eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
     # Data collator
     # This one will take care of randomly masking the tokens.
@@ -468,8 +468,8 @@ def main():
         metrics = trainer.evaluate()
-        max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
-        metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
+        max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
+        metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
         perplexity = math.exp(metrics["eval_loss"])
         metrics["perplexity"] = perplexity
...
@@ -154,10 +154,10 @@ class DataTrainingArguments:
             "value if set."
         },
     )
-    max_val_samples: Optional[int] = field(
+    max_eval_samples: Optional[int] = field(
         default=None,
         metadata={
-            "help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
+            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
             "value if set."
         },
     )
@@ -397,8 +397,8 @@ def main():
         if "validation" not in tokenized_datasets:
             raise ValueError("--do_eval requires a validation dataset")
         eval_dataset = tokenized_datasets["validation"]
-        if data_args.max_val_samples is not None:
-            eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
+        if data_args.max_eval_samples is not None:
+            eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
     # Data collator
     data_collator = DataCollatorForPermutationLanguageModeling(
@@ -444,8 +444,8 @@ def main():
         metrics = trainer.evaluate()
-        max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
-        metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
+        max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
+        metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
         perplexity = math.exp(metrics["eval_loss"])
         metrics["perplexity"] = perplexity
...
@@ -127,10 +127,10 @@ class DataTrainingArguments:
             "value if set."
         },
     )
-    max_val_samples: Optional[int] = field(
+    max_eval_samples: Optional[int] = field(
         default=None,
         metadata={
-            "help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
+            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
             "value if set."
         },
     )
@@ -363,8 +363,8 @@ def main():
         if "validation" not in datasets:
             raise ValueError("--do_eval requires a validation dataset")
         eval_dataset = datasets["validation"]
-        if data_args.max_val_samples is not None:
-            eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
+        if data_args.max_eval_samples is not None:
+            eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
         eval_dataset = eval_dataset.map(
             preprocess_function,
             batched=True,
@@ -422,8 +422,8 @@ def main():
         logger.info("*** Evaluate ***")
         metrics = trainer.evaluate()
-        max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
-        metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
+        max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
+        metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
         trainer.log_metrics("eval", metrics)
         trainer.save_metrics("eval", metrics)
...
@@ -133,17 +133,17 @@ class DataTrainingArguments:
             "value if set."
         },
     )
-    max_val_samples: Optional[int] = field(
+    max_eval_samples: Optional[int] = field(
         default=None,
         metadata={
-            "help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
+            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
             "value if set."
         },
     )
-    max_test_samples: Optional[int] = field(
+    max_predict_samples: Optional[int] = field(
         default=None,
         metadata={
-            "help": "For debugging purposes or quicker training, truncate the number of test examples to this "
+            "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
             "value if set."
         },
     )
@@ -468,9 +468,9 @@ def main():
         if "validation" not in datasets:
             raise ValueError("--do_eval requires a validation dataset")
         eval_examples = datasets["validation"]
-        if data_args.max_val_samples is not None:
+        if data_args.max_eval_samples is not None:
             # We will select sample from whole data
-            eval_examples = eval_examples.select(range(data_args.max_val_samples))
+            eval_examples = eval_examples.select(range(data_args.max_eval_samples))
         # Validation Feature Creation
         eval_dataset = eval_examples.map(
             prepare_validation_features,
@@ -479,28 +479,28 @@ def main():
             remove_columns=column_names,
             load_from_cache_file=not data_args.overwrite_cache,
         )
-        if data_args.max_val_samples is not None:
+        if data_args.max_eval_samples is not None:
             # During Feature creation dataset samples might increase, we will select required samples again
-            eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
+            eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
     if training_args.do_predict:
         if "test" not in datasets:
             raise ValueError("--do_predict requires a test dataset")
-        test_examples = datasets["test"]
-        if data_args.max_test_samples is not None:
+        predict_examples = datasets["test"]
+        if data_args.max_predict_samples is not None:
             # We will select sample from whole data
-            test_examples = test_examples.select(range(data_args.max_test_samples))
-        # Test Feature Creation
-        test_dataset = test_examples.map(
+            predict_examples = predict_examples.select(range(data_args.max_predict_samples))
+        # Predict Feature Creation
+        predict_dataset = predict_examples.map(
             prepare_validation_features,
             batched=True,
             num_proc=data_args.preprocessing_num_workers,
             remove_columns=column_names,
             load_from_cache_file=not data_args.overwrite_cache,
         )
-        if data_args.max_test_samples is not None:
+        if data_args.max_predict_samples is not None:
             # During Feature creation dataset samples might increase, we will select required samples again
-            test_dataset = test_dataset.select(range(data_args.max_test_samples))
+            predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
     # Data collator
     # We have already padded to max length if the corresponding flag is True, otherwise we need to pad in the data
@@ -581,8 +581,8 @@ def main():
         logger.info("*** Evaluate ***")
         metrics = trainer.evaluate()
-        max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
-        metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
+        max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
+        metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
         trainer.log_metrics("eval", metrics)
         trainer.save_metrics("eval", metrics)
@@ -590,14 +590,16 @@ def main():
     # Prediction
     if training_args.do_predict:
         logger.info("*** Predict ***")
-        results = trainer.predict(test_dataset, test_examples)
+        results = trainer.predict(predict_dataset, predict_examples)
         metrics = results.metrics
-        max_test_samples = data_args.max_test_samples if data_args.max_test_samples is not None else len(test_dataset)
-        metrics["test_samples"] = min(max_test_samples, len(test_dataset))
+        max_predict_samples = (
+            data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
+        )
+        metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset))
-        trainer.log_metrics("test", metrics)
-        trainer.save_metrics("test", metrics)
+        trainer.log_metrics("predict", metrics)
+        trainer.save_metrics("predict", metrics)
     if training_args.push_to_hub:
         trainer.push_to_hub()
...
@@ -132,17 +132,17 @@ class DataTrainingArguments:
             "value if set."
         },
     )
-    max_val_samples: Optional[int] = field(
+    max_eval_samples: Optional[int] = field(
         default=None,
         metadata={
-            "help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
+            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
             "value if set."
         },
     )
-    max_test_samples: Optional[int] = field(
+    max_predict_samples: Optional[int] = field(
        default=None,
        metadata={
-            "help": "For debugging purposes or quicker training, truncate the number of test examples to this "
+            "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
             "value if set."
         },
     )
@@ -504,9 +504,9 @@ def main():
         if "validation" not in datasets:
             raise ValueError("--do_eval requires a validation dataset")
         eval_examples = datasets["validation"]
-        if data_args.max_val_samples is not None:
+        if data_args.max_eval_samples is not None:
             # Selecting Eval Samples from Dataset
-            eval_examples = eval_examples.select(range(data_args.max_val_samples))
+            eval_examples = eval_examples.select(range(data_args.max_eval_samples))
         # Create Features from Eval Dataset
         eval_dataset = eval_examples.map(
             prepare_validation_features,
@@ -515,28 +515,28 @@ def main():
             remove_columns=column_names,
             load_from_cache_file=not data_args.overwrite_cache,
         )
-        if data_args.max_val_samples is not None:
+        if data_args.max_eval_samples is not None:
             # Selecting Samples from Dataset again since Feature Creation might increase samples size
-            eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
+            eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
     if training_args.do_predict:
         if "test" not in datasets:
             raise ValueError("--do_predict requires a test dataset")
-        test_examples = datasets["test"]
-        if data_args.max_test_samples is not None:
+        predict_examples = datasets["test"]
+        if data_args.max_predict_samples is not None:
             # We will select sample from whole data
-            test_examples = test_examples.select(range(data_args.max_test_samples))
+            predict_examples = predict_examples.select(range(data_args.max_predict_samples))
         # Test Feature Creation
-        test_dataset = test_examples.map(
+        predict_dataset = predict_examples.map(
             prepare_validation_features,
             batched=True,
             num_proc=data_args.preprocessing_num_workers,
             remove_columns=column_names,
             load_from_cache_file=not data_args.overwrite_cache,
         )
-        if data_args.max_test_samples is not None:
+        if data_args.max_predict_samples is not None:
             # During Feature creation dataset samples might increase, we will select required samples again
-            test_dataset = test_dataset.select(range(data_args.max_test_samples))
+            predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
     # Data collator
     # We have already padded to max length if the corresponding flag is True, otherwise we need to pad in the data
@@ -620,8 +620,8 @@ def main():
         logger.info("*** Evaluate ***")
         metrics = trainer.evaluate()
-        max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
-        metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
+        max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
+        metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
         trainer.log_metrics("eval", metrics)
         trainer.save_metrics("eval", metrics)
@@ -629,14 +629,16 @@ def main():
     # Prediction
     if training_args.do_predict:
         logger.info("*** Predict ***")
-        results = trainer.predict(test_dataset, test_examples)
+        results = trainer.predict(predict_dataset, predict_examples)
         metrics = results.metrics
-        max_test_samples = data_args.max_test_samples if data_args.max_test_samples is not None else len(test_dataset)
-        metrics["test_samples"] = min(max_test_samples, len(test_dataset))
+        max_predict_samples = (
+            data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
+        )
+        metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset))
-        trainer.log_metrics("test", metrics)
-        trainer.save_metrics("test", metrics)
+        trainer.log_metrics("predict", metrics)
+        trainer.save_metrics("predict", metrics)
     if training_args.push_to_hub:
         trainer.push_to_hub()
...
@@ -183,20 +183,20 @@ def parse_args():
         "value if set.",
     )
     parser.add_argument(
-        "--max_val_samples",
+        "--max_eval_samples",
         type=int,
         default=None,
-        help="For debugging purposes or quicker training, truncate the number of validation examples to this "
+        help="For debugging purposes or quicker training, truncate the number of evaluation examples to this "
         "value if set.",
     )
     parser.add_argument(
         "--overwrite_cache", type=bool, default=False, help="Overwrite the cached training and evaluation sets"
     )
     parser.add_argument(
-        "--max_test_samples",
+        "--max_predict_samples",
         type=int,
         default=None,
-        help="For debugging purposes or quicker training, truncate the number of test examples to this",
+        help="For debugging purposes or quicker training, truncate the number of prediction examples to this",
     )
     args = parser.parse_args()
@@ -481,9 +481,9 @@ def main():
     if "validation" not in raw_datasets:
         raise ValueError("--do_eval requires a validation dataset")
     eval_examples = raw_datasets["validation"]
-    if args.max_val_samples is not None:
+    if args.max_eval_samples is not None:
         # We will select sample from whole data
-        eval_examples = eval_examples.select(range(args.max_val_samples))
+        eval_examples = eval_examples.select(range(args.max_eval_samples))
     # Validation Feature Creation
     eval_dataset = eval_examples.map(
         prepare_validation_features,
@@ -493,28 +493,28 @@ def main():
         load_from_cache_file=not args.overwrite_cache,
     )
-    if args.max_val_samples is not None:
+    if args.max_eval_samples is not None:
         # During Feature creation dataset samples might increase, we will select required samples again
-        eval_dataset = eval_dataset.select(range(args.max_val_samples))
+        eval_dataset = eval_dataset.select(range(args.max_eval_samples))
     if args.do_predict:
         if "test" not in raw_datasets:
             raise ValueError("--do_predict requires a test dataset")
-        test_examples = raw_datasets["test"]
-        if args.max_test_samples is not None:
+        predict_examples = raw_datasets["test"]
+        if args.max_predict_samples is not None:
             # We will select sample from whole data
-            test_examples = test_examples.select(range(args.max_test_samples))
-        # Test Feature Creation
-        test_dataset = test_examples.map(
+            predict_examples = predict_examples.select(range(args.max_predict_samples))
+        # Predict Feature Creation
+        predict_dataset = predict_examples.map(
             prepare_validation_features,
             batched=True,
             num_proc=args.preprocessing_num_workers,
             remove_columns=column_names,
             load_from_cache_file=not args.overwrite_cache,
         )
-        if args.max_test_samples is not None:
+        if args.max_predict_samples is not None:
             # During Feature creation dataset samples might increase, we will select required samples again
-            test_dataset = test_dataset.select(range(args.max_test_samples))
+            predict_dataset = predict_dataset.select(range(args.max_predict_samples))
     # Log a few random samples from the training set:
     for index in random.sample(range(len(train_dataset)), 3):
@@ -539,9 +539,9 @@ def main():
     eval_dataloader = DataLoader(eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size)
     if args.do_predict:
-        test_dataset.set_format(type="torch", columns=["attention_mask", "input_ids", "token_type_ids"])
-        test_dataloader = DataLoader(
-            test_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size
+        predict_dataset.set_format(type="torch", columns=["attention_mask", "input_ids", "token_type_ids"])
+        predict_dataloader = DataLoader(
+            predict_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size
         )
     # Post-processing:
@@ -737,7 +737,7 @@ def main():
         all_end_top_log_probs = []
         all_end_top_index = []
         all_cls_logits = []
-        for step, batch in enumerate(test_dataloader):
+        for step, batch in enumerate(predict_dataloader):
             with torch.no_grad():
                 outputs = model(**batch)
                 start_top_log_probs = outputs.start_top_log_probs
@@ -762,10 +762,10 @@ def main():
         max_len = max([x.shape[1] for x in all_end_top_log_probs])  # Get the max_length of the tensor
         # concatenate all numpy arrays collected above
-        start_top_log_probs_concat = create_and_fill_np_array(all_start_top_log_probs, test_dataset, max_len)
-        start_top_index_concat = create_and_fill_np_array(all_start_top_index, test_dataset, max_len)
-        end_top_log_probs_concat = create_and_fill_np_array(all_end_top_log_probs, test_dataset, max_len)
-        end_top_index_concat = create_and_fill_np_array(all_end_top_index, test_dataset, max_len)
+        start_top_log_probs_concat = create_and_fill_np_array(all_start_top_log_probs, predict_dataset, max_len)
+        start_top_index_concat = create_and_fill_np_array(all_start_top_index, predict_dataset, max_len)
+        end_top_log_probs_concat = create_and_fill_np_array(all_end_top_log_probs, predict_dataset, max_len)
+        end_top_index_concat = create_and_fill_np_array(all_end_top_index, predict_dataset, max_len)
         all_cls_logits = np.concatenate(all_cls_logits, axis=0)
         # delete the list of numpy arrays
@@ -774,7 +774,7 @@ def main():
         del end_top_log_probs
         del end_top_index
-        test_dataset.set_format(type=None, columns=list(test_dataset.features.keys()))
+        predict_dataset.set_format(type=None, columns=list(predict_dataset.features.keys()))
         outputs_numpy = (
             start_top_log_probs_concat,
             start_top_index_concat,
@@ -783,9 +783,9 @@ def main():
             cls_logits,
         )
-        prediction = post_processing_function(test_examples, test_dataset, outputs_numpy)
-        test_metric = metric.compute(predictions=prediction.predictions, references=prediction.label_ids)
-        logger.info(f"Test metrics: {test_metric}")
+        prediction = post_processing_function(predict_examples, predict_dataset, outputs_numpy)
+        predict_metric = metric.compute(predictions=prediction.predictions, references=prediction.label_ids)
+        logger.info(f"Predict metrics: {predict_metric}")
     if args.output_dir is not None:
         accelerator.wait_for_everyone()
...
@@ -205,20 +205,20 @@ def parse_args():
         "value if set.",
     )
     parser.add_argument(
-        "--max_val_samples",
+        "--max_eval_samples",
         type=int,
         default=None,
-        help="For debugging purposes or quicker training, truncate the number of validation examples to this "
+        help="For debugging purposes or quicker training, truncate the number of evaluation examples to this "
         "value if set.",
     )
     parser.add_argument(
         "--overwrite_cache", type=bool, default=False, help="Overwrite the cached training and evaluation sets"
     )
     parser.add_argument(
-        "--max_test_samples",
+        "--max_predict_samples",
         type=int,
         default=None,
-        help="For debugging purposes or quicker training, truncate the number of test examples to this",
+        help="For debugging purposes or quicker training, truncate the number of prediction examples to this",
     )
     parser.add_argument(
         "--model_type",
@@ -486,9 +486,9 @@ def main():
     if "validation" not in raw_datasets:
         raise ValueError("--do_eval requires a validation dataset")
     eval_examples = raw_datasets["validation"]
-    if args.max_val_samples is not None:
+    if args.max_eval_samples is not None:
         # We will select sample from whole data
-        eval_examples = eval_examples.select(range(args.max_val_samples))
+        eval_examples = eval_examples.select(range(args.max_eval_samples))
     # Validation Feature Creation
     eval_dataset = eval_examples.map(
         prepare_validation_features,
@@ -498,28 +498,28 @@ def main():
         load_from_cache_file=not args.overwrite_cache,
     )
-    if args.max_val_samples is not None:
+    if args.max_eval_samples is not None:
         # During Feature creation dataset samples might increase, we will select required samples again
-        eval_dataset = eval_dataset.select(range(args.max_val_samples))
+        eval_dataset = eval_dataset.select(range(args.max_eval_samples))
     if args.do_predict:
         if "test" not in raw_datasets:
             raise ValueError("--do_predict requires a test dataset")
-        test_examples = raw_datasets["test"]
-        if args.max_test_samples is not None:
+        predict_examples = raw_datasets["test"]
+        if args.max_predict_samples is not None:
             # We will select sample from whole data
-            test_examples = test_examples.select(range(args.max_test_samples))
-        # Test Feature Creation
-        test_dataset = test_examples.map(
+            predict_examples = predict_examples.select(range(args.max_predict_samples))
+        # Predict Feature Creation
+        predict_dataset = predict_examples.map(
             prepare_validation_features,
             batched=True,
             num_proc=args.preprocessing_num_workers,
             remove_columns=column_names,
             load_from_cache_file=not args.overwrite_cache,
         )
-        if args.max_test_samples is not None:
+        if args.max_predict_samples is not None:
             # During Feature creation dataset samples might increase, we will select required samples again
-            test_dataset = test_dataset.select(range(args.max_test_samples))
+            predict_dataset = predict_dataset.select(range(args.max_predict_samples))
     # Log a few random samples from the training set:
     for index in random.sample(range(len(train_dataset)), 3):
@@ -544,9 +544,9 @@ def main():
     eval_dataloader = DataLoader(eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size)
     if args.do_predict:
-        test_dataset.set_format(type="torch", columns=["attention_mask", "input_ids", "token_type_ids"])
-        test_dataloader = DataLoader(
-            test_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size
+        predict_dataset.set_format(type="torch", columns=["attention_mask", "input_ids", "token_type_ids"])
+        predict_dataloader = DataLoader(
+            predict_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size
         )
     # Post-processing:
@@ -714,7 +714,7 @@ def main():
     if args.do_predict:
         all_start_logits = []
         all_end_logits = []
-        for step, batch in enumerate(test_dataloader):
+        for step, batch in enumerate(predict_dataloader):
             with torch.no_grad():
                 outputs = model(**batch)
                 start_logits = outputs.start_logits
@@ -729,19 +729,19 @@ def main():
         max_len = max([x.shape[1] for x in all_start_logits])  # Get the max_length of the tensor
         # concatenate the numpy array
-        start_logits_concat = create_and_fill_np_array(all_start_logits, test_dataset, max_len)
-        end_logits_concat = create_and_fill_np_array(all_end_logits, test_dataset, max_len)
+        start_logits_concat = create_and_fill_np_array(all_start_logits, predict_dataset, max_len)
+        end_logits_concat = create_and_fill_np_array(all_end_logits, predict_dataset, max_len)
         # delete the list of numpy arrays
         del all_start_logits
         del all_end_logits
         # Now we need to add extra columns which we removed for post processing
-        test_dataset.set_format(type=None, columns=list(test_dataset.features.keys()))
+        predict_dataset.set_format(type=None, columns=list(predict_dataset.features.keys()))
         outputs_numpy = (start_logits_concat, end_logits_concat)
-        prediction = post_processing_function(test_examples, test_dataset, outputs_numpy)
-        eval_metric = metric.compute(predictions=prediction.predictions, references=prediction.label_ids)
-        logger.info(f"Test metrics: {eval_metric}")
+        prediction = post_processing_function(predict_examples, predict_dataset, outputs_numpy)
+        predict_metric = metric.compute(predictions=prediction.predictions, references=prediction.label_ids)
+        logger.info(f"Predict metrics: {predict_metric}")
     if args.output_dir is not None:
         accelerator.wait_for_everyone()
...
@@ -66,16 +66,16 @@ class QuestionAnsweringTrainer(Trainer):
         self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics)
         return metrics
-    def predict(self, test_dataset, test_examples, ignore_keys=None):
-        test_dataloader = self.get_test_dataloader(test_dataset)
+    def predict(self, predict_dataset, predict_examples, ignore_keys=None):
+        predict_dataloader = self.get_test_dataloader(predict_dataset)
         # Temporarily disable metric computation, we will do it in the loop here.
         compute_metrics = self.compute_metrics
         self.compute_metrics = None
         try:
             output = self.prediction_loop(
-                test_dataloader,
-                description="Evaluation",
+                predict_dataloader,
+                description="Prediction",
                 # No point gathering the predictions if there are no metrics, otherwise we defer to
                 # self.args.prediction_loss_only
                 prediction_loss_only=True if compute_metrics is None else None,
@@ -87,7 +87,7 @@ class QuestionAnsweringTrainer(Trainer):
         if self.post_process_function is None or self.compute_metrics is None:
             return output
-        eval_preds = self.post_process_function(test_examples, test_dataset, output.predictions, "test")
-        metrics = self.compute_metrics(eval_preds)
-        return PredictionOutput(predictions=eval_preds.predictions, label_ids=eval_preds.label_ids, metrics=metrics)
+        predictions = self.post_process_function(predict_examples, predict_dataset, output.predictions, "predict")
+        metrics = self.compute_metrics(predictions)
+        return PredictionOutput(predictions=predictions.predictions, label_ids=predictions.label_ids, metrics=metrics)
@@ -178,17 +178,17 @@ class DataTrainingArguments:
             "value if set."
         },
     )
-    max_val_samples: Optional[int] = field(
+    max_eval_samples: Optional[int] = field(
         default=None,
         metadata={
-            "help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
+            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
             "value if set."
         },
     )
-    max_test_samples: Optional[int] = field(
+    max_predict_samples: Optional[int] = field(
         default=None,
         metadata={
-            "help": "For debugging purposes or quicker training, truncate the number of test examples to this "
+            "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
             "value if set."
         },
     )
@@ -438,8 +438,8 @@ def main():
         if "validation" not in datasets:
             raise ValueError("--do_eval requires a validation dataset")
         eval_dataset = datasets["validation"]
-        if data_args.max_val_samples is not None:
-            eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
+        if data_args.max_eval_samples is not None:
+            eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
         eval_dataset = eval_dataset.map(
             preprocess_function,
             batched=True,
@@ -452,10 +452,10 @@ def main():
         max_target_length = data_args.val_max_target_length
         if "test" not in datasets:
             raise ValueError("--do_predict requires a test dataset")
-        test_dataset = datasets["test"]
-        if data_args.max_test_samples is not None:
-            test_dataset = test_dataset.select(range(data_args.max_test_samples))
-        test_dataset = test_dataset.map(
+        predict_dataset = datasets["test"]
+        if data_args.max_predict_samples is not None:
+            predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
+        predict_dataset = predict_dataset.map(
             preprocess_function,
             batched=True,
             num_proc=data_args.preprocessing_num_workers,
@@ -547,37 +547,39 @@ def main():
         metrics = trainer.evaluate(
             max_length=data_args.val_max_target_length, num_beams=data_args.num_beams, metric_key_prefix="eval"
         )
-        max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
-        metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
+        max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
+        metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
         trainer.log_metrics("eval", metrics)
         trainer.save_metrics("eval", metrics)
     if training_args.do_predict:
-        logger.info("*** Test ***")
-        test_results = trainer.predict(
-            test_dataset,
-            metric_key_prefix="test",
+        logger.info("*** Predict ***")
+        predict_results = trainer.predict(
+            predict_dataset,
+            metric_key_prefix="predict",
             max_length=data_args.val_max_target_length,
             num_beams=data_args.num_beams,
         )
-        metrics = test_results.metrics
-        max_test_samples = data_args.max_test_samples if data_args.max_test_samples is not None else len(test_dataset)
-        metrics["test_samples"] = min(max_test_samples, len(test_dataset))
+        metrics = predict_results.metrics
+        max_predict_samples = (
+            data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
+        )
+        metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset))
-        trainer.log_metrics("test", metrics)
-        trainer.save_metrics("test", metrics)
+        trainer.log_metrics("predict", metrics)
+        trainer.save_metrics("predict", metrics)
         if trainer.is_world_process_zero():
             if training_args.predict_with_generate:
-                test_preds = tokenizer.batch_decode(
-                    test_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
+                predictions = tokenizer.batch_decode(
+                    predict_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
                 )
-                test_preds = [pred.strip() for pred in test_preds]
-                output_test_preds_file = os.path.join(training_args.output_dir, "test_generations.txt")
-                with open(output_test_preds_file, "w") as writer:
-                    writer.write("\n".join(test_preds))
+                predictions = [pred.strip() for pred in predictions]
+                output_prediction_file = os.path.join(training_args.output_dir, "generated_predictions.txt")
+                with open(output_prediction_file, "w") as writer:
+                    writer.write("\n".join(predictions))
     if training_args.push_to_hub:
         trainer.push_to_hub()
...
@@ -100,17 +100,17 @@ class DataTrainingArguments:
             "value if set."
         },
     )
-    max_val_samples: Optional[int] = field(
+    max_eval_samples: Optional[int] = field(
         default=None,
         metadata={
-            "help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
+            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
             "value if set."
         },
     )
-    max_test_samples: Optional[int] = field(
+    max_predict_samples: Optional[int] = field(
         default=None,
         metadata={
-            "help": "For debugging purposes or quicker training, truncate the number of test examples to this "
+            "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
             "value if set."
         },
     )
@@ -390,15 +390,15 @@ def main():
         if "validation" not in datasets and "validation_matched" not in datasets:
             raise ValueError("--do_eval requires a validation dataset")
         eval_dataset = datasets["validation_matched" if data_args.task_name == "mnli" else "validation"]
-        if data_args.max_val_samples is not None:
-            eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
+        if data_args.max_eval_samples is not None:
+            eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
     if training_args.do_predict or data_args.task_name is not None or data_args.test_file is not None:
         if "test" not in datasets and "test_matched" not in datasets:
             raise ValueError("--do_predict requires a test dataset")
-        test_dataset = datasets["test_matched" if data_args.task_name == "mnli" else "test"]
-        if data_args.max_test_samples is not None:
-            test_dataset = test_dataset.select(range(data_args.max_test_samples))
+        predict_dataset = datasets["test_matched" if data_args.task_name == "mnli" else "test"]
+        if data_args.max_predict_samples is not None:
+            predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
     # Log a few random samples from the training set:
     if training_args.do_train:
@@ -483,32 +483,34 @@ def main():
         for eval_dataset, task in zip(eval_datasets, tasks):
             metrics = trainer.evaluate(eval_dataset=eval_dataset)
-            max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
-            metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
+            max_eval_samples = (
+                data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
+            )
+            metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
             trainer.log_metrics("eval", metrics)
             trainer.save_metrics("eval", metrics)
     if training_args.do_predict:
-        logger.info("*** Test ***")
+        logger.info("*** Predict ***")
         # Loop to handle MNLI double evaluation (matched, mis-matched)
         tasks = [data_args.task_name]
-        test_datasets = [test_dataset]
+        predict_datasets = [predict_dataset]
         if data_args.task_name == "mnli":
             tasks.append("mnli-mm")
-            test_datasets.append(datasets["test_mismatched"])
+            predict_datasets.append(datasets["test_mismatched"])
-        for test_dataset, task in zip(test_datasets, tasks):
+        for predict_dataset, task in zip(predict_datasets, tasks):
             # Removing the `label` columns because it contains -1 and Trainer won't like that.
-            test_dataset.remove_columns_("label")
-            predictions = trainer.predict(test_dataset=test_dataset).predictions
+            predict_dataset.remove_columns_("label")
+            predictions = trainer.predict(predict_dataset, metric_key_prefix="predict").predictions
             predictions = np.squeeze(predictions) if is_regression else np.argmax(predictions, axis=1)
-            output_test_file = os.path.join(training_args.output_dir, f"test_results_{task}.txt")
+            output_predict_file = os.path.join(training_args.output_dir, f"predict_results_{task}.txt")
             if trainer.is_world_process_zero():
-                with open(output_test_file, "w") as writer:
-                    logger.info(f"***** Test results {task} *****")
+                with open(output_predict_file, "w") as writer:
+                    logger.info(f"***** Predict results {task} *****")
                     writer.write("index\tprediction\n")
                     for index, item in enumerate(predictions):
                         if is_regression:
...
@@ -84,17 +84,17 @@ class DataTrainingArguments:
             "value if set."
         },
     )
-    max_val_samples: Optional[int] = field(
+    max_eval_samples: Optional[int] = field(
         default=None,
         metadata={
-            "help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
+            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
             "value if set."
         },
     )
-    max_test_samples: Optional[int] = field(
+    max_predict_samples: Optional[int] = field(
         default=None,
         metadata={
-            "help": "For debugging purposes or quicker training, truncate the number of test examples to this "
+            "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
             "value if set."
         },
     )
@@ -221,8 +221,8 @@ def main():
         label_list = eval_dataset.features["label"].names
     if training_args.do_predict:
-        test_dataset = load_dataset("xnli", model_args.language, split="test", cache_dir=model_args.cache_dir)
-        label_list = test_dataset.features["label"].names
+        predict_dataset = load_dataset("xnli", model_args.language, split="test", cache_dir=model_args.cache_dir)
+        label_list = predict_dataset.features["label"].names
     # Labels
     num_labels = len(label_list)
@@ -286,8 +286,8 @@ def main():
             logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
     if training_args.do_eval:
-        if data_args.max_val_samples is not None:
-            eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
+        if data_args.max_eval_samples is not None:
+            eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
         eval_dataset = eval_dataset.map(
             preprocess_function,
             batched=True,
@@ -295,9 +295,9 @@ def main():
         )
     if training_args.do_predict:
-        if data_args.max_test_samples is not None:
-            test_dataset = test_dataset.select(range(data_args.max_test_samples))
-        test_dataset = test_dataset.map(
+        if data_args.max_predict_samples is not None:
+            predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
+        predict_dataset = predict_dataset.map(
             preprocess_function,
             batched=True,
             load_from_cache_file=not data_args.overwrite_cache,
@@ -360,8 +360,8 @@ def main():
         logger.info("*** Evaluate ***")
         metrics = trainer.evaluate(eval_dataset=eval_dataset)
-        max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
-        metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
+        max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
+        metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
         trainer.log_metrics("eval", metrics)
         trainer.save_metrics("eval", metrics)
@@ -369,18 +369,20 @@ def main():
     # Prediction
     if training_args.do_predict:
         logger.info("*** Predict ***")
-        predictions, labels, metrics = trainer.predict(test_dataset)
-        max_test_samples = data_args.max_test_samples if data_args.max_test_samples is not None else len(test_dataset)
-        metrics["test_samples"] = min(max_test_samples, len(test_dataset))
+        predictions, labels, metrics = trainer.predict(predict_dataset, metric_key_prefix="predict")
+        max_predict_samples = (
+            data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
+        )
+        metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset))
-        trainer.log_metrics("test", metrics)
-        trainer.save_metrics("test", metrics)
+        trainer.log_metrics("predict", metrics)
+        trainer.save_metrics("predict", metrics)
         predictions = np.argmax(predictions, axis=1)
-        output_test_file = os.path.join(training_args.output_dir, "test_predictions.txt")
+        output_predict_file = os.path.join(training_args.output_dir, "predictions.txt")
         if trainer.is_world_process_zero():
-            with open(output_test_file, "w") as writer:
+            with open(output_predict_file, "w") as writer:
                 writer.write("index\tprediction\n")
                 for index, item in enumerate(predictions):
                     item = label_list[item]
...
@@ -128,17 +128,17 @@ class DataTrainingArguments:
             "value if set."
         },
     )
-    max_val_samples: Optional[int] = field(
+    max_eval_samples: Optional[int] = field(
         default=None,
         metadata={
-            "help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
+            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
             "value if set."
         },
     )
-    max_test_samples: Optional[int] = field(
+    max_predict_samples: Optional[int] = field(
         default=None,
         metadata={
-            "help": "For debugging purposes or quicker training, truncate the number of test examples to this "
+            "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
             "value if set."
         },
     )
@@ -363,8 +363,8 @@ def main():
         if "validation" not in datasets:
             raise ValueError("--do_eval requires a validation dataset")
         eval_dataset = datasets["validation"]
-        if data_args.max_val_samples is not None:
-            eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
+        if data_args.max_eval_samples is not None:
+            eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
         eval_dataset = eval_dataset.map(
             tokenize_and_align_labels,
             batched=True,
@@ -375,10 +375,10 @@ def main():
     if training_args.do_predict:
         if "test" not in datasets:
             raise ValueError("--do_predict requires a test dataset")
-        test_dataset = datasets["test"]
-        if data_args.max_test_samples is not None:
-            test_dataset = test_dataset.select(range(data_args.max_test_samples))
-        test_dataset = test_dataset.map(
+        predict_dataset = datasets["test"]
+        if data_args.max_predict_samples is not None:
+            predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
+        predict_dataset = predict_dataset.map(
             tokenize_and_align_labels,
             batched=True,
             num_proc=data_args.preprocessing_num_workers,
@@ -462,8 +462,8 @@ def main():
         metrics = trainer.evaluate()
-        max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
-        metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
+        max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
+        metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
         trainer.log_metrics("eval", metrics)
         trainer.save_metrics("eval", metrics)
@@ -472,7 +472,7 @@ def main():
     if training_args.do_predict:
         logger.info("*** Predict ***")
-        predictions, labels, metrics = trainer.predict(test_dataset)
+        predictions, labels, metrics = trainer.predict(predict_dataset, metric_key_prefix="predict")
         predictions = np.argmax(predictions, axis=2)
         # Remove ignored index (special tokens)
@@ -481,13 +481,13 @@ def main():
             for prediction, label in zip(predictions, labels)
         ]
-        trainer.log_metrics("test", metrics)
-        trainer.save_metrics("test", metrics)
+        trainer.log_metrics("predict", metrics)
+        trainer.save_metrics("predict", metrics)
         # Save predictions
-        output_test_predictions_file = os.path.join(training_args.output_dir, "test_predictions.txt")
+        output_predictions_file = os.path.join(training_args.output_dir, "predictions.txt")
         if trainer.is_world_process_zero():
-            with open(output_test_predictions_file, "w") as writer:
+            with open(output_predictions_file, "w") as writer:
                 for prediction in true_predictions:
                     writer.write(" ".join(prediction) + "\n")
...
...@@ -167,17 +167,17 @@ class DataTrainingArguments: ...@@ -167,17 +167,17 @@ class DataTrainingArguments:
"value if set." "value if set."
}, },
) )
max_val_samples: Optional[int] = field( max_eval_samples: Optional[int] = field(
default=None, default=None,
metadata={ metadata={
"help": "For debugging purposes or quicker training, truncate the number of validation examples to this " "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
"value if set." "value if set."
}, },
) )
max_test_samples: Optional[int] = field( max_predict_samples: Optional[int] = field(
default=None, default=None,
metadata={ metadata={
"help": "For debugging purposes or quicker training, truncate the number of test examples to this " "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
"value if set." "value if set."
}, },
) )
...@@ -432,8 +432,8 @@ def main(): ...@@ -432,8 +432,8 @@ def main():
if "validation" not in datasets: if "validation" not in datasets:
raise ValueError("--do_eval requires a validation dataset") raise ValueError("--do_eval requires a validation dataset")
eval_dataset = datasets["validation"] eval_dataset = datasets["validation"]
if data_args.max_val_samples is not None: if data_args.max_eval_samples is not None:
eval_dataset = eval_dataset.select(range(data_args.max_val_samples)) eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
eval_dataset = eval_dataset.map( eval_dataset = eval_dataset.map(
preprocess_function, preprocess_function,
batched=True, batched=True,
...@@ -446,10 +446,10 @@ def main(): ...@@ -446,10 +446,10 @@ def main():
max_target_length = data_args.val_max_target_length max_target_length = data_args.val_max_target_length
if "test" not in datasets: if "test" not in datasets:
raise ValueError("--do_predict requires a test dataset") raise ValueError("--do_predict requires a test dataset")
test_dataset = datasets["test"] predict_dataset = datasets["test"]
if data_args.max_test_samples is not None: if data_args.max_predict_samples is not None:
test_dataset = test_dataset.select(range(data_args.max_test_samples)) predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
test_dataset = test_dataset.map( predict_dataset = predict_dataset.map(
preprocess_function, preprocess_function,
batched=True, batched=True,
num_proc=data_args.preprocessing_num_workers, num_proc=data_args.preprocessing_num_workers,
...@@ -539,37 +539,39 @@ def main(): ...@@ -539,37 +539,39 @@ def main():
metrics = trainer.evaluate( metrics = trainer.evaluate(
max_length=data_args.val_max_target_length, num_beams=data_args.num_beams, metric_key_prefix="eval" max_length=data_args.val_max_target_length, num_beams=data_args.num_beams, metric_key_prefix="eval"
) )
max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset) max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
metrics["eval_samples"] = min(max_val_samples, len(eval_dataset)) metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
trainer.log_metrics("eval", metrics) trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics) trainer.save_metrics("eval", metrics)
if training_args.do_predict: if training_args.do_predict:
logger.info("*** Test ***") logger.info("*** Predict ***")
test_results = trainer.predict( predict_results = trainer.predict(
test_dataset, predict_dataset,
metric_key_prefix="test", metric_key_prefix="predict",
max_length=data_args.val_max_target_length, max_length=data_args.val_max_target_length,
num_beams=data_args.num_beams, num_beams=data_args.num_beams,
) )
metrics = test_results.metrics metrics = predict_results.metrics
max_test_samples = data_args.max_test_samples if data_args.max_test_samples is not None else len(test_dataset) max_predict_samples = (
data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
)
metrics["test_samples"] = min(max_test_samples, len(test_dataset)) metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset))
trainer.log_metrics("test", metrics) trainer.log_metrics("predict", metrics)
trainer.save_metrics("test", metrics) trainer.save_metrics("predict", metrics)
if trainer.is_world_process_zero(): if trainer.is_world_process_zero():
if training_args.predict_with_generate: if training_args.predict_with_generate:
test_preds = tokenizer.batch_decode( predictions = tokenizer.batch_decode(
test_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True predict_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
) )
test_preds = [pred.strip() for pred in test_preds] predictions = [pred.strip() for pred in predictions]
output_test_preds_file = os.path.join(training_args.output_dir, "test_generations.txt") output_prediction_file = os.path.join(training_args.output_dir, "generated_predictions.txt")
with open(output_test_preds_file, "w") as writer: with open(output_prediction_file, "w") as writer:
writer.write("\n".join(test_preds)) writer.write("\n".join(predictions))
if training_args.push_to_hub: if training_args.push_to_hub:
trainer.push_to_hub() trainer.push_to_hub()
......
...@@ -164,17 +164,17 @@ class DataTrainingArguments: ...@@ -164,17 +164,17 @@ class DataTrainingArguments:
"value if set." "value if set."
}, },
) )
max_val_samples: Optional[int] = field( max_eval_samples: Optional[int] = field(
default=None, default=None,
metadata={ metadata={
"help": "For debugging purposes or quicker training, truncate the number of validation examples to this " "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
"value if set." "value if set."
}, },
) )
max_test_samples: Optional[int] = field( max_predict_samples: Optional[int] = field(
default=None, default=None,
metadata={ metadata={
"help": "For debugging purposes or quicker training, truncate the number of test examples to this " "help": "For debugging purposes or quicker training, truncate the number of predict examples to this "
"value if set." "value if set."
}, },
) )
...@@ -468,13 +468,13 @@ def main(): ...@@ -468,13 +468,13 @@ def main():
if "validation" in datasets: if "validation" in datasets:
eval_dataset = datasets["validation"] eval_dataset = datasets["validation"]
if data_args.max_val_samples is not None: if data_args.max_eval_samples is not None:
eval_dataset = eval_dataset.select(range(data_args.max_val_samples)) eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
if "test" in datasets: if "test" in datasets:
test_dataset = datasets["test"] predict_dataset = datasets["test"]
if data_args.max_test_samples is not None: if data_args.max_predict_samples is not None:
test_dataset = test_dataset.select(range(data_args.max_test_samples)) predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
# endregion # endregion
...@@ -513,15 +513,15 @@ def main(): ...@@ -513,15 +513,15 @@ def main():
# region Prediction # region Prediction
if "test" in datasets: if "test" in datasets:
logger.info("Doing predictions on test dataset...") logger.info("Doing predictions on Predict dataset...")
test_dataset = DataSequence( predict_dataset = DataSequence(
test_dataset, non_label_column_names, batch_size=training_args.per_device_eval_batch_size, labels=False predict_dataset, non_label_column_names, batch_size=training_args.per_device_eval_batch_size, labels=False
) )
predictions = model.predict(test_dataset)["logits"] predictions = model.predict(predict_dataset)["logits"]
predictions = np.squeeze(predictions) if is_regression else np.argmax(predictions, axis=1) predictions = np.squeeze(predictions) if is_regression else np.argmax(predictions, axis=1)
output_test_file = os.path.join(training_args.output_dir, "test_results.txt") output_predict_file = os.path.join(training_args.output_dir, "predict_results.txt")
with open(output_test_file, "w") as writer: with open(output_predict_file, "w") as writer:
writer.write("index\tprediction\n") writer.write("index\tprediction\n")
for index, item in enumerate(predictions): for index, item in enumerate(predictions):
if is_regression: if is_regression:
...@@ -529,7 +529,7 @@ def main(): ...@@ -529,7 +529,7 @@ def main():
else: else:
item = model.config.id2label[item] item = model.config.id2label[item]
writer.write(f"{index}\t{item}\n") writer.write(f"{index}\t{item}\n")
logger.info(f"Wrote predictions to {output_test_file}!") logger.info(f"Wrote predictions to {output_predict_file}!")
# endregion # endregion
......
...@@ -157,17 +157,17 @@ class DataTrainingArguments: ...@@ -157,17 +157,17 @@ class DataTrainingArguments:
"value if set." "value if set."
}, },
) )
max_val_samples: Optional[int] = field( max_eval_samples: Optional[int] = field(
default=None, default=None,
metadata={ metadata={
"help": "For debugging purposes or quicker training, truncate the number of validation examples to this " "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
"value if set." "value if set."
}, },
) )
max_test_samples: Optional[int] = field( max_predict_samples: Optional[int] = field(
default=None, default=None,
metadata={ metadata={
"help": "For debugging purposes or quicker training, truncate the number of test examples to this " "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
"value if set." "value if set."
}, },
) )
...@@ -379,8 +379,8 @@ def main(): ...@@ -379,8 +379,8 @@ def main():
raise ValueError("--do_eval requires a validation dataset") raise ValueError("--do_eval requires a validation dataset")
eval_dataset = datasets["validation"] eval_dataset = datasets["validation"]
# Selecting samples from dataset # Selecting samples from dataset
if data_args.max_val_samples is not None: if data_args.max_eval_samples is not None:
eval_dataset = eval_dataset.select(range(data_args.max_val_samples)) eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
# tokenize validation dataset # tokenize validation dataset
eval_dataset = eval_dataset.map( eval_dataset = eval_dataset.map(
tokenize_function, tokenize_function,
...@@ -393,12 +393,12 @@ def main(): ...@@ -393,12 +393,12 @@ def main():
if training_args.do_predict: if training_args.do_predict:
if "test" not in datasets: if "test" not in datasets:
raise ValueError("--do_predict requires a test dataset") raise ValueError("--do_predict requires a test dataset")
test_dataset = datasets["test"] predict_dataset = datasets["test"]
# Selecting samples from dataset # Selecting samples from dataset
if data_args.max_test_samples is not None: if data_args.max_predict_samples is not None:
test_dataset = test_dataset.select(range(data_args.max_test_samples)) predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
# tokenize test dataset # tokenize predict dataset
test_dataset = test_dataset.map( predict_dataset = predict_dataset.map(
tokenize_function, tokenize_function,
batched=True, batched=True,
num_proc=data_args.preprocessing_num_workers, num_proc=data_args.preprocessing_num_workers,
...@@ -455,8 +455,8 @@ def main(): ...@@ -455,8 +455,8 @@ def main():
metrics = trainer.evaluate() metrics = trainer.evaluate()
max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset) max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
metrics["eval_samples"] = min(max_val_samples, len(eval_dataset)) metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
trainer.log_metrics("eval", metrics) trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics) trainer.save_metrics("eval", metrics)
...@@ -464,13 +464,13 @@ def main(): ...@@ -464,13 +464,13 @@ def main():
# Prediction # Prediction
if training_args.do_predict: if training_args.do_predict:
logger.info("*** Predict ***") logger.info("*** Predict ***")
predictions, labels, metrics = trainer.predict(test_dataset) predictions, labels, metrics = trainer.predict(predict_dataset)
max_test_samples = data_args.max_test_samples if data_args.max_test_samples is not None else len(test_dataset) max_predict_samples = data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
metrics["test_samples"] = min(max_test_samples, len(test_dataset)) metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset))
trainer.log_metrics("test", metrics) trainer.log_metrics("predict", metrics)
trainer.save_metrics("test", metrics) trainer.save_metrics("predict", metrics)
# write custom code for saving predictions according to task # write custom code for saving predictions according to task
......
...@@ -578,7 +578,7 @@ class TestDeepSpeedWithLauncher(TestCasePlus): ...@@ -578,7 +578,7 @@ class TestDeepSpeedWithLauncher(TestCasePlus):
args.extend( args.extend(
""" """
--do_eval --do_eval
--max_val_samples 100 --max_eval_samples 100
--per_device_eval_batch_size 2 --per_device_eval_batch_size 2
""".split() """.split()
) )
...@@ -620,7 +620,7 @@ class TestDeepSpeedWithLauncher(TestCasePlus): ...@@ -620,7 +620,7 @@ class TestDeepSpeedWithLauncher(TestCasePlus):
--do_train --do_train
--do_eval --do_eval
--max_train_samples 10 --max_train_samples 10
--max_val_samples 10 --max_eval_samples 10
--per_device_train_batch_size 5 --per_device_train_batch_size 5
--per_device_eval_batch_size 5 --per_device_eval_batch_size 5
--num_train_epochs 1 --num_train_epochs 1
......
...@@ -191,7 +191,7 @@ class TestTrainerExt(TestCasePlus): ...@@ -191,7 +191,7 @@ class TestTrainerExt(TestCasePlus):
--output_dir {output_dir} --output_dir {output_dir}
--overwrite_output_dir --overwrite_output_dir
--max_train_samples 8 --max_train_samples 8
--max_val_samples 8 --max_eval_samples 8
--max_source_length {max_len} --max_source_length {max_len}
--max_target_length {max_len} --max_target_length {max_len}
--val_max_target_length {max_len} --val_max_target_length {max_len}
......
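The changes above consistently rename `max_val_samples` to `max_eval_samples` and `max_test_samples` to `max_predict_samples`, and switch the predict path from `test_*` names to `predict_*` names. Below is a minimal, self-contained sketch (not part of the commit) illustrating the renamed CLI flags end to end; the `DataTrainingArguments` dataclass here is a hypothetical stripped-down stand-in for the full dataclass used in the example scripts, and only `HfArgumentParser` is taken from the `transformers` library.

```python
# Hypothetical, stripped-down stand-in for the examples' DataTrainingArguments,
# showing only the renamed sample-truncation fields.
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class DataTrainingArguments:
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={"help": "Truncate the number of training examples for debugging or quicker training."},
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={"help": "Truncate the number of evaluation examples for debugging or quicker training."},
    )
    max_predict_samples: Optional[int] = field(
        default=None,
        metadata={"help": "Truncate the number of prediction examples for debugging or quicker training."},
    )


if __name__ == "__main__":
    # Example invocation:
    #   python demo_args.py --max_train_samples 50 --max_eval_samples 50 --max_predict_samples 50
    (data_args,) = HfArgumentParser(DataTrainingArguments).parse_args_into_dataclasses()
    print(data_args)
```

With this naming, the flag (`--max_eval_samples`), the dataset variable (`eval_dataset`), and the logged metric key (`eval_samples`) all share the same prefix, and likewise for `predict`, which is the inconsistency this commit resolves.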