OpenDAS / ColossalAI · Commits · 1f90a3b1

Commit 1f90a3b1 (unverified)
Authored Mar 29, 2022 by ver217; committed by GitHub on Mar 29, 2022
Parent: c11ff81b

[zero] polish ZeroInitContext (#540)

Showing 7 changed files with 23 additions and 38 deletions.
colossalai/zero/init_ctx/init_context.py                           +7  -14
tests/test_zero_data_parallel/test_init_context.py                 +7  -7
tests/test_zero_data_parallel/test_shard_model_v2.py               +2  -4
tests/test_zero_data_parallel/test_sharded_optim_v2.py             +1  -3
tests/test_zero_data_parallel/test_sharded_optim_with_sync_bn.py   +2  -3
tests/test_zero_data_parallel/test_state_dict.py                   +2  -4
tests/test_zero_data_parallel/test_zero_engine.py                  +2  -3
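Taken together, the diffs below drop the `convert_fp16` flag from `ZeroInitContext` (parameters and buffers are now always cast to fp16 inside the context), change the default `model_numel_tensor` dtype from `torch.int` to `torch.long`, and reorder or merge imports in the affected tests. A minimal before/after sketch of a caller, condensed from the test changes below; the `nn.Linear` model is a stand-in, and the usual `colossalai.launch` distributed setup is assumed to have run already, as it does in the tests:

```python
import torch
import torch.nn as nn
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import TensorShardStrategy

# Before this commit (old call shape used in the tests):
#   with ZeroInitContext(convert_fp16=True,
#                        target_device=torch.cuda.current_device(),
#                        shard_strategy=TensorShardStrategy(),
#                        shard_param=True):
#       model = nn.Linear(16, 16)

# After this commit the flag is gone; the context always casts to fp16:
with ZeroInitContext(target_device=torch.cuda.current_device(),
                     shard_strategy=TensorShardStrategy(),
                     shard_param=True):
    model = nn.Linear(16, 16)
```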
colossalai/zero/init_ctx/init_context.py (+7 -14)

```diff
@@ -4,12 +4,11 @@ from typing import Optional
 import torch
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.utils.memory_utils.memory_monitor import colo_cuda_memory_used
-from colossalai.logging import get_dist_logger
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
 from colossalai.zero.sharded_param import ShardedParamV2
 from torch.distributed import ProcessGroup
+from colossalai.logging import get_dist_logger, disable_existing_loggers


 def _substitute_init_recursively(cls, func):
@@ -107,20 +106,16 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
     """

     def __init__(self,
-                 convert_fp16: bool,
                  target_device: torch.device,
                  shard_strategy: BaseShardStrategy,
                  shard_param: bool = False,
                  shard_grad: bool = False,
                  rm_torch_payload_on_the_fly: bool = False,
-                 model_numel_tensor: torch.Tensor = torch.zeros(1, dtype=torch.int),
+                 model_numel_tensor: torch.Tensor = torch.zeros(1, dtype=torch.long),
                  dp_process_group: Optional[ProcessGroup] = None):
         super().__init__()
-        self.convert_fp16 = convert_fp16
         self.target_device = target_device
         self.shard_param = shard_param
         self.shard_grad = shard_grad
         self.shard_strategy = shard_strategy
         self.rm_torch_payload_on_the_fly = rm_torch_payload_on_the_fly
         self.initialized_param_list = []
@@ -157,11 +152,10 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
             target_device = self.target_device

-            # convert to fp16 if necessary
-            if self.convert_fp16:
-                param.data = param.data.to(torch.half)
-                if param.grad is not None:
-                    param.grad = param.grad.to(torch.half)
+            # convert to fp16
+            param.data = param.data.to(torch.half)
+            if param.grad is not None:
+                param.grad = param.grad.to(torch.half)

             # move torch parameters to the target device
             param.data = param.data.to(target_device)
@@ -179,5 +173,4 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         # We must cast them
         for buffer in module.buffers(recurse=False):
             buffer.data = buffer.data.to(device=torch.cuda.current_device())
-            if self.convert_fp16:
-                buffer.data = cast_tensor_to_fp16(buffer.data)
+            buffer.data = cast_tensor_to_fp16(buffer.data)
```
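For reference, a small usage sketch of the constructor as it looks after this change, exercising the parameters visible in the new signature above. This is illustrative only: the model is a stand-in, and distributed initialization via `colossalai.launch` is assumed to have happened beforehand (otherwise the default `dp_process_group` lookup through `gpc` would fail):

```python
import torch
import torch.nn as nn
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import TensorShardStrategy

# New default dtype per the signature above: torch.long rather than torch.int.
model_numel_tensor = torch.zeros(1, dtype=torch.long)

with ZeroInitContext(target_device=torch.device('cuda', torch.cuda.current_device()),
                     shard_strategy=TensorShardStrategy(),
                     shard_param=True,
                     rm_torch_payload_on_the_fly=False,
                     model_numel_tensor=model_numel_tensor):
    # Per the hunks above, parameters created here are cast to fp16
    # unconditionally and moved to target_device; buffers are moved to the
    # current CUDA device and cast to fp16 as well.
    model = nn.Linear(16, 16)
```

Passing `model_numel_tensor` mirrors what test_init_context.py does below, presumably so the context can record the parameter count of the model built inside it, as the name suggests.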
tests/test_zero_data_parallel/test_init_context.py (+7 -7)

```diff
@@ -7,16 +7,17 @@ import colossalai
 import pytest
 import torch
 import torch.multiprocessing as mp
-from colossalai.testing import parameterize
+from colossalai.logging import get_dist_logger
+from colossalai.testing import parameterize, rerun_on_exception
 from colossalai.utils import free_port
-from colossalai.utils.memory_tracer.model_data_memtracer import col_model_data_mem_usage
-from colossalai.zero.init_ctx import ZeroInitContext
+from colossalai.utils.cuda import get_current_device
+from colossalai.utils.memory_tracer.model_data_memtracer import \
+    col_model_data_mem_usage
+from colossalai.utils.memory_utils.memory_monitor import colo_cuda_memory_used
+from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
-from colossalai.testing import rerun_on_exception
 from tests.components_to_test.registry import non_distributed_component_funcs
-from colossalai.logging import get_dist_logger
 from common import CONFIG
@@ -36,8 +37,7 @@ def run_model_test(init_device_type, shard_strategy_class):
             continue
         model_numel_tensor = torch.zeros(1, dtype=torch.int)
-        with ZeroInitContext(convert_fp16=True, target_device=init_device,
+        with ZeroInitContext(target_device=init_device,
                              shard_strategy=shard_strategy_class(),
                              shard_param=True,
                              model_numel_tensor=model_numel_tensor,
```
tests/test_zero_data_parallel/test_shard_model_v2.py (+2 -4)

```diff
@@ -7,14 +7,13 @@ import colossalai
 import pytest
 import torch
 import torch.multiprocessing as mp
-from colossalai.testing import parameterize
+from colossalai.testing import parameterize, rerun_on_exception
 from colossalai.utils import free_port
 from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
 from colossalai.zero.sharded_model.utils import col_model_deepcopy
-from colossalai.testing import rerun_on_exception
 from tests.components_to_test.registry import non_distributed_component_funcs
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -32,8 +31,7 @@ def run_model_test(enable_autocast, shard_strategy_class):
         rm_torch_payload_on_the_fly = False
-        with ZeroInitContext(convert_fp16=True, target_device=torch.cuda.current_device(),
+        with ZeroInitContext(target_device=torch.cuda.current_device(),
                              shard_strategy=shard_strategy,
                              shard_param=True,
                              rm_torch_payload_on_the_fly=rm_torch_payload_on_the_fly):
```
tests/test_zero_data_parallel/test_sharded_optim_v2.py (+1 -3)

```diff
@@ -8,7 +8,7 @@ import torch.distributed as dist
 import torch.multiprocessing as mp
 from colossalai.amp import convert_to_apex_amp
 from colossalai.nn.optimizer import CPUAdam
-from colossalai.testing import parameterize
+from colossalai.testing import parameterize, rerun_on_exception
 from colossalai.utils import free_port
 from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
@@ -16,7 +16,6 @@ from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_model.utils import col_model_deepcopy
 from colossalai.zero.sharded_optim import ShardedOptimizerV2
 from colossalai.zero.sharded_optim._utils import has_inf_or_nan
-from colossalai.testing import rerun_on_exception
 from tests.components_to_test.registry import non_distributed_component_funcs
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -59,7 +58,6 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, g
     model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()
     with ZeroInitContext(
-            convert_fp16=True,
             target_device=torch.device(f'cpu:0') if cpu_offload else torch.device(f'cuda:{get_current_device()}'),
             shard_strategy=shard_strategy,
             shard_param=True,
```
tests/test_zero_data_parallel/test_sharded_optim_with_sync_bn.py (+2 -3)

```diff
@@ -10,11 +10,11 @@ import torch.distributed as dist
 import torch.multiprocessing as mp
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
+from colossalai.testing import rerun_on_exception
 from colossalai.utils import free_port
 from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.zero.shard_utils import TensorShardStrategy
 from torchvision.models import resnet50
-from colossalai.testing import rerun_on_exception


 def run_dist(rank, world_size, port):
@@ -30,8 +30,7 @@ def run_dist(rank, world_size, port):
                       port=port,
                       backend='nccl')
-    with ZeroInitContext(convert_fp16=True, target_device=torch.cuda.current_device(),
+    with ZeroInitContext(target_device=torch.cuda.current_device(),
                          shard_strategy=gpc.config.zero.model_config.shard_strategy,
                          shard_param=True):
         model = resnet50()
```
tests/test_zero_data_parallel/test_state_dict.py (+2 -4)

```diff
@@ -8,13 +8,12 @@ import colossalai
 import pytest
 import torch
 import torch.multiprocessing as mp
-from colossalai.testing import parameterize
+from colossalai.testing import parameterize, rerun_on_exception
 from colossalai.utils import free_port
 from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_model.utils import col_model_deepcopy
-from colossalai.testing import rerun_on_exception
 from tests.components_to_test.registry import non_distributed_component_funcs
 from common import CONFIG
@@ -28,8 +27,7 @@ def run_zero_state_dict(shard_strategy_class):
         get_components_func = non_distributed_component_funcs.get_callable(model_name)
         model_builder, train_dataloader, test_dataloader, optimizer, criterion = get_components_func()
-        with ZeroInitContext(convert_fp16=True, target_device=torch.cuda.current_device(),
+        with ZeroInitContext(target_device=torch.cuda.current_device(),
                              shard_strategy=shard_strategy,
                              shard_param=True,
                              rm_torch_payload_on_the_fly=False):
```
tests/test_zero_data_parallel/test_zero_engine.py (+2 -3)

```diff
@@ -9,11 +9,11 @@ import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
 from colossalai.core import global_context as gpc
+from colossalai.testing import rerun_on_exception
 from colossalai.utils import free_port
 from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.zero.sharded_model.utils import col_model_deepcopy
 from colossalai.zero.sharded_optim._utils import has_inf_or_nan
-from colossalai.testing import rerun_on_exception
 from tests.components_to_test.registry import non_distributed_component_funcs
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -32,8 +32,7 @@ def run_dist(rank, world_size, port, parallel_config):
     for model_name in test_models:
         get_components_func = non_distributed_component_funcs.get_callable(model_name)
         model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()
-        with ZeroInitContext(convert_fp16=hasattr(gpc.config, 'fp16'), target_device=torch.cuda.current_device(),
+        with ZeroInitContext(target_device=torch.cuda.current_device(),
                              shard_strategy=gpc.config.zero.model_config.shard_strategy,
                              shard_param=True):
             colo_model = model_builder(checkpoint=True)
```
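Both test_sharded_optim_with_sync_bn.py and test_zero_engine.py above read the shard strategy from `gpc.config.zero.model_config.shard_strategy`. A minimal sketch of the config shape that attribute access implies, written in the dict-style config ColossalAI accepts; the exact contents of the real test configs may differ:

```python
# Hypothetical config module, e.g. the `config` argument handed to colossalai.launch.
from colossalai.zero.shard_utils import TensorShardStrategy

zero = dict(
    model_config=dict(
        shard_strategy=TensorShardStrategy(),
    ),
)
```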