OpenDAS / ColossalAI · Commits

Commit 3601b2ba (unverified)
Authored Mar 25, 2022 by Frank Lee; committed by GitHub on Mar 25, 2022

[test] fixed rerun_on_exception and adapted test cases (#487)
Parent: 4d322b79
Showing 11 changed files with 24 additions and 2 deletions (+24 -2)
tests/test_trainer/test_trainer_with_pipe_schedule.py              +2 -0
tests/test_utils/test_commons.py                                   +2 -1
tests/test_utils/test_gradient_accumluation.py                     +2 -0
tests/test_utils/test_zero_gradient_clippling.py                   +2 -1
tests/test_zero_data_parallel/test_init_context.py                 +2 -0
tests/test_zero_data_parallel/test_shard_model_v2.py               +2 -0
tests/test_zero_data_parallel/test_shard_param.py                  +3 -0
tests/test_zero_data_parallel/test_sharded_optim_v2.py             +2 -0
tests/test_zero_data_parallel/test_sharded_optim_with_sync_bn.py   +2 -0
tests/test_zero_data_parallel/test_state_dict.py                   +2 -0
tests/test_zero_data_parallel/test_zero_engine.py                  +3 -0
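Every test below gains the same decoration, @rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*"), so a distributed test is retried when it crashes only because the rendezvous port from a previous run is still occupied. As a rough illustration of those semantics, here is a minimal sketch of such a decorator; the name _rerun_on_exception_sketch and the max_try parameter are assumptions for illustration, not the actual colossalai.testing.rerun_on_exception implementation:

import re
from functools import wraps


def _rerun_on_exception_sketch(exception_type=Exception, pattern=None, max_try=3):
    """Retry the wrapped test when it raises `exception_type` and the
    exception message matches `pattern` (a regular expression)."""

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            for attempt in range(1, max_try + 1):
                try:
                    return func(*args, **kwargs)
                except exception_type as e:
                    # Do not retry errors the pattern does not recognise.
                    if pattern is not None and not re.search(pattern, str(e)):
                        raise
                    # Give up after the last allowed attempt.
                    if attempt == max_try:
                        raise
        return wrapper

    return decorator

Used as in the hunks below, the decorator sits directly above the test function, so pytest collects the wrapped callable and a flaky port error triggers a rerun instead of a failure.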
tests/test_trainer/test_trainer_with_pipe_schedule.py  (view file @ 3601b2ba)

@@ -17,6 +17,7 @@ from torch.optim import Adam
 from torchvision import transforms
 from torchvision.datasets import CIFAR10
 from torchvision.models import resnet18
+from colossalai.testing import rerun_on_exception

 BATCH_SIZE = 4
 IMG_SIZE = 32

@@ -85,6 +86,7 @@ def run_trainer_with_pipeline(rank, world_size, port):
 @pytest.mark.dist
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_trainer_with_pipeline():
     world_size = 4
     run_func = partial(run_trainer_with_pipeline, world_size=world_size, port=free_port())
tests/test_utils/test_commons.py  (view file @ 3601b2ba)

 from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
 from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move, colo_model_data_tensor_move_inline
 from colossalai.utils import free_port
+from colossalai.testing import rerun_on_exception
 from colossalai.zero.sharded_param import ShardedTensor
 import colossalai

@@ -47,6 +47,7 @@ def run_tensor_move(rank):
     GLOBAL_MODEL_DATA_TRACER.close()

+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_tensor_move():
     mp.spawn(run_tensor_move, nprocs=1)
tests/test_utils/test_gradient_accumluation.py  (view file @ 3601b2ba)

@@ -10,6 +10,7 @@ import torch.nn as nn
 from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
 from colossalai.utils import free_port, get_dataloader
+from colossalai.testing import rerun_on_exception
 from torch.optim import Adam
 from torchvision import transforms
 from torchvision.datasets import CIFAR10

@@ -86,6 +87,7 @@ def run_no_pipeline(rank, world_size, port):
 @pytest.mark.dist
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_engine():
     world_size = 4
     func = partial(run_no_pipeline, world_size=world_size, port=free_port())
tests/test_utils/test_zero_gradient_clippling.py  (view file @ 3601b2ba)

@@ -14,9 +14,9 @@ from colossalai.logging import disable_existing_loggers
 from colossalai.utils import checkpoint, clip_grad_norm_fp32, free_port
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.nn.utils import clip_grad_norm_
-from colossalai.testing import parameterize
 from colossalai.zero.shard_utils.tensor_shard_strategy import TensorShardStrategy
 from functools import partial
+from colossalai.testing import parameterize, rerun_on_exception

 def checkpoint_wrapper(module, enable=True):

@@ -102,6 +102,7 @@ def run_dist(rank, world_size, port):
 @pytest.mark.dist
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_zero_clip_grad():
     world_size = 4
     run_func = partial(run_dist, world_size=world_size, port=free_port())
tests/test_zero_data_parallel/test_init_context.py  (view file @ 3601b2ba)

@@ -14,6 +14,7 @@ from colossalai.utils.memory_tracer.model_data_memtracer import \
     GLOBAL_MODEL_DATA_TRACER
 from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
+from colossalai.testing import rerun_on_exception
 from tests.components_to_test.registry import non_distributed_component_funcs
 from common import CONFIG

@@ -57,6 +58,7 @@ def run_dist(rank, world_size, port):
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 4])
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_zero_init_context(world_size):
     run_func = partial(run_dist, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)
tests/test_zero_data_parallel/test_shard_model_v2.py  (view file @ 3601b2ba)

@@ -14,6 +14,7 @@ from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardS
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
 from colossalai.zero.sharded_model.utils import col_model_deepcopy
+from colossalai.testing import rerun_on_exception
 from tests.components_to_test.registry import non_distributed_component_funcs
 from torch.nn.parallel import DistributedDataParallel as DDP

@@ -63,6 +64,7 @@ def run_dist(rank, world_size, port):
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 2])
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_shard_model_v2(world_size):
     run_func = partial(run_dist, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)
tests/test_zero_data_parallel/test_shard_param.py  (view file @ 3601b2ba)

@@ -10,6 +10,7 @@ from colossalai.utils import free_port
 from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
 from colossalai.zero.sharded_param import ShardedTensor
 from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
+from colossalai.testing import rerun_on_exception
 from tests.test_zero_data_parallel.common import CONFIG, allclose

@@ -35,6 +36,7 @@ def _run_shard_tensor(rank, world_size, port):
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 2])
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_shard_tensor(world_size):
     run_func = partial(_run_shard_tensor, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)

@@ -55,6 +57,7 @@ def _run_shard_param_v2(rank, world_size, port):
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 2])
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_shard_param_v2(world_size):
     run_func = partial(_run_shard_param_v2, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)
tests/test_zero_data_parallel/test_sharded_optim_v2.py  (view file @ 3601b2ba)

@@ -15,6 +15,7 @@ from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_model.utils import col_model_deepcopy
 from colossalai.zero.sharded_optim import ShardedOptimizerV2
 from colossalai.zero.sharded_optim._utils import has_inf_or_nan
+from colossalai.testing import rerun_on_exception
 from tests.components_to_test.registry import non_distributed_component_funcs
 from torch.nn.parallel import DistributedDataParallel as DDP

@@ -106,6 +107,7 @@ def _run_dist(rank, world_size, port):
 # use_cpuadam = True can be used with cpu_offload = False
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 2])
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_sharded_optim_v2(world_size):
     run_func = partial(_run_dist, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)
tests/test_zero_data_parallel/test_sharded_optim_with_sync_bn.py  (view file @ 3601b2ba)

@@ -14,6 +14,7 @@ from colossalai.utils import free_port
 from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.zero.shard_utils import TensorShardStrategy
 from torchvision.models import resnet50
+from colossalai.testing import rerun_on_exception

 def run_dist(rank, world_size, port):

@@ -71,6 +72,7 @@ def run_dist(rank, world_size, port):
 @pytest.mark.dist
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_sharded_optim_with_sync_bn():
     """
     This test is to make sure that buffers are synchronized between ranks
tests/test_zero_data_parallel/test_state_dict.py  (view file @ 3601b2ba)

@@ -14,6 +14,7 @@ from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_model.utils import col_model_deepcopy
+from colossalai.testing import rerun_on_exception
 from tests.components_to_test.registry import non_distributed_component_funcs
 from common import CONFIG

@@ -51,6 +52,7 @@ def run_dist(rank, world_size, port):
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 2])
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_zero_state_dict(world_size):
     run_func = partial(run_dist, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)
tests/test_zero_data_parallel/test_zero_engine.py  (view file @ 3601b2ba)

@@ -13,6 +13,7 @@ from colossalai.utils import free_port
 from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.zero.sharded_model.utils import col_model_deepcopy
 from colossalai.zero.sharded_optim._utils import has_inf_or_nan
+from colossalai.testing import rerun_on_exception
 from tests.components_to_test.registry import non_distributed_component_funcs
 from torch.nn.parallel import DistributedDataParallel as DDP

@@ -96,6 +97,7 @@ def run_dist(rank, world_size, port, parallel_config):
 @pytest.mark.skip
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [2, 4])
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_mp_engine(world_size):
     run_func = partial(run_dist, world_size=world_size, port=free_port(), parallel_config=MP_PARALLEL_CONFIG)
     mp.spawn(run_func, nprocs=world_size)

@@ -103,6 +105,7 @@ def test_mp_engine(world_size):
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 2])
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_zero_engine(world_size):
     run_func = partial(run_dist, world_size=world_size, port=free_port(), parallel_config=ZERO_PARALLEL_CONFIG)
     mp.spawn(run_func, nprocs=world_size)