Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
62b4ce73
Commit
62b4ce73
authored
Apr 12, 2022
by
FrankLeeeee
Committed by
Frank Lee
Apr 12, 2022
Browse files
[test] added missing decorators to model checkpointing tests
parent
1cb7bdad
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
12 additions
and
12 deletions
+12
-12
tests/test_utils/test_checkpoint/test_checkpoint_1d.py
tests/test_utils/test_checkpoint/test_checkpoint_1d.py
+3
-3
tests/test_utils/test_checkpoint/test_checkpoint_2d.py
tests/test_utils/test_checkpoint/test_checkpoint_2d.py
+3
-3
tests/test_utils/test_checkpoint/test_checkpoint_2p5d.py
tests/test_utils/test_checkpoint/test_checkpoint_2p5d.py
+3
-3
tests/test_utils/test_checkpoint/test_checkpoint_3d.py
tests/test_utils/test_checkpoint/test_checkpoint_3d.py
+3
-3
No files found.
tests/test_utils/test_checkpoint/test_checkpoint_1d.py
View file @
62b4ce73
...
@@ -15,6 +15,7 @@ from colossalai.initialize import launch
...
@@ -15,6 +15,7 @@ from colossalai.initialize import launch
from
colossalai.logging
import
disable_existing_loggers
from
colossalai.logging
import
disable_existing_loggers
from
colossalai.utils
import
free_port
,
is_using_pp
from
colossalai.utils
import
free_port
,
is_using_pp
from
colossalai.utils.checkpointing
import
gather_pipeline_parallel_state_dict
,
load_checkpoint
,
save_checkpoint
from
colossalai.utils.checkpointing
import
gather_pipeline_parallel_state_dict
,
load_checkpoint
,
save_checkpoint
from
colossalai.testing
import
rerun_on_exception
def
build_pipeline
(
model
):
def
build_pipeline
(
model
):
...
@@ -38,9 +39,7 @@ def check_equal(A, B):
...
@@ -38,9 +39,7 @@ def check_equal(A, B):
def
check_checkpoint_1d
(
rank
,
world_size
,
port
):
def
check_checkpoint_1d
(
rank
,
world_size
,
port
):
config
=
dict
(
config
=
dict
(
parallel
=
dict
(
pipeline
=
dict
(
size
=
2
),
tensor
=
dict
(
size
=
4
,
mode
=
"1d"
)),)
parallel
=
dict
(
pipeline
=
dict
(
size
=
2
),
tensor
=
dict
(
size
=
4
,
mode
=
"1d"
)),
)
disable_existing_loggers
()
disable_existing_loggers
()
launch
(
config
=
config
,
rank
=
rank
,
world_size
=
world_size
,
host
=
"localhost"
,
port
=
port
,
backend
=
"nccl"
)
launch
(
config
=
config
,
rank
=
rank
,
world_size
=
world_size
,
host
=
"localhost"
,
port
=
port
,
backend
=
"nccl"
)
...
@@ -68,6 +67,7 @@ def check_checkpoint_1d(rank, world_size, port):
...
@@ -68,6 +67,7 @@ def check_checkpoint_1d(rank, world_size, port):
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
dist
@
rerun_on_exception
(
exception_type
=
mp
.
ProcessRaisedException
,
pattern
=
".*Address already in use.*"
)
def
test_checkpoint_1d
():
def
test_checkpoint_1d
():
world_size
=
8
world_size
=
8
run_func
=
partial
(
check_checkpoint_1d
,
world_size
=
world_size
,
port
=
free_port
())
run_func
=
partial
(
check_checkpoint_1d
,
world_size
=
world_size
,
port
=
free_port
())
...
...
tests/test_utils/test_checkpoint/test_checkpoint_2d.py
View file @
62b4ce73
...
@@ -15,6 +15,7 @@ from colossalai.initialize import launch
...
@@ -15,6 +15,7 @@ from colossalai.initialize import launch
from
colossalai.logging
import
disable_existing_loggers
from
colossalai.logging
import
disable_existing_loggers
from
colossalai.utils
import
free_port
,
get_current_device
,
is_using_pp
from
colossalai.utils
import
free_port
,
get_current_device
,
is_using_pp
from
colossalai.utils.checkpointing
import
gather_pipeline_parallel_state_dict
,
load_checkpoint
,
save_checkpoint
from
colossalai.utils.checkpointing
import
gather_pipeline_parallel_state_dict
,
load_checkpoint
,
save_checkpoint
from
colossalai.testing
import
rerun_on_exception
def
build_pipeline
(
model
):
def
build_pipeline
(
model
):
...
@@ -38,9 +39,7 @@ def check_equal(A, B):
...
@@ -38,9 +39,7 @@ def check_equal(A, B):
def
check_checkpoint_2d
(
rank
,
world_size
,
port
):
def
check_checkpoint_2d
(
rank
,
world_size
,
port
):
config
=
dict
(
config
=
dict
(
parallel
=
dict
(
pipeline
=
dict
(
size
=
2
),
tensor
=
dict
(
size
=
4
,
mode
=
"2d"
)),)
parallel
=
dict
(
pipeline
=
dict
(
size
=
2
),
tensor
=
dict
(
size
=
4
,
mode
=
"2d"
)),
)
disable_existing_loggers
()
disable_existing_loggers
()
launch
(
config
=
config
,
rank
=
rank
,
world_size
=
world_size
,
host
=
"localhost"
,
port
=
port
,
backend
=
"nccl"
)
launch
(
config
=
config
,
rank
=
rank
,
world_size
=
world_size
,
host
=
"localhost"
,
port
=
port
,
backend
=
"nccl"
)
...
@@ -68,6 +67,7 @@ def check_checkpoint_2d(rank, world_size, port):
...
@@ -68,6 +67,7 @@ def check_checkpoint_2d(rank, world_size, port):
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
dist
@
rerun_on_exception
(
exception_type
=
mp
.
ProcessRaisedException
,
pattern
=
".*Address already in use.*"
)
def
test_checkpoint_2d
():
def
test_checkpoint_2d
():
world_size
=
8
world_size
=
8
run_func
=
partial
(
check_checkpoint_2d
,
world_size
=
world_size
,
port
=
free_port
())
run_func
=
partial
(
check_checkpoint_2d
,
world_size
=
world_size
,
port
=
free_port
())
...
...
tests/test_utils/test_checkpoint/test_checkpoint_2p5d.py
View file @
62b4ce73
...
@@ -15,6 +15,7 @@ from colossalai.initialize import launch
...
@@ -15,6 +15,7 @@ from colossalai.initialize import launch
from
colossalai.logging
import
disable_existing_loggers
from
colossalai.logging
import
disable_existing_loggers
from
colossalai.utils
import
free_port
,
get_current_device
,
is_using_pp
from
colossalai.utils
import
free_port
,
get_current_device
,
is_using_pp
from
colossalai.utils.checkpointing
import
gather_pipeline_parallel_state_dict
,
load_checkpoint
,
save_checkpoint
from
colossalai.utils.checkpointing
import
gather_pipeline_parallel_state_dict
,
load_checkpoint
,
save_checkpoint
from
colossalai.testing
import
rerun_on_exception
def
build_pipeline
(
model
):
def
build_pipeline
(
model
):
...
@@ -38,9 +39,7 @@ def check_equal(A, B):
...
@@ -38,9 +39,7 @@ def check_equal(A, B):
def
check_checkpoint_2p5d
(
rank
,
world_size
,
port
):
def
check_checkpoint_2p5d
(
rank
,
world_size
,
port
):
config
=
dict
(
config
=
dict
(
parallel
=
dict
(
pipeline
=
dict
(
size
=
2
),
tensor
=
dict
(
size
=
4
,
depth
=
1
,
mode
=
"2.5d"
)),)
parallel
=
dict
(
pipeline
=
dict
(
size
=
2
),
tensor
=
dict
(
size
=
4
,
depth
=
1
,
mode
=
"2.5d"
)),
)
disable_existing_loggers
()
disable_existing_loggers
()
launch
(
config
=
config
,
rank
=
rank
,
world_size
=
world_size
,
host
=
"localhost"
,
port
=
port
,
backend
=
"nccl"
)
launch
(
config
=
config
,
rank
=
rank
,
world_size
=
world_size
,
host
=
"localhost"
,
port
=
port
,
backend
=
"nccl"
)
...
@@ -68,6 +67,7 @@ def check_checkpoint_2p5d(rank, world_size, port):
...
@@ -68,6 +67,7 @@ def check_checkpoint_2p5d(rank, world_size, port):
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
dist
@
rerun_on_exception
(
exception_type
=
mp
.
ProcessRaisedException
,
pattern
=
".*Address already in use.*"
)
def
test_checkpoint_2p5d
():
def
test_checkpoint_2p5d
():
world_size
=
8
world_size
=
8
run_func
=
partial
(
check_checkpoint_2p5d
,
world_size
=
world_size
,
port
=
free_port
())
run_func
=
partial
(
check_checkpoint_2p5d
,
world_size
=
world_size
,
port
=
free_port
())
...
...
tests/test_utils/test_checkpoint/test_checkpoint_3d.py
View file @
62b4ce73
...
@@ -15,6 +15,7 @@ from colossalai.initialize import launch
...
@@ -15,6 +15,7 @@ from colossalai.initialize import launch
from
colossalai.logging
import
disable_existing_loggers
from
colossalai.logging
import
disable_existing_loggers
from
colossalai.utils
import
free_port
,
get_current_device
,
is_using_pp
from
colossalai.utils
import
free_port
,
get_current_device
,
is_using_pp
from
colossalai.utils.checkpointing
import
gather_pipeline_parallel_state_dict
,
load_checkpoint
,
save_checkpoint
from
colossalai.utils.checkpointing
import
gather_pipeline_parallel_state_dict
,
load_checkpoint
,
save_checkpoint
from
colossalai.testing
import
rerun_on_exception
def
build_pipeline
(
model
):
def
build_pipeline
(
model
):
...
@@ -38,9 +39,7 @@ def check_equal(A, B):
...
@@ -38,9 +39,7 @@ def check_equal(A, B):
def
check_checkpoint_3d
(
rank
,
world_size
,
port
):
def
check_checkpoint_3d
(
rank
,
world_size
,
port
):
config
=
dict
(
config
=
dict
(
parallel
=
dict
(
pipeline
=
dict
(
size
=
1
),
tensor
=
dict
(
size
=
8
,
mode
=
"3d"
)),)
parallel
=
dict
(
pipeline
=
dict
(
size
=
1
),
tensor
=
dict
(
size
=
8
,
mode
=
"3d"
)),
)
disable_existing_loggers
()
disable_existing_loggers
()
launch
(
config
=
config
,
rank
=
rank
,
world_size
=
world_size
,
host
=
"localhost"
,
port
=
port
,
backend
=
"nccl"
)
launch
(
config
=
config
,
rank
=
rank
,
world_size
=
world_size
,
host
=
"localhost"
,
port
=
port
,
backend
=
"nccl"
)
...
@@ -68,6 +67,7 @@ def check_checkpoint_3d(rank, world_size, port):
...
@@ -68,6 +67,7 @@ def check_checkpoint_3d(rank, world_size, port):
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
dist
@
rerun_on_exception
(
exception_type
=
mp
.
ProcessRaisedException
,
pattern
=
".*Address already in use.*"
)
def
test_checkpoint_3d
():
def
test_checkpoint_3d
():
world_size
=
8
world_size
=
8
run_func
=
partial
(
check_checkpoint_3d
,
world_size
=
world_size
,
port
=
free_port
())
run_func
=
partial
(
check_checkpoint_3d
,
world_size
=
world_size
,
port
=
free_port
())
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment