OpenDAS / ColossalAI · Commit 1f992058

[Gemini] remove static tracer (#2083)

Unverified commit 1f992058, authored Dec 06, 2022 by Jiarui Fang, committed via GitHub on Dec 06, 2022.
Parent: 28ef3f29
Showing 4 changed files with 15 additions and 23 deletions.
colossalai/gemini/gemini_mgr.py                         +3  -21
colossalai/gemini/memory_tracer/runtime_mem_tracer.py  +10   -0
colossalai/nn/parallel/gemini_parallel.py               +1   -1
tests/test_tensor/model/test_model.py                   +1   -1
colossalai/gemini/gemini_mgr.py

```diff
@@ -26,27 +26,13 @@ class GeminiManager:
         chunk_manager (ChunkManager): A ``ChunkManager`` instance.
     """

-    def __init__(self,
-                 placement_policy: str,
-                 chunk_manager: ChunkManager,
-                 module: Optional[torch.nn.Module] = None,
-                 use_static_memstats: bool = False) -> None:
+    def __init__(self, placement_policy: str, chunk_manager: ChunkManager) -> None:
         assert placement_policy in PlacementPolicyFactory.get_polocy_names()
         self.policy_name = placement_policy
         policy_cls = PlacementPolicyFactory.create(placement_policy)
         self._chunk_manager = chunk_manager
-        # self._mem_stats_collector = ChunkMemStatsCollector(chunk_manager) if policy_cls.need_mem_stats else None
-        self.use_static_memstats = use_static_memstats
-        if policy_cls.need_mem_stats:
-            if use_static_memstats:
-                assert module is not None
-                self._mem_stats_collector = StaticMemStatsCollector(module, chunk_manager)
-            else:
-                self._mem_stats_collector = ChunkMemStatsCollector(chunk_manager)
-        else:
-            self._mem_stats_collector = None
+        self._mem_stats_collector = ChunkMemStatsCollector(chunk_manager) if policy_cls.need_mem_stats else None
         self._placement_policy = policy_cls(chunk_manager, self._mem_stats_collector)
         self._compute_list: List[Tuple[Chunk, ...]] = []
         self._compute_idx: int = -1
@@ -60,10 +46,6 @@ class GeminiManager:

     def pre_iter(self, *args):
         if self._mem_stats_collector and self._warmup:
-            if self.use_static_memstats:
-                self._mem_stats_collector.init_mem_stats(*args)
-                self._warmup = False
-            else:
-                self._mem_stats_collector.start_collection()
+            self._mem_stats_collector.start_collection()

     def post_iter(self):
```
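To make the effect of the hunks above concrete, here is a minimal standalone sketch of the simplified flow. Only the names `GeminiManager`, `_mem_stats_collector`, `_warmup`, `pre_iter`, and `start_collection` mirror the diff; the collector stub is illustrative and is not the real `ChunkMemStatsCollector` API.

```python
class MemStatsCollectorStub:
    """Illustrative stand-in for ChunkMemStatsCollector."""

    def __init__(self):
        self.collecting = False

    def start_collection(self):
        self.collecting = True


class GeminiManagerSketch:
    def __init__(self, need_mem_stats: bool):
        # Mirrors the new one-liner: create a collector only when the
        # placement policy needs memory stats, otherwise store None.
        self._mem_stats_collector = MemStatsCollectorStub() if need_mem_stats else None
        self._warmup = True

    def pre_iter(self, *args):
        # The static-memstats branch is gone: when a collector exists,
        # the warmup iteration simply starts runtime collection.
        if self._mem_stats_collector and self._warmup:
            self._mem_stats_collector.start_collection()


mgr = GeminiManagerSketch(need_mem_stats=True)
mgr.pre_iter()
assert mgr._mem_stats_collector.collecting
```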
colossalai/gemini/memory_tracer/runtime_mem_tracer.py

```diff
@@ -9,6 +9,16 @@ __all__ = ['RuntimeMemTracer']

 class RuntimeMemTracer():
+    """RuntimeMemTracer for a module trained with ColoParameter.
+
+    Traces non-model memory usage during the fwd+bwd process.
+    The statistics are obtained by feeding the module inputs with the same
+    shapes as in the real training process and running a single fwd+bwd pass.
+
+    NOTE:
+    1. The premise for using this tracer is that the target DNN executes the same operations at each iteration.
+    2. Module buffers are viewed as non-model data.
+    """

     def __init__(self, module: torch.nn.Module, dtype: torch.dtype = torch.half):
         super().__init__()
```
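The new docstring describes the core technique: run one dummy fwd+bwd pass with correctly shaped inputs and infer non-model memory from it. A rough sketch of that idea in plain PyTorch follows; it is not the ColossalAI implementation, it assumes a CUDA device, and as a simplification it lumps gradients in with non-model data, where the real tracer is finer-grained.

```python
import torch
import torch.nn as nn


def trace_non_model_memory(module: nn.Module, sample_input: torch.Tensor) -> int:
    """Return approximate peak non-model CUDA memory (bytes) for one fwd+bwd."""
    module = module.cuda().half()
    sample_input = sample_input.cuda().half()

    # Per the docstring, module buffers count as non-model data, so only
    # parameters are counted as model data here.
    model_bytes = sum(p.numel() * p.element_size() for p in module.parameters())

    torch.cuda.reset_peak_memory_stats()
    out = module(sample_input)
    out.sum().backward()    # a single fwd+bwd, mirroring a real iteration

    # Everything beyond the parameters (activations, workspaces, buffers,
    # and, in this coarse sketch, gradients) is treated as non-model data.
    return torch.cuda.max_memory_allocated() - model_bytes


if torch.cuda.is_available():
    mlp = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 10))
    print(trace_non_model_memory(mlp, torch.randn(32, 512)))
```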
colossalai/nn/parallel/gemini_parallel.py

```diff
@@ -50,5 +50,5 @@ class GeminiDDP(ZeroDDP):
                                                             hidden_dim=hidden_dim,
                                                             search_range_mb=search_range_mb,
                                                             min_chunk_size_mb=min_chunk_size_mb)
-        gemini_manager = GeminiManager(placement_policy, chunk_manager, module)
+        gemini_manager = GeminiManager(placement_policy, chunk_manager)
         super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32)
```
tests/test_tensor/model/test_model.py

```diff
@@ -117,7 +117,7 @@ def run_1d_hybrid_tp(model_name):
         else:
             output_torch = model_torch(data, label)
             loss_torch = output_torch
-        assert torch.allclose(loss, loss_torch, rtol=1e-2)
+        assert torch.allclose(loss, loss_torch, rtol=1e-2), f"model_name {model_name} failed"

         torch.distributed.barrier()
         loss.backward()
```
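The only change here is attaching a failure message to the assertion, so that when several model configurations run through the same test, a failure names the offender. A tiny self-contained illustration (the tensor values are made up for demonstration):

```python
import torch

loss = torch.tensor(1.000)        # loss from the parallel run
loss_torch = torch.tensor(1.005)  # reference loss from plain PyTorch
model_name = "bert"

# rtol=1e-2 tolerates roughly 1% relative drift between the two losses;
# the f-string message is shown only when the assertion fails.
assert torch.allclose(loss, loss_torch, rtol=1e-2), f"model_name {model_name} failed"
```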