Unverified commit 63841c55, authored by Stas Bekman and committed by GitHub

add tests for the new sharded ddp fairscale integration (#9177)

parent bf713cde
@@ -14,10 +14,12 @@
 import os
 import sys
+import unittest
 from unittest.mock import patch
 
 from transformers import BertTokenizer, EncoderDecoderModel
 from transformers.file_utils import is_datasets_available
+from transformers.integrations import is_fairscale_available
 from transformers.testing_utils import (
     TestCasePlus,
     execute_subprocess_async,
     get_gpu_count,
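For context, is_fairscale_available is just an import probe. Below is a minimal sketch of how such a helper is commonly written, assuming the usual importlib pattern rather than quoting the transformers implementation:

# Sketch (assumption): availability helpers like is_fairscale_available typically
# boil down to an import probe; this is illustrative, not the exact library code.
import importlib.util

def is_fairscale_available():
    return importlib.util.find_spec("fairscale") is not None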
@@ -38,9 +40,20 @@ MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"
 MBART_TINY = "sshleifer/tiny-mbart"
 
 
+# a candidate for testing_utils
+def require_fairscale(test_case):
+    """
+    Decorator marking a test that requires fairscale
+    """
+    if not is_fairscale_available():
+        return unittest.skip("test requires fairscale")(test_case)
+    else:
+        return test_case
+
+
 class TestFinetuneTrainer(TestCasePlus):
-    def finetune_trainer_quick(self, distributed=None):
-        output_dir = self.run_trainer(1, "12", MBART_TINY, 1, distributed)
+    def finetune_trainer_quick(self, distributed=None, extra_args_str=None):
+        output_dir = self.run_trainer(1, "12", MBART_TINY, 1, distributed, extra_args_str)
         logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
         eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
         first_step_stats = eval_metrics[0]
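The decorator added above works because unittest.skip(reason) returns a decorator that is applied to the test on the spot. A small self-contained illustration of that mechanism, with the availability check stubbed out as an assumption:

import unittest

FAIRSCALE_INSTALLED = False  # stubbed for the demo; the real decorator calls is_fairscale_available()

def require_fairscale_demo(test_case):
    # unittest.skip(reason) returns a decorator; applying it right away marks the test as skipped
    if not FAIRSCALE_INSTALLED:
        return unittest.skip("test requires fairscale")(test_case)
    return test_case

class DemoTest(unittest.TestCase):
    @require_fairscale_demo
    def test_is_skipped(self):
        self.fail("never reached while fairscale is missing")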
@@ -59,6 +72,16 @@ class TestFinetuneTrainer(TestCasePlus):
     def test_finetune_trainer_ddp(self):
         self.finetune_trainer_quick(distributed=True)
 
+    @require_torch_multi_gpu
+    @require_fairscale
+    def test_finetune_trainer_ddp_sharded_ddp(self):
+        self.finetune_trainer_quick(distributed=True, extra_args_str="--sharded_ddp")
+
+    @require_torch_multi_gpu
+    @require_fairscale
+    def test_finetune_trainer_ddp_sharded_ddp_fp16(self):
+        self.finetune_trainer_quick(distributed=True, extra_args_str="--sharded_ddp --fp16")
+
     @slow
     def test_finetune_trainer_slow(self):
         # There is a missing call to __init__process_group somewhere
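Per the commit message, the two new tests exercise the fairscale sharded DDP integration selected by --sharded_ddp. The sketch below shows roughly what that wrapping looks like using fairscale's public API (OSS plus ShardedDataParallel); it is not the Trainer's actual code path, and it assumes a torch.distributed process group is already initialized:

import torch
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
from fairscale.optim import OSS

def wrap_sharded(model, lr=3e-5):
    # OSS shards optimizer state across ranks instead of replicating it on every GPU;
    # ShardedDDP keeps gradient reduction consistent with that sharded optimizer.
    optimizer = OSS(params=model.parameters(), optim=torch.optim.AdamW, lr=lr)
    model = ShardedDDP(model, optimizer)
    return model, optimizer

The fp16 variant adds --fp16 on top, which with sharded DDP typically pairs with fairscale's sharded gradient scaler rather than the stock torch.cuda.amp one.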
@@ -195,7 +218,13 @@ class TestFinetuneTrainer(TestCasePlus):
         trainer.train()
 
     def run_trainer(
-        self, eval_steps: int, max_len: str, model_name: str, num_train_epochs: int, distributed: bool = False
+        self,
+        eval_steps: int,
+        max_len: str,
+        model_name: str,
+        num_train_epochs: int,
+        distributed: bool = False,
+        extra_args_str: str = None,
     ):
         data_dir = self.examples_dir / "seq2seq/test_data/wmt_en_ro"
         output_dir = self.get_auto_remove_tmp_dir()
@@ -231,6 +260,9 @@ class TestFinetuneTrainer(TestCasePlus):
         """.split()
         # --eval_beams 2
 
+        if extra_args_str is not None:
+            args.extend(extra_args_str.split())
+
         if distributed:
             n_gpu = get_gpu_count()
             distributed_args = f"""
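The distributed branch that follows the extension above (truncated in this view) launches the fine-tuning script in a subprocess through torch.distributed.launch. A minimal sketch of that command assembly, with build_launch_cmd as an illustrative helper rather than code from the test file:

import sys

def build_launch_cmd(script: str, args: list, n_gpu: int) -> list:
    # Prepend "python -m torch.distributed.launch --nproc_per_node=N script.py" to the flag list.
    distributed_args = f"-m torch.distributed.launch --nproc_per_node={n_gpu} {script}".split()
    return [sys.executable] + distributed_args + args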