chenpangpang / transformers · Commits

Commit 82c7e879 (unverified)
Authored Nov 01, 2023 by Hz, Ji; committed by GitHub on Nov 01, 2023
Parent: 7d8ff362

device agnostic fsdp testing (#27120)

* make fsdp test cases device agnostic
* make style
Showing 1 changed file with 11 additions and 10 deletions:

tests/fsdp/test_fsdp.py (+11, -10)
@@ -24,18 +24,19 @@ from tests.trainer.test_trainer import TrainerIntegrationCommon  # noqa
 from transformers import is_torch_available
 from transformers.testing_utils import (
     TestCasePlus,
+    backend_device_count,
     execute_subprocess_async,
-    get_gpu_count,
     mockenv_context,
     require_accelerate,
     require_fsdp,
-    require_torch_gpu,
-    require_torch_multi_gpu,
+    require_torch_accelerator,
+    require_torch_multi_accelerator,
     slow,
+    torch_device,
 )
 from transformers.trainer_callback import TrainerState
 from transformers.trainer_utils import FSDPOption, set_seed
-from transformers.utils import is_accelerate_available, is_torch_bf16_gpu_available
+from transformers.utils import is_accelerate_available, is_torch_bf16_available_on_device


 if is_torch_available():
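For orientation, the new names in this import block come from transformers.testing_utils and transformers.utils and replace CUDA-only helpers: get_gpu_count becomes backend_device_count, require_torch_gpu becomes require_torch_accelerator, require_torch_multi_gpu becomes require_torch_multi_accelerator, and is_torch_bf16_gpu_available becomes is_torch_bf16_available_on_device. A minimal probe of the detection side, assuming a transformers build that ships these helpers (the script itself is illustrative, not part of the commit):

# Hypothetical probe for the device-agnostic helpers imported above.
from transformers.testing_utils import backend_device_count, torch_device
from transformers.utils import is_torch_bf16_available_on_device

# torch_device is the backend string the test suite detected ("cuda", "npu", "xpu", "cpu", ...)
print(f"detected device: {torch_device}")
print(f"device count:    {backend_device_count(torch_device)}")
print(f"bf16 available:  {is_torch_bf16_available_on_device(torch_device)}")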
@@ -46,7 +47,7 @@ else:
 # default torch.distributed port
 DEFAULT_MASTER_PORT = "10999"
 dtypes = ["fp16"]
-if is_torch_bf16_gpu_available():
+if is_torch_bf16_available_on_device(torch_device):
     dtypes += ["bf16"]
 sharding_strategies = ["full_shard", "shard_grad_op"]
 state_dict_types = ["FULL_STATE_DICT", "SHARDED_STATE_DICT"]
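The dtype gate above is the heart of the device-agnostic change: instead of asking whether a CUDA GPU supports bf16, the suite asks whether bf16 works on whatever torch_device was detected. A standalone sketch of the same pattern, assuming the same transformers helpers:

# Hypothetical standalone version of the dtype selection used by the FSDP tests above.
from transformers.testing_utils import torch_device
from transformers.utils import is_torch_bf16_available_on_device

dtypes = ["fp16"]
if is_torch_bf16_available_on_device(torch_device):  # checks bf16 support for the active backend, not just CUDA
    dtypes += ["bf16"]
print(f"dtypes exercised on {torch_device}: {dtypes}")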
@@ -100,7 +101,7 @@ def get_launcher(distributed=False, use_accelerate=False):
     # - it won't be able to handle that
     # 2. for now testing with just 2 gpus max (since some quality tests may give different
     # results with mode gpus because we use very little data)
-    num_gpus = min(2, get_gpu_count()) if distributed else 1
+    num_gpus = min(2, backend_device_count(torch_device)) if distributed else 1
     master_port = get_master_port(real_launcher=True)
     if use_accelerate:
         return f"""accelerate launch
@@ -121,7 +122,7 @@ def _parameterized_custom_name_func(func, param_num, param):


 @require_accelerate
-@require_torch_gpu
+@require_torch_accelerator
 @require_fsdp_version
 class TrainerIntegrationFSDP(TestCasePlus, TrainerIntegrationCommon):
     def setUp(self):
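The class-level guard changes the same way: @require_torch_gpu becomes @require_torch_accelerator, while the accelerate and FSDP guards stay as they were. A hypothetical skeleton mirroring that decorator stack (the real suite uses require_fsdp_version, which wraps require_fsdp with a minimum-version check):

# Hypothetical skeleton of a test class guarded like TrainerIntegrationFSDP above.
from transformers.testing_utils import (
    TestCasePlus,
    require_accelerate,
    require_fsdp,
    require_torch_accelerator,
    torch_device,
)


@require_accelerate
@require_torch_accelerator  # previously @require_torch_gpu; any supported accelerator now qualifies
@require_fsdp
class FSDPGuardExample(TestCasePlus):
    def setUp(self):
        super().setUp()
        # remember which backend the suite detected; real tests use this to build device-specific args
        self.device = torch_device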
@@ -170,7 +171,7 @@ class TrainerIntegrationFSDP(TestCasePlus, TrainerIntegrationCommon):
         self.assertEqual(os.environ.get("ACCELERATE_USE_FSDP", "false"), "true")

     @parameterized.expand(params, name_func=_parameterized_custom_name_func)
-    @require_torch_multi_gpu
+    @require_torch_multi_accelerator
     @slow
     def test_basic_run(self, sharding_strategy, dtype):
         launcher = get_launcher(distributed=True, use_accelerate=False)
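test_basic_run is a genuinely multi-process test, so its guard moves from @require_torch_multi_gpu to @require_torch_multi_accelerator, which skips unless more than one device of the active backend is visible. A hypothetical minimal test using the same guard:

# Hypothetical minimal multi-device test guarded like test_basic_run above.
import unittest

from transformers.testing_utils import (
    backend_device_count,
    require_torch_multi_accelerator,
    torch_device,
)


@require_torch_multi_accelerator  # skipped unless at least 2 devices of the active backend are available
class MultiAcceleratorSmokeTest(unittest.TestCase):
    def test_backend_has_multiple_devices(self):
        self.assertGreaterEqual(backend_device_count(torch_device), 2)

The remaining two hunks apply the identical decorator swap to the other multi-process tests.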
@@ -182,7 +183,7 @@ class TrainerIntegrationFSDP(TestCasePlus, TrainerIntegrationCommon):
         execute_subprocess_async(cmd, env=self.get_env())

     @parameterized.expand(dtypes)
-    @require_torch_multi_gpu
+    @require_torch_multi_accelerator
     @slow
     @unittest.skipIf(not is_torch_greater_or_equal_than_2_1, reason="This test on pytorch 2.0 takes 4 hours.")
     def test_basic_run_with_cpu_offload(self, dtype):
@@ -195,7 +196,7 @@ class TrainerIntegrationFSDP(TestCasePlus, TrainerIntegrationCommon):
         execute_subprocess_async(cmd, env=self.get_env())

     @parameterized.expand(state_dict_types, name_func=_parameterized_custom_name_func)
-    @require_torch_multi_gpu
+    @require_torch_multi_accelerator
     @slow
     def test_training_and_can_resume_normally(self, state_dict_type):
         output_dir = self.get_auto_remove_tmp_dir("./xxx", after=False)