OpenDAS / ColossalAI · Commit 7ef3507a
"examples/tutorial/vscode:/vscode.git/clone" did not exist on "de56b563b96bb03b6df058fe704a19f24d444bbc"
Unverified commit, authored Mar 25, 2022 by Jiarui Fang; committed by GitHub on Mar 25, 2022
[zero] show model data cuda memory usage after zero context init. (#515)
Parent: a2e61d61
Showing 4 changed files with 38 additions and 9 deletions (+38, -9):
colossalai/utils/memory_tracer/model_data_memtracer.py (+13, -2)
colossalai/zero/init_ctx/init_context.py (+18, -4)
colossalai/zero/sharded_model/sharded_model_v2.py (+5, -3)
tests/test_utils/test_commons.py (+2, -0)
colossalai/utils/memory_tracer/model_data_memtracer.py
@@ -22,13 +22,24 @@ class ModelDataTracer(metaclass=SingletonMeta):
     def __init__(self) -> None:
         self._cuda_usage = 0
+        self._start_flag = False
 
-    def add_tensor(self, t: torch.Tensor):
+    def start(self) -> None:
+        self._start_flag = True
+
+    def close(self) -> None:
+        self._start_flag = False
+
+    def add_tensor(self, t: torch.Tensor) -> None:
+        if not self._start_flag:
+            return
         assert isinstance(t, torch.Tensor), f"ModelDataTracer add_tensor() should accept a torch.Tensor"
         mem_use = _col_tensor_mem_usage(t)
         self._cuda_usage += mem_use
 
-    def delete_tensor(self, t: torch.Tensor):
+    def delete_tensor(self, t: torch.Tensor) -> None:
+        if not self._start_flag:
+            return
         assert isinstance(t, torch.Tensor), f"ModelDataTracer delete_tensor() should accept a torch.Tensor"
         mem_use = _col_tensor_mem_usage(t)
         self._cuda_usage -= mem_use
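The effect of the change above is that add_tensor() and delete_tensor() become no-ops unless tracing has been switched on with start(). A minimal, self-contained sketch of that gating pattern, using a toy class for illustration (not the real ModelDataTracer; a plain byte count stands in for _col_tensor_mem_usage):

import torch

class ToyTracer:
    """Toy illustration of the start()/close() gating; not ColossalAI code."""

    def __init__(self) -> None:
        self._cuda_usage = 0
        self._start_flag = False

    def start(self) -> None:
        self._start_flag = True

    def close(self) -> None:
        self._start_flag = False

    def add_tensor(self, t: torch.Tensor) -> None:
        if not self._start_flag:
            return  # tracing is off: registration is silently ignored
        self._cuda_usage += t.numel() * t.element_size()

tracer = ToyTracer()
tracer.add_tensor(torch.ones(2, 3))  # ignored: start() has not been called
tracer.start()
tracer.add_tensor(torch.ones(2, 3))  # counted: 6 elements * 4 bytes = 24
tracer.close()
print(tracer._cuda_usage)  # 24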
colossalai/zero/init_ctx/init_context.py
@@ -10,6 +10,7 @@ from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model._zero3_utils import cast_tensor_to_fp16
 from colossalai.zero.sharded_param import ShardedParamV2
 from torch.distributed import ProcessGroup
+from colossalai.logging import get_dist_logger, disable_existing_loggers
 
 # Inserts _post_init_method at the end of init method

@@ -126,8 +127,15 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         self.model_numel_tensor = model_numel_tensor
         self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA)
 
+    def _pre_context_exec(self):
+        """
+        The Callback function when entering the context
+        """
+        self.logger = get_dist_logger("ZeroInitContext")
+        GLOBAL_MODEL_DATA_TRACER.start()
+
     def _post_context_exec(self):
-        """The callback function when the context exits.
+        """The callback function when exiting context.
         """
         if not self.rm_torch_payload_on_the_fly:
             for param in self.initialized_param_list:

@@ -135,9 +143,14 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
                 param.col_attr.remove_torch_payload()
 
         del self.initialized_param_list
+        GLOBAL_MODEL_DATA_TRACER.close()
+        cuda_mem_MB = GLOBAL_MODEL_DATA_TRACER.cuda_usage / 1e6
+        self.logger.info(f"Existing ZeRO Context Model Data CUDA Memory Usage {cuda_mem_MB} MB", [0])
 
-    def _post_init_method(self, module):
-        r"""The function to call at the end of the constructor of each nn.Module.
+    def _post_init_method(self, module: torch.nn.Module):
+        """
+        The function to call at the end of the constructor of each module.
+        NOTE() The module may be passed to this function multiple times.
         """
         for param in module.parameters():
             # avoid adapting a param to ShardedParam twice

@@ -165,7 +178,8 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
             if self.shard_param:
                 self.shard_strategy.shard([param.col_attr.sharded_data_tensor], self.dp_process_group)
-                GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr.sharded_data_tensor.payload)
+            if param.col_attr.sharded_data_tensor.device.type == 'cuda':
+                GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr.sharded_data_tensor.payload)
             # if param.col_attr.grad and self.shard_grad:
             #     self.shard_strategy.shard([param.col_attr._grad_sharded_tensor], self.dp_process_group)
             #     GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr._grad_sharded_tensor.payload)
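Together, _pre_context_exec() and _post_context_exec() bracket the tracer around parameter construction, so the logged figure covers exactly the model data allocated inside the context, and only payloads already resident on CUDA are counted. A hedged usage sketch follows; of the constructor arguments, only shard_strategy, shard_param and dp_process_group are confirmed by this diff, the import paths are assumed from the file layout, TensorShardStrategy is assumed to be one of the available BaseShardStrategy implementations, and the real constructor may require further arguments not shown here:

import torch.nn as nn
from colossalai.zero.init_ctx import ZeroInitContext       # path assumed
from colossalai.zero.shard_utils import TensorShardStrategy  # impl assumed

# Parameters built inside the context are sharded and, when on CUDA,
# registered with GLOBAL_MODEL_DATA_TRACER. On exit, rank 0 logs a line like
# "Existing ZeRO Context Model Data CUDA Memory Usage <n> MB".
with ZeroInitContext(shard_strategy=TensorShardStrategy(),  # assumed strategy
                     shard_param=True):                     # confirmed by diff
    model = nn.Linear(1024, 1024)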
colossalai/zero/sharded_model/sharded_model_v2.py
@@ -23,9 +23,11 @@ from ._zero3_utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tenso
 class ShardedModelV2(nn.Module):
-    """A wrapper for a sharded module, which implements Zero Redundancy Optimizer (ZeRO) stage 3.
-    Parameter, gradient and optimizer states are sharded, so memory efficiency is boosted drastically
-    compared to classic data parallelism while the computational granularity and communication efficiency are retained.
+    """
+    A wrapper for the PyTorch module shards the model parameters among multiple GPU memory.
+    Only 1/#nproc of parameters, gradients are stored in local CUDA memory, so forward and backward
+    passes can be executed with limited CUDA memory budget.
+    Note that you must use `ShardedModelV2` with `ShardedOptimizerV2`.
 
     Args:
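The new docstring's claim that only 1/#nproc of the parameters live in local CUDA memory is easy to quantify. A back-of-envelope sketch with illustrative numbers (not taken from this commit), assuming fp16 parameter payloads as the cast_tensor_to_fp16 import suggests:

# Per-rank parameter memory under 1/#nproc sharding.
numel = 7_000_000_000            # illustrative 7B-parameter model
total_bytes = numel * 2          # fp16: 2 bytes per element, ~14 GB in total
nproc = 8                        # illustrative data-parallel world size
per_rank = total_bytes / nproc   # each rank keeps ~1.75 GB of parameter shards
print(f"{per_rank / 1e9:.2f} GB per rank")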
tests/test_utils/test_commons.py
@@ -16,6 +16,7 @@ def run_tensor_move(rank):
     colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl')
 
     assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage == 0)
+    GLOBAL_MODEL_DATA_TRACER.start()
 
     src_t = torch.ones(2, 3).cuda()
     GLOBAL_MODEL_DATA_TRACER.add_tensor(src_t)

@@ -39,6 +40,7 @@ def run_tensor_move(rank):
     colo_model_data_tensor_move(src_t, tgt_t)
     assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage == 24), f"cuda usage {GLOBAL_MODEL_DATA_TRACER.cuda_usage}"
     assert (torch.sum(tgt_t.payload) == 6.0), f"{torch.sum(tgt_t.payload)} vs. 6.0"
+    GLOBAL_MODEL_DATA_TRACER.close()
 
 def test_tensor_move():
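The 24 in the first assertion is simply the byte size of the moved tensor, which is what the tracer accumulates. A quick standalone check in plain PyTorch:

import torch

t = torch.ones(2, 3)                 # float32 by default
print(t.numel() * t.element_size())  # 6 elements * 4 bytes = 24 bytes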