Unverified Commit f0435f5a authored by Stas Bekman, committed by GitHub

these should run fine on multi-gpu (#8582)

parent 36a19915
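
For context: require_torch_non_multi_gpu_but_fix_me is a temporary skip marker in transformers.testing_utils that kept these example tests from running on machines with more than one GPU until someone verified them there. Below is a minimal sketch of what such a guard decorator typically looks like; it is an illustration based on that description, not the verbatim testing_utils implementation, and it assumes torch is installed:

    import unittest

    import torch


    def require_torch_non_multi_gpu_but_fix_me(test_case):
        # Skip the decorated test when more than one CUDA device is visible.
        # The "but_fix_me" suffix flags tests that should eventually be made
        # multi-GPU safe rather than stay skipped forever.
        if torch.cuda.device_count() > 1:
            return unittest.skip("test requires 0 or 1 GPU")(test_case)
        return test_case

This commit drops the marker (and its imports) from the seq2seq example tests, leaving @slow and @require_torch_gpu as the only gates.
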
@@ -13,7 +13,7 @@ from distillation import SummarizationDistiller, distill_main
 from finetune import SummarizationModule, main
 from transformers import MarianMTModel
 from transformers.file_utils import cached_path
-from transformers.testing_utils import TestCasePlus, require_torch_gpu, require_torch_non_multi_gpu_but_fix_me, slow
+from transformers.testing_utils import TestCasePlus, require_torch_gpu, slow
 from utils import load_json
@@ -32,7 +32,6 @@ class TestMbartCc25Enro(TestCasePlus):
     @slow
     @require_torch_gpu
-    @require_torch_non_multi_gpu_but_fix_me
     def test_model_download(self):
         """This warms up the cache so that we can time the next test without including download time, which varies between machines."""
         MarianMTModel.from_pretrained(MARIAN_MODEL)
@@ -40,7 +39,6 @@ class TestMbartCc25Enro(TestCasePlus):
     # @timeout_decorator.timeout(1200)
     @slow
     @require_torch_gpu
-    @require_torch_non_multi_gpu_but_fix_me
     def test_train_mbart_cc25_enro_script(self):
         env_vars_to_replace = {
             "$MAX_LEN": 64,
@@ -129,7 +127,6 @@ class TestDistilMarianNoTeacher(TestCasePlus):
     @timeout_decorator.timeout(600)
     @slow
     @require_torch_gpu
-    @require_torch_non_multi_gpu_but_fix_me
     def test_opus_mt_distill_script(self):
         data_dir = f"{self.test_file_dir_str}/test_data/wmt_en_ro"
         env_vars_to_replace = {
...
@@ -19,13 +19,7 @@ import unittest
 from parameterized import parameterized
 from transformers import FSMTForConditionalGeneration, FSMTTokenizer
-from transformers.testing_utils import (
-    get_tests_dir,
-    require_torch,
-    require_torch_non_multi_gpu_but_fix_me,
-    slow,
-    torch_device,
-)
+from transformers.testing_utils import get_tests_dir, require_torch, slow, torch_device
 from utils import calculate_bleu
@@ -54,7 +48,6 @@ class ModelEvalTester(unittest.TestCase):
         ]
     )
     @slow
-    @require_torch_non_multi_gpu_but_fix_me
     def test_bleu_scores(self, pair, min_bleu_score):
         # note: this test is not testing the best performance since it only evals a small batch
         # but it should be enough to detect a regression in the output quality
...
@@ -19,14 +19,7 @@ from run_eval import generate_summaries_or_translations, run_generate
 from run_eval_search import run_search
 from transformers import AutoConfig, AutoModelForSeq2SeqLM
 from transformers.hf_api import HfApi
-from transformers.testing_utils import (
-    CaptureStderr,
-    CaptureStdout,
-    TestCasePlus,
-    require_torch_gpu,
-    require_torch_non_multi_gpu_but_fix_me,
-    slow,
-)
+from transformers.testing_utils import CaptureStderr, CaptureStdout, TestCasePlus, require_torch_gpu, slow
 from utils import ROUGE_KEYS, label_smoothed_nll_loss, lmap, load_json
@@ -135,7 +128,6 @@ class TestSummarizationDistiller(TestCasePlus):
     @slow
     @require_torch_gpu
-    @require_torch_non_multi_gpu_but_fix_me
     def test_hub_configs(self):
         """I put require_torch_gpu cause I only want this to run with self-scheduled."""
@@ -153,12 +145,10 @@ class TestSummarizationDistiller(TestCasePlus):
                 failures.append(m)
         assert not failures, f"The following models could not be loaded through AutoConfig: {failures}"

-    @require_torch_non_multi_gpu_but_fix_me
     def test_distill_no_teacher(self):
         updates = dict(student_encoder_layers=2, student_decoder_layers=1, no_teacher=True)
         self._test_distiller_cli(updates)

-    @require_torch_non_multi_gpu_but_fix_me
     def test_distill_checkpointing_with_teacher(self):
         updates = dict(
             student_encoder_layers=2,
@@ -183,7 +173,6 @@ class TestSummarizationDistiller(TestCasePlus):
         convert_pl_to_hf(ckpts[0], transformer_ckpts[0].parent, out_path_new)
         assert os.path.exists(os.path.join(out_path_new, "pytorch_model.bin"))

-    @require_torch_non_multi_gpu_but_fix_me
     def test_loss_fn(self):
         model = AutoModelForSeq2SeqLM.from_pretrained(BART_TINY)
         input_ids, mask = model.dummy_inputs["input_ids"], model.dummy_inputs["attention_mask"]
@@ -204,7 +193,6 @@ class TestSummarizationDistiller(TestCasePlus):
         # TODO: understand why this breaks
         self.assertEqual(nll_loss, model_computed_loss)

-    @require_torch_non_multi_gpu_but_fix_me
     def test_distill_mbart(self):
         updates = dict(
             student_encoder_layers=2,
@@ -229,7 +217,6 @@ class TestSummarizationDistiller(TestCasePlus):
         assert len(all_files) > 2
         self.assertEqual(len(transformer_ckpts), 2)

-    @require_torch_non_multi_gpu_but_fix_me
     def test_distill_t5(self):
         updates = dict(
             student_encoder_layers=1,
@@ -241,7 +228,6 @@ class TestSummarizationDistiller(TestCasePlus):
         )
         self._test_distiller_cli(updates)

-    @require_torch_non_multi_gpu_but_fix_me
     def test_distill_different_base_models(self):
         updates = dict(
             teacher=T5_TINY,
@@ -321,21 +307,18 @@ class TestTheRest(TestCasePlus):
     # test one model to quickly (no-@slow) catch simple problems and do an
     # extensive testing of functionality with multiple models as @slow separately
-    @require_torch_non_multi_gpu_but_fix_me
     def test_run_eval(self):
         self.run_eval_tester(T5_TINY)

     # any extra models should go into the list here - can be slow
     @parameterized.expand([BART_TINY, MBART_TINY])
     @slow
-    @require_torch_non_multi_gpu_but_fix_me
     def test_run_eval_slow(self, model):
         self.run_eval_tester(model)

     # testing with 2 models to validate: 1. translation (t5) 2. summarization (mbart)
     @parameterized.expand([T5_TINY, MBART_TINY])
     @slow
-    @require_torch_non_multi_gpu_but_fix_me
     def test_run_eval_search(self, model):
         input_file_name = Path(self.get_auto_remove_tmp_dir()) / "utest_input.source"
         output_file_name = input_file_name.parent / "utest_output.txt"
@@ -386,7 +369,6 @@ class TestTheRest(TestCasePlus):
     @parameterized.expand(
         [T5_TINY, BART_TINY, MBART_TINY, MARIAN_TINY, FSMT_TINY],
     )
-    @require_torch_non_multi_gpu_but_fix_me
     def test_finetune(self, model):
         args_d: dict = CHEAP_ARGS.copy()
         task = "translation" if model in [MBART_TINY, MARIAN_TINY, FSMT_TINY] else "summarization"
@@ -438,7 +420,6 @@ class TestTheRest(TestCasePlus):
         assert isinstance(example_batch, dict)
         assert len(example_batch) >= 4

-    @require_torch_non_multi_gpu_but_fix_me
     def test_finetune_extra_model_args(self):
         args_d: dict = CHEAP_ARGS.copy()
@@ -489,7 +470,6 @@ class TestTheRest(TestCasePlus):
             model = main(args)
         assert str(excinfo.value) == f"model config doesn't have a `{unsupported_param}` attribute"

-    @require_torch_non_multi_gpu_but_fix_me
     def test_finetune_lr_schedulers(self):
         args_d: dict = CHEAP_ARGS.copy()
...
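
With the marker gone, these tests are also collected on multi-GPU runners. The ones still decorated with @slow only execute when slow tests are enabled (RUN_SLOW=1 in the transformers test setup), and @require_torch_gpu continues to require at least one CUDA device.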