Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
c4739a72
Unverified
Commit
c4739a72
authored
Nov 16, 2022
by
Jiarui Fang
Committed by
GitHub
Nov 16, 2022
Browse files
[Gemini] polish memstats collector (#1962)
parent
fea3cb66
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
200 additions
and
173 deletions
+200
-173
colossalai/gemini/gemini_mgr.py
colossalai/gemini/gemini_mgr.py
+6
-5
colossalai/gemini/memory_tracer/__init__.py
colossalai/gemini/memory_tracer/__init__.py
+9
-4
colossalai/gemini/memory_tracer/chunk_memstats_collector.py
colossalai/gemini/memory_tracer/chunk_memstats_collector.py
+25
-0
colossalai/gemini/memory_tracer/memstats_collector.py
colossalai/gemini/memory_tracer/memstats_collector.py
+6
-133
colossalai/gemini/memory_tracer/static_memstats_collector.py
colossalai/gemini/memory_tracer/static_memstats_collector.py
+105
-0
colossalai/gemini/placement_policy.py
colossalai/gemini/placement_policy.py
+22
-12
colossalai/zero/sharded_model/sharded_model_v2.py
colossalai/zero/sharded_model/sharded_model_v2.py
+27
-19
No files found.
colossalai/gemini/gemini_mgr.py
View file @
c4739a72
...
...
@@ -6,7 +6,7 @@ import torch
from
colossalai.gemini.chunk
import
Chunk
,
ChunkManager
from
.memory_tracer
.memstats_collector
import
MemStatsCollector
V2
,
MemStatsCollector
Static
from
.memory_tracer
import
Chunk
MemStatsCollector
,
Static
MemStatsCollector
from
.placement_policy
import
PlacementPolicyFactory
...
...
@@ -26,7 +26,8 @@ class GeminiManager:
chunk_manager (ChunkManager): A ``ChunkManager`` instance.
"""
def
__init__
(
self
,
placement_policy
:
str
,
def
__init__
(
self
,
placement_policy
:
str
,
chunk_manager
:
ChunkManager
,
module
:
Optional
[
torch
.
nn
.
Module
]
=
None
,
use_static_memstats
:
bool
=
False
)
->
None
:
...
...
@@ -35,14 +36,14 @@ class GeminiManager:
self
.
policy_name
=
placement_policy
policy_cls
=
PlacementPolicyFactory
.
create
(
placement_policy
)
self
.
_chunk_manager
=
chunk_manager
# self._mem_stats_collector = MemStatsCollector
V2
(chunk_manager) if policy_cls.need_mem_stats else None
# self._mem_stats_collector =
Chunk
MemStatsCollector(chunk_manager) if policy_cls.need_mem_stats else None
self
.
use_static_memstats
=
use_static_memstats
if
policy_cls
.
need_mem_stats
:
if
use_static_memstats
:
assert
module
is
not
None
self
.
_mem_stats_collector
=
MemStatsCollector
Static
(
module
,
chunk_manager
)
self
.
_mem_stats_collector
=
Static
MemStatsCollector
(
module
,
chunk_manager
)
else
:
self
.
_mem_stats_collector
=
MemStatsCollector
V2
(
chunk_manager
)
self
.
_mem_stats_collector
=
Chunk
MemStatsCollector
(
chunk_manager
)
else
:
self
.
_mem_stats_collector
=
None
...
...
colossalai/gemini/memory_tracer/__init__.py
View file @
c4739a72
from
.model_data_memtracer
import
GLOBAL_MODEL_DATA_TRACER
from
.memory_monitor
import
AsyncMemoryMonitor
,
SyncCudaMemoryMonitor
from
.memstats_collector
import
MemStatsCollector
from
.memory_monitor
import
AsyncMemoryMonitor
,
SyncCudaMemoryMonitor
# isort:skip
from
.memstats_collector
import
MemStatsCollector
# isort:skip
from
.model_data_memtracer
import
GLOBAL_MODEL_DATA_TRACER
# isort:skip
from
.chunk_memstats_collector
import
ChunkMemStatsCollector
# isort:skip
from
.static_memstats_collector
import
StaticMemStatsCollector
# isort:skip
__all__
=
[
'AsyncMemoryMonitor'
,
'SyncCudaMemoryMonitor'
,
'MemStatsCollector'
,
'GLOBAL_MODEL_DATA_TRACER'
]
__all__
=
[
'AsyncMemoryMonitor'
,
'SyncCudaMemoryMonitor'
,
'MemStatsCollector'
,
'ChunkMemStatsCollector'
,
'StaticMemStatsCollector'
,
'GLOBAL_MODEL_DATA_TRACER'
]
colossalai/gemini/memory_tracer/chunk_memstats_collector.py
0 → 100644
View file @
c4739a72
from
colossalai.gemini.chunk
import
ChunkManager
from
colossalai.utils
import
get_current_device
from
colossalai.utils.memory
import
colo_device_memory_capacity
from
.memstats_collector
import
MemStatsCollector
class
ChunkMemStatsCollector
(
MemStatsCollector
):
def
__init__
(
self
,
chunk_manager
:
ChunkManager
)
->
None
:
super
().
__init__
()
self
.
_chunk_manager
=
chunk_manager
def
sample_model_data
(
self
)
->
None
:
"""Sampling model data statistics.
"""
if
self
.
_start_flag
:
cuda_mem
=
self
.
_chunk_manager
.
total_mem
[
'cuda'
]
cpu_mem
=
self
.
_chunk_manager
.
total_mem
[
'cpu'
]
self
.
_model_data_cuda_list
.
append
(
cuda_mem
)
self
.
_model_data_cpu_list
.
append
(
cpu_mem
)
@
property
def
cuda_margin_mem
(
self
)
->
float
:
return
colo_device_memory_capacity
(
get_current_device
())
-
max
(
self
.
overall_mem_stats
(
'cuda'
))
colossalai/gemini/memory_tracer/memstats_collector.py
View file @
c4739a72
from
colossalai.gemini.memory_tracer
import
SyncCudaMemoryMonitor
from
colossalai.utils.memory
import
colo_device_memory_used
,
colo_device_memory_capacity
from
colossalai.utils
import
get_current_device
from
colossalai.gemini.stateful_tensor
import
StatefulTensor
from
colossalai.gemini.chunk
import
ChunkManager
import
torch
import
torch.nn
as
nn
import
time
from
typing
import
List
,
Optional
from
typing
import
List
from
colossalai.fx.passes.meta_info_prop
import
MetaInfoProp
from
colossalai.fx.profiler
import
(
calculate_fwd_out
,
calculate_fwd_tmp
,
is_compatible_with_meta
,
parameter_size
)
from
torch.fx
import
symbolic_trace
import
torch
if
is_compatible_with_meta
():
from
colossalai.fx.profiler
import
MetaTensor
from
colossalai.gemini.memory_tracer
import
SyncCudaMemoryMonitor
from
colossalai.gemini.stateful_tensor
import
StatefulTensor
from
colossalai.utils.memory
import
colo_device_memory_used
class
MemStatsCollector
:
"""
A Memory statistic collector.
It works in two phases.
It works in two phases.
Phase 1. Collection Phase: collect memory usage statistics of CPU and GPU.
The first iteration of DNN training.
Phase 2. Runtime Phase: use the read-only collected stats
...
...
@@ -138,121 +129,3 @@ class MemStatsCollector:
self
.
_start_flag
=
False
self
.
_step_idx
=
0
self
.
_step_total
=
0
class
MemStatsCollectorV2
(
MemStatsCollector
):
def
__init__
(
self
,
chunk_manager
:
ChunkManager
)
->
None
:
super
().
__init__
()
self
.
_chunk_manager
=
chunk_manager
def
sample_model_data
(
self
)
->
None
:
"""Sampling model data statistics.
"""
if
self
.
_start_flag
:
cuda_mem
=
self
.
_chunk_manager
.
total_mem
[
'cuda'
]
cpu_mem
=
self
.
_chunk_manager
.
total_mem
[
'cpu'
]
self
.
_model_data_cuda_list
.
append
(
cuda_mem
)
self
.
_model_data_cpu_list
.
append
(
cpu_mem
)
@
property
def
cuda_margin_mem
(
self
)
->
float
:
return
colo_device_memory_capacity
(
get_current_device
())
-
max
(
self
.
overall_mem_stats
(
'cuda'
))
class
MemStatsCollectorStatic
(
MemStatsCollectorV2
):
"""
A Static Memory statistic collector.
"""
def
__init__
(
self
,
module
:
nn
.
Module
,
chunk_manager
:
ChunkManager
)
->
None
:
super
().
__init__
(
chunk_manager
)
self
.
module
=
module
self
.
module_info_list
=
[]
def
init_mem_stats
(
self
,
*
inputs
):
self
.
register_opnodes_recursively
(
self
.
module
)
self
.
refactor_module
()
self
.
module
=
self
.
module
.
cpu
()
self
.
module
.
train
()
data
=
[
MetaTensor
(
torch
.
rand
(
inp
.
shape
,
device
=
'meta'
),
fake_device
=
'cpu'
)
for
inp
in
inputs
]
gm
=
symbolic_trace
(
self
.
module
)
interp
=
MetaInfoProp
(
gm
)
interp
.
propagate
(
*
data
)
total_mem
=
0
for
inp
in
inputs
:
total_mem
+=
inp
.
numel
()
*
inp
.
element_size
()
last_node
=
None
module_name_list
=
[
mInfo
.
module_full_name
for
mInfo
in
self
.
module_info_list
]
for
node
in
gm
.
graph
.
nodes
:
total_mem
=
total_mem
+
calculate_fwd_tmp
(
node
)
+
calculate_fwd_out
(
node
)
if
node
.
op
==
"call_module"
:
if
node
.
name
.
endswith
(
"_0"
)
and
node
.
name
[:
-
2
]
in
module_name_list
:
self
.
_non_model_data_cuda_list
.
append
(
total_mem
)
last_node
=
node
self
.
_non_model_data_cuda_list
.
append
(
total_mem
)
self
.
_non_model_data_cuda_list
=
self
.
_non_model_data_cuda_list
[
1
:]
cur_module_mem_fwd
=
0
cur_module_mem_bwd
=
0
grad_module_out
=
last_node
.
meta
[
"fwd_mem_out"
]
for
node
in
gm
.
graph
.
nodes
.
__reversed__
():
cur_module_mem_fwd
=
cur_module_mem_fwd
+
calculate_fwd_tmp
(
node
)
+
calculate_fwd_out
(
node
)
cur_module_mem_bwd
=
cur_module_mem_bwd
+
node
.
meta
[
"bwd_mem_tmp"
]
+
node
.
meta
[
"bwd_mem_out"
]
if
node
.
op
==
"call_module"
:
if
node
.
name
.
endswith
(
"_0"
)
and
node
.
name
[:
-
2
]
in
module_name_list
:
self
.
_non_model_data_cuda_list
.
append
(
total_mem
+
grad_module_out
+
cur_module_mem_bwd
)
total_mem
=
total_mem
-
cur_module_mem_fwd
cur_module_mem_fwd
=
0
cur_module_mem_bwd
=
0
grad_module_out
=
node
.
meta
[
"bwd_mem_out"
]
self
.
_step_total
=
len
(
self
.
_non_model_data_cuda_list
)
self
.
recover_module
()
def
refactor_module
(
self
):
for
modInfo
in
self
.
module_info_list
:
temp_node
=
nn
.
Sequential
(
nn
.
ReLU
(),
modInfo
.
module
)
modInfo
.
parent_module
.
__setattr__
(
modInfo
.
module_name
,
temp_node
)
def
recover_module
(
self
):
for
modInfo
in
self
.
module_info_list
:
modInfo
.
parent_module
.
__setattr__
(
modInfo
.
module_name
,
modInfo
.
module
)
def
register_opnodes_recursively
(
self
,
module
:
torch
.
nn
.
Module
,
name
:
str
=
""
,
full_name
:
str
=
""
,
parent_module
:
Optional
[
torch
.
nn
.
Module
]
=
None
):
assert
isinstance
(
module
,
torch
.
nn
.
Module
)
for
child_name
,
child
in
module
.
named_children
():
self
.
register_opnodes_recursively
(
child
,
child_name
,
full_name
+
"_"
+
child_name
,
module
)
# Early return on modules with no parameters.
if
len
(
list
(
module
.
parameters
(
recurse
=
False
)))
==
0
:
return
self
.
module_info_list
.
append
(
ModuleInfos
(
module
,
name
,
full_name
[
1
:],
parent_module
))
class
ModuleInfos
:
def
__init__
(
self
,
module
:
torch
.
nn
.
Module
,
module_name
:
str
,
module_full_name
:
str
,
parent_module
:
torch
.
nn
.
Module
):
self
.
module
=
module
self
.
module_name
=
module_name
self
.
module_full_name
=
module_full_name
self
.
parent_module
=
parent_module
\ No newline at end of file
colossalai/gemini/memory_tracer/static_memstats_collector.py
0 → 100644
View file @
c4739a72
from
typing
import
Optional
import
torch
import
torch.nn
as
nn
from
torch.fx
import
symbolic_trace
from
colossalai.fx.passes.meta_info_prop
import
MetaInfoProp
from
colossalai.fx.profiler
import
calculate_fwd_out
,
calculate_fwd_tmp
,
is_compatible_with_meta
from
colossalai.gemini.chunk
import
ChunkManager
if
is_compatible_with_meta
():
from
colossalai.fx.profiler
import
MetaTensor
from
.chunk_memstats_collector
import
ChunkMemStatsCollector
class
ModuleInfos
:
def
__init__
(
self
,
module
:
torch
.
nn
.
Module
,
module_name
:
str
,
module_full_name
:
str
,
parent_module
:
torch
.
nn
.
Module
):
self
.
module
=
module
self
.
module_name
=
module_name
self
.
module_full_name
=
module_full_name
self
.
parent_module
=
parent_module
class
StaticMemStatsCollector
(
ChunkMemStatsCollector
):
"""
A Static Memory statistic collector.
"""
def
__init__
(
self
,
module
:
nn
.
Module
,
chunk_manager
:
ChunkManager
)
->
None
:
super
().
__init__
(
chunk_manager
)
self
.
module
=
module
self
.
module_info_list
=
[]
def
init_mem_stats
(
self
,
*
inputs
):
self
.
register_opnodes_recursively
(
self
.
module
)
self
.
refactor_module
()
self
.
module
=
self
.
module
.
cpu
()
self
.
module
.
train
()
data
=
[
MetaTensor
(
torch
.
rand
(
inp
.
shape
,
device
=
'meta'
),
fake_device
=
'cpu'
)
for
inp
in
inputs
]
gm
=
symbolic_trace
(
self
.
module
)
interp
=
MetaInfoProp
(
gm
)
interp
.
propagate
(
*
data
)
total_mem
=
0
for
inp
in
inputs
:
total_mem
+=
inp
.
numel
()
*
inp
.
element_size
()
last_node
=
None
module_name_list
=
[
mInfo
.
module_full_name
for
mInfo
in
self
.
module_info_list
]
for
node
in
gm
.
graph
.
nodes
:
total_mem
=
total_mem
+
calculate_fwd_tmp
(
node
)
+
calculate_fwd_out
(
node
)
if
node
.
op
==
"call_module"
:
if
node
.
name
.
endswith
(
"_0"
)
and
node
.
name
[:
-
2
]
in
module_name_list
:
self
.
_non_model_data_cuda_list
.
append
(
total_mem
)
last_node
=
node
self
.
_non_model_data_cuda_list
.
append
(
total_mem
)
self
.
_non_model_data_cuda_list
=
self
.
_non_model_data_cuda_list
[
1
:]
cur_module_mem_fwd
=
0
cur_module_mem_bwd
=
0
grad_module_out
=
last_node
.
meta
[
"fwd_mem_out"
]
for
node
in
gm
.
graph
.
nodes
.
__reversed__
():
cur_module_mem_fwd
=
cur_module_mem_fwd
+
calculate_fwd_tmp
(
node
)
+
calculate_fwd_out
(
node
)
cur_module_mem_bwd
=
cur_module_mem_bwd
+
node
.
meta
[
"bwd_mem_tmp"
]
+
node
.
meta
[
"bwd_mem_out"
]
if
node
.
op
==
"call_module"
:
if
node
.
name
.
endswith
(
"_0"
)
and
node
.
name
[:
-
2
]
in
module_name_list
:
self
.
_non_model_data_cuda_list
.
append
(
total_mem
+
grad_module_out
+
cur_module_mem_bwd
)
total_mem
=
total_mem
-
cur_module_mem_fwd
cur_module_mem_fwd
=
0
cur_module_mem_bwd
=
0
grad_module_out
=
node
.
meta
[
"bwd_mem_out"
]
self
.
_step_total
=
len
(
self
.
_non_model_data_cuda_list
)
self
.
recover_module
()
def
refactor_module
(
self
):
for
modInfo
in
self
.
module_info_list
:
temp_node
=
nn
.
Sequential
(
nn
.
ReLU
(),
modInfo
.
module
)
modInfo
.
parent_module
.
__setattr__
(
modInfo
.
module_name
,
temp_node
)
def
recover_module
(
self
):
for
modInfo
in
self
.
module_info_list
:
modInfo
.
parent_module
.
__setattr__
(
modInfo
.
module_name
,
modInfo
.
module
)
def
register_opnodes_recursively
(
self
,
module
:
torch
.
nn
.
Module
,
name
:
str
=
""
,
full_name
:
str
=
""
,
parent_module
:
Optional
[
torch
.
nn
.
Module
]
=
None
):
assert
isinstance
(
module
,
torch
.
nn
.
Module
)
for
child_name
,
child
in
module
.
named_children
():
self
.
register_opnodes_recursively
(
child
,
child_name
,
full_name
+
"_"
+
child_name
,
module
)
# Early return on modules with no parameters.
if
len
(
list
(
module
.
parameters
(
recurse
=
False
)))
==
0
:
return
self
.
module_info_list
.
append
(
ModuleInfos
(
module
,
name
,
full_name
[
1
:],
parent_module
))
colossalai/gemini/placement_policy.py
View file @
c4739a72
import
functools
from
abc
import
ABC
,
abstractmethod
from
time
import
time
from
typing
import
List
,
Optional
,
Tuple
,
Dict
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Type
import
torch
from
colossalai.utils
import
get_current_device
from
colossalai.utils.memory
import
colo_device_memory_capacity
from
colossalai.gemini.memory_tracer.memstats_collector
import
MemStatsCollectorV2
from
typing
import
Type
import
functools
from
colossalai.gemini.chunk
import
Chunk
,
ChunkManager
from
colossalai.gemini.memory_tracer
import
ChunkMemStatsCollector
from
colossalai.utils
import
get_current_device
from
colossalai.utils.memory
import
colo_device_memory_capacity
class
PlacementPolicy
(
ABC
):
need_mem_stats
:
bool
=
False
def
__init__
(
self
,
chunk_manager
:
ChunkManager
,
mem_stats_collector
:
Optional
[
MemStatsCollectorV2
]
=
None
)
->
None
:
def
__init__
(
self
,
chunk_manager
:
ChunkManager
,
mem_stats_collector
:
Optional
[
ChunkMemStatsCollector
]
=
None
)
->
None
:
self
.
chunk_manager
=
chunk_manager
self
.
mem_stats_collector
:
Optional
[
MemStatsCollector
V2
]
=
mem_stats_collector
self
.
mem_stats_collector
:
Optional
[
Chunk
MemStatsCollector
]
=
mem_stats_collector
@
abstractmethod
def
evict_tensors
(
self
,
can_evict_chunks
:
List
[
Chunk
],
**
kwargs
)
->
Tuple
[
int
,
float
]:
...
...
@@ -29,7 +31,9 @@ class PlacementPolicy(ABC):
class
CPUPlacementPolicy
(
PlacementPolicy
):
def
__init__
(
self
,
chunk_manager
:
ChunkManager
,
mem_stats_collector
:
Optional
[
MemStatsCollectorV2
]
=
None
)
->
None
:
def
__init__
(
self
,
chunk_manager
:
ChunkManager
,
mem_stats_collector
:
Optional
[
ChunkMemStatsCollector
]
=
None
)
->
None
:
super
().
__init__
(
chunk_manager
,
mem_stats_collector
=
mem_stats_collector
)
def
evict_tensors
(
self
,
can_evict_chunks
:
List
[
Chunk
],
**
kwargs
)
->
Tuple
[
int
,
float
]:
...
...
@@ -44,7 +48,9 @@ class CPUPlacementPolicy(PlacementPolicy):
class
CUDAPlacementPolicy
(
PlacementPolicy
):
def
__init__
(
self
,
chunk_manager
:
ChunkManager
,
mem_stats_collector
:
Optional
[
MemStatsCollectorV2
]
=
None
)
->
None
:
def
__init__
(
self
,
chunk_manager
:
ChunkManager
,
mem_stats_collector
:
Optional
[
ChunkMemStatsCollector
]
=
None
)
->
None
:
assert
torch
.
cuda
.
is_available
(),
'Cannot use CUDATensorPlacementPolicy when CUDA is not available'
super
().
__init__
(
chunk_manager
,
mem_stats_collector
=
mem_stats_collector
)
...
...
@@ -65,7 +71,9 @@ class AutoPlacementPolicy(PlacementPolicy):
_warmup_non_model_data_ratio
:
float
=
0.8
_steady_cuda_cap_ratio
:
float
=
0.9
def
__init__
(
self
,
chunk_manager
:
ChunkManager
,
mem_stats_collector
:
Optional
[
MemStatsCollectorV2
]
=
None
)
->
None
:
def
__init__
(
self
,
chunk_manager
:
ChunkManager
,
mem_stats_collector
:
Optional
[
ChunkMemStatsCollector
]
=
None
)
->
None
:
super
().
__init__
(
chunk_manager
,
mem_stats_collector
=
mem_stats_collector
)
def
evict_tensors
(
self
,
...
...
@@ -154,7 +162,9 @@ class ConstPlacementPolicy(PlacementPolicy):
need_mem_stats
:
bool
=
False
_accessed_memory_boundary
=
512
*
1024
**
2
def
__init__
(
self
,
chunk_manager
:
ChunkManager
,
mem_stats_collector
:
Optional
[
MemStatsCollectorV2
]
=
None
)
->
None
:
def
__init__
(
self
,
chunk_manager
:
ChunkManager
,
mem_stats_collector
:
Optional
[
ChunkMemStatsCollector
]
=
None
)
->
None
:
super
().
__init__
(
chunk_manager
,
mem_stats_collector
=
mem_stats_collector
)
def
evict_tensors
(
self
,
...
...
colossalai/zero/sharded_model/sharded_model_v2.py
View file @
c4739a72
import
functools
import
itertools
from
collections
import
OrderedDict
from
typing
import
Any
,
Optional
,
Iterator
,
Tuple
from
copy
import
deepcopy
import
itertools
from
typing
import
Any
,
Iterator
,
Optional
,
Tuple
import
torch
import
torch.distributed
as
dist
import
torch.nn
as
nn
from
torch.distributed
import
ProcessGroup
from
torch.nn.parameter
import
Parameter
from
colossalai.context.parallel_mode
import
ParallelMode
from
colossalai.core
import
global_context
as
gpc
from
colossalai.gemini.memory_tracer
import
MemStatsCollector
,
StaticMemStatsCollector
from
colossalai.gemini.ophooks
import
register_ophooks_recursively
from
colossalai.zero.utils
import
ZeroHook
from
colossalai.gemini.paramhooks
import
BaseParamHookMgr
from
colossalai.gemini.stateful_tensor
import
TensorState
from
colossalai.gemini.stateful_tensor_mgr
import
StatefulTensorMgr
from
colossalai.gemini.tensor_placement_policy
import
TensorPlacementPolicy
,
TensorPlacementPolicyFactory
from
colossalai.gemini.tensor_utils
import
colo_model_data_move_to_cpu
from
colossalai.logging
import
get_dist_logger
from
colossalai.utils
import
get_current_device
,
disposable
from
colossalai.gemini.memory_tracer.memstats_collector
import
MemStatsCollector
,
MemStatsCollectorStatic
from
colossalai.utils
import
disposable
,
get_current_device
from
colossalai.utils.memory
import
colo_device_memory_capacity
from
colossalai.zero.shard_utils
import
BaseShardStrategy
from
colossalai.zero.sharded_model.reduce_scatter
import
ReduceScatterBucketer
from
torch.distributed
import
ProcessGroup
from
torch.nn.parameter
import
Parameter
from
colossalai.gemini.tensor_utils
import
colo_model_data_move_to_cpu
from
colossalai.gemini.stateful_tensor
import
TensorState
from
colossalai.gemini.stateful_tensor_mgr
import
StatefulTensorMgr
from
colossalai.gemini.tensor_placement_policy
import
TensorPlacementPolicyFactory
,
TensorPlacementPolicy
from
colossalai.zero.utils
import
ZeroHook
from
._utils
import
(
cast_float_arguments
,
cast_tensor_to_fp16
,
cast_tensor_to_fp32
,
chunk_and_pad
,
free_storage
,
get_gradient_predivide_factor
)
from
._utils
import
(
cast_float_arguments
,
cast_tensor_to_fp16
,
cast_tensor_to_fp32
,
chunk_and_pad
,
free_storage
,
get_gradient_predivide_factor
,
)
try
:
from
torch.nn.modules.module
import
_EXTRA_STATE_KEY_SUFFIX
...
...
@@ -49,7 +57,7 @@ class ShardedModelV2(nn.Module):
module (nn.Module): A sharded module, which must be initialized by `ZeroInitContext`.
shard_strategy (BaseShardStrategy): A shard strategy to manage shard behavior.
process_group (Optional[ProcessGroup], optional): Data parallel process group. Defaults to None.
reduce_scatter_process_group (Optional[ProcessGroup], optional): Reduce-scatter process group.
reduce_scatter_process_group (Optional[ProcessGroup], optional): Reduce-scatter process group.
Generally, it should be `None`, and it's the same as `process_group`. Defaults to None.
reduce_scatter_bucket_size_mb (int, optional): Reduce-scatter bucket size in *MB*. Defaults to 25.
fp32_reduce_scatter (bool, optional): If set to `True`, gradients are forced to FP32 before reduce-scatter. Defaults to False.
...
...
@@ -60,10 +68,10 @@ class ShardedModelV2(nn.Module):
Note that 'auto' policy can only work well when no other processes use CUDA during your training.
Defaults to 'cuda'.
gradient_predivide_factor (Optional[float], optional): Gradient is divived by this value before reduce-scatter. Defaults to 1.0.
reuse_fp16_shard (bool, optional): Whether to reuse fp16 shard for param and grad.
Enabling this can reduce GPU memory usage, but you have to make sure you disable it when using gradient accumulation.
In this mode, grad will be fp16. Make sure your optimizer supports mixed precision (fp32 param and fp16 grad).
We find that PyTorch's optimizers don't support mixed precision,
reuse_fp16_shard (bool, optional): Whether to reuse fp16 shard for param and grad.
Enabling this can reduce GPU memory usage, but you have to make sure you disable it when using gradient accumulation.
In this mode, grad will be fp16. Make sure your optimizer supports mixed precision (fp32 param and fp16 grad).
We find that PyTorch's optimizers don't support mixed precision,
so we recommend you enable this only when using our CPUAdam with CPU offload. Defaults to False.
"""
...
...
@@ -116,7 +124,7 @@ class ShardedModelV2(nn.Module):
self
.
_use_memory_tracer
=
tensor_placement_policy
==
'auto'
if
self
.
_use_memory_tracer
:
if
self
.
user_static_memstats
:
self
.
_memstats_collector
=
MemStatsCollector
Static
(
self
.
module
)
self
.
_memstats_collector
=
Static
MemStatsCollector
(
self
.
module
)
else
:
self
.
_memstats_collector
=
MemStatsCollector
()
self
.
_start_collect_memstats
=
disposable
(
self
.
_memstats_collector
.
start_collection
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment