OpenDAS / ColossalAI · Commits

Commit c4739a72 (unverified)
[Gemini] polish memstats collector (#1962)

Authored Nov 16, 2022 by Jiarui Fang; committed by GitHub on Nov 16, 2022.
Parent: fea3cb66

Showing 7 changed files with 200 additions and 173 deletions (+200 −173).
colossalai/gemini/gemini_mgr.py                                  +6   −5
colossalai/gemini/memory_tracer/__init__.py                      +9   −4
colossalai/gemini/memory_tracer/chunk_memstats_collector.py      +25  −0
colossalai/gemini/memory_tracer/memstats_collector.py            +6   −133
colossalai/gemini/memory_tracer/static_memstats_collector.py     +105 −0
colossalai/gemini/placement_policy.py                            +22  −12
colossalai/zero/sharded_model/sharded_model_v2.py                +27  −19
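
Taken together: `MemStatsCollectorV2` is renamed to `ChunkMemStatsCollector` and `MemStatsCollectorStatic` to `StaticMemStatsCollector`, and each collector moves out of memstats_collector.py into its own module under colossalai/gemini/memory_tracer/; call sites in gemini_mgr.py, placement_policy.py, and sharded_model_v2.py are updated to match.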
colossalai/gemini/gemini_mgr.py
@@ -6,7 +6,7 @@ import torch
 from colossalai.gemini.chunk import Chunk, ChunkManager
-from .memory_tracer.memstats_collector import MemStatsCollectorV2, MemStatsCollectorStatic
+from .memory_tracer import ChunkMemStatsCollector, StaticMemStatsCollector
 from .placement_policy import PlacementPolicyFactory
@@ -26,7 +26,8 @@ class GeminiManager:
         chunk_manager (ChunkManager): A ``ChunkManager`` instance.
     """

-    def __init__(self, placement_policy: str,
+    def __init__(self,
+                 placement_policy: str,
                  chunk_manager: ChunkManager,
                  module: Optional[torch.nn.Module] = None,
                  use_static_memstats: bool = False) -> None:
@@ -35,14 +36,14 @@ class GeminiManager:
         self.policy_name = placement_policy
         policy_cls = PlacementPolicyFactory.create(placement_policy)
         self._chunk_manager = chunk_manager
-        # self._mem_stats_collector = MemStatsCollectorV2(chunk_manager) if policy_cls.need_mem_stats else None
+        # self._mem_stats_collector = ChunkMemStatsCollector(chunk_manager) if policy_cls.need_mem_stats else None
         self.use_static_memstats = use_static_memstats
         if policy_cls.need_mem_stats:
             if use_static_memstats:
                 assert module is not None
-                self._mem_stats_collector = MemStatsCollectorStatic(module, chunk_manager)
+                self._mem_stats_collector = StaticMemStatsCollector(module, chunk_manager)
             else:
-                self._mem_stats_collector = MemStatsCollectorV2(chunk_manager)
+                self._mem_stats_collector = ChunkMemStatsCollector(chunk_manager)
         else:
             self._mem_stats_collector = None
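
As a usage sketch (not part of the diff), the constructor above now picks the collector as follows; `chunk_manager` and `model` are assumed placeholders for a previously built `ChunkManager` and `torch.nn.Module`:

from colossalai.gemini.gemini_mgr import GeminiManager

# Default: sample chunk statistics at runtime via ChunkMemStatsCollector.
manager = GeminiManager('auto', chunk_manager)

# Opt in to the fx-traced StaticMemStatsCollector, which needs the module up front.
manager = GeminiManager('auto', chunk_manager, module=model, use_static_memstats=True)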
colossalai/gemini/memory_tracer/__init__.py
-from .model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
-from .memory_monitor import AsyncMemoryMonitor, SyncCudaMemoryMonitor
-from .memstats_collector import MemStatsCollector
+from .memory_monitor import AsyncMemoryMonitor, SyncCudaMemoryMonitor    # isort:skip
+from .memstats_collector import MemStatsCollector    # isort:skip
+from .model_data_memtracer import GLOBAL_MODEL_DATA_TRACER    # isort:skip
+from .chunk_memstats_collector import ChunkMemStatsCollector    # isort:skip
+from .static_memstats_collector import StaticMemStatsCollector    # isort:skip

-__all__ = ['AsyncMemoryMonitor', 'SyncCudaMemoryMonitor', 'MemStatsCollector', 'GLOBAL_MODEL_DATA_TRACER']
+__all__ = [
+    'AsyncMemoryMonitor', 'SyncCudaMemoryMonitor', 'MemStatsCollector', 'ChunkMemStatsCollector',
+    'StaticMemStatsCollector', 'GLOBAL_MODEL_DATA_TRACER'
+]
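
With the re-exports above, downstream code can import both collectors from the package root rather than reaching into submodules:

from colossalai.gemini.memory_tracer import ChunkMemStatsCollector, StaticMemStatsCollector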
colossalai/gemini/memory_tracer/chunk_memstats_collector.py (new file, mode 0 → 100644)
+from colossalai.gemini.chunk import ChunkManager
+from colossalai.utils import get_current_device
+from colossalai.utils.memory import colo_device_memory_capacity
+
+from .memstats_collector import MemStatsCollector
+
+
+class ChunkMemStatsCollector(MemStatsCollector):
+
+    def __init__(self, chunk_manager: ChunkManager) -> None:
+        super().__init__()
+        self._chunk_manager = chunk_manager
+
+    def sample_model_data(self) -> None:
+        """Sampling model data statistics.
+        """
+        if self._start_flag:
+            cuda_mem = self._chunk_manager.total_mem['cuda']
+            cpu_mem = self._chunk_manager.total_mem['cpu']
+            self._model_data_cuda_list.append(cuda_mem)
+            self._model_data_cpu_list.append(cpu_mem)
+
+    @property
+    def cuda_margin_mem(self) -> float:
+        return colo_device_memory_capacity(get_current_device()) - max(self.overall_mem_stats('cuda'))
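
A toy, self-contained walk-through of the `cuda_margin_mem` arithmetic; the capacity and samples below are invented stand-ins for `colo_device_memory_capacity(...)` and `overall_mem_stats('cuda')`:

GiB = 1024**3
capacity = 16 * GiB                                    # pretend device capacity
overall_cuda_samples = [6 * GiB, 10 * GiB, 8 * GiB]    # pretend per-step overall usage
margin = capacity - max(overall_cuda_samples)
print(margin / GiB)    # 6.0 -> CUDA memory left over after the peak step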
colossalai/gemini/memory_tracer/memstats_collector.py
-from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
-from colossalai.utils.memory import colo_device_memory_used, colo_device_memory_capacity
-from colossalai.utils import get_current_device
-from colossalai.gemini.stateful_tensor import StatefulTensor
-from colossalai.gemini.chunk import ChunkManager
-
-import torch
-import torch.nn as nn
-import time
-from typing import List, Optional
-
-from colossalai.fx.passes.meta_info_prop import MetaInfoProp
-from colossalai.fx.profiler import (calculate_fwd_out, calculate_fwd_tmp, is_compatible_with_meta, parameter_size)
-from torch.fx import symbolic_trace
-
-if is_compatible_with_meta():
-    from colossalai.fx.profiler import MetaTensor
+import time
+from typing import List
+
+import torch
+
+from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
+from colossalai.gemini.stateful_tensor import StatefulTensor
+from colossalai.utils.memory import colo_device_memory_used


 class MemStatsCollector:
@@ -138,121 +129,3 @@ class MemStatsCollector:
         self._start_flag = False
         self._step_idx = 0
         self._step_total = 0
-
-
-class MemStatsCollectorV2(MemStatsCollector):
-
-    def __init__(self, chunk_manager: ChunkManager) -> None:
-        super().__init__()
-        self._chunk_manager = chunk_manager
-
-    def sample_model_data(self) -> None:
-        """Sampling model data statistics.
-        """
-        if self._start_flag:
-            cuda_mem = self._chunk_manager.total_mem['cuda']
-            cpu_mem = self._chunk_manager.total_mem['cpu']
-            self._model_data_cuda_list.append(cuda_mem)
-            self._model_data_cpu_list.append(cpu_mem)
-
-    @property
-    def cuda_margin_mem(self) -> float:
-        return colo_device_memory_capacity(get_current_device()) - max(self.overall_mem_stats('cuda'))
-
-
-class MemStatsCollectorStatic(MemStatsCollectorV2):
-    """
-    A Static Memory statistic collector.
-    """
-
-    def __init__(self, module: nn.Module, chunk_manager: ChunkManager) -> None:
-        super().__init__(chunk_manager)
-        self.module = module
-        self.module_info_list = []
-
-    def init_mem_stats(self, *inputs):
-
-        self.register_opnodes_recursively(self.module)
-        self.refactor_module()
-
-        self.module = self.module.cpu()
-        self.module.train()
-
-        data = [MetaTensor(torch.rand(inp.shape, device='meta'), fake_device='cpu') for inp in inputs]
-        gm = symbolic_trace(self.module)
-        interp = MetaInfoProp(gm)
-        interp.propagate(*data)
-
-        total_mem = 0
-        for inp in inputs:
-            total_mem += inp.numel() * inp.element_size()
-        last_node = None
-        module_name_list = [mInfo.module_full_name for mInfo in self.module_info_list]
-        for node in gm.graph.nodes:
-            total_mem = total_mem + calculate_fwd_tmp(node) + calculate_fwd_out(node)
-            if node.op == "call_module":
-                if node.name.endswith("_0") and node.name[:-2] in module_name_list:
-                    self._non_model_data_cuda_list.append(total_mem)
-            last_node = node
-        self._non_model_data_cuda_list.append(total_mem)
-        self._non_model_data_cuda_list = self._non_model_data_cuda_list[1:]
-
-        cur_module_mem_fwd = 0
-        cur_module_mem_bwd = 0
-        grad_module_out = last_node.meta["fwd_mem_out"]
-        for node in gm.graph.nodes.__reversed__():
-            cur_module_mem_fwd = cur_module_mem_fwd + calculate_fwd_tmp(node) + calculate_fwd_out(node)
-            cur_module_mem_bwd = cur_module_mem_bwd + node.meta["bwd_mem_tmp"] + node.meta["bwd_mem_out"]
-            if node.op == "call_module":
-                if node.name.endswith("_0") and node.name[:-2] in module_name_list:
-                    self._non_model_data_cuda_list.append(total_mem + grad_module_out + cur_module_mem_bwd)
-                    total_mem = total_mem - cur_module_mem_fwd
-                    cur_module_mem_fwd = 0
-                    cur_module_mem_bwd = 0
-                    grad_module_out = node.meta["bwd_mem_out"]
-
-        self._step_total = len(self._non_model_data_cuda_list)
-        self.recover_module()
-
-    def refactor_module(self):
-        for modInfo in self.module_info_list:
-            temp_node = nn.Sequential(nn.ReLU(), modInfo.module)
-            modInfo.parent_module.__setattr__(modInfo.module_name, temp_node)
-
-    def recover_module(self):
-        for modInfo in self.module_info_list:
-            modInfo.parent_module.__setattr__(modInfo.module_name, modInfo.module)
-
-    def register_opnodes_recursively(self,
-                                     module: torch.nn.Module,
-                                     name: str = "",
-                                     full_name: str = "",
-                                     parent_module: Optional[torch.nn.Module] = None):
-
-        assert isinstance(module, torch.nn.Module)
-
-        for child_name, child in module.named_children():
-            self.register_opnodes_recursively(child, child_name, full_name + "_" + child_name, module)
-
-        # Early return on modules with no parameters.
-        if len(list(module.parameters(recurse=False))) == 0:
-            return
-
-        self.module_info_list.append(ModuleInfos(module, name, full_name[1:], parent_module))
-
-
-class ModuleInfos:
-
-    def __init__(self, module: torch.nn.Module, module_name: str, module_full_name: str,
-                 parent_module: torch.nn.Module):
-        self.module = module
-        self.module_name = module_name
-        self.module_full_name = module_full_name
-        self.parent_module = parent_module
\ No newline at end of file
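
Note that the two classes deleted here are moved rather than dropped: they reappear as `ChunkMemStatsCollector` in chunk_memstats_collector.py above and as `StaticMemStatsCollector` in static_memstats_collector.py below.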
colossalai/gemini/memory_tracer/static_memstats_collector.py (new file, mode 0 → 100644)
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from torch.fx import symbolic_trace
+
+from colossalai.fx.passes.meta_info_prop import MetaInfoProp
+from colossalai.fx.profiler import calculate_fwd_out, calculate_fwd_tmp, is_compatible_with_meta
+from colossalai.gemini.chunk import ChunkManager
+
+if is_compatible_with_meta():
+    from colossalai.fx.profiler import MetaTensor
+
+from .chunk_memstats_collector import ChunkMemStatsCollector
+
+
+class ModuleInfos:
+
+    def __init__(self, module: torch.nn.Module, module_name: str, module_full_name: str,
+                 parent_module: torch.nn.Module):
+        self.module = module
+        self.module_name = module_name
+        self.module_full_name = module_full_name
+        self.parent_module = parent_module
+
+
+class StaticMemStatsCollector(ChunkMemStatsCollector):
+    """
+    A Static Memory statistic collector.
+    """
+
+    def __init__(self, module: nn.Module, chunk_manager: ChunkManager) -> None:
+        super().__init__(chunk_manager)
+        self.module = module
+        self.module_info_list = []
+
+    def init_mem_stats(self, *inputs):
+
+        self.register_opnodes_recursively(self.module)
+        self.refactor_module()
+
+        self.module = self.module.cpu()
+        self.module.train()
+
+        data = [MetaTensor(torch.rand(inp.shape, device='meta'), fake_device='cpu') for inp in inputs]
+        gm = symbolic_trace(self.module)
+        interp = MetaInfoProp(gm)
+        interp.propagate(*data)
+
+        total_mem = 0
+        for inp in inputs:
+            total_mem += inp.numel() * inp.element_size()
+        last_node = None
+        module_name_list = [mInfo.module_full_name for mInfo in self.module_info_list]
+        for node in gm.graph.nodes:
+            total_mem = total_mem + calculate_fwd_tmp(node) + calculate_fwd_out(node)
+            if node.op == "call_module":
+                if node.name.endswith("_0") and node.name[:-2] in module_name_list:
+                    self._non_model_data_cuda_list.append(total_mem)
+            last_node = node
+        self._non_model_data_cuda_list.append(total_mem)
+        self._non_model_data_cuda_list = self._non_model_data_cuda_list[1:]
+
+        cur_module_mem_fwd = 0
+        cur_module_mem_bwd = 0
+        grad_module_out = last_node.meta["fwd_mem_out"]
+        for node in gm.graph.nodes.__reversed__():
+            cur_module_mem_fwd = cur_module_mem_fwd + calculate_fwd_tmp(node) + calculate_fwd_out(node)
+            cur_module_mem_bwd = cur_module_mem_bwd + node.meta["bwd_mem_tmp"] + node.meta["bwd_mem_out"]
+            if node.op == "call_module":
+                if node.name.endswith("_0") and node.name[:-2] in module_name_list:
+                    self._non_model_data_cuda_list.append(total_mem + grad_module_out + cur_module_mem_bwd)
+                    total_mem = total_mem - cur_module_mem_fwd
+                    cur_module_mem_fwd = 0
+                    cur_module_mem_bwd = 0
+                    grad_module_out = node.meta["bwd_mem_out"]
+
+        self._step_total = len(self._non_model_data_cuda_list)
+        self.recover_module()
+
+    def refactor_module(self):
+        for modInfo in self.module_info_list:
+            temp_node = nn.Sequential(nn.ReLU(), modInfo.module)
+            modInfo.parent_module.__setattr__(modInfo.module_name, temp_node)
+
+    def recover_module(self):
+        for modInfo in self.module_info_list:
+            modInfo.parent_module.__setattr__(modInfo.module_name, modInfo.module)
+
+    def register_opnodes_recursively(self,
+                                     module: torch.nn.Module,
+                                     name: str = "",
+                                     full_name: str = "",
+                                     parent_module: Optional[torch.nn.Module] = None):
+
+        assert isinstance(module, torch.nn.Module)
+
+        for child_name, child in module.named_children():
+            self.register_opnodes_recursively(child, child_name, full_name + "_" + child_name, module)
+
+        # Early return on modules with no parameters.
+        if len(list(module.parameters(recurse=False))) == 0:
+            return
+
+        self.module_info_list.append(ModuleInfos(module, name, full_name[1:], parent_module))
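
The `node.name.endswith("_0")` test relies on an fx naming detail: `refactor_module()` wraps every parameterised submodule in `nn.Sequential(nn.ReLU(), module)`, so after tracing, each wrapped submodule shows up as a pair of `call_module` nodes named `<name>_0` (the sentinel ReLU, marking the module boundary) and `<name>_1` (the original module). A small self-contained sketch of that behaviour, using only plain PyTorch:

import torch.nn as nn
from torch.fx import symbolic_trace


class Net(nn.Module):

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 4)

    def forward(self, x):
        return self.layer(x)


net = Net()
# Wrap the parameterised submodule the way refactor_module() does.
net.layer = nn.Sequential(nn.ReLU(), net.layer)
gm = symbolic_trace(net)
print([n.name for n in gm.graph.nodes if n.op == "call_module"])
# ['layer_0', 'layer_1'] -> the "_0" node marks the module boundary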
colossalai/gemini/placement_policy.py
 import functools
 from abc import ABC, abstractmethod
 from time import time
-from typing import List, Optional, Tuple, Dict
+from typing import Dict, List, Optional, Tuple, Type

 import torch

-from colossalai.utils import get_current_device
-from colossalai.utils.memory import colo_device_memory_capacity
-from colossalai.gemini.memory_tracer.memstats_collector import MemStatsCollectorV2
-from typing import Type
-import functools
 from colossalai.gemini.chunk import Chunk, ChunkManager
+from colossalai.gemini.memory_tracer import ChunkMemStatsCollector
+from colossalai.utils import get_current_device
+from colossalai.utils.memory import colo_device_memory_capacity


 class PlacementPolicy(ABC):
     need_mem_stats: bool = False

-    def __init__(self, chunk_manager: ChunkManager, mem_stats_collector: Optional[MemStatsCollectorV2] = None) -> None:
+    def __init__(self,
+                 chunk_manager: ChunkManager,
+                 mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
         self.chunk_manager = chunk_manager
-        self.mem_stats_collector: Optional[MemStatsCollectorV2] = mem_stats_collector
+        self.mem_stats_collector: Optional[ChunkMemStatsCollector] = mem_stats_collector

     @abstractmethod
     def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]:
@@ -29,7 +31,9 @@ class PlacementPolicy(ABC):
 class CPUPlacementPolicy(PlacementPolicy):

-    def __init__(self, chunk_manager: ChunkManager, mem_stats_collector: Optional[MemStatsCollectorV2] = None) -> None:
+    def __init__(self,
+                 chunk_manager: ChunkManager,
+                 mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
         super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)

     def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]:
@@ -44,7 +48,9 @@ class CPUPlacementPolicy(PlacementPolicy):
 class CUDAPlacementPolicy(PlacementPolicy):

-    def __init__(self, chunk_manager: ChunkManager, mem_stats_collector: Optional[MemStatsCollectorV2] = None) -> None:
+    def __init__(self,
+                 chunk_manager: ChunkManager,
+                 mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
         assert torch.cuda.is_available(), 'Cannot use CUDATensorPlacementPolicy when CUDA is not available'
         super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)
@@ -65,7 +71,9 @@ class AutoPlacementPolicy(PlacementPolicy):
     _warmup_non_model_data_ratio: float = 0.8
     _steady_cuda_cap_ratio: float = 0.9

-    def __init__(self, chunk_manager: ChunkManager, mem_stats_collector: Optional[MemStatsCollectorV2] = None) -> None:
+    def __init__(self,
+                 chunk_manager: ChunkManager,
+                 mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
         super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)

     def evict_tensors(self,
@@ -154,7 +162,9 @@ class ConstPlacementPolicy(PlacementPolicy):
     need_mem_stats: bool = False
     _accessed_memory_boundary = 512 * 1024**2

-    def __init__(self, chunk_manager: ChunkManager, mem_stats_collector: Optional[MemStatsCollectorV2] = None) -> None:
+    def __init__(self,
+                 chunk_manager: ChunkManager,
+                 mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
         super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)

     def evict_tensors(self,
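
A minimal sketch of a policy written against the updated base-class signature; `NoopPlacementPolicy` is hypothetical (not part of this commit), and the `(int, float)` return pair is read here as (volume evicted, time spent), an assumption based on the signature alone:

from typing import List, Optional, Tuple

from colossalai.gemini.chunk import Chunk, ChunkManager
from colossalai.gemini.memory_tracer import ChunkMemStatsCollector
from colossalai.gemini.placement_policy import PlacementPolicy


class NoopPlacementPolicy(PlacementPolicy):
    # No memory statistics required, so GeminiManager would build no collector.
    need_mem_stats: bool = False

    def __init__(self,
                 chunk_manager: ChunkManager,
                 mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
        super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)

    def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]:
        # Evict nothing: zero volume moved, zero time spent.
        return 0, 0.0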
colossalai/zero/sharded_model/sharded_model_v2.py
 import functools
+import itertools
 from collections import OrderedDict
-from typing import Any, Optional, Iterator, Tuple
-from copy import deepcopy
-import itertools
+from typing import Any, Iterator, Optional, Tuple

 import torch
 import torch.distributed as dist
 import torch.nn as nn
+from torch.distributed import ProcessGroup
+from torch.nn.parameter import Parameter

 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
+from colossalai.gemini.memory_tracer import MemStatsCollector, StaticMemStatsCollector
 from colossalai.gemini.ophooks import register_ophooks_recursively
 from colossalai.gemini.paramhooks import BaseParamHookMgr
+from colossalai.gemini.stateful_tensor import TensorState
+from colossalai.gemini.stateful_tensor_mgr import StatefulTensorMgr
+from colossalai.gemini.tensor_placement_policy import TensorPlacementPolicy, TensorPlacementPolicyFactory
+from colossalai.gemini.tensor_utils import colo_model_data_move_to_cpu
 from colossalai.logging import get_dist_logger
-from colossalai.utils import get_current_device, disposable
-from colossalai.gemini.memory_tracer.memstats_collector import MemStatsCollector, MemStatsCollectorStatic
+from colossalai.utils import disposable, get_current_device
 from colossalai.utils.memory import colo_device_memory_capacity
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
-from torch.distributed import ProcessGroup
-from torch.nn.parameter import Parameter
-from colossalai.gemini.tensor_utils import colo_model_data_move_to_cpu
-from colossalai.gemini.stateful_tensor import TensorState
-from colossalai.gemini.stateful_tensor_mgr import StatefulTensorMgr
-from colossalai.gemini.tensor_placement_policy import TensorPlacementPolicyFactory, TensorPlacementPolicy
 from colossalai.zero.utils import ZeroHook

-from ._utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad, free_storage,
-                     get_gradient_predivide_factor)
+from ._utils import (
+    cast_float_arguments,
+    cast_tensor_to_fp16,
+    cast_tensor_to_fp32,
+    chunk_and_pad,
+    free_storage,
+    get_gradient_predivide_factor,
+)

 try:
     from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
@@ -116,7 +124,7 @@ class ShardedModelV2(nn.Module):
         self._use_memory_tracer = tensor_placement_policy == 'auto'
         if self._use_memory_tracer:
             if self.user_static_memstats:
-                self._memstats_collector = MemStatsCollectorStatic(self.module)
+                self._memstats_collector = StaticMemStatsCollector(self.module)
             else:
                 self._memstats_collector = MemStatsCollector()
             self._start_collect_memstats = disposable(self._memstats_collector.start_collection)
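
`disposable(...)` wraps `start_collection` so the hook can fire only once per collector. A hedged sketch of that call-once idea, as an illustration rather than ColossalAI's exact implementation:

import functools


def call_once(func):
    """Run `func` on the first invocation only; later calls become no-ops."""
    done = False

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        nonlocal done
        if not done:
            done = True
            return func(*args, **kwargs)

    return wrapper


start = call_once(lambda: print("collection started"))
start()    # prints once
start()    # no-op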