Unverified Commit 9454f437 authored by Fanli Lin's avatar Fanli Lin Committed by GitHub
Browse files

[tests] make `TestDeepSpeedModelZoo` device-agnostic (#31402)

* fix

* use accelerator device count

* ci fix
parent 7977f206
...@@ -2432,6 +2432,10 @@ if is_torch_available(): ...@@ -2432,6 +2432,10 @@ if is_torch_available():
BACKEND_MANUAL_SEED = {"cuda": torch.cuda.manual_seed, "cpu": torch.manual_seed, "default": torch.manual_seed} BACKEND_MANUAL_SEED = {"cuda": torch.cuda.manual_seed, "cpu": torch.manual_seed, "default": torch.manual_seed}
BACKEND_EMPTY_CACHE = {"cuda": torch.cuda.empty_cache, "cpu": None, "default": None} BACKEND_EMPTY_CACHE = {"cuda": torch.cuda.empty_cache, "cpu": None, "default": None}
BACKEND_DEVICE_COUNT = {"cuda": torch.cuda.device_count, "cpu": lambda: 0, "default": lambda: 1} BACKEND_DEVICE_COUNT = {"cuda": torch.cuda.device_count, "cpu": lambda: 0, "default": lambda: 1}
else:
BACKEND_MANUAL_SEED = {"default": None}
BACKEND_EMPTY_CACHE = {"default": None}
BACKEND_DEVICE_COUNT = {"default": lambda: 0}
def backend_manual_seed(device: str, seed: int): def backend_manual_seed(device: str, seed: int):
......
...@@ -23,12 +23,13 @@ from tests.trainer.test_trainer import TrainerIntegrationCommon # noqa ...@@ -23,12 +23,13 @@ from tests.trainer.test_trainer import TrainerIntegrationCommon # noqa
from transformers import is_torch_available from transformers import is_torch_available
from transformers.testing_utils import ( from transformers.testing_utils import (
TestCasePlus, TestCasePlus,
backend_device_count,
execute_subprocess_async, execute_subprocess_async,
get_gpu_count,
get_tests_dir, get_tests_dir,
require_deepspeed, require_deepspeed,
require_torch_gpu, require_torch_accelerator,
slow, slow,
torch_device,
) )
from transformers.trainer_utils import set_seed from transformers.trainer_utils import set_seed
...@@ -143,7 +144,7 @@ def get_launcher(distributed=False): ...@@ -143,7 +144,7 @@ def get_launcher(distributed=False):
# - it won't be able to handle that # - it won't be able to handle that
# 2. for now testing with just 2 gpus max (since some quality tests may give different # 2. for now testing with just 2 gpus max (since some quality tests may give different
# results with more gpus because we use very little data) # results with more gpus because we use very little data)
num_gpus = min(2, get_gpu_count()) if distributed else 1 num_gpus = min(2, backend_device_count(torch_device)) if distributed else 1
master_port = os.environ.get("DS_TEST_PORT", DEFAULT_MASTER_PORT) master_port = os.environ.get("DS_TEST_PORT", DEFAULT_MASTER_PORT)
return f"deepspeed --num_nodes 1 --num_gpus {num_gpus} --master_port {master_port}".split() return f"deepspeed --num_nodes 1 --num_gpus {num_gpus} --master_port {master_port}".split()
...@@ -326,7 +327,7 @@ params = list(itertools.product(stages, task_cmds.keys())) ...@@ -326,7 +327,7 @@ params = list(itertools.product(stages, task_cmds.keys()))
@slow @slow
@require_deepspeed @require_deepspeed
@require_torch_gpu @require_torch_accelerator
class TestDeepSpeedModelZoo(TestCasePlus): class TestDeepSpeedModelZoo(TestCasePlus):
"""This class is for testing via an external script - can do multiple gpus""" """This class is for testing via an external script - can do multiple gpus"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment