Unverified Commit b842d727 authored by Stas Bekman's avatar Stas Bekman Committed by GitHub
Browse files

fix deepspeed tests (#15881)

* fix deepspeed tests

* style

* more fixes
parent 6ccfa217
...@@ -20,6 +20,7 @@ import unittest ...@@ -20,6 +20,7 @@ import unittest
from copy import deepcopy from copy import deepcopy
from parameterized import parameterized from parameterized import parameterized
from tests.trainer.test_trainer import TrainerIntegrationCommon # noqa
from transformers import AutoModel, TrainingArguments, is_torch_available, logging from transformers import AutoModel, TrainingArguments, is_torch_available, logging
from transformers.deepspeed import HfDeepSpeedConfig, is_deepspeed_available from transformers.deepspeed import HfDeepSpeedConfig, is_deepspeed_available
from transformers.file_utils import WEIGHTS_NAME from transformers.file_utils import WEIGHTS_NAME
...@@ -39,11 +40,13 @@ from transformers.testing_utils import ( ...@@ -39,11 +40,13 @@ from transformers.testing_utils import (
) )
from transformers.trainer_utils import get_last_checkpoint, set_seed from transformers.trainer_utils import get_last_checkpoint, set_seed
from ..trainer.test_trainer import TrainerIntegrationCommon # noqa
if is_torch_available(): if is_torch_available():
from ..trainer.test_trainer import RegressionModelConfig, RegressionPreTrainedModel, get_regression_trainer # noqa from tests.trainer.test_trainer import ( # noqa
RegressionModelConfig,
RegressionPreTrainedModel,
get_regression_trainer,
)
set_seed(42) set_seed(42)
......
...@@ -18,6 +18,7 @@ import subprocess ...@@ -18,6 +18,7 @@ import subprocess
from os.path import dirname from os.path import dirname
from parameterized import parameterized from parameterized import parameterized
from tests.trainer.test_trainer import TrainerIntegrationCommon # noqa
from transformers import is_torch_available from transformers import is_torch_available
from transformers.testing_utils import ( from transformers.testing_utils import (
TestCasePlus, TestCasePlus,
...@@ -29,11 +30,13 @@ from transformers.testing_utils import ( ...@@ -29,11 +30,13 @@ from transformers.testing_utils import (
) )
from transformers.trainer_utils import set_seed from transformers.trainer_utils import set_seed
from ..trainer.test_trainer import TrainerIntegrationCommon # noqa
if is_torch_available(): if is_torch_available():
from ..trainer.test_trainer import RegressionModelConfig, RegressionPreTrainedModel, get_regression_trainer # noqa from tests.trainer.test_trainer import ( # noqa
RegressionModelConfig,
RegressionPreTrainedModel,
get_regression_trainer,
)
set_seed(42) set_seed(42)
...@@ -97,8 +100,8 @@ def get_launcher(distributed=False): ...@@ -97,8 +100,8 @@ def get_launcher(distributed=False):
def make_task_cmds(): def make_task_cmds():
data_dir_samples = f"{FIXTURE_DIRECTORY}/tests_samples" data_dir_samples = f"{FIXTURE_DIRECTORY}/tests_samples"
data_dir_wmt = f"{FIXTURE_DIRECTORY}/wmt_en_ro" data_dir_wmt = f"{data_dir_samples}/wmt_en_ro"
data_dir_xsum = f"{FIXTURE_DIRECTORY}/xsum" data_dir_xsum = f"{data_dir_samples}/xsum"
args_main = """ args_main = """
--do_train --do_train
--max_train_samples 4 --max_train_samples 4
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment