Unverified commit 1762ded3 authored by Zachary Mueller, committed by GitHub


Fix metric calculation in examples and setup tests to run on multi-gpu for no_trainer scripts (#17331)

* Fix length in no_trainer examples

* Add setup and teardown

* Use the new Accelerate config generator so the tests automatically run on CPU, single GPU, or multi-GPU depending on the environment
parent 6e195eb9
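The metric fix is an off-by-one correction: `step` comes from `enumerate(eval_dataloader)` and is zero-based, so the final batch is `step == len(eval_dataloader) - 1`, not `len(eval_dataloader)`. Only on that final batch does `accelerator.gather` return padded duplicates, which must be trimmed so exactly `len(eval_dataloader.dataset)` samples reach the metric. A minimal sketch of the pattern the diffs below apply, condensed from the example scripts (the `model`, `eval_dataloader`, `accelerator`, and `metric` objects are assumed to be prepared as in those scripts):

    samples_seen = 0
    for step, batch in enumerate(eval_dataloader):
        with torch.no_grad():
            outputs = model(**batch)
        predictions = outputs.logits.argmax(dim=-1)
        # gather() pads the last batch so every process contributes the same number of rows
        predictions, references = accelerator.gather((predictions, batch["labels"]))
        if accelerator.num_processes > 1:
            if step == len(eval_dataloader) - 1:  # zero-based, so this is the final batch
                # Trim the padded duplicates so only len(dataset) examples are scored
                predictions = predictions[: len(eval_dataloader.dataset) - samples_seen]
                references = references[: len(eval_dataloader.dataset) - samples_seen]
            else:
                samples_seen += references.shape[0]
        metric.add_batch(predictions=predictions, references=references)
    eval_metric = metric.compute()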
@@ -489,7 +489,7 @@ def main():
             predictions, references = accelerator.gather((predictions, batch["labels"]))
             # If we are in a multiprocess environment, the last batch has duplicates
             if accelerator.num_processes > 1:
-                if step == len(eval_dataloader):
+                if step == len(eval_dataloader) - 1:
                     predictions = predictions[: len(eval_dataloader.dataset) - samples_seen]
                     references = references[: len(eval_dataloader.dataset) - samples_seen]
                 else:
...
@@ -574,7 +574,7 @@ def main():
             predictions, references = accelerator.gather((predictions, batch["labels"]))
             # If we are in a multiprocess environment, the last batch has duplicates
             if accelerator.num_processes > 1:
-                if step == len(eval_dataloader):
+                if step == len(eval_dataloader) - 1:
                     predictions = predictions[: len(eval_dataloader.dataset) - samples_seen]
                     references = references[: len(eval_dataloader.dataset) - samples_seen]
                 else:
...
@@ -591,7 +591,7 @@ def main():
             # If we are in a multiprocess environment, the last batch has duplicates
             if accelerator.num_processes > 1:
-                if step == len(eval_dataloader):
+                if step == len(eval_dataloader) - 1:
                     predictions = predictions[: len(eval_dataloader.dataset) - samples_seen]
                     references = references[: len(eval_dataloader.dataset) - samples_seen]
                 else:
...
@@ -310,7 +310,9 @@ def parse_args():
 def main():
     args = parse_args()
+    # Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
+    # If we're using tracking, we also need to initialize it here and it will pick up all supported trackers in the environment
+    accelerator = Accelerator(log_with="all", logging_dir=args.output_dir) if args.with_tracking else Accelerator()
     if args.source_prefix is None and args.model_name_or_path in [
         "t5-small",
         "t5-base",
@@ -322,9 +324,6 @@ def main():
             "You're running a t5 model but didn't provide a source prefix, which is the expected, e.g. with "
             "`--source_prefix 'summarize: ' `"
         )
-    # Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
-    # If we're using tracking, we also need to initialize it here and it will pick up all supported trackers in the environment
-    accelerator = Accelerator(log_with="all", logging_dir=args.output_dir) if args.with_tracking else Accelerator()
     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -675,11 +674,11 @@ def main():
             decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
             # If we are in a multiprocess environment, the last batch has duplicates
             if accelerator.num_processes > 1:
-                if step == len(eval_dataloader):
+                if step == len(eval_dataloader) - 1:
                     decoded_preds = decoded_preds[: len(eval_dataloader.dataset) - samples_seen]
                     decoded_labels = decoded_labels[: len(eval_dataloader.dataset) - samples_seen]
                 else:
-                    samples_seen += decoded_labels.shape[0]
+                    samples_seen += len(decoded_labels)
             metric.add_batch(
                 predictions=decoded_preds,
...
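The summarization and translation hunks also change how gathered samples are counted: after `tokenizer.batch_decode` (and `postprocess_text`), `decoded_labels` is a plain Python list of strings with no `.shape` attribute, so `len()` is the correct way to count it. A tiny self-contained illustration:

    # After batch_decode the labels are plain Python strings, not tensors.
    decoded_labels = ["a summary", "another summary"]  # stands in for tokenizer.batch_decode(...)
    samples_seen = 0
    samples_seen += len(decoded_labels)        # works for a list
    # samples_seen += decoded_labels.shape[0]  # would raise AttributeError: 'list' object has no attribute 'shape'
    print(samples_seen)  # 2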
@@ -18,49 +18,18 @@ import argparse
 import json
 import logging
 import os
+import shutil
+import subprocess
 import sys
-from unittest.mock import patch
+import tempfile

 import torch

+from accelerate.utils import write_basic_config
 from transformers.testing_utils import TestCasePlus, get_gpu_count, slow, torch_device
 from transformers.utils import is_apex_available

-SRC_DIRS = [
-    os.path.join(os.path.dirname(__file__), dirname)
-    for dirname in [
-        "text-generation",
-        "text-classification",
-        "token-classification",
-        "language-modeling",
-        "multiple-choice",
-        "question-answering",
-        "summarization",
-        "translation",
-        "image-classification",
-        "speech-recognition",
-        "audio-classification",
-        "speech-pretraining",
-        "image-pretraining",
-        "semantic-segmentation",
-    ]
-]
-sys.path.extend(SRC_DIRS)
-
-if SRC_DIRS is not None:
-    import run_clm_no_trainer
-    import run_glue_no_trainer
-    import run_image_classification_no_trainer
-    import run_mlm_no_trainer
-    import run_ner_no_trainer
-    import run_qa_no_trainer as run_squad_no_trainer
-    import run_semantic_segmentation_no_trainer
-    import run_summarization_no_trainer
-    import run_swag_no_trainer
-    import run_translation_no_trainer
-
 logging.basicConfig(level=logging.DEBUG)

 logger = logging.getLogger()
@@ -94,10 +63,22 @@ logger.addHandler(stream_handler)
 class ExamplesTestsNoTrainer(TestCasePlus):
+    @classmethod
+    def setUpClass(cls):
+        # Write Accelerate config, will pick up on CPU, GPU, and multi-GPU
+        cls.tmpdir = tempfile.mkdtemp()
+        cls.configPath = os.path.join(cls.tmpdir, "default_config.yml")
+        write_basic_config(save_location=cls.configPath)
+        cls._launch_args = ["accelerate", "launch", "--config_file", cls.configPath]
+
+    @classmethod
+    def tearDownClass(cls):
+        shutil.rmtree(cls.tmpdir)
+
     def test_run_glue_no_trainer(self):
         tmp_dir = self.get_auto_remove_tmp_dir()
         testargs = f"""
-            run_glue_no_trainer.py
+            {self.examples_dir}/pytorch/text-classification/run_glue_no_trainer.py
             --model_name_or_path distilbert-base-uncased
             --output_dir {tmp_dir}
             --train_file ./tests/fixtures/tests_samples/MRPC/train.csv
@@ -113,8 +94,7 @@ class ExamplesTestsNoTrainer(TestCasePlus):
         if is_cuda_and_apex_available():
             testargs.append("--fp16")
-        with patch.object(sys, "argv", testargs):
-            run_glue_no_trainer.main()
+        _ = subprocess.run(self._launch_args + testargs, stdout=subprocess.PIPE)
         result = get_results(tmp_dir)
         self.assertGreaterEqual(result["eval_accuracy"], 0.75)
         self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
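The two hunks above replace the old in-process invocation (`patch.object(sys, "argv", ...)` plus calling the script's `main()`) with a real `accelerate launch` subprocess, driven by a config file that `accelerate.utils.write_basic_config` generates once per test class to match the current machine. A condensed sketch of that harness, using a plain `unittest.TestCase` and a placeholder script path in place of `TestCasePlus` and `self.examples_dir`:

    import os
    import shutil
    import subprocess
    import tempfile
    import unittest

    from accelerate.utils import write_basic_config


    class ExampleScriptTest(unittest.TestCase):
        @classmethod
        def setUpClass(cls):
            # Introspect the environment (CPU / single GPU / multi-GPU) and write a matching config
            cls.tmpdir = tempfile.mkdtemp()
            cls.config_path = os.path.join(cls.tmpdir, "default_config.yml")
            write_basic_config(save_location=cls.config_path)
            cls.launch_args = ["accelerate", "launch", "--config_file", cls.config_path]

        @classmethod
        def tearDownClass(cls):
            shutil.rmtree(cls.tmpdir)

        def test_script_runs(self):
            # "path/to/run_glue_no_trainer.py" is a placeholder for one of the example scripts
            testargs = ["path/to/run_glue_no_trainer.py", "--output_dir", self.tmpdir]
            subprocess.run(self.launch_args + testargs, check=True)

Launching through a subprocess means the same test transparently exercises however many processes the generated config requests, instead of always running single-process inside the test interpreter.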
@@ -123,7 +103,7 @@ class ExamplesTestsNoTrainer(TestCasePlus):
     def test_run_clm_no_trainer(self):
         tmp_dir = self.get_auto_remove_tmp_dir()
         testargs = f"""
-            run_clm_no_trainer.py
+            {self.examples_dir}/pytorch/language-modeling/run_clm_no_trainer.py
             --model_name_or_path distilgpt2
             --train_file ./tests/fixtures/sample_text.txt
             --validation_file ./tests/fixtures/sample_text.txt
@@ -140,8 +120,7 @@ class ExamplesTestsNoTrainer(TestCasePlus):
             # Skipping because there are not enough batches to train the model + would need a drop_last to work.
             return
-        with patch.object(sys, "argv", testargs):
-            run_clm_no_trainer.main()
+        _ = subprocess.run(self._launch_args + testargs, stdout=subprocess.PIPE)
         result = get_results(tmp_dir)
         self.assertLess(result["perplexity"], 100)
         self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
@@ -150,7 +129,7 @@ class ExamplesTestsNoTrainer(TestCasePlus):
     def test_run_mlm_no_trainer(self):
         tmp_dir = self.get_auto_remove_tmp_dir()
         testargs = f"""
-            run_mlm_no_trainer.py
+            {self.examples_dir}/pytorch/language-modeling/run_mlm_no_trainer.py
             --model_name_or_path distilroberta-base
             --train_file ./tests/fixtures/sample_text.txt
             --validation_file ./tests/fixtures/sample_text.txt
@@ -160,8 +139,7 @@ class ExamplesTestsNoTrainer(TestCasePlus):
             --with_tracking
         """.split()
-        with patch.object(sys, "argv", testargs):
-            run_mlm_no_trainer.main()
+        _ = subprocess.run(self._launch_args + testargs, stdout=subprocess.PIPE)
         result = get_results(tmp_dir)
         self.assertLess(result["perplexity"], 42)
         self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
@@ -173,7 +151,7 @@ class ExamplesTestsNoTrainer(TestCasePlus):
         tmp_dir = self.get_auto_remove_tmp_dir()
         testargs = f"""
-            run_ner_no_trainer.py
+            {self.examples_dir}/pytorch/token-classification/run_ner_no_trainer.py
             --model_name_or_path bert-base-uncased
             --train_file tests/fixtures/tests_samples/conll/sample.json
             --validation_file tests/fixtures/tests_samples/conll/sample.json
@@ -187,8 +165,7 @@ class ExamplesTestsNoTrainer(TestCasePlus):
             --with_tracking
         """.split()
-        with patch.object(sys, "argv", testargs):
-            run_ner_no_trainer.main()
+        _ = subprocess.run(self._launch_args + testargs, stdout=subprocess.PIPE)
         result = get_results(tmp_dir)
         self.assertGreaterEqual(result["eval_accuracy"], 0.75)
         self.assertLess(result["train_loss"], 0.5)
@@ -198,7 +175,7 @@ class ExamplesTestsNoTrainer(TestCasePlus):
     def test_run_squad_no_trainer(self):
         tmp_dir = self.get_auto_remove_tmp_dir()
         testargs = f"""
-            run_qa_no_trainer.py
+            {self.examples_dir}/pytorch/question-answering/run_qa_no_trainer.py
             --model_name_or_path bert-base-uncased
             --version_2_with_negative
             --train_file tests/fixtures/tests_samples/SQUAD/sample.json
@@ -213,8 +190,7 @@ class ExamplesTestsNoTrainer(TestCasePlus):
             --with_tracking
         """.split()
-        with patch.object(sys, "argv", testargs):
-            run_squad_no_trainer.main()
+        _ = subprocess.run(self._launch_args + testargs, stdout=subprocess.PIPE)
         result = get_results(tmp_dir)
         # Because we use --version_2_with_negative the testing script uses SQuAD v2 metrics.
         self.assertGreaterEqual(result["eval_f1"], 30)
@@ -225,7 +201,7 @@ class ExamplesTestsNoTrainer(TestCasePlus):
     def test_run_swag_no_trainer(self):
         tmp_dir = self.get_auto_remove_tmp_dir()
         testargs = f"""
-            run_swag_no_trainer.py
+            {self.examples_dir}/pytorch/multiple-choice/run_swag_no_trainer.py
             --model_name_or_path bert-base-uncased
             --train_file tests/fixtures/tests_samples/swag/sample.json
             --validation_file tests/fixtures/tests_samples/swag/sample.json
@@ -238,8 +214,7 @@ class ExamplesTestsNoTrainer(TestCasePlus):
             --with_tracking
         """.split()
-        with patch.object(sys, "argv", testargs):
-            run_swag_no_trainer.main()
+        _ = subprocess.run(self._launch_args + testargs, stdout=subprocess.PIPE)
         result = get_results(tmp_dir)
         self.assertGreaterEqual(result["eval_accuracy"], 0.8)
         self.assertTrue(os.path.exists(os.path.join(tmp_dir, "swag_no_trainer")))
@@ -248,7 +223,7 @@ class ExamplesTestsNoTrainer(TestCasePlus):
     def test_run_summarization_no_trainer(self):
         tmp_dir = self.get_auto_remove_tmp_dir()
         testargs = f"""
-            run_summarization_no_trainer.py
+            {self.examples_dir}/pytorch/summarization/run_summarization_no_trainer.py
             --model_name_or_path t5-small
             --train_file tests/fixtures/tests_samples/xsum/sample.json
             --validation_file tests/fixtures/tests_samples/xsum/sample.json
@@ -262,8 +237,7 @@ class ExamplesTestsNoTrainer(TestCasePlus):
             --with_tracking
         """.split()
-        with patch.object(sys, "argv", testargs):
-            run_summarization_no_trainer.main()
+        _ = subprocess.run(self._launch_args + testargs, stdout=subprocess.PIPE)
         result = get_results(tmp_dir)
         self.assertGreaterEqual(result["eval_rouge1"], 10)
         self.assertGreaterEqual(result["eval_rouge2"], 2)
@@ -276,7 +250,7 @@ class ExamplesTestsNoTrainer(TestCasePlus):
     def test_run_translation_no_trainer(self):
         tmp_dir = self.get_auto_remove_tmp_dir()
         testargs = f"""
-            run_translation_no_trainer.py
+            {self.examples_dir}/pytorch/translation/run_translation_no_trainer.py
             --model_name_or_path sshleifer/student_marian_en_ro_6_1
             --source_lang en
             --target_lang ro
@@ -294,8 +268,7 @@ class ExamplesTestsNoTrainer(TestCasePlus):
             --with_tracking
         """.split()
-        with patch.object(sys, "argv", testargs):
-            run_translation_no_trainer.main()
+        _ = subprocess.run(self._launch_args + testargs, stdout=subprocess.PIPE)
         result = get_results(tmp_dir)
         self.assertGreaterEqual(result["eval_bleu"], 30)
         self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
@@ -308,7 +281,7 @@ class ExamplesTestsNoTrainer(TestCasePlus):
         tmp_dir = self.get_auto_remove_tmp_dir()
         testargs = f"""
-            run_semantic_segmentation_no_trainer.py
+            {self.examples_dir}/pytorch/semantic-segmentation/run_semantic_segmentation_no_trainer.py
             --dataset_name huggingface/semantic-segmentation-test-sample
             --output_dir {tmp_dir}
             --max_train_steps=10
@@ -319,15 +292,14 @@ class ExamplesTestsNoTrainer(TestCasePlus):
             --checkpointing_steps epoch
         """.split()
-        with patch.object(sys, "argv", testargs):
-            run_semantic_segmentation_no_trainer.main()
+        _ = subprocess.run(self._launch_args + testargs, stdout=subprocess.PIPE)
         result = get_results(tmp_dir)
         self.assertGreaterEqual(result["eval_overall_accuracy"], 0.10)

     def test_run_image_classification_no_trainer(self):
         tmp_dir = self.get_auto_remove_tmp_dir()
         testargs = f"""
-            run_image_classification_no_trainer.py
+            {self.examples_dir}/pytorch/image-classification/run_image_classification_no_trainer.py
             --dataset_name huggingface/image-classification-test-sample
             --output_dir {tmp_dir}
             --num_warmup_steps=8
@@ -339,8 +311,7 @@ class ExamplesTestsNoTrainer(TestCasePlus):
             --seed 42
         """.split()
-        with patch.object(sys, "argv", testargs):
-            run_image_classification_no_trainer.main()
+        _ = subprocess.run(self._launch_args + testargs, stdout=subprocess.PIPE)
         result = get_results(tmp_dir)
         self.assertGreaterEqual(result["eval_accuracy"], 0.50)
         self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
...
@@ -528,7 +528,7 @@ def main():
             predictions, references = accelerator.gather((predictions, batch["labels"]))
             # If we are in a multiprocess environment, the last batch has duplicates
             if accelerator.num_processes > 1:
-                if step == len(eval_dataloader):
+                if step == len(eval_dataloader) - 1:
                     predictions = predictions[: len(eval_dataloader.dataset) - samples_seen]
                     references = references[: len(eval_dataloader.dataset) - samples_seen]
                 else:
...
@@ -683,7 +683,7 @@ def main():
             predictions_gathered, labels_gathered = accelerator.gather((predictions, labels))
             # If we are in a multiprocess environment, the last batch has duplicates
             if accelerator.num_processes > 1:
-                if step == len(eval_dataloader):
+                if step == len(eval_dataloader) - 1:
                     predictions_gathered = predictions_gathered[: len(eval_dataloader.dataset) - samples_seen]
                     labels_gathered = labels_gathered[: len(eval_dataloader.dataset) - samples_seen]
                 else:
...
@@ -661,11 +661,11 @@ def main():
             # If we are in a multiprocess environment, the last batch has duplicates
             if accelerator.num_processes > 1:
-                if step == len(eval_dataloader):
+                if step == len(eval_dataloader) - 1:
                     decoded_preds = decoded_preds[: len(eval_dataloader.dataset) - samples_seen]
                     decoded_labels = decoded_labels[: len(eval_dataloader.dataset) - samples_seen]
                 else:
-                    samples_seen += decoded_labels.shape[0]
+                    samples_seen += len(decoded_labels)
             metric.add_batch(predictions=decoded_preds, references=decoded_labels)
         eval_metric = metric.compute()
...