Unverified Commit 19597998 authored by Sylvain Gugger and committed by GitHub

Don't compute metrics in LM examples on TPU (#16029)

parent 10591399
@@ -43,6 +43,7 @@ from transformers import (
     Trainer,
     TrainingArguments,
     default_data_collator,
+    is_torch_tpu_available,
     set_seed,
 )
 from transformers.testing_utils import CaptureLogger
@@ -479,8 +480,10 @@ def main():
         tokenizer=tokenizer,
         # Data collator will default to DataCollatorWithPadding, so we change it.
         data_collator=default_data_collator,
-        compute_metrics=compute_metrics if training_args.do_eval else None,
-        preprocess_logits_for_metrics=preprocess_logits_for_metrics if training_args.do_eval else None,
+        compute_metrics=compute_metrics if training_args.do_eval and not is_torch_tpu_available() else None,
+        preprocess_logits_for_metrics=preprocess_logits_for_metrics
+        if training_args.do_eval and not is_torch_tpu_available()
+        else None,
     )

     # Training
@@ -43,6 +43,7 @@ from transformers import (
     HfArgumentParser,
     Trainer,
     TrainingArguments,
+    is_torch_tpu_available,
     set_seed,
 )
 from transformers.trainer_utils import get_last_checkpoint
@@ -513,8 +514,10 @@ def main():
         eval_dataset=eval_dataset if training_args.do_eval else None,
         tokenizer=tokenizer,
         data_collator=data_collator,
-        compute_metrics=compute_metrics if training_args.do_eval else None,
-        preprocess_logits_for_metrics=preprocess_logits_for_metrics if training_args.do_eval else None,
+        compute_metrics=compute_metrics if training_args.do_eval and not is_torch_tpu_available() else None,
+        preprocess_logits_for_metrics=preprocess_logits_for_metrics
+        if training_args.do_eval and not is_torch_tpu_available()
+        else None,
     )

     # Training
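For context, a minimal sketch of the pattern both hunks apply. The two hook functions below are stand-ins in the spirit of the language-modeling example scripts (their real definitions are outside the diff), and metric_kwargs is a hypothetical helper introduced only for illustration: the idea is that metric hooks are handed to the Trainer only when evaluation is requested and the run is not on a TPU, where gathering and post-processing logits for metrics is disproportionately slow.

import numpy as np
from transformers import is_torch_tpu_available


def preprocess_logits_for_metrics(logits, labels):
    # Reduce logits to predicted token ids before they are gathered for metrics,
    # so only small integer tensors (not vocab-sized logits) cross devices.
    if isinstance(logits, tuple):
        logits = logits[0]
    return logits.argmax(dim=-1)


def compute_metrics(eval_preds):
    # eval_preds holds the gathered predictions and labels as numpy arrays.
    preds, labels = eval_preds
    # Shift so that tokens < n predict token n, then score token-level accuracy.
    labels = labels[:, 1:].reshape(-1)
    preds = preds[:, :-1].reshape(-1)
    return {"accuracy": float(np.mean(preds == labels))}


def metric_kwargs(do_eval: bool) -> dict:
    # Same gating as the diff: drop both hooks unless we are evaluating
    # and we are not running on a TPU.
    enabled = do_eval and not is_torch_tpu_available()
    return {
        "compute_metrics": compute_metrics if enabled else None,
        "preprocess_logits_for_metrics": preprocess_logits_for_metrics if enabled else None,
    }

The dict returned by metric_kwargs could be splatted into the Trainer(...) call; on TPU both entries come back as None, so the Trainer skips logit post-processing and metric computation entirely and evaluation only reports the loss.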