OpenDAS / ColossalAI / Commits / ea287207

Commit ea287207, authored Mar 10, 2022 by Jiarui Fang, committed Mar 11, 2022 by Frank Lee

[zero] global model data memory tracer (#360)

parent cb34cd38
Showing 5 changed files with 94 additions and 4 deletions (+94 -4)
colossalai/utils/commons/singleton_meta.py           +18  -0
colossalai/utils/memory_tracer/allocator.py          +60  -0
colossalai/zero/init_ctx/init_context.py              +9  -1
colossalai/zero/sharded_param/sharded_tensor.py       +1  -1
tests/test_zero_data_parallel/test_init_context.py    +6  -2
colossalai/utils/commons/singleton_meta.py  (new file, 0 → 100644)
class SingletonMeta(type):
    """
    The Singleton class can be implemented in different ways in Python. Some
    possible methods include: base class, decorator, metaclass. We will use the
    metaclass because it is best suited for this purpose.
    """
    _instances = {}

    def __call__(cls, *args, **kwargs):
        """
        Possible changes to the value of the `__init__` argument do not affect
        the returned instance.
        """
        if cls not in cls._instances:
            instance = super().__call__(*args, **kwargs)
            cls._instances[cls] = instance
        return cls._instances[cls]
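To illustrate the metaclass, here is a minimal sketch (not part of the commit; `Config` is a hypothetical example class):

from colossalai.utils.commons.singleton_meta import SingletonMeta

class Config(metaclass=SingletonMeta):
    def __init__(self, value=0):
        self.value = value

a = Config(value=1)
b = Config(value=2)  # __init__ is not re-run; the first instance is returned
assert a is b and a.value == 1

This is exactly why ModelDataTracer below can safely be instantiated once at module scope and shared everywhere.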
colossalai/utils/memory_tracer/allocator.py  (new file, 0 → 100644)
import torch

from colossalai.utils.commons.singleton_meta import SingletonMeta
from colossalai.zero.sharded_param import ShardedTensor
from typing import Union


def col_tensor_mem_usage(t: Union[torch.Tensor, ShardedTensor]) -> int:
    """Return the memory footprint in bytes of a tensor, or of a ShardedTensor's payload."""
    if isinstance(t, ShardedTensor):
        target = t.payload
    else:
        target = t
    return target.numel() * target.element_size()


class ModelDataTracer(metaclass=SingletonMeta):
    """
    A singleton to trace model data usage during runtime.
    """

    def __init__(self) -> None:
        self._cpu_usage = 0
        self._cuda_usage = 0

    def trace_tensor(self, t: torch.Tensor):
        """Add the tensor's footprint to the CPU or CUDA counter."""
        mem_use = col_tensor_mem_usage(t)
        if t.device.type == 'cpu':
            self._cpu_usage += mem_use
        elif t.device.type == 'cuda':
            self._cuda_usage += mem_use
        else:
            raise RuntimeError(f"unsupported device type: {t.device.type}")

    def detach_tensor(self, t: torch.Tensor):
        """Subtract the tensor's footprint from the CPU or CUDA counter."""
        mem_use = col_tensor_mem_usage(t)
        if t.device.type == 'cpu':
            self._cpu_usage -= mem_use
        elif t.device.type == 'cuda':
            self._cuda_usage -= mem_use
        else:
            raise RuntimeError(f"unsupported device type: {t.device.type}")

    @property
    def cpu_usage(self):
        return self._cpu_usage

    @property
    def cuda_usage(self):
        return self._cuda_usage


GLOBAL_MODEL_DATA_TRACER = ModelDataTracer()


def col_allocate_payload(device: torch.device) -> torch.Tensor:
    # Placeholder: payload allocation that updates the tracer (not yet implemented).
    pass


def col_release_payload(t: torch.Tensor):
    # Placeholder: payload release that updates the tracer (not yet implemented).
    pass
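For intuition, a minimal usage sketch of the tracer (illustrative only; it assumes no other tensors have been traced yet, since the singleton's counters are process-wide):

import torch
from colossalai.utils.memory_tracer.allocator import GLOBAL_MODEL_DATA_TRACER

t = torch.zeros(1024, dtype=torch.float16)   # 1024 elements * 2 bytes each, on CPU
GLOBAL_MODEL_DATA_TRACER.trace_tensor(t)
assert GLOBAL_MODEL_DATA_TRACER.cpu_usage == 2048

GLOBAL_MODEL_DATA_TRACER.detach_tensor(t)    # stop accounting for t
assert GLOBAL_MODEL_DATA_TRACER.cpu_usage == 0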
colossalai/zero/init_ctx/init_context.py
...
@@ -4,6 +4,7 @@ import torch
 from colossalai.utils.cuda import get_current_device
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_param import ShardedParamV2
+from colossalai.utils.memory_tracer.allocator import GLOBAL_MODEL_DATA_TRACER

 # Inserts _post_init_method at the end of init method
...
@@ -76,11 +77,16 @@ class InsertPostInitMethodToModuleSubClasses(object):
 class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
-    """
+    r"""
     A context to initialize the model.

     1. Convert the model to fp16.
     2. The parameters of the module are adapted to type ShardedParameter.
     3. Shard the param and grad according to flags.
+
+    rm_torch_payload_on_the_fly:
+        True: remove tensor payload on param.data after module init finished.
+        False: remove tensor payload on param.data after the context exits.
+        This is used when you add some logic to operate tensors in __init__ of module.
+        See torchvision resnet18.
     """

     def __init__(self,
...
@@ -134,5 +140,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         if self.shard_param:
             self.shard_strategy.shard(tensor_list=[param.col_attr._data_sharded_tensor])
+            GLOBAL_MODEL_DATA_TRACER.trace_tensor(param.col_attr._data_sharded_tensor.payload)
         if param.col_attr.grad and self.shard_grad:
             self.shard_strategy.shard(tensor_list=[param.col_attr._grad_sharded_tensor])
+            GLOBAL_MODEL_DATA_TRACER.trace_tensor(param.col_attr._grad_sharded_tensor.payload)
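With these hooks, every sharded payload created inside the context is counted once, so the tracer's cuda_usage approximates the per-rank model data footprint. As a hypothetical back-of-the-envelope check (expected_cuda_usage is not part of the commit; it assumes fp16 payloads sharded evenly across ranks):

def expected_cuda_usage(model, world_size):
    # Each rank keeps roughly 1/world_size of every fp16 parameter payload.
    total_bytes = sum(p.numel() * 2 for p in model.parameters())
    return total_bytes // world_size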
colossalai/zero/sharded_param/sharded_tensor.py
...
@@ -7,7 +7,7 @@ class ShardedTensor(object):
     def __init__(self, tensor: torch.Tensor, process_group: Optional[dist.ProcessGroup] = None) -> None:
         r"""
         A tensor sharded in multiple processes.
         Constructed from an existing torch.Tensor instance.
         """
         self._payload = tensor
         self.process_group = process_group
...
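Note that col_tensor_mem_usage (above) reads t.payload for a ShardedTensor, so it reports whatever the local payload currently holds, not the logical full tensor. A small sketch of that assumption (it presumes ShardedTensor exposes the wrapped tensor via a payload property, as the tracer code uses):

import torch
from colossalai.zero.sharded_param import ShardedTensor
from colossalai.utils.memory_tracer.allocator import col_tensor_mem_usage

full = torch.zeros(1000)   # fp32: 1000 * 4 = 4000 bytes
st = ShardedTensor(full)   # before sharding, the payload is the whole tensor
assert col_tensor_mem_usage(st) == 4000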
tests/test_zero_data_parallel/test_init_context.py
...
@@ -13,7 +13,8 @@ from colossalai.zero.shard_utils.tensor_shard_strategy import \
     TensorShardStrategy
 from tests.components_to_test.registry import non_distributed_component_funcs
-from common import CONFIG, Net
+from common import CONFIG
+from colossalai.utils.memory_tracer.allocator import GLOBAL_MODEL_DATA_TRACER


 def run_dist(rank, world_size, port):
...
@@ -33,9 +34,12 @@ def run_dist(rank, world_size, port):
     assert param.col_attr.data.is_sharded
     assert param.col_attr.data.payload.device.type == 'cuda'
+    print(f'cuda usage {GLOBAL_MODEL_DATA_TRACER.cuda_usage}')
+    assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage > 0)


 @pytest.mark.dist
-@pytest.mark.parametrize("world_size", [1, 2, 4])
+@pytest.mark.parametrize("world_size", [1, 4])
 def test_zero_init_context(world_size):
     run_func = partial(run_dist, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)
...