Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
4d9332b4
Unverified
Commit
4d9332b4
authored
Apr 19, 2022
by
Jiarui Fang
Committed by
GitHub
Apr 19, 2022
Browse files
[refactor] moving memtracer to gemini (#801)
parent
8711c706
Changes
24
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
5 additions
and
7 deletions
+5
-7
tests/test_gemini/test_stateful_tensor_mgr.py
tests/test_gemini/test_stateful_tensor_mgr.py
+2
-2
tests/test_zero/test_init_context.py
tests/test_zero/test_init_context.py
+1
-1
tests/test_zero/test_mem_collector.py
tests/test_zero/test_mem_collector.py
+2
-2
tests/test_zero/test_zero_engine.py
tests/test_zero/test_zero_engine.py
+0
-2
No files found.
tests/test_
zero
/test_stateful_tensor_mgr.py
→
tests/test_
gemini
/test_stateful_tensor_mgr.py
View file @
4d9332b4
...
...
@@ -3,8 +3,8 @@ import colossalai
import
pytest
import
torch.multiprocessing
as
mp
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.
utils
.memory_tracer
import
MemStatsCollector
from
colossalai.
utils
.memory_tracer
.model_data_memtracer
import
GLOBAL_MODEL_DATA_TRACER
from
colossalai.
gemini
.memory_tracer
import
MemStatsCollector
from
colossalai.
gemini
.memory_tracer
import
GLOBAL_MODEL_DATA_TRACER
from
colossalai.utils.memory
import
colo_set_process_memory_fraction
from
colossalai.gemini
import
StatefulTensorMgr
from
colossalai.zero.sharded_param.sharded_param
import
ShardedParamV2
...
...
tests/test_zero/test_init_context.py
View file @
4d9332b4
...
...
@@ -11,7 +11,7 @@ from colossalai.logging import get_dist_logger
from
colossalai.testing
import
parameterize
,
rerun_if_address_is_in_use
from
colossalai.utils
import
free_port
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.
utils
.memory_tracer.model_data_memtracer
import
\
from
colossalai.
gemini
.memory_tracer.model_data_memtracer
import
\
colo_model_mem_usage
from
colossalai.utils.memory
import
colo_device_memory_used
from
colossalai.zero.init_ctx
import
ZeroInitContext
...
...
tests/test_zero/test_mem_collector.py
View file @
4d9332b4
...
...
@@ -14,7 +14,7 @@ from colossalai.testing import rerun_if_address_is_in_use
from
functools
import
partial
class
TestModel
(
torch
.
nn
.
Module
):
class
My
TestModel
(
torch
.
nn
.
Module
):
def
__init__
(
self
)
->
None
:
super
().
__init__
()
...
...
@@ -37,7 +37,7 @@ def run_mem_collector_testing():
colo_set_process_memory_fraction
(
fraction
)
shard_strategy
=
BucketTensorShardStrategy
()
with
ZeroInitContext
(
target_device
=
get_current_device
(),
shard_strategy
=
shard_strategy
,
shard_param
=
True
):
model
=
TestModel
()
model
=
My
TestModel
()
model
=
ShardedModelV2
(
module
=
model
,
shard_strategy
=
shard_strategy
,
...
...
tests/test_zero/test_zero_engine.py
View file @
4d9332b4
...
...
@@ -91,8 +91,6 @@ def run_dist(rank, world_size, port, parallel_config):
# FIXME: enable this test in next PR
@
pytest
.
mark
.
skip
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
parametrize
(
"world_size"
,
[
2
,
4
])
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment