Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
11bddb6e
Commit
11bddb6e
authored
Mar 08, 2022
by
Jiarui Fang
Committed by
Frank Lee
Mar 11, 2022
Browse files
[zero] update zero context init with the updated test utils (#327)
parent
6268446b
Changes
10
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
97 additions
and
50 deletions
+97
-50
colossalai/engine/ophooks/_memtracer_ophook.py
colossalai/engine/ophooks/_memtracer_ophook.py
+8
-8
colossalai/zero/init_ctx/init_context.py
colossalai/zero/init_ctx/init_context.py
+19
-10
colossalai/zero/shard_utils/base_shard_strategy.py
colossalai/zero/shard_utils/base_shard_strategy.py
+5
-6
colossalai/zero/sharded_param/sharded_param.py
colossalai/zero/sharded_param/sharded_param.py
+14
-2
tests/components_to_test/nested_model.py
tests/components_to_test/nested_model.py
+13
-6
tests/components_to_test/repeated_computed_layer.py
tests/components_to_test/repeated_computed_layer.py
+9
-3
tests/components_to_test/resnet.py
tests/components_to_test/resnet.py
+9
-3
tests/test_engine/test_engine.py
tests/test_engine/test_engine.py
+3
-2
tests/test_zero_data_parallel/test_init_context.py
tests/test_zero_data_parallel/test_init_context.py
+15
-10
tests/test_zero_data_parallel/test_shard_param.py
tests/test_zero_data_parallel/test_shard_param.py
+2
-0
No files found.
colossalai/engine/ophooks/_memtracer_ophook.py
View file @
11bddb6e
from
re
import
S
from
colossalai.context.parallel_mode
import
ParallelMode
import
torch
from
.
import
BaseOpHook
...
...
@@ -7,7 +6,7 @@ from colossalai.registry import OPHOOKS
from
colossalai.logging
import
get_dist_logger
from
time
import
sleep
,
time
import
pickle
from
typing
import
Union
,
Optional
from
typing
import
Optional
from
colossalai.core
import
global_context
as
gpc
...
...
@@ -25,6 +24,7 @@ def get_cuda_memory_used(device: Optional[torch.device]) -> int:
class
AsyncMemoryMonitor
:
def
__init__
(
self
,
power
=
10
):
"""
An async memory monitor running during computation.
...
...
@@ -106,6 +106,7 @@ class MemTracerOpHook(BaseOpHook):
_data_prefix (string): the prefix of the stats data file
_rank (int): the rank of current node
'''
def
__init__
(
self
,
warmup
:
int
=
50
,
refreshrate
:
int
=
10
,
data_prefix
:
str
=
"memstats"
):
super
().
__init__
()
self
.
async_mem_monitor
=
AsyncMemoryMonitor
()
...
...
@@ -178,8 +179,7 @@ class MemTracerOpHook(BaseOpHook):
# every `refreshrate` times, refresh the file
if
self
.
valid_iter
!=
0
and
self
.
valid_iter
%
self
.
refreshrate
==
0
:
# output file info
self
.
_logger
.
info
(
f
'dump a memory statistics as pickle to
{
self
.
_dataprefix
}
-
{
self
.
_rank
}
.pkl'
)
self
.
_logger
.
info
(
f
'dump a memory statistics as pickle to
{
self
.
_dataprefix
}
-
{
self
.
_rank
}
.pkl'
)
self
.
save_results
()
self
.
_count
+=
1
self
.
_logger
.
debug
(
f
'data file has been refreshed
{
self
.
_count
}
times'
)
...
...
colossalai/zero/init_ctx/init_context.py
View file @
11bddb6e
...
...
@@ -82,25 +82,31 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
3. Shard the param and grad according to flags.
"""
def
__init__
(
self
,
def
__init__
(
self
,
convert_fp16
:
bool
,
convert_cuda
:
bool
,
shard_strategy
:
BaseShardStrategy
,
shard_param
:
bool
=
False
,
shard_grad
:
bool
=
False
,
):
rm_torch_payload_on_the_fly
=
False
):
super
().
__init__
()
self
.
convert_fp16
=
convert_fp16
self
.
convert_cuda
=
convert_cuda
self
.
shard_param
=
shard_param
self
.
shard_grad
=
shard_grad
self
.
shard_strategy
=
shard_strategy
self
.
rm_torch_payload_on_the_fly
=
rm_torch_payload_on_the_fly
self
.
initialized_param_list
=
[]
def
_post_context_exec
(
self
):
"""The callback function when the context exits.
"""
pass
if
not
self
.
rm_torch_payload_on_the_fly
:
for
param
in
self
.
initialized_param_list
:
assert
hasattr
(
param
,
'ca_attr'
)
param
.
ca_attr
.
remove_torch_payload
()
del
self
.
initialized_param_list
def
_post_init_method
(
self
,
module
):
r
"""The function to call at the end of the constructor of each nn.Module.
...
...
@@ -121,7 +127,10 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
if
param
.
grad
is
not
None
:
param
.
grad
=
param
.
grad
.
to
(
torch
.
half
).
to
(
target_device
)
param
.
ca_attr
=
ShardedParamV2
(
param
)
param
.
ca_attr
=
ShardedParamV2
(
param
,
rm_torch_payload
=
self
.
rm_torch_payload_on_the_fly
)
self
.
initialized_param_list
.
append
(
param
)
if
self
.
shard_param
:
self
.
shard_strategy
.
shard
(
tensor_list
=
[
param
.
ca_attr
.
_data_sharded_tensor
])
if
param
.
ca_attr
.
grad
and
self
.
shard_grad
:
...
...
colossalai/zero/shard_utils/base_shard_strategy.py
View file @
11bddb6e
...
...
@@ -7,6 +7,11 @@ from typing import List, Optional
class
BaseShardStrategy
(
ABC
):
def
__init__
(
self
,
process_group
:
Optional
[
dist
.
ProcessGroup
]
=
None
)
->
None
:
"""Abstract Shard Strategy. Use to shard a tensors on multiple GPUs.
Args:
process_group (Optional[dist.ProcessGroup], optional): the process group. Defaults to None.
"""
self
.
process_group
=
process_group
self
.
world_size
=
dist
.
get_world_size
(
self
.
process_group
)
self
.
local_rank
=
dist
.
get_rank
(
self
.
process_group
)
...
...
@@ -14,14 +19,8 @@ class BaseShardStrategy(ABC):
@
abstractmethod
def
shard
(
self
,
tensor_list
:
List
[
ShardedTensor
]):
r
"""
Shard the memory of each tensor across multiple processes.
"""
pass
@
abstractmethod
def
gather
(
self
,
tensor_list
:
List
[
ShardedTensor
]):
r
"""
Duplicate the tensor payload on each process.
"""
pass
colossalai/zero/sharded_param/sharded_param.py
View file @
11bddb6e
...
...
@@ -10,7 +10,10 @@ from typing import Union, Tuple, Optional
class
ShardedParamV2
(
object
):
def
__init__
(
self
,
param
:
torch
.
nn
.
Parameter
,
process_group
:
Optional
[
dist
.
ProcessGroup
]
=
None
)
->
None
:
def
__init__
(
self
,
param
:
torch
.
nn
.
Parameter
,
process_group
:
Optional
[
dist
.
ProcessGroup
]
=
None
,
rm_torch_payload
=
False
)
->
None
:
self
.
_data_sharded_tensor
=
ShardedTensor
(
param
.
data
,
process_group
)
if
param
.
requires_grad
and
param
.
grad
is
not
None
:
self
.
_grad_sharded_tensor
=
ShardedTensor
(
param
.
grad
,
process_group
)
...
...
@@ -19,7 +22,16 @@ class ShardedParamV2(object):
self
.
_grad_sharded_tensor
=
None
# make sure the shared param is the only owner of payload
param
.
data
=
torch
.
empty
([],
dtype
=
param
.
dtype
,
device
=
param
.
device
)
# param.data may still be used to initialize other parts of the model.
# For example: File "resnet.py", line 190, in __init__
# nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
# So we cannot empty .data at this point.
self
.
param
=
param
if
rm_torch_payload
:
self
.
remove_torch_payload
()
def
remove_torch_payload
(
self
):
self
.
param
.
data
=
torch
.
empty
([],
dtype
=
self
.
param
.
dtype
,
device
=
self
.
param
.
device
)
@
property
def
data
(
self
):
...
...
tests/components_to_test/nested_model.py
View file @
11bddb6e
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
colossalai.nn
import
CheckpointModule
from
.utils
import
DummyDataGenerator
from
.registry
import
non_distributed_component_funcs
...
...
@@ -15,10 +16,10 @@ class SubNet(nn.Module):
return
F
.
linear
(
x
,
weight
,
self
.
bias
)
class
NestedNet
(
nn
.
Module
):
class
NestedNet
(
CheckpointModule
):
def
__init__
(
self
)
->
None
:
super
().
__init__
()
def
__init__
(
self
,
checkpoint
=
False
)
->
None
:
super
().
__init__
(
checkpoint
)
self
.
fc1
=
nn
.
Linear
(
5
,
5
)
self
.
sub_fc
=
SubNet
(
5
)
self
.
fc2
=
nn
.
Linear
(
5
,
2
)
...
...
@@ -41,9 +42,15 @@ class DummyDataLoader(DummyDataGenerator):
@
non_distributed_component_funcs
.
register
(
name
=
'nested_model'
)
def
get_training_components
():
model
=
NestedNet
()
def
model_builder
(
checkpoint
):
return
NestedNet
(
checkpoint
)
trainloader
=
DummyDataLoader
()
testloader
=
DummyDataLoader
()
optim
=
torch
.
optim
.
Adam
(
model
.
parameters
(),
lr
=
0.001
)
def
optim_builder
(
model
):
return
torch
.
optim
.
Adam
(
model
.
parameters
(),
lr
=
0.001
)
criterion
=
torch
.
nn
.
CrossEntropyLoss
()
return
model
,
trainloader
,
testloader
,
optim
,
criterion
return
model_builder
,
trainloader
,
testloader
,
optim_builder
,
criterion
tests/components_to_test/repeated_computed_layer.py
View file @
11bddb6e
...
...
@@ -36,9 +36,15 @@ class DummyDataLoader(DummyDataGenerator):
@
non_distributed_component_funcs
.
register
(
name
=
'repeated_computed_layers'
)
def
get_training_components
():
model
=
NetWithRepeatedlyComputedLayers
(
checkpoint
=
True
)
def
model_builder
(
checkpoint
=
True
):
return
NetWithRepeatedlyComputedLayers
(
checkpoint
)
trainloader
=
DummyDataLoader
()
testloader
=
DummyDataLoader
()
optim
=
torch
.
optim
.
Adam
(
model
.
parameters
(),
lr
=
0.001
)
def
optim_builder
(
model
):
return
torch
.
optim
.
Adam
(
model
.
parameters
(),
lr
=
0.001
)
criterion
=
torch
.
nn
.
CrossEntropyLoss
()
return
model
,
trainloader
,
testloader
,
optim
,
criterion
return
model_builder
,
trainloader
,
testloader
,
optim_builder
,
criterion
tests/components_to_test/resnet.py
View file @
11bddb6e
...
...
@@ -22,9 +22,15 @@ def get_cifar10_dataloader(train):
@
non_distributed_component_funcs
.
register
(
name
=
'resnet18'
)
def
get_resnet_training_components
():
model
=
resnet18
(
num_classes
=
10
)
def
model_builder
(
checkpoint
=
False
):
return
resnet18
(
num_classes
=
10
)
trainloader
=
get_cifar10_dataloader
(
train
=
True
)
testloader
=
get_cifar10_dataloader
(
train
=
False
)
optim
=
torch
.
optim
.
Adam
(
model
.
parameters
(),
lr
=
0.001
)
def
optim_builder
(
model
):
return
torch
.
optim
.
Adam
(
model
.
parameters
(),
lr
=
0.001
)
criterion
=
torch
.
nn
.
CrossEntropyLoss
()
return
model
,
trainloader
,
testloader
,
optim
,
criterion
return
model_builder
,
trainloader
,
testloader
,
optim_builder
,
criterion
tests/test_engine/test_engine.py
View file @
11bddb6e
...
...
@@ -16,10 +16,11 @@ CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None
def
run_train
():
for
get_components_func
in
non_distributed_component_funcs
:
model
,
train_dataloader
,
_
,
optimizer
,
criterion
=
get_components_func
()
model_builder
,
train_dataloader
,
_
,
optimizer_builder
,
criterion
=
get_components_func
()
model
=
model_builder
(
checkpoint
=
False
)
engine
,
train_dataloader
,
*
args
=
colossalai
.
initialize
(
model
=
model
,
optimizer
=
optimizer
,
optimizer
=
optimizer_builder
(
model
)
,
criterion
=
criterion
,
train_dataloader
=
train_dataloader
)
...
...
tests/test_zero_data_parallel/test_init_context.py
View file @
11bddb6e
...
...
@@ -9,16 +9,21 @@ import torch
import
torch.multiprocessing
as
mp
from
colossalai.zero.shard_utils.tensor_shard_strategy
import
TensorShardStrategy
from
colossalai.zero.init_ctx
import
ZeroInitContext
from
common
import
CONFIG
,
Net
from
common
import
CONFIG
from
colossalai.utils
import
free_port
from
tests.components_to_test.registry
import
non_distributed_component_funcs
def
run_dist
(
rank
,
world_size
,
port
):
colossalai
.
launch
(
config
=
CONFIG
,
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
with
ZeroInitContext
(
convert_fp16
=
True
,
convert_cuda
=
True
,
shard_strategy
=
TensorShardStrategy
(),
shard_param
=
True
):
# Note: moving the model to CUDA via Net(checkpoint=True).cuda() is useless here
model
=
Net
(
checkpoint
=
True
)
for
get_components_func
in
non_distributed_component_funcs
:
model_builder
,
_
,
_
,
_
,
_
=
get_components_func
()
with
ZeroInitContext
(
convert_fp16
=
True
,
convert_cuda
=
True
,
shard_strategy
=
TensorShardStrategy
(),
shard_param
=
True
):
model
=
model_builder
(
checkpoint
=
True
)
for
param
in
model
.
parameters
():
assert
hasattr
(
param
,
'ca_attr'
)
...
...
tests/test_zero_data_parallel/test_shard_param.py
View file @
11bddb6e
...
...
@@ -46,6 +46,8 @@ def _run_shard_param_v2(rank, world_size, port):
sparam
=
ShardedParamV2
(
param
=
param
,
process_group
=
None
)
allclose
(
sparam
.
data
,
param_ref
.
data
)
sparam
.
remove_torch_payload
()
assert
(
param
.
data
.
numel
()
==
1
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment