Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
8cf7ff08
"vscode:/vscode.git/clone" did not exist on "b2e0d502b8b9b7d4e6263fd97dff9974eace9a60"
Commit
8cf7ff08
authored
Mar 18, 2022
by
ver217
Browse files
polish code
parent
e99af94a
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
24 additions
and
28 deletions
+24
-28
tests/test_zero_data_parallel/common.py
tests/test_zero_data_parallel/common.py
+11
-14
tests/test_zero_data_parallel/test_shard_model_v2.py
tests/test_zero_data_parallel/test_shard_model_v2.py
+4
-5
tests/test_zero_data_parallel/test_sharded_optim_v2.py
tests/test_zero_data_parallel/test_sharded_optim_v2.py
+4
-5
tests/test_zero_data_parallel/test_state_dict.py
tests/test_zero_data_parallel/test_state_dict.py
+5
-4
No files found.
tests/test_zero_data_parallel/common.py
View file @
8cf7ff08
import
imp
from
functools
import
partial
import
torch
import
torch.distributed
as
dist
from
colossalai.logging
import
get_dist_logger
from
colossalai.nn.optimizer
import
CPUAdam
from
colossalai.utils
import
checkpoint
from
colossalai.zero.shard_utils
import
TensorShardStrategy
from
colossalai.zero.sharded_model
import
ShardedModelV2
...
...
@@ -20,8 +18,7 @@ _ZERO_MODEL_CONFIG = dict(reduce_scatter_bucket_size_mb=25,
use_memory_tracer
=
False
,
shard_strategy
=
TensorShardStrategy
)
_ZERO_OPTIMIZER_CONFIG
=
dict
(
cpu_offload
=
False
,
_ZERO_OPTIMIZER_CONFIG
=
dict
(
cpu_offload
=
False
,
initial_scale
=
2
**
5
,
min_scale
=
1
,
growth_factor
=
2
,
...
...
@@ -35,7 +32,7 @@ ZERO_PARALLEL_CONFIG = dict(fp16=dict(mode=None,),
zero
=
dict
(
model_config
=
_ZERO_MODEL_CONFIG
,
optimizer_config
=
_ZERO_OPTIMIZER_CONFIG
,
),
),
parallel
=
dict
(
pipeline
=
dict
(
size
=
1
),
tensor
=
dict
(
size
=
1
,
mode
=
None
)))
CONFIG
=
dict
(
fp16
=
dict
(
mode
=
None
,),
...
...
tests/test_zero_data_parallel/test_shard_model_v2.py
View file @
8cf7ff08
...
...
@@ -10,8 +10,7 @@ import torch.multiprocessing as mp
from
colossalai.testing
import
parameterize
from
colossalai.utils
import
free_port
from
colossalai.zero.init_ctx
import
ZeroInitContext
from
colossalai.zero.shard_utils
import
(
BucketTensorShardStrategy
,
TensorShardStrategy
)
from
colossalai.zero.shard_utils
import
(
BucketTensorShardStrategy
,
TensorShardStrategy
)
from
colossalai.zero.sharded_model
import
ShardedModelV2
from
colossalai.zero.sharded_model._zero3_utils
import
cast_tensor_to_fp16
from
colossalai.zero.sharded_model.utils
import
col_model_deepcopy
...
...
@@ -22,10 +21,10 @@ from common import CONFIG, check_grads_padding, run_fwd_bwd
@
parameterize
(
"enable_autocast"
,
[
True
])
@
parameterize
(
"shard_strategy"
,
[
TensorShardStrategy
,
BucketTensorShardStrategy
])
def
run_model_test
(
enable_autocast
,
shard_strategy
):
@
parameterize
(
"shard_strategy
_class
"
,
[
TensorShardStrategy
,
BucketTensorShardStrategy
])
def
run_model_test
(
enable_autocast
,
shard_strategy
_class
):
test_models
=
[
'repeated_computed_layers'
,
'resnet18'
,
'bert'
]
shard_strategy
=
shard_strategy
()
shard_strategy
=
shard_strategy
_class
()
for
model_name
in
test_models
:
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
model_name
)
model_builder
,
train_dataloader
,
_
,
_
,
criterion
=
get_components_func
()
...
...
tests/test_zero_data_parallel/test_sharded_optim_v2.py
View file @
8cf7ff08
...
...
@@ -9,8 +9,7 @@ from colossalai.nn.optimizer import CPUAdam
from
colossalai.testing
import
parameterize
from
colossalai.utils
import
free_port
from
colossalai.zero.init_ctx
import
ZeroInitContext
from
colossalai.zero.shard_utils
import
(
BucketTensorShardStrategy
,
TensorShardStrategy
)
from
colossalai.zero.shard_utils
import
(
BucketTensorShardStrategy
,
TensorShardStrategy
)
from
colossalai.zero.sharded_model
import
ShardedModelV2
from
colossalai.zero.sharded_model.utils
import
col_model_deepcopy
from
colossalai.zero.sharded_optim
import
ShardedOptimizerV2
...
...
@@ -41,10 +40,10 @@ def _run_step(model, optimizer, data, label, criterion, enable_autocast=False):
@
parameterize
(
"cpu_offload"
,
[
True
,
False
])
@
parameterize
(
"use_cpuadam"
,
[
True
,
False
])
@
parameterize
(
"shard_strategy"
,
[
TensorShardStrategy
,
BucketTensorShardStrategy
])
def
_run_test_sharded_optim_v2
(
cpu_offload
,
shard_strategy
,
use_cpuadam
):
@
parameterize
(
"shard_strategy
_class
"
,
[
TensorShardStrategy
,
BucketTensorShardStrategy
])
def
_run_test_sharded_optim_v2
(
cpu_offload
,
shard_strategy
_class
,
use_cpuadam
):
test_models
=
[
'repeated_computed_layers'
,
'resnet18'
,
'bert'
]
shard_strategy
=
shard_strategy
()
shard_strategy
=
shard_strategy
_class
()
if
use_cpuadam
and
cpu_offload
is
False
:
return
...
...
tests/test_zero_data_parallel/test_state_dict.py
View file @
8cf7ff08
...
...
@@ -8,20 +8,21 @@ import colossalai
import
pytest
import
torch
import
torch.multiprocessing
as
mp
from
colossalai.testing
import
parameterize
from
colossalai.utils
import
free_port
from
colossalai.zero.init_ctx
import
ZeroInitContext
from
colossalai.zero.shard_utils
import
(
BucketTensorShardStrategy
,
TensorShardStrategy
)
from
colossalai.zero.sharded_model
import
ShardedModelV2
from
colossalai.zero.sharded_model.utils
import
col_model_deepcopy
from
tests.components_to_test.registry
import
non_distributed_component_funcs
from
colossalai.testing
import
parameterize
from
common
import
CONFIG
@
parameterize
(
"shard_strategy"
,
[
TensorShardStrategy
,
BucketTensorShardStrategy
])
def
run_zero_state_dict
(
shard_strategy
):
@
parameterize
(
"shard_strategy
_class
"
,
[
TensorShardStrategy
,
BucketTensorShardStrategy
])
def
run_zero_state_dict
(
shard_strategy
_class
):
test_models
=
[
'repeated_computed_layers'
,
'resnet18'
]
shard_strategy
=
shard_strategy
()
shard_strategy
=
shard_strategy
_class
()
for
model_name
in
test_models
:
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
model_name
)
model_builder
,
train_dataloader
,
test_dataloader
,
optimizer
,
criterion
=
get_components_func
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment