OpenDAS / ColossalAI · Commits · 3712ac7f

Unverified commit 3712ac7f, authored Nov 18, 2022 by Jiarui Fang, committed by GitHub on Nov 18, 2022.

[Gemini] add bert for MemtracerWrapper unintests (#1982)

Parent: e481489a

Showing 4 changed files with 32 additions and 11 deletions (+32 −11)
colossalai/gemini/memory_tracer/module_tracer_wrapper.py   +3 −0
colossalai/gemini/ophooks/mem_trace_hook.py                 +1 −0
tests/test_gemini/test_mem_tracer.py                        +23 −6
tests/test_zero/test_shard_model_v2.py                      +5 −5
colossalai/gemini/memory_tracer/module_tracer_wrapper.py

@@ -28,6 +28,9 @@ class _Wrapper():
     def show_mem_stats(self):
         self._ophook_list[0].show_mem_stats()
 
+    def named_buffers(self):
+        return self._model.named_buffers()
+
 
 def MemtracerWrapper(model):
     ophook_list = [MemTracerOpHook()]
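For context, the sketch below is not part of the commit: it shows how the new named_buffers() delegation is exercised. The wrapper forwards the call to the wrapped model, so its buffers can be enumerated and moved to CUDA before tracing, which is what the updated test does for the BERT model. ToyModel is a hypothetical stand-in for the registry model builders used in the real test, and the snippet assumes a CUDA device and an installed colossalai.

import torch
import torch.nn as nn

from colossalai.gemini.memory_tracer import MemtracerWrapper


class ToyModel(nn.Module):
    """Hypothetical stand-in for the builders from tests.components_to_test."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 4)
        # a registered buffer, similar to the position_ids buffer of a BERT embedding
        self.register_buffer('position_ids', torch.arange(4))

    def forward(self, x):
        return self.fc(x)


model = MemtracerWrapper(ToyModel())

# named_buffers() is now forwarded to the wrapped model, so buffers can be
# pushed to CUDA up front and treated as non-model data by the memory tracer.
for name, buff in model.named_buffers():
    buff.data = buff.data.cuda()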
colossalai/gemini/ophooks/mem_trace_hook.py

@@ -7,6 +7,7 @@ from colossalai.gemini.ophooks import BaseOpHook
 class MemTracerOpHook(BaseOpHook):
     """
     TODO() what if parameters are sharded by multiple submodules.
+    register buff on its father node
     """
 
     def __init__(self):
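The one-line docstring addition points at the limitation the hook still has with buffers that live on a non-leaf ("father") module, as BertEmbeddings does with its position_ids buffer. Below is a minimal, self-contained illustration of that situation; the module names are hypothetical and the code is not from the commit.

import torch
import torch.nn as nn


class Parent(nn.Module):
    """Hypothetical non-leaf module that owns a buffer directly."""

    def __init__(self):
        super().__init__()
        self.child = nn.Linear(8, 8)                    # leaf submodule: op hooks fire here
        self.register_buffer('ids', torch.arange(8))    # buffer registered on the non-leaf parent

    def forward(self, x):
        return self.child(x) + self.ids


m = Parent()
print([name for name, _ in m.named_buffers()])        # ['ids'] -> owned by the parent
print([name for name, _ in m.child.named_buffers()])  # []      -> no leaf submodule sees it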
tests/test_gemini/test_mem_tracer.py

from functools import partial

import pytest
import torch
import torch.nn as nn
import torch.multiprocessing as mp

import colossalai
from colossalai.gemini.memory_tracer import MemtracerWrapper
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from tests.components_to_test.registry import non_distributed_component_funcs
@@ -17,16 +22,20 @@ def run_fwd_bwd(model, data, label, criterion, enable_autocast=False):
     model.backward(loss)
 
 
-def test_tracer():
-    # reset the manager, in case that there exists memory information left
-    test_models = ['repeated_computed_layers', 'resnet18', 'no_leaf_module']
+def run_tracer(rank, world_size, port, grad_check=True):
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    test_models = ['repeated_computed_layers', 'resnet18', 'no_leaf_module', 'bert']
     for model_name in test_models:
         get_components_func = non_distributed_component_funcs.get_callable(model_name)
         model_builder, train_dataloader, _, _, criterion = get_components_func()
 
-        # init model on cpu
-        model = MemtracerWrapper(model_builder())
+        # TODO() memtrace hook can not handle buff registered on a non-leaf module (for example the BertEmbedding).
+        # a simple method is that always puts buff on cuda and viewed them as non-model data.
+        model = MemtracerWrapper(model_builder(grad_check))
+
+        for n, buff in model.named_buffers():
+            buff.data = buff.data.cuda()
 
         for i, (data, label) in enumerate(train_dataloader):
             if i > 1:
                 break
@@ -38,5 +47,13 @@ def test_tracer():
         # model._ophook_list[0].print_non_model_data()
 
 
+@pytest.mark.dist
+@pytest.mark.parametrize("world_size", [1])
+@rerun_if_address_is_in_use()
+def test_tracer(world_size):
+    run_func = partial(run_tracer, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
 if __name__ == '__main__':
-    test_tracer()
+    test_tracer(1)
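The rewritten test adopts the repository's usual distributed-test launch pattern. As a standalone sketch (not part of the commit, with a hypothetical worker function), the pattern looks like this: bind the worker's keyword arguments with functools.partial, pick a free port, and let torch.multiprocessing spawn one process per rank, which passes the rank as the first positional argument.

from functools import partial

import torch.multiprocessing as mp

from colossalai.utils import free_port


def worker(rank, world_size, port):
    # each spawned process would call colossalai.launch(config={}, rank=rank,
    # world_size=world_size, host='localhost', port=port, backend='nccl') here,
    # exactly as run_tracer does in the test above
    print(f"rank {rank}/{world_size} on port {port}")


if __name__ == '__main__':
    world_size = 1
    run_func = partial(worker, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)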
tests/test_zero/test_shard_model_v2.py

@@ -3,21 +3,21 @@
 from functools import partial
 
-import colossalai
 import pytest
 import torch
 import torch.multiprocessing as mp
-from common import CONFIG, check_grads_padding, run_fwd_bwd
-from torch.nn.parallel import DistributedDataParallel as DDP
-
+
+import colossalai
 from colossalai.testing import parameterize, rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.zero.init_ctx import ZeroInitContext
-from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
+from colossalai.zero.shard_utils import BucketTensorShardStrategy
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
 from colossalai.zero.sharded_model.utils import col_model_deepcopy
 from tests.components_to_test.registry import non_distributed_component_funcs
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+from common import CONFIG, check_grads_padding, run_fwd_bwd
 
 
 @parameterize("enable_autocast", [True])