Unverified Commit 82c7e879 authored by Hz, Ji, committed by GitHub

device agnostic fsdp testing (#27120)

* make fsdp test cases device agnostic

* make style
parent 7d8ff362
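
The commit applies one substitution pattern throughout the test module: CUDA-only helpers (get_gpu_count, require_torch_gpu, require_torch_multi_gpu, is_torch_bf16_gpu_available) are replaced by device-agnostic counterparts keyed on torch_device. Below is a minimal sketch of how the new helpers compose, assuming a transformers checkout that ships them; the class and test names are illustrative, not part of this commit.

import unittest

from transformers.testing_utils import (
    backend_device_count,
    require_torch_accelerator,
    torch_device,
)
from transformers.utils import is_torch_bf16_available_on_device

# torch_device resolves to the active backend ("cuda", "npu", "cpu", ...), so the same
# collection logic works on any accelerator instead of assuming NVIDIA GPUs.
dtypes = ["fp16"]
if is_torch_bf16_available_on_device(torch_device):
    dtypes += ["bf16"]


@require_torch_accelerator  # skips the test when no accelerator backend is present
class DeviceAgnosticSketch(unittest.TestCase):  # hypothetical name, for illustration only
    def test_device_count(self):
        # backend_device_count(torch_device) replaces the CUDA-only get_gpu_count()
        num_devices = min(2, backend_device_count(torch_device))
        self.assertGreaterEqual(num_devices, 1)
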
@@ -24,18 +24,19 @@ from tests.trainer.test_trainer import TrainerIntegrationCommon  # noqa
 from transformers import is_torch_available
 from transformers.testing_utils import (
     TestCasePlus,
+    backend_device_count,
     execute_subprocess_async,
-    get_gpu_count,
     mockenv_context,
     require_accelerate,
     require_fsdp,
-    require_torch_gpu,
-    require_torch_multi_gpu,
+    require_torch_accelerator,
+    require_torch_multi_accelerator,
     slow,
+    torch_device,
 )
 from transformers.trainer_callback import TrainerState
 from transformers.trainer_utils import FSDPOption, set_seed
-from transformers.utils import is_accelerate_available, is_torch_bf16_gpu_available
+from transformers.utils import is_accelerate_available, is_torch_bf16_available_on_device


 if is_torch_available():
@@ -46,7 +47,7 @@ else:
 # default torch.distributed port
 DEFAULT_MASTER_PORT = "10999"
 dtypes = ["fp16"]
-if is_torch_bf16_gpu_available():
+if is_torch_bf16_available_on_device(torch_device):
     dtypes += ["bf16"]
 sharding_strategies = ["full_shard", "shard_grad_op"]
 state_dict_types = ["FULL_STATE_DICT", "SHARDED_STATE_DICT"]
@@ -100,7 +101,7 @@ def get_launcher(distributed=False, use_accelerate=False):
     #    - it won't be able to handle that
     # 2. for now testing with just 2 gpus max (since some quality tests may give different
     #    results with mode gpus because we use very little data)
-    num_gpus = min(2, get_gpu_count()) if distributed else 1
+    num_gpus = min(2, backend_device_count(torch_device)) if distributed else 1
     master_port = get_master_port(real_launcher=True)
     if use_accelerate:
         return f"""accelerate launch
@@ -121,7 +122,7 @@ def _parameterized_custom_name_func(func, param_num, param):


 @require_accelerate
-@require_torch_gpu
+@require_torch_accelerator
 @require_fsdp_version
 class TrainerIntegrationFSDP(TestCasePlus, TrainerIntegrationCommon):
     def setUp(self):
@@ -170,7 +171,7 @@ class TrainerIntegrationFSDP(TestCasePlus, TrainerIntegrationCommon):
         self.assertEqual(os.environ.get("ACCELERATE_USE_FSDP", "false"), "true")

     @parameterized.expand(params, name_func=_parameterized_custom_name_func)
-    @require_torch_multi_gpu
+    @require_torch_multi_accelerator
     @slow
     def test_basic_run(self, sharding_strategy, dtype):
         launcher = get_launcher(distributed=True, use_accelerate=False)
@@ -182,7 +183,7 @@ class TrainerIntegrationFSDP(TestCasePlus, TrainerIntegrationCommon):
         execute_subprocess_async(cmd, env=self.get_env())

     @parameterized.expand(dtypes)
-    @require_torch_multi_gpu
+    @require_torch_multi_accelerator
     @slow
     @unittest.skipIf(not is_torch_greater_or_equal_than_2_1, reason="This test on pytorch 2.0 takes 4 hours.")
     def test_basic_run_with_cpu_offload(self, dtype):
@@ -195,7 +196,7 @@ class TrainerIntegrationFSDP(TestCasePlus, TrainerIntegrationCommon):
         execute_subprocess_async(cmd, env=self.get_env())

     @parameterized.expand(state_dict_types, name_func=_parameterized_custom_name_func)
-    @require_torch_multi_gpu
+    @require_torch_multi_accelerator
     @slow
     def test_training_and_can_resume_normally(self, state_dict_type):
         output_dir = self.get_auto_remove_tmp_dir("./xxx", after=False)
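
With the class-level requirement relaxed from require_torch_gpu to require_torch_accelerator, the suite can be collected on non-CUDA backends (NPU, XPU, etc.) as well as CUDA. Assuming a machine with at least two devices and accelerate installed, a typical invocation would be RUN_SLOW=1 python -m pytest tests/fsdp/test_fsdp.py, since every test here is marked @slow.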