"templates/vscode:/vscode.git/clone" did not exist on "3bd1fe431585e233efb4564d12d751b3174996c3"
Unverified Commit b842d727 authored by Stas Bekman's avatar Stas Bekman Committed by GitHub
Browse files

fix deepspeed tests (#15881)

* fix deepspeed tests

* style

* more fixes
parent 6ccfa217
......@@ -20,6 +20,7 @@ import unittest
from copy import deepcopy
from parameterized import parameterized
from tests.trainer.test_trainer import TrainerIntegrationCommon # noqa
from transformers import AutoModel, TrainingArguments, is_torch_available, logging
from transformers.deepspeed import HfDeepSpeedConfig, is_deepspeed_available
from transformers.file_utils import WEIGHTS_NAME
......@@ -39,11 +40,13 @@ from transformers.testing_utils import (
)
from transformers.trainer_utils import get_last_checkpoint, set_seed
from ..trainer.test_trainer import TrainerIntegrationCommon # noqa
if is_torch_available():
from ..trainer.test_trainer import RegressionModelConfig, RegressionPreTrainedModel, get_regression_trainer # noqa
from tests.trainer.test_trainer import ( # noqa
RegressionModelConfig,
RegressionPreTrainedModel,
get_regression_trainer,
)
set_seed(42)
......
......@@ -18,6 +18,7 @@ import subprocess
from os.path import dirname
from parameterized import parameterized
from tests.trainer.test_trainer import TrainerIntegrationCommon # noqa
from transformers import is_torch_available
from transformers.testing_utils import (
TestCasePlus,
......@@ -29,11 +30,13 @@ from transformers.testing_utils import (
)
from transformers.trainer_utils import set_seed
from ..trainer.test_trainer import TrainerIntegrationCommon # noqa
if is_torch_available():
from ..trainer.test_trainer import RegressionModelConfig, RegressionPreTrainedModel, get_regression_trainer # noqa
from tests.trainer.test_trainer import ( # noqa
RegressionModelConfig,
RegressionPreTrainedModel,
get_regression_trainer,
)
set_seed(42)
......@@ -97,8 +100,8 @@ def get_launcher(distributed=False):
def make_task_cmds():
data_dir_samples = f"{FIXTURE_DIRECTORY}/tests_samples"
data_dir_wmt = f"{FIXTURE_DIRECTORY}/wmt_en_ro"
data_dir_xsum = f"{FIXTURE_DIRECTORY}/xsum"
data_dir_wmt = f"{data_dir_samples}/wmt_en_ro"
data_dir_xsum = f"{data_dir_samples}/xsum"
args_main = """
--do_train
--max_train_samples 4
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment