Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
5bec3b21
Unverified
Commit
5bec3b21
authored
Nov 18, 2022
by
Jiarui Fang
Committed by
GitHub
Nov 18, 2022
Browse files
[Gemini] open grad checkpoint when model building (#1984)
parent
c26f21d3
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
7 additions
and
5 deletions
+7
-5
tests/test_gemini/test_mem_tracer.py
tests/test_gemini/test_mem_tracer.py
+7
-5
No files found.
tests/test_gemini/test_mem_tracer.py
View file @
5bec3b21
...
@@ -22,9 +22,10 @@ def run_fwd_bwd(model, data, label, criterion, enable_autocast=False):
...
@@ -22,9 +22,10 @@ def run_fwd_bwd(model, data, label, criterion, enable_autocast=False):
model
.
backward
(
loss
)
model
.
backward
(
loss
)
def
run_tracer
(
rank
,
world_size
,
port
,
grad_check
=
True
):
def
run_tracer
(
rank
,
world_size
,
port
,
use_
grad_check
=
True
):
colossalai
.
launch
(
config
=
{},
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
colossalai
.
launch
(
config
=
{},
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
test_models
=
[
'repeated_computed_layers'
,
'resnet18'
,
'no_leaf_module'
,
'bert'
]
test_models
=
[
'repeated_computed_layers'
,
'resnet18'
,
'no_leaf_module'
,
'bert'
]
# test_models = ['bert']
for
model_name
in
test_models
:
for
model_name
in
test_models
:
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
model_name
)
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
model_name
)
model_builder
,
train_dataloader
,
_
,
_
,
criterion
=
get_components_func
()
model_builder
,
train_dataloader
,
_
,
_
,
criterion
=
get_components_func
()
...
@@ -32,7 +33,7 @@ def run_tracer(rank, world_size, port, grad_check=True):
...
@@ -32,7 +33,7 @@ def run_tracer(rank, world_size, port, grad_check=True):
# init model on cpu
# init model on cpu
# TODO() memtrace hook can not handle buff registered on a non-leaf module (for example the BertEmbedding).
# TODO() memtrace hook can not handle buff registered on a non-leaf module (for example the BertEmbedding).
# a simple method is that always puts buff on cuda and viewed them as non-model data.
# a simple method is that always puts buff on cuda and viewed them as non-model data.
model
=
MemtracerWrapper
(
model_builder
(
grad_check
))
model
=
MemtracerWrapper
(
model_builder
(
checkpoint
=
use_
grad_check
))
for
n
,
buff
in
model
.
named_buffers
():
for
n
,
buff
in
model
.
named_buffers
():
buff
.
data
=
buff
.
data
.
cuda
()
buff
.
data
=
buff
.
data
.
cuda
()
...
@@ -44,14 +45,15 @@ def run_tracer(rank, world_size, port, grad_check=True):
...
@@ -44,14 +45,15 @@ def run_tracer(rank, world_size, port, grad_check=True):
run_fwd_bwd
(
model
,
data
,
label
,
criterion
,
False
)
run_fwd_bwd
(
model
,
data
,
label
,
criterion
,
False
)
#
model._ophook_list[0].print_non_model_data()
model
.
_ophook_list
[
0
].
print_non_model_data
()
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
parametrize
(
"world_size"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"world_size"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"use_grad_check"
,
[
True
,
False
])
@
rerun_if_address_is_in_use
()
@
rerun_if_address_is_in_use
()
def
test_tracer
(
world_size
):
def
test_tracer
(
world_size
,
use_grad_check
):
run_func
=
partial
(
run_tracer
,
world_size
=
world_size
,
port
=
free_port
())
run_func
=
partial
(
run_tracer
,
world_size
=
world_size
,
port
=
free_port
()
,
use_grad_check
=
use_grad_check
)
mp
.
spawn
(
run_func
,
nprocs
=
world_size
)
mp
.
spawn
(
run_func
,
nprocs
=
world_size
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment