OpenDAS / ColossalAI

Commit 799d105b
Authored Mar 08, 2022 by jiaruifang; committed Mar 11, 2022 by Frank Lee
Parent: dec24561

using pytest parametrize

Showing 3 changed files with 43 additions and 48 deletions
tests/test_zero_data_parallel/test_shard_model_v2.py   +4  -9
tests/test_zero_data_parallel/test_shard_param.py      +24 -22
tests/test_zero_data_parallel/test_zero_param_mgr.py   +15 -17
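All three files apply the same mechanical change: each test previously hard-coded world_size = 2 inside its body; after this commit, @pytest.mark.parametrize supplies world_size as a function argument, so pytest generates one test case per listed value. A minimal sketch of the before/after pattern (test_before and test_after are placeholder names, and dist is the project's own custom marker):

import pytest

# Before: a single hard-coded configuration buried in the test body.
@pytest.mark.dist
def test_before():
    world_size = 2
    assert world_size > 0

# After: pytest collects test_after[1], test_after[2] and test_after[4],
# one generated case per value in the list.
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2, 4])
def test_after(world_size):
    assert world_size > 0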
tests/test_zero_data_parallel/test_shard_model_v2.py

@@ -30,12 +30,7 @@ def run_fwd_bwd(model, x, enable_autocast=False):
 def run_dist(rank, world_size, port):
-    colossalai.launch(config=CONFIG,
-                      rank=rank,
-                      world_size=world_size,
-                      host='localhost',
-                      port=port,
-                      backend='nccl')
+    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     model = Net(checkpoint=True).cuda()
     zero_model = copy.deepcopy(model)

@@ -52,11 +47,11 @@ def run_dist(rank, world_size, port):
 @pytest.mark.dist
-def test_shard_model_v2():
-    world_size = 2
+@pytest.mark.parametrize("world_size", [1, 2, 4])
+def test_shard_model_v2(world_size):
     run_func = partial(run_dist, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)

 if __name__ == '__main__':
-    test_shard_model_v2()
+    test_shard_model_v2(world_size=2)
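The launcher idiom above is shared by every test in this commit: torch.multiprocessing.spawn(fn, nprocs=n) calls fn(rank) with only the process index for ranks 0..n-1, so world_size and port must be pre-bound with functools.partial. A standalone sketch of that pattern, using a hypothetical _worker that prints instead of calling colossalai.launch:

from functools import partial

import torch.multiprocessing as mp


def _worker(rank, world_size, port):
    # mp.spawn supplies rank; world_size and port were bound via partial.
    print(f'rank {rank} of {world_size} would launch on port {port}')


if __name__ == '__main__':
    run_func = partial(_worker, world_size=2, port=29500)
    mp.spawn(run_func, nprocs=2)  # calls run_func(0) and run_func(1) in new processes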
tests/test_zero_data_parallel/test_shard_param.py

@@ -4,19 +4,21 @@
 from copy import deepcopy
 from functools import partial

-import colossalai
-from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
 import pytest
 import torch
 import torch.multiprocessing as mp
+
+import colossalai
+from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
 from colossalai.zero.shard_utils import TensorShardStrategy
 from colossalai.zero.sharded_param import ShardedTensor, ShardedParam
 from colossalai.utils import free_port
 from colossalai.logging import get_dist_logger, disable_existing_loggers
 from tests.test_zero_data_parallel.common import Net, CONFIG, allclose


-def run_shard_tensor(rank, world_size, port):
+def _run_shard_tensor(rank, world_size, port):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

     t = ShardedTensor(tensor=torch.randn(world_size * 2, 3))
     assert list(t.origin_shape) == [world_size * 2, 3]

@@ -32,9 +34,9 @@ def run_shard_tensor(rank, world_size, port):
 @pytest.mark.dist
-def test_shard_tensor():
-    world_size = 2
-    run_func = partial(run_shard_tensor, world_size=world_size, port=free_port())
+@pytest.mark.parametrize("world_size", [1, 2])
+def test_shard_tensor(world_size):
+    run_func = partial(_run_shard_tensor, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)

@@ -52,8 +54,8 @@ def _run_shard_param_v2(rank, world_size, port):
 @pytest.mark.dist
-def test_shard_param_v2():
-    world_size = 2
+@pytest.mark.parametrize("world_size", [1, 2])
+def test_shard_param_v2(world_size):
     run_func = partial(_run_shard_param_v2, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)

@@ -86,40 +88,40 @@ def _run_test_shard_param(rank, world_size, port):
 @pytest.mark.dist
-def test_shard_param():
-    world_size = 2
+@pytest.mark.parametrize("world_size", [1, 2])
+def test_shard_param(world_size):
     run_func = partial(_run_test_shard_param, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)


-def run_init_shard_param(rank, world_size, port):
+def _run_init_shard_param(rank, world_size, port):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    param = torch.nn.Parameter(data=torch.rand(2, 3))
+    param = torch.nn.Parameter(data=torch.rand(world_size, 3))
     sparam = ShardedParam(param, None, True)
     payload = sparam.payload(torch.device('cuda'))
     assert (list(payload.shape) == [3])
     del sparam

-    param_shape = (2, 3)
+    param_shape = (world_size, 3)
     sparam = ShardedParam(param_shape, process_group=None, is_sharded=True, device=torch.device('cpu'))
     payload = sparam.payload(torch.device('cuda'))
     assert (list(payload.shape) == [3])

-    param_shape = (2, 3)
+    param_shape = (world_size, 3)
     sparam = ShardedParam(param_shape, process_group=None, is_sharded=False, device=torch.device('cpu'))
     payload = sparam.payload(torch.device('cuda'))
-    assert (list(payload.shape) == [2, 3])
+    assert (list(payload.shape) == [world_size, 3])


 @pytest.mark.dist
-def test_init_shard_param():
-    world_size = 2
-    run_func = partial(run_init_shard_param, world_size=world_size, port=free_port())
+@pytest.mark.parametrize("world_size", [1, 4])
+def test_init_shard_param(world_size):
+    run_func = partial(_run_init_shard_param, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)


 if __name__ == '__main__':
-    test_shard_tensor()
-    test_shard_param()
-    test_shard_param_v2()
-    test_init_shard_param()
+    test_shard_tensor(2)
+    test_shard_param(2)
+    test_shard_param_v2(2)
+    test_init_shard_param(4)
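Two details in this file are worth noting. The helper renames (run_shard_tensor to _run_shard_tensor, run_init_shard_param to _run_init_shard_param) follow the usual leading-underscore convention for module-private functions, keeping the spawn workers visually distinct from the collected test_* functions. More substantively, _run_init_shard_param now ties the parameter shape to world_size: a (world_size, 3) tensor holds 3 * world_size elements, so an even split across world_size ranks leaves each rank 3 elements, and the [3] payload assertion holds for every parametrized value rather than only for world_size = 2. A hypothetical illustration of that arithmetic in plain PyTorch (chunk_shard is not ColossalAI's ShardedParam, just a stand-in for an even ZeRO-style split):

import torch


def chunk_shard(tensor, rank, world_size):
    # Stand-in shard: flatten, then take this rank's equal chunk.
    return tensor.reshape(-1).chunk(world_size)[rank]


world_size = 4
param = torch.rand(world_size, 3)    # 3 * world_size elements in total
for rank in range(world_size):
    assert list(chunk_shard(param, rank, world_size).shape) == [3]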
tests/test_zero_data_parallel/test_zero_param_mgr.py

 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-

-import os
 from functools import partial
-from pathlib import Path

-import colossalai
 import pytest
 import torch
 import torch.multiprocessing as mp
+
+import colossalai
 from colossalai.zero.sharded_model.param_manager import Zero3ParameterManager
 from colossalai.core import global_context as gpc
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.utils import free_port

 from common import CONFIG


 def run_shard_shape_check(rank, world_size, port):
-    colossalai.launch(config=CONFIG,
-                      rank=rank,
-                      world_size=world_size,
-                      host='localhost',
-                      port=port,
-                      backend='nccl')
+    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

     model = torch.nn.Linear(2, 4 * world_size)
     gpc.init_parallel_groups()
     Zero3ParameterManager(module=model, process_group=gpc.get_group(ParallelMode.DATA), offload_config=CONFIG.get('offload_param_config'))

     assert (model.weight.numel() == 4 * 2)
     assert (model.bias.numel() == 4)


 @pytest.mark.dist
-def test_run_shard_shape():
-    world_size = 2
+@pytest.mark.parametrize("world_size", [1, 2, 4])
+def test_run_shard_shape(world_size):
     run_func = partial(run_shard_shape_check, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)


 if __name__ == '__main__':
-    test_run_shard_shape()
+    test_run_shard_shape(2)
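The assertions in run_shard_shape_check rest on the same arithmetic: torch.nn.Linear(2, 4 * world_size) has a weight of shape (4 * world_size, 2) and a bias of 4 * world_size elements, so an even split across world_size ranks leaves each rank 4 * 2 weight elements and 4 bias elements regardless of the parametrized value. A quick standalone check of that reasoning in plain PyTorch (no Zero3ParameterManager involved):

import torch

for world_size in (1, 2, 4):
    model = torch.nn.Linear(2, 4 * world_size)
    # Full sizes before sharding:
    assert model.weight.numel() == (4 * world_size) * 2
    assert model.bias.numel() == 4 * world_size
    # Per-rank sizes after an even 1/world_size split:
    assert model.weight.numel() // world_size == 4 * 2
    assert model.bias.numel() // world_size == 4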