OpenDAS / ColossalAI · Commits

Commit 20e255d4 (unverified)
Authored Nov 07, 2022 by Zihao; committed by GitHub on Nov 07, 2022

MemStatsCollectorStatic (#1765)
Parent: 327d07c4

Showing 4 changed files with 142 additions and 11 deletions (+142 −11):

    colossalai/gemini/gemini_mgr.py                          +23   −5
    colossalai/gemini/memory_tracer/memstats_collector.py   +107   −1
    colossalai/nn/parallel/data_parallel.py                   +1   −1
    colossalai/zero/sharded_model/sharded_model_v2.py        +11   −4
colossalai/gemini/gemini_mgr.py (view file @ 20e255d4)

@@ -6,7 +6,7 @@ import torch
 from colossalai.gemini.chunk import Chunk, ChunkManager
-from .memory_tracer.memstats_collector import MemStatsCollectorV2
+from .memory_tracer.memstats_collector import MemStatsCollectorV2, MemStatsCollectorStatic
 from .placement_policy import PlacementPolicyFactory

@@ -26,12 +26,26 @@ class GeminiManager:
         chunk_manager (ChunkManager): A ``ChunkManager`` instance.
     """

-    def __init__(self, placement_policy: str, chunk_manager: ChunkManager) -> None:
+    def __init__(self,
+                 placement_policy: str,
+                 chunk_manager: ChunkManager,
+                 module: Optional[torch.nn.Module] = None,
+                 use_static_memstats: bool = False) -> None:
         assert placement_policy in PlacementPolicyFactory.get_polocy_names()
         self.policy_name = placement_policy
         policy_cls = PlacementPolicyFactory.create(placement_policy)
         self._chunk_manager = chunk_manager
-        self._mem_stats_collector = MemStatsCollectorV2(chunk_manager) if policy_cls.need_mem_stats else None
+        # self._mem_stats_collector = MemStatsCollectorV2(chunk_manager) if policy_cls.need_mem_stats else None
+        self.use_static_memstats = use_static_memstats
+        if policy_cls.need_mem_stats:
+            if use_static_memstats:
+                assert module is not None
+                self._mem_stats_collector = MemStatsCollectorStatic(module, chunk_manager)
+            else:
+                self._mem_stats_collector = MemStatsCollectorV2(chunk_manager)
+        else:
+            self._mem_stats_collector = None
         self._placement_policy = policy_cls(chunk_manager, self._mem_stats_collector)
         self._compute_list: List[Tuple[Chunk, ...]] = []
         self._compute_idx: int = -1

@@ -43,9 +57,13 @@ class GeminiManager:
         self._warmup = True
         self._comp_cuda_demand_time = 0

-    def pre_iter(self):
+    def pre_iter(self, *args):
         if self._mem_stats_collector and self._warmup:
-            self._mem_stats_collector.start_collection()
+            if self.use_static_memstats:
+                self._mem_stats_collector.init_mem_stats(*args)
+                self._warmup = False
+            else:
+                self._mem_stats_collector.start_collection()

     def post_iter(self):
         """This function must be called when each iteration finishes
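
The net effect of these hunks: GeminiManager can now build a MemStatsCollectorStatic instead of the runtime MemStatsCollectorV2, provided the caller hands it the module up front. A minimal sketch of the new construction path, assuming the 'auto' placement policy; `model` is an illustrative stand-in and the ChunkManager setup is elided because its constructor is version-specific:

import torch
import torch.nn as nn

from colossalai.gemini.gemini_mgr import GeminiManager

model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 8))   # illustrative stand-in
chunk_manager = ...   # a configured colossalai.gemini.chunk.ChunkManager (setup elided)

# use_static_memstats=True requires the module up front; otherwise the
# `assert module is not None` in __init__ fires.
gemini_mgr = GeminiManager('auto', chunk_manager, module=model, use_static_memstats=True)

# pre_iter now takes the forward args. On the warmup iteration it routes them
# to MemStatsCollectorStatic.init_mem_stats instead of starting runtime
# collection, and clears the warmup flag immediately.
gemini_mgr.pre_iter(torch.rand(4, 64))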
colossalai/gemini/memory_tracer/memstats_collector.py (view file @ 20e255d4)

@@ -5,8 +5,16 @@ from colossalai.gemini.stateful_tensor import StatefulTensor
 from colossalai.gemini.chunk import ChunkManager
 import torch
+import torch.nn as nn
 import time
-from typing import List
+from typing import List, Optional
+
+from colossalai.fx.passes.meta_info_prop import MetaInfoProp
+from colossalai.fx.profiler import (calculate_fwd_out, calculate_fwd_tmp, is_compatible_with_meta,
+                                    parameter_size)
+from torch.fx import symbolic_trace
+
+if is_compatible_with_meta():
+    from colossalai.fx.profiler import MetaTensor


 class MemStatsCollector:

@@ -150,3 +158,101 @@ class MemStatsCollectorV2(MemStatsCollector):
     @property
     def cuda_margin_mem(self) -> float:
         return colo_device_memory_capacity(get_current_device()) - max(self.overall_mem_stats('cuda'))
+
+
+class MemStatsCollectorStatic(MemStatsCollectorV2):
+    """
+    A Static Memory statistic collector.
+    """
+
+    def __init__(self, module: nn.Module, chunk_manager: ChunkManager) -> None:
+        super().__init__(chunk_manager)
+        self.module = module
+        self.module_info_list = []
+
+    def init_mem_stats(self, *inputs):
+
+        self.register_opnodes_recursively(self.module)
+        self.refactor_module()
+
+        self.module = self.module.cpu()
+        self.module.train()
+
+        data = [MetaTensor(torch.rand(inp.shape, device='meta'), fake_device='cpu') for inp in inputs]
+        gm = symbolic_trace(self.module)
+        interp = MetaInfoProp(gm)
+        interp.propagate(*data)
+
+        total_mem = 0
+        for inp in inputs:
+            total_mem += inp.numel() * inp.element_size()
+        last_node = None
+        module_name_list = [mInfo.module_full_name for mInfo in self.module_info_list]
+        for node in gm.graph.nodes:
+            total_mem = total_mem + calculate_fwd_tmp(node) + calculate_fwd_out(node)
+            if node.op == "call_module":
+                if node.name.endswith("_0") and node.name[:-2] in module_name_list:
+                    self._non_model_data_cuda_list.append(total_mem)
+            last_node = node
+        self._non_model_data_cuda_list.append(total_mem)
+        self._non_model_data_cuda_list = self._non_model_data_cuda_list[1:]
+
+        cur_module_mem_fwd = 0
+        cur_module_mem_bwd = 0
+        grad_module_out = last_node.meta["fwd_mem_out"]
+        for node in gm.graph.nodes.__reversed__():
+            cur_module_mem_fwd = cur_module_mem_fwd + calculate_fwd_tmp(node) + calculate_fwd_out(node)
+            cur_module_mem_bwd = cur_module_mem_bwd + node.meta["bwd_mem_tmp"] + node.meta["bwd_mem_out"]
+            if node.op == "call_module":
+                if node.name.endswith("_0") and node.name[:-2] in module_name_list:
+                    self._non_model_data_cuda_list.append(total_mem + grad_module_out + cur_module_mem_bwd)
+                    total_mem = total_mem - cur_module_mem_fwd
+                    cur_module_mem_fwd = 0
+                    cur_module_mem_bwd = 0
+                    grad_module_out = node.meta["bwd_mem_out"]
+
+        self._step_total = len(self._non_model_data_cuda_list)
+        self.recover_module()
+
+    def refactor_module(self):
+        for modInfo in self.module_info_list:
+            temp_node = nn.Sequential(nn.ReLU(), modInfo.module)
+            modInfo.parent_module.__setattr__(modInfo.module_name, temp_node)
+
+    def recover_module(self):
+        for modInfo in self.module_info_list:
+            modInfo.parent_module.__setattr__(modInfo.module_name, modInfo.module)
+
+    def register_opnodes_recursively(self,
+                                     module: torch.nn.Module,
+                                     name: str = "",
+                                     full_name: str = "",
+                                     parent_module: Optional[torch.nn.Module] = None):
+
+        assert isinstance(module, torch.nn.Module)
+
+        for child_name, child in module.named_children():
+            self.register_opnodes_recursively(child, child_name, full_name + "_" + child_name, module)
+
+        # Early return on modules with no parameters.
+        if len(list(module.parameters(recurse=False))) == 0:
+            return
+
+        self.module_info_list.append(ModuleInfos(module, name, full_name[1:], parent_module))
+
+
+class ModuleInfos:
+
+    def __init__(self,
+                 module: torch.nn.Module,
+                 module_name: str,
+                 module_full_name: str,
+                 parent_module: torch.nn.Module):
+        self.module = module
+        self.module_name = module_name
+        self.module_full_name = module_full_name
+        self.parent_module = parent_module
\ No newline at end of file
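
The non-obvious piece above is how refactor_module and the "_0" check in init_mem_stats cooperate: wrapping every parameterized submodule in nn.Sequential(nn.ReLU(), module) plants a sentinel ReLU whose traced node name ends in "_0", which marks where each registered module begins in the FX graph; recover_module then unwraps the sentinels so the model is unchanged afterwards. A self-contained sketch of the naming behavior this relies on (toy model, plain PyTorch, independent of ColossalAI):

import torch.nn as nn
from torch.fx import symbolic_trace

class Toy(nn.Module):
    """Mirrors what refactor_module produces: Sequential(ReLU(), original)."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Sequential(nn.ReLU(), nn.Linear(8, 8))
        self.fc2 = nn.Sequential(nn.ReLU(), nn.Linear(8, 4))

    def forward(self, x):
        return self.fc2(self.fc1(x))

gm = symbolic_trace(Toy())
print([n.name for n in gm.graph.nodes if n.op == "call_module"])
# ['fc1_0', 'fc1_1', 'fc2_0', 'fc2_1']
# The sentinel ReLU traces as '<full_name>_0', so init_mem_stats can detect a
# module boundary via node.name.endswith("_0") and node.name[:-2] in
# module_name_list, and snapshot the accumulated activation memory there.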
colossalai/nn/parallel/data_parallel.py (view file @ 20e255d4)

@@ -267,7 +267,7 @@ class ZeroDDP(ColoDDP):
     def forward(self, *args, **kwargs):
         args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half)
         self.module.zero_grad(set_to_none=True)
-        self.gemini_manager.pre_iter()
+        self.gemini_manager.pre_iter(*args)
         with ParamOpHookManager.use_hooks(self.param_op_hook):
             outputs = self.module(*args, **kwargs)
         if self.force_outputs_fp32:
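
This one-line change is the plumbing that gets real inputs to the static collector: ZeroDDP.forward already receives *args, and now passes them on to GeminiManager.pre_iter. Roughly, under the setup sketched earlier (ZeroDDP's full constructor signature is elided; `batch` is an illustrative input):

from colossalai.nn.parallel import ZeroDDP

ddp_model = ZeroDDP(model, gemini_mgr)   # wraps `model` with the manager built above

# On the warmup step this call now reaches
# gemini_mgr.pre_iter(batch) -> MemStatsCollectorStatic.init_mem_stats(batch).
logits = ddp_model(batch)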
colossalai/zero/sharded_model/sharded_model_v2.py (view file @ 20e255d4)

@@ -13,7 +13,7 @@ from colossalai.zero.utils import ZeroHook
 from colossalai.gemini.paramhooks import BaseParamHookMgr
 from colossalai.logging import get_dist_logger
 from colossalai.utils import get_current_device, disposable
-from colossalai.gemini.memory_tracer.memstats_collector import MemStatsCollector
+from colossalai.gemini.memory_tracer.memstats_collector import MemStatsCollector, MemStatsCollectorStatic
 from colossalai.utils.memory import colo_device_memory_capacity
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer

@@ -77,6 +77,7 @@ class ShardedModelV2(nn.Module):
                  tensor_placement_policy: str = 'cuda',
                  gradient_predivide_factor: Optional[float] = 1.0,
                  reuse_fp16_shard: bool = False,
+                 user_static_memstats: bool = False,
                  *args,
                  **kwargs):
         assert not isinstance(module, ShardedModelV2), 'Nested ShardedModelV2 is not supported.'

@@ -110,10 +111,14 @@ class ShardedModelV2(nn.Module):
         self.world_size = dist.get_world_size(self.process_group)
         self.rank = dist.get_rank(self.process_group)
         self.shard_strategy = shard_strategy
+        self.user_static_memstats = user_static_memstats

         self._use_memory_tracer = tensor_placement_policy == 'auto'
         if self._use_memory_tracer:
-            self._memstats_collector = MemStatsCollector()
+            if self.user_static_memstats:
+                self._memstats_collector = MemStatsCollectorStatic(self.module)
+            else:
+                self._memstats_collector = MemStatsCollector()
             self._start_collect_memstats = disposable(self._memstats_collector.start_collection)
             self._finish_collect_memstats = disposable(self._memstats_collector.finish_collection)
         else:

@@ -206,9 +211,11 @@ class ShardedModelV2(nn.Module):
                 f.write(str(self._memstats_collector.non_model_data_list('cpu', 'GB')))
                 f.write('\n')

-    def _pre_forward_operations(self):
+    def _pre_forward_operations(self, *args):
         # the operation will affect the memory tracer behavior in ZeroHook
         if self._memstats_collector:
+            if self.user_static_memstats:
+                self.init_mem_stats(*args)
             self._start_collect_memstats()

         for p in self.module.parameters():

@@ -223,7 +230,7 @@ class ShardedModelV2(nn.Module):
             p.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD)

     def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
-        self._pre_forward_operations()
+        self._pre_forward_operations(*args)
         args, kwargs = cast_float_arguments(cast_tensor_to_fp16, *args, **kwargs)
         outputs = self.module(*args, **kwargs)
         self._post_forward_operations()
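
The same switch reaches the older ZeRO path, where it is spelled user_static_memstats and only takes effect under tensor_placement_policy='auto'. A hedged usage sketch: `model` and `data` are illustrative, TensorShardStrategy is one concrete strategy shipped in this release, and the usual ZeroInitContext model setup is elided:

from colossalai.zero.shard_utils import TensorShardStrategy
from colossalai.zero.sharded_model import ShardedModelV2

sharded_model = ShardedModelV2(model,
                               TensorShardStrategy(),
                               tensor_placement_policy='auto',   # the tracer only runs under 'auto'
                               user_static_memstats=True)        # flag name as spelled in this commit

# forward(*args) now threads the inputs through _pre_forward_operations(*args)
# to the static collector before runtime collection starts.
outputs = sharded_model(data)   # `data` is an illustrative input batch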