OpenDAS / ColossalAI / Commits / e532679c

Commit e532679c authored Jan 10, 2023 by oahzxl

    Merge branch 'main' of https://github.com/oahzxl/ColossalAI into chunk

Parents: c1492e50, 7d5640b9
Changes: 461 files in this merge. Showing 20 changed files with 1123 additions and 681 deletions (+1123, -681).
colossalai/gemini/chunk/manager.py (+239, -230)
colossalai/gemini/chunk/search_utils.py (+46, -14)
colossalai/gemini/chunk/utils.py (+3, -2)
colossalai/gemini/gemini_mgr.py (+34, -10)
colossalai/gemini/memory_tracer/__init__.py (+10, -4)
colossalai/gemini/memory_tracer/chunk_memstats_collector.py (+36, -0)
colossalai/gemini/memory_tracer/memory_monitor.py (+147, -142)
colossalai/gemini/memory_tracer/memory_stats.py (+127, -0)
colossalai/gemini/memory_tracer/memstats_collector.py (+38, -90)
colossalai/gemini/memory_tracer/param_runtime_order.py (+42, -0)
colossalai/gemini/memory_tracer/runtime_mem_tracer.py (+99, -0)
colossalai/gemini/memory_tracer/static_memstats_collector.py (+105, -0)
colossalai/gemini/memory_tracer/utils.py (+3, -53)
colossalai/gemini/ophooks/__init__.py (+2, -3)
colossalai/gemini/ophooks/_memtracer_ophook.py (+0, -117)
colossalai/gemini/ophooks/_shard_grad_ophook.py (+2, -1)
colossalai/gemini/ophooks/runtime_mem_tracer_hook.py (+145, -0)
colossalai/gemini/placement_policy.py (+23, -13)
colossalai/gemini/tensor_utils.py (+14, -0)
colossalai/global_variables.py (+8, -2)
colossalai/gemini/chunk/manager.py (view file @ e532679c)

-import torch
-from typing import Optional, Dict, Deque, Set, List, Tuple, Iterable
 from collections import deque
+from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple
+
+import torch
+
+from colossalai.gemini.chunk import Chunk, ChunkFullError, TensorState
 from colossalai.tensor import ColoTensor
-from colossalai.gemini.chunk import ChunkFullError, TensorState, Chunk
-from colossalai.utils import get_current_device
+from colossalai.utils import get_current_device


 class ChunkManager:
...
@@ -16,13 +17,13 @@ class ChunkManager:
         init_device (torch.device): optional, the device on which the chunk is initialized. The default is None.
     """

-    def __init__(self, chunk_configuration: Dict[int, Dict], init_device: Optional[torch.device] = None) -> None:
+    def __init__(self, chunk_configuration, init_device: Optional[torch.device] = None) -> None:
         self.device = init_device or get_current_device()
-        self.size_config: Dict[int, int] = dict()
+        self.dp_degree_chunk_size_dict: Dict[int, int] = dict()
         self.kwargs_config = chunk_configuration
         for k, v in self.kwargs_config.items():
-            self.size_config[k] = v.pop('chunk_size')
+            self.dp_degree_chunk_size_dict[k] = v.pop('chunk_size')
             v['init_device'] = self.device

         self.chunk_groups: Dict[str, Deque] = dict()
...
@@ -31,20 +32,28 @@ class ChunkManager:
         self.accessed_mem: int = 0
         self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0}

-    def append_tensor(self, tensor: ColoTensor, group_type: str, config_key: int, pin_memory: bool = False) -> None:
-        """Append a tensor to a chunk.
+    def register_tensor(self,
+                        tensor: ColoTensor,
+                        group_type: str,
+                        config_key: int,
+                        cpu_offload: bool = False,
+                        pin_memory: bool = False) -> None:
+        """
+        Register a tensor to the chunk manager.
+        Then, the tensor should be accessed by `get_chunks`.

         Args:
             tensor: the tensor appended to the chunk
-            group_type: the data type of the group
-            config_key: the key of the group's name, usually the size of the dp world
+            group_type: the data type of the group.
+            config_key: the key of the group's name, the size of the dp world
+            cpu_offload: if True, the chunk will be closed on CPU
             pin_memory: whether the chunk is pinned in the cpu memory
         """
         assert tensor not in self.tensor_chunk_map
         assert isinstance(tensor, ColoTensor), "Please feed ColoTensor to this ChunkManager"
-        assert config_key in self.size_config
+        assert config_key in self.dp_degree_chunk_size_dict

-        chunk_size = self.size_config[config_key]
+        chunk_size = self.dp_degree_chunk_size_dict[config_key]
         chunk_kwargs = self.kwargs_config[config_key]
         group_name = "{}_{}".format(group_type, config_key)
         chunk_group = self.__get_chunk_group(group_name)
...
@@ -67,6 +76,7 @@ class ChunkManager:
                 chunk_size=chunk_size,
                 process_group=tensor.process_group,
                 dtype=tensor.dtype,
+                cpu_shard_init=cpu_offload,
                 pin_memory=pin_memory,
                 **chunk_kwargs,
             )
...
@@ -206,9 +216,8 @@ class ChunkManager:
         return self.chunk_groups[group_name]

     def __close_one_chunk(self, chunk: Chunk):
-        device = get_current_device() if chunk.keep_gathered else self.device    # keep gathered chunk in cuda
         self.__sub_memroy_usage(chunk.memory_usage)
-        chunk.close_chunk(device)
+        chunk.close_chunk()
         self.__add_memory_usage(chunk.memory_usage)

     def __sub_memroy_usage(self, usage: Dict[str, int]):
...
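For orientation, here is a minimal sketch of how the renamed `register_tensor` API might be driven by hand. The configuration values and the `model` variable are illustrative assumptions, not part of the commit; real configurations come from `search_chunk_configuration` in the next file.

    # Hypothetical sketch, not from the commit: wiring a ChunkManager manually.
    # Assumes `model` was built under ColoInitContext so its parameters are
    # ColoTensors, and that the dp world size (and hence the config key) is 1.
    import torch
    from colossalai.gemini.chunk import ChunkManager

    config = {1: dict(chunk_size=1024 * 1024, keep_gathered=False)}    # dp_degree -> chunk init args
    manager = ChunkManager(config, init_device=torch.device('cuda'))
    for p in model.parameters():
        manager.register_tensor(p, group_type='fp16_param', config_key=1,
                                cpu_offload=False, pin_memory=False)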
colossalai/gemini/chunk/search_utils.py (view file @ e532679c)

 import math
-from typing import Dict, List, Tuple
+from typing import Dict, List, Optional, Tuple

 import numpy as np
 import torch.nn as nn

+from colossalai.gemini.memory_tracer import MemStats, OrderedParamGenerator
 from colossalai.tensor import ColoParameter
...
@@ -12,7 +13,8 @@ def in_ddp(param: nn.Parameter) -> bool:
 def _filter_exlarge_params(model: nn.Module, size_dict: Dict[int, List[int]]) -> None:
-    """Filter those parameters whose size is too large from others.
+    """
+    Filter those parameters whose size is too large (more than 3x standard deviations) from others.
     """
     params_size = [p.numel() for p in model.parameters() if in_ddp(p)]
     params_size_arr = np.array(params_size)
...
@@ -39,11 +41,20 @@ def _get_unused_byte(size_list: List[int], chunk_size: int) -> int:
     return left + acc


-def clasify_params(model: nn.Module) -> Dict[int, List[ColoParameter]]:
-    """Clasify each parameter by its size of DP group.
+def classify_params_by_dp_degree(param_order: OrderedParamGenerator) -> Dict[int, List[ColoParameter]]:
+    """classify_params_by_dp_degree
+
+    Classify the parameters by their dp degree
+
+    Args:
+        param_order (OrderedParamGenerator): the order in which the params are visited
+
+    Returns:
+        Dict[int, List[ColoParameter]]: a dict contains the classification results.
+        The keys are dp_degrees and the values are parameters.
     """
     params_dict: Dict[int, List[ColoParameter]] = dict()
-    for param in model.parameters():
+    for param in param_order.generate():
         assert isinstance(param, ColoParameter), "please init model in the ColoInitContext"
         if not in_ddp(param):
             continue
...
@@ -62,24 +73,45 @@ def search_chunk_configuration(
         search_range_mb: float,
         search_interval_byte: int,    # hidden size is the best value for the interval
         min_chunk_size_mb: float = 32,
-        filter_exlarge_params: bool = True) -> Tuple[Dict, int]:
+        filter_exlarge_params: bool = True,
+        memstas: Optional[MemStats] = None) -> Tuple[Dict, int]:
+    """search_chunk_configuration
+
+    Args:
+        model (nn.Module): torch module
+        search_range_mb (float): searching range in mega byte.
+        search_interval_byte (int): searching interval in byte.
+        filter_exlarge_params (bool, optional): filter extreme large parameters. Defaults to True.
+
+    Returns:
+        Tuple[Dict, int]: chunk config (a dict of dp_degree -> chunk init args) and its memory chunk waste in byte.
+    """
+
+    if memstas is not None:
+        param_order = memstas.param_order()
+    else:
+        # build the param visited order right now
+        param_order = OrderedParamGenerator()
+        for p in model.parameters():
+            param_order.append(p)

     search_range_byte = round(search_range_mb * 1024**2)
     min_chunk_size_byte = round(min_chunk_size_mb * 1024**2)
     assert search_range_byte >= 0

-    params_dict = clasify_params(model)
+    params_dict = classify_params_by_dp_degree(param_order)
     config_dict: Dict[int, Dict] = dict()

     size_dict: Dict[int, List[int]] = dict()
-    for key in params_dict:
-        params_list = params_dict[key]
+    for dp_degree in params_dict:
+        params_list = params_dict[dp_degree]
         size_list = [p.numel() for p in params_list]
         # let small parameters keep gathered in CUDA all the time
         total_size = sum(size_list)
         if total_size < min_chunk_size_byte:
-            config_dict[key] = dict(chunk_size=total_size, keep_gathered=True)
+            config_dict[dp_degree] = dict(chunk_size=total_size, keep_gathered=True)
         else:
-            size_dict[key] = size_list
+            size_dict[dp_degree] = size_list

     if filter_exlarge_params:
         _filter_exlarge_params(model, size_dict)
...
@@ -100,9 +132,9 @@ def search_chunk_configuration(
             min_chunk_waste = temp_waste
             best_chunk_size = chunk_size

-    for key in params_dict:
-        if key in config_dict:
+    for dp_degree in params_dict:
+        if dp_degree in config_dict:
             continue
-        config_dict[key] = dict(chunk_size=best_chunk_size, keep_gathered=False)
+        config_dict[dp_degree] = dict(chunk_size=best_chunk_size, keep_gathered=False)

     return config_dict, min_chunk_waste
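A hedged usage sketch of the search entry point after this change. The model and the numeric choices are assumptions; `memstas` (the parameter name exactly as it appears in this commit) is the new optional hand-off from a runtime memory tracer.

    # Hypothetical sketch: deriving a chunk config for a ColoInitContext model.
    config_dict, wasted_bytes = search_chunk_configuration(
        model,                        # nn.Module whose params are ColoParameters
        search_range_mb=32,           # width of the chunk-size search window
        search_interval_byte=1024,    # step size; the hidden size is the best value
        min_chunk_size_mb=32,
        filter_exlarge_params=True,
        memstas=None)                 # or a MemStats from RuntimeMemTracer to reuse its param order
    # config_dict: dp_degree -> dict(chunk_size=..., keep_gathered=...)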
colossalai/gemini/chunk/utils.py (view file @ e532679c)

@@ -7,6 +7,7 @@ import torch.nn as nn
 from colossalai.gemini.chunk import ChunkManager
 from colossalai.gemini.chunk.search_utils import in_ddp, search_chunk_configuration
+from colossalai.gemini.memory_tracer import MemStats


 def init_chunk_manager(model: nn.Module,
...
@@ -37,13 +38,13 @@ def init_chunk_manager(model: nn.Module,
     total_size = sum(params_sizes) / 1024**2

     dist.barrier()
-    begine = time()
+    begin = time()

     config_dict, wasted_size = search_chunk_configuration(model, **kwargs_dict)

     dist.barrier()
     end = time()
-    span_s = end - begine
+    span_s = end - begin
     wasted_size /= 1024**2

     if dist.get_rank() == 0:
...
colossalai/gemini/gemini_mgr.py (view file @ e532679c)

-import torch
 import functools
-from .memory_tracer.memstats_collector import MemStatsCollectorV2
-from typing import List, Optional, Tuple
 from time import time
+from typing import List, Optional, Tuple
+
+import torch

 from colossalai.gemini.chunk import Chunk, ChunkManager
+from colossalai.gemini.memory_tracer import MemStats
+
+from .memory_tracer import ChunkMemStatsCollector
 from .placement_policy import PlacementPolicyFactory
...
@@ -21,13 +25,20 @@ class GeminiManager:
         If it's 'auto', they are moving dynamically based on CPU and CUDA memory usage. It will utilize heterogeneous memory space evenly and well.
         Note that 'auto' policy can only work well when no other processes use CUDA during your training.
         chunk_manager (ChunkManager): A ``ChunkManager`` instance.
+        memstats (MemStats, optional): mem stats collected by a runtime mem tracer. If None, the GeminiManager collects them during a warmup iteration.
     """

-    def __init__(self, placement_policy: str, chunk_manager: ChunkManager) -> None:
-        assert placement_policy in PlacementPolicyFactory.get_polocy_names()
+    def __init__(self, placement_policy: str, chunk_manager: ChunkManager, memstats: Optional[MemStats] = None) -> None:
+
+        assert placement_policy in PlacementPolicyFactory.get_policy_names()
         self.policy_name = placement_policy
         policy_cls = PlacementPolicyFactory.create(placement_policy)
         self._chunk_manager = chunk_manager
-        self._mem_stats_collector = MemStatsCollectorV2(chunk_manager) if policy_cls.need_mem_stats else None
+
+        self._premade_memstats_ = memstats is not None
+        self._memstats = memstats
+        self._mem_stats_collector = ChunkMemStatsCollector(chunk_manager,
+                                                           self._memstats) if policy_cls.need_mem_stats else None
         self._placement_policy = policy_cls(chunk_manager, self._mem_stats_collector)
         self._compute_list: List[Tuple[Chunk, ...]] = []
         self._compute_idx: int = -1
...
@@ -39,7 +50,20 @@ class GeminiManager:
         self._warmup = True
         self._comp_cuda_demand_time = 0

-    def pre_iter(self):
+    def memstats(self):
+        """memstats
+
+        Get the memory statistics during training.
+        The stats could be collected by a runtime memory tracer, or collected by the GeminiManager.
+        Note, for the latter, you cannot access the memstats before the warmup iteration finishes.
+        """
+        if self._premade_memstats_:
+            return self._memstats
+        else:
+            assert not self._warmup, "Gemini Manager has memstats after warm up! Now is during warmup."
+            return self._mem_stats_collector._memstats
+
+    def pre_iter(self, *args):
         if self._mem_stats_collector and self._warmup:
             self._mem_stats_collector.start_collection()
...
@@ -57,7 +81,7 @@ class GeminiManager:
         self._comp_cuda_demand_time = 0

     def adjust_layout(self, chunks: Tuple[Chunk, ...]) -> None:
-        """ Adjust the layout of statefuil tensor according to the information provided
+        """ Adjust the layout of stateful tensors according to the information provided
         by mem_stats_collector, which should belong to a Sharded Model.
         """
         # find stateful tensor in state COMPUTE
...
@@ -109,9 +133,9 @@ class GeminiManager:
         if self._mem_stats_collector:
             self._mem_stats_collector.sample_overall_data()

-    def sample_model_data(self):
+    def record_model_data_volume(self):
         if self._mem_stats_collector:
-            self._mem_stats_collector.sample_model_data()
+            self._mem_stats_collector.record_model_data_volume()

     @property
     def chunk_manager(self):
...
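The new `memstats` argument lets a pre-collected `MemStats` replace the warmup iteration. A sketch of both paths, assuming `chunk_manager` is an initialized `ChunkManager` and `tracer` is a `RuntimeMemTracer` that has already run one traced iteration:

    # Hypothetical sketch, not from the commit.
    # Path 1: no premade stats; memstats() is only valid after the warmup iteration.
    mgr = GeminiManager('auto', chunk_manager)
    # Path 2: hand over stats collected by a runtime mem tracer; no warmup needed.
    mgr = GeminiManager('auto', chunk_manager, memstats=tracer.memstats())
    stats = mgr.memstats()    # immediately accessible on the premade path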
colossalai/gemini/memory_tracer/__init__.py (view file @ e532679c)

-from .model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
-from .memory_monitor import AsyncMemoryMonitor, SyncCudaMemoryMonitor
-from .memstats_collector import MemStatsCollector
+from .param_runtime_order import OrderedParamGenerator    # isort:skip
+from .memory_stats import MemStats    # isort:skip
+from .memory_monitor import AsyncMemoryMonitor, SyncCudaMemoryMonitor    # isort:skip
+from .memstats_collector import MemStatsCollector    # isort:skip
+from .chunk_memstats_collector import ChunkMemStatsCollector    # isort:skip
+from .static_memstats_collector import StaticMemStatsCollector    # isort:skip

-__all__ = ['AsyncMemoryMonitor', 'SyncCudaMemoryMonitor', 'MemStatsCollector', 'GLOBAL_MODEL_DATA_TRACER']
+__all__ = [
+    'AsyncMemoryMonitor', 'SyncCudaMemoryMonitor', 'MemStatsCollector', 'ChunkMemStatsCollector',
+    'StaticMemStatsCollector', 'MemStats', 'OrderedParamGenerator'
+]
colossalai/gemini/memory_tracer/chunk_memstats_collector.py (new file, 0 → 100644; view file @ e532679c)

from typing import Optional

from colossalai.gemini.chunk import ChunkManager
from colossalai.gemini.memory_tracer import MemStats
from colossalai.utils import get_current_device
from colossalai.utils.memory import colo_device_memory_capacity

from .memstats_collector import MemStatsCollector


class ChunkMemStatsCollector(MemStatsCollector):

    def __init__(self, chunk_manager: ChunkManager, memstats: Optional[MemStats] = None) -> None:
        """
        Memory Statistic Collector for Chunks.

        Args:
            chunk_manager (ChunkManager): the chunk manager.
            memstats (Optional[MemStats], optional): memory statistics collected by RMT. Defaults to None.
        """
        super().__init__(memstats)
        self._chunk_manager = chunk_manager

    # override
    def record_model_data_volume(self) -> None:
        """
        Record model data volume on cuda and cpu.
        """
        if self._start_flag and not self.use_outside_memstats:
            cuda_mem = self._chunk_manager.total_mem['cuda']
            self._memstats.record_max_cuda_model_data(cuda_mem)

    @property
    def cuda_margin_mem(self) -> float:
        return colo_device_memory_capacity(get_current_device()) - self._memstats.max_overall_cuda
colossalai/gemini/memory_tracer/memory_monitor.py (view file @ e532679c)

+import json
 from abc import abstractmethod
 from concurrent.futures import ThreadPoolExecutor
 from time import sleep, time
-import json

 import torch

-from colossalai.utils import colo_device_memory_used
-from colossalai.utils import get_current_device
+from colossalai.utils import colo_device_memory_used, get_current_device


 class MemoryMonitor:
...
@@ -134,7 +133,13 @@ class SyncCudaMemoryMonitor(MemoryMonitor):
         torch.cuda.synchronize()
         torch.cuda.reset_peak_memory_stats()

-    def finish(self):
+    def finish(self) -> int:
+        """
+        Return max gpu memory used since latest `start()`.
+
+        Returns:
+            int: max GPU memory
+        """
         torch.cuda.synchronize()
         self.time_stamps.append(time())
         max_usage = torch.cuda.max_memory_allocated()
...
colossalai/gemini/memory_tracer/memory_stats.py (new file, 0 → 100644; view file @ e532679c)

from typing import Any, Dict, List, Optional

import torch

from colossalai.gemini.memory_tracer import OrderedParamGenerator


class MemStats(object):

    def __init__(self) -> None:
        """
        Store the non model data statistics used for Gemini and ZeroOptimizer.
        """
        # (preop_step, List[param])
        self._step_param_dict = dict()
        # (param, List[preop_step])
        self._param_step_dict = dict()
        # (preop_step, non_model_data) non model data used during preop_step ~ (preop_step+1)
        self._step_nmd_dict = dict()
        self._param_runtime_order = OrderedParamGenerator()

        self._preop_step = 0

        self._prev_overall_cuda = -1
        self._max_overall_cuda = 0
        self._prev_md_cuda = -1

        # old version
        self._model_data_cuda_list = []
        self._model_data_cpu_list = []

        self._overall_cuda_list = []
        self._overall_cpu_list = []

        self._non_model_data_cuda_list = []
        self._non_model_data_cpu_list = []

    def calc_max_cuda_non_model_data(self):
        if self._prev_overall_cuda != -1 and self._prev_md_cuda != -1:
            max_cuda_non_model_data = self._prev_overall_cuda - self._prev_md_cuda
            self._step_nmd_dict[self._preop_step - 1] = max_cuda_non_model_data
            # compatibility of the old version.
            self._non_model_data_cuda_list.append(max_cuda_non_model_data)

    def record_max_cuda_model_data(self, val):
        self._prev_md_cuda = val

    def record_max_cuda_overall_data(self, val):
        self._prev_overall_cuda = val
        self._max_overall_cuda = max(self._max_overall_cuda, val)

    @property
    def max_overall_cuda(self):
        return self._max_overall_cuda

    def increase_preop_step(self, param_list: List[torch.nn.Parameter]):
        """
        The time step is increased. The param list is used between the current and the next
        time step.

        Args:
            param_list (List[torch.nn.Parameter]): a list of torch parameters.
        """
        for p in param_list:
            if p not in self._param_step_dict:
                self._param_step_dict[p] = [self._preop_step]
            else:
                self._param_step_dict[p].append(self._preop_step)
            self._param_runtime_order.append(p)
        self._step_param_dict[self._preop_step] = param_list
        self._preop_step += 1

    def param_used_step(self, param: torch.nn.Parameter) -> Optional[List[int]]:
        """param_used_step
        Get the timestep list using the param.

        Args:
            param (torch.nn.Parameter): a torch param

        Returns:
            Optional[List[int]]: a list of int indicating the time steps of the preop hook.
        """
        if param not in self._param_step_dict:
            return None
        else:
            return self._param_step_dict[param]

    def param_order(self):
        if self._param_runtime_order.is_empty():
            raise RuntimeError
        else:
            return self._param_runtime_order

    def non_model_data_list(self, device_type: str) -> List[int]:
        if device_type == 'cuda':
            return self._non_model_data_cuda_list
        elif device_type == 'cpu':
            return self._non_model_data_cpu_list
        else:
            raise TypeError

    def max_non_model_data(self, device_type: str) -> float:
        if device_type == 'cuda':
            return max(self._non_model_data_cuda_list)
        elif device_type == 'cpu':
            return max(self._non_model_data_cpu_list)
        else:
            raise TypeError

    def clear(self):
        self._model_data_cuda_list = []
        self._overall_cuda_list = []

        self._model_data_cpu_list = []
        self._overall_cpu_list = []

        self._non_model_data_cpu_list = []
        self._non_model_data_cuda_list = []

        self._param_runtime_order.clear()
        self._step_param_dict.clear()
        self._param_step_dict.clear()
        self._step_nmd_dict.clear()
        self._preop_step = 0

        self._prev_overall_cuda = -1
        self._prev_md_cuda = -1
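To make the bookkeeping contract concrete, here is the per-op recording order that `ParamMemTracerHook.pre_op` (later in this commit) follows. The byte counts and `params_of_this_op` are made-up placeholders for illustration:

    # Illustrative only: one pre-op step of MemStats recording.
    stats = MemStats()
    stats.record_max_cuda_overall_data(10 * 1024**2)    # CUDA peak observed since the previous op
    stats.calc_max_cuda_non_model_data()                # overall - model data, keyed to the previous step
    stats.record_max_cuda_model_data(6 * 1024**2)       # model data resident for the coming op
    stats.increase_preop_step(params_of_this_op)        # log param order, advance the step counter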
colossalai/gemini/memory_tracer/memstats_collector.py (view file @ e532679c)

+import time
+from typing import List, Optional
+
+import torch
+
 from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
-from colossalai.utils.memory import colo_device_memory_used, colo_device_memory_capacity
-from colossalai.utils import get_current_device
 from colossalai.gemini.stateful_tensor import StatefulTensor
-from colossalai.gemini.chunk import ChunkManager
+from colossalai.utils.memory import colo_device_memory_used

-import torch
-import time
-from typing import List
+from .memory_stats import MemStats


 class MemStatsCollector:
...
@@ -21,48 +22,22 @@ class MemStatsCollector:
     It has a Sampling counter which is reset after DNN training iteration.
     """

-    def __init__(self) -> None:
+    def __init__(self, memstats: Optional[MemStats] = None) -> None:
         self._mem_monitor = SyncCudaMemoryMonitor()
-        self._model_data_cuda_list = []
-        self._overall_cuda_list = []
-
-        self._model_data_cpu_list = []
-        self._overall_cpu_list = []
-
-        self._non_model_data_cuda_list = []
-        self._non_model_data_cpu_list = []
         self._sampling_time = []

         self._start_flag = False
         self._step_idx = 0
         self._step_total = 0

-    def overall_mem_stats(self, device_type: str) -> List[int]:
-        if device_type == 'cuda':
-            return self._overall_cuda_list
-        elif device_type == 'cpu':
-            return self._overall_cpu_list
-        else:
-            raise TypeError
-
-    def model_data_list(self, device_type: str) -> List[int]:
-        if device_type == 'cuda':
-            return self._model_data_cuda_list
-        elif device_type == 'cpu':
-            return self._model_data_cpu_list
-        else:
-            raise TypeError
-
-    def non_model_data_list(self, device_type: str) -> List[int]:
-        if device_type == 'cuda':
-            return self._non_model_data_cuda_list
-        elif device_type == 'cpu':
-            return self._non_model_data_cpu_list
-        else:
-            raise TypeError
+        if memstats is not None:
+            self.use_outside_memstats = True
+            self._memstats = memstats
+        else:
+            self.use_outside_memstats = False
+            self._memstats = MemStats()

     def next_period_non_model_data_usage(self, device_type: str) -> int:
         """
-        Get max non model data memory usage of current sampling period
+        Maximum non model data memory usage during the next Op run

         Args:
             device_type (str): device type, can be 'cpu' or 'cuda'.
...
@@ -72,7 +47,10 @@ class MemStatsCollector:
         """
         assert not self._start_flag, 'Cannot get mem stats info during collection phase.'
         assert self._step_total > 0, 'Cannot get mem stats info before collection phase.'
-        next_non_model_data = self.non_model_data_list(device_type)[self._step_idx]
+        assert len(self._memstats.non_model_data_list(device_type)) > self._step_idx, \
+            f"{len(self._memstats.non_model_data_list(device_type))} should be > than step idx {self._step_idx}, " \
+            f"step total {self._step_total}"
+        next_non_model_data = self._memstats.non_model_data_list(device_type)[self._step_idx]
         self._step_idx = (self._step_idx + 1) % self._step_total
         return next_non_model_data
...
@@ -86,67 +64,37 @@ class MemStatsCollector:
     def finish_collection(self):
         self.sample_overall_data()
-        self._step_total = len(self._sampling_time)
+        # self._step_total = len(self._sampling_time)
+        self._step_total = len(self._memstats.non_model_data_list('cuda'))
         self._start_flag = False
         self._mem_monitor.finish()
+        print(f'finish_collection {self._step_total}')

-    def sample_model_data(self) -> None:
-        """Sampling model data statistics.
+    # deprecated
+    def record_model_data_volume(self) -> None:
+        """
+        Sampling model data statistics.
         """
-        if self._start_flag:
+        if self._start_flag and not self.use_outside_memstats:
+            # The following code works for ZeroInitContext, which is deprecated in v0.1.12
             cuda_mem = StatefulTensor.GST_MGR.total_mem['cuda']
-            cpu_mem = StatefulTensor.GST_MGR.total_mem['cpu']
-            self._model_data_cuda_list.append(cuda_mem)
-            self._model_data_cpu_list.append(cpu_mem)
+            self._memstats.record_max_cuda_model_data(cuda_mem)

     def sample_overall_data(self) -> None:
-        """Sampling non model data statistics.
+        """
+        Sampling overall and non model data cuda memory statistics.
         """
-        if self._start_flag:
-            # overall data recording is after model data recording
-            if len(self._model_data_cuda_list) == 0:
-                return
-
-            self._overall_cuda_list.append(self._mem_monitor.finish())
-            self._overall_cpu_list.append(colo_device_memory_used(torch.device('cpu')))
-
-            assert len(self._model_data_cuda_list) == len(self._overall_cuda_list)
-
-            self._non_model_data_cuda_list.append(self._overall_cuda_list[-1] - self._model_data_cuda_list[-1])
-            self._non_model_data_cpu_list.append(self._overall_cpu_list[-1] - self._model_data_cpu_list[-1])
+        if self._start_flag and not self.use_outside_memstats:
+            cuda_overall = self._mem_monitor.finish()
+            self._memstats.record_max_cuda_overall_data(cuda_overall)
+            self._memstats.calc_max_cuda_non_model_data()
+
+            self._mem_monitor.start()
+
+        if self._start_flag:
             self._sampling_time.append(time.time())
-            self._mem_monitor.start()

     def clear(self) -> None:
-        self._model_data_cuda_list = []
-        self._overall_cuda_list = []
-
-        self._model_data_cpu_list = []
-        self._overall_cpu_list = []
-
-        self._non_model_data_cpu_list = []
-        self._non_model_data_cuda_list = []
-
+        self._memstats.clear()
         self._start_flag = False
         self._step_idx = 0
         self._step_total = 0
-
-
-class MemStatsCollectorV2(MemStatsCollector):
-
-    def __init__(self, chunk_manager: ChunkManager) -> None:
-        super().__init__()
-        self._chunk_manager = chunk_manager
-
-    def sample_model_data(self) -> None:
-        """Sampling model data statistics.
-        """
-        if self._start_flag:
-            cuda_mem = self._chunk_manager.total_mem['cuda']
-            cpu_mem = self._chunk_manager.total_mem['cpu']
-            self._model_data_cuda_list.append(cuda_mem)
-            self._model_data_cpu_list.append(cpu_mem)
-
-    @property
-    def cuda_margin_mem(self) -> float:
-        return colo_device_memory_capacity(get_current_device()) - max(self.overall_mem_stats('cuda'))
colossalai/gemini/memory_tracer/param_runtime_order.py (new file, 0 → 100644; view file @ e532679c)

from abc import ABC

import torch


class ParamGenerator(ABC):

    def append(self, param: torch.nn.Parameter):
        pass

    def generate(self):
        pass

    def clear(self):
        pass


class OrderedParamGenerator(ParamGenerator):
    """OrderedParamGenerator

    Contain the order of parameters visited during runtime.
    """

    def __init__(self) -> None:
        self.param_visited_order = []

    def append(self, param: torch.nn.Parameter):
        self.param_visited_order.append(param)

    def generate(self):
        visited_set = set()
        for p in self.param_visited_order:
            if p not in visited_set:
                yield p
            visited_set.add(p)
        del visited_set

    def is_empty(self):
        return len(self.param_visited_order) == 0

    def clear(self):
        self.param_visited_order = []
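`generate()` yields each parameter once, in first-visit order, even if it was appended many times while tracing repeated uses. A small self-contained check:

    import torch
    from colossalai.gemini.memory_tracer import OrderedParamGenerator

    p1, p2 = torch.nn.Parameter(torch.ones(1)), torch.nn.Parameter(torch.ones(1))
    gen = OrderedParamGenerator()
    for p in (p1, p2, p1):    # p1 is visited twice at runtime
        gen.append(p)
    assert list(gen.generate()) == [p1, p2]    # de-duplicated, first-visit order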
colossalai/gemini/memory_tracer/runtime_mem_tracer.py (new file, 0 → 100644; view file @ e532679c)

import torch.nn

from colossalai.gemini.memory_tracer import MemStats
from colossalai.gemini.ophooks.runtime_mem_tracer_hook import GradMemStats, GradMemTracerHook, ParamMemTracerHook
from colossalai.nn.parallel.data_parallel import _cast_float
from colossalai.tensor.param_op_hook import ColoParamOpHookManager

__all__ = ['RuntimeMemTracer']


class RuntimeMemTracer():
    """RuntimeMemTracer for the module training using ColoParameter.

    Trace non-model memory usage during the fwd+bwd process.
    It is obtained by using tensors with the same shapes as in the training process as the inputs
    and running a single fwd+bwd to trace the statistics.

    NOTE()
    1. The premise to use this tracer is that the target DNN executes the same operations at each iteration,
    2. Module buffers are viewed as non-model data.
    """

    def __init__(self, module: torch.nn.Module, dtype: torch.dtype = torch.half):
        super().__init__()
        self.module = module
        self.dtype = dtype
        self._gradstat = GradMemStats()
        self._memstats = MemStats()
        self.param_op_hook = ParamMemTracerHook(self._memstats, self._gradstat)
        self.grad_hook = GradMemTracerHook(self._gradstat)
        self.cpu_param_data_dict = {}

        for p in module.parameters():
            p.data = p.data.to(dtype)

        self._cast_buffers_to_cuda_dtype()

    def parameters_in_runtime_order(self):
        return self._memstats._param_runtime_order.generate()

    def memstats(self):
        return self._memstats

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

    def _backup_params(self):
        """
        The function is called before forward. Backup model params on cpu.
        """
        for p in self.module.parameters():
            self.cpu_param_data_dict[p] = torch.empty(p.data.shape, dtype=self.dtype, device="cpu")
            self.cpu_param_data_dict[p].copy_(p.data)

    def _restore_params(self):
        """
        This function is called after backward. Restore model params.
        """
        for p in self.module.parameters():
            p.data = torch.empty(p.data.shape, dtype=self.dtype, device="cpu", requires_grad=p.data.requires_grad)
            p.data.copy_(self.cpu_param_data_dict[p])
        self.cpu_param_data_dict.clear()

    def _pre_forward(self):
        self._clear_cuda_mem_info()
        self._backup_params()
        self.grad_hook.register_grad_hook(self.module)
        self.param_op_hook.mem_monitor.start()

    def forward(self, *args, **kwargs):
        args, kwargs = _cast_float(args, self.dtype), _cast_float(kwargs, self.dtype)
        self.module.zero_grad(set_to_none=True)
        self._pre_forward()
        with ColoParamOpHookManager.use_hooks(self.param_op_hook):
            outputs = self.module(*args, **kwargs)
        return outputs

    def backward(self, loss):
        with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook):
            loss.backward()
        self._post_backward()

    def _post_backward(self):
        cuda_volume = self.param_op_hook.mem_monitor.finish()
        self._memstats.record_max_cuda_overall_data(cuda_volume)
        # calc the last Op non model data
        self._memstats.calc_max_cuda_non_model_data()
        self.grad_hook.remove_grad_hook()
        self._restore_params()

    def _clear_cuda_mem_info(self):
        self._memstats.clear()
        self._gradstat.clear()

    def _cast_buffers_to_cuda_dtype(self):
        for buffer in self.module.buffers():
            buffer.data = buffer.cuda()
            if torch.is_floating_point(buffer):
                buffer.data = buffer.data.to(self.dtype)
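A hedged sketch of one tracing iteration; `model` and `data` are assumptions that must match the shapes of real training, per the premise stated in the class docstring:

    # Hypothetical sketch, not from the commit.
    tracer = RuntimeMemTracer(model, dtype=torch.half)
    out = tracer(data)               # traced forward pass
    loss = out.sum()                 # placeholder loss for the sketch
    tracer.backward(loss)            # traced backward; params are restored afterwards
    memstats = tracer.memstats()     # per pre-op-step non-model-data statistics
    # the stats can then be fed to the chunk search (note the `memstas` spelling in this commit)
    config, waste = search_chunk_configuration(model, 32, 1024, memstas=memstats)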
colossalai/gemini/memory_tracer/static_memstats_collector.py (new file, 0 → 100644; view file @ e532679c)

from typing import Optional

import torch
import torch.nn as nn
from torch.fx import symbolic_trace

from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.profiler import calculate_fwd_out, calculate_fwd_tmp, is_compatible_with_meta
from colossalai.gemini.chunk import ChunkManager

if is_compatible_with_meta():
    from colossalai.fx.profiler import MetaTensor

from .chunk_memstats_collector import ChunkMemStatsCollector


class ModuleInfos:

    def __init__(self, module: torch.nn.Module, module_name: str, module_full_name: str,
                 parent_module: torch.nn.Module):
        self.module = module
        self.module_name = module_name
        self.module_full_name = module_full_name
        self.parent_module = parent_module


class StaticMemStatsCollector(ChunkMemStatsCollector):
    """
    A Static Memory statistic collector.
    """

    def __init__(self, module: nn.Module, chunk_manager: ChunkManager) -> None:
        super().__init__(chunk_manager)
        self.module = module
        self.module_info_list = []

    def init_mem_stats(self, *inputs):

        self.register_opnodes_recursively(self.module)
        self.refactor_module()

        self.module = self.module.cpu()
        self.module.train()

        data = [MetaTensor(torch.rand(inp.shape, device='meta'), fake_device='cpu') for inp in inputs]
        gm = symbolic_trace(self.module)
        interp = MetaInfoProp(gm)
        interp.propagate(*data)

        total_mem = 0
        for inp in inputs:
            total_mem += inp.numel() * inp.element_size()
        last_node = None
        module_name_list = [mInfo.module_full_name for mInfo in self.module_info_list]
        for node in gm.graph.nodes:
            total_mem = total_mem + calculate_fwd_tmp(node) + calculate_fwd_out(node)
            if node.op == "call_module":
                if node.name.endswith("_0") and node.name[:-2] in module_name_list:
                    self._non_model_data_cuda_list.append(total_mem)
                last_node = node
        self._non_model_data_cuda_list.append(total_mem)
        self._non_model_data_cuda_list = self._non_model_data_cuda_list[1:]

        cur_module_mem_fwd = 0
        cur_module_mem_bwd = 0
        grad_module_out = last_node.meta["fwd_mem_out"]
        for node in gm.graph.nodes.__reversed__():
            cur_module_mem_fwd = cur_module_mem_fwd + calculate_fwd_tmp(node) + calculate_fwd_out(node)
            cur_module_mem_bwd = cur_module_mem_bwd + node.meta["bwd_mem_tmp"] + node.meta["bwd_mem_out"]
            if node.op == "call_module":
                if node.name.endswith("_0") and node.name[:-2] in module_name_list:
                    self._non_model_data_cuda_list.append(total_mem + grad_module_out + cur_module_mem_bwd)
                    total_mem = total_mem - cur_module_mem_fwd
                    cur_module_mem_fwd = 0
                    cur_module_mem_bwd = 0
                    grad_module_out = node.meta["bwd_mem_out"]

        self._step_total = len(self._non_model_data_cuda_list)
        self.recover_module()

    def refactor_module(self):
        for modInfo in self.module_info_list:
            temp_node = nn.Sequential(nn.ReLU(), modInfo.module)
            modInfo.parent_module.__setattr__(modInfo.module_name, temp_node)

    def recover_module(self):
        for modInfo in self.module_info_list:
            modInfo.parent_module.__setattr__(modInfo.module_name, modInfo.module)

    def register_opnodes_recursively(self,
                                     module: torch.nn.Module,
                                     name: str = "",
                                     full_name: str = "",
                                     parent_module: Optional[torch.nn.Module] = None):

        assert isinstance(module, torch.nn.Module)

        for child_name, child in module.named_children():
            self.register_opnodes_recursively(child, child_name, full_name + "_" + child_name, module)

        # Early return on modules with no parameters.
        if len(list(module.parameters(recurse=False))) == 0:
            return

        self.module_info_list.append(ModuleInfos(module, name, full_name[1:], parent_module))
colossalai/gemini/memory_tracer/model_data_memtracer.py → colossalai/gemini/memory_tracer/utils.py (renamed; view file @ e532679c)

-from colossalai.context.singleton_meta import SingletonMeta
+from typing import Optional, Tuple
+
 import torch
-from typing import Tuple, Optional
-from colossalai.logging import DistributedLogger


 def colo_model_optimizer_usage(optim) -> Tuple[int, int]:
...
@@ -58,52 +57,3 @@ def colo_model_mem_usage(model: torch.nn.Module) -> Tuple[int, int]:
             cpu_mem_usage += t_cpu

     return cuda_mem_usage, cpu_mem_usage
-
-
-class ModelDataTracer(metaclass=SingletonMeta):
-    """
-    A tracer singleton to trace model data usage during runtime.
-    You have to register a model on the singleton first.
-    """
-
-    def __init__(self) -> None:
-        self._logger = DistributedLogger("ModelDataTracer")
-        self._model = None
-        self._opitimizer = None
-
-    def _get_mem_usage(self) -> Tuple[int, int]:
-        """
-        get the memory usage of the model registered.
-        Returns:
-            Tuple[int, int]: cuda, cpu mem usage
-        """
-        cuda_use_opt, cpu_use_opt = colo_model_optimizer_usage(self._opitimizer)
-        cuda_use_model, cpu_use_model = colo_model_mem_usage(self._model)
-        return cuda_use_opt + cuda_use_model, cpu_use_opt + cpu_use_model
-
-    def register_model(self, model) -> None:
-        if self._model is not None:
-            self._logger.warning("ModelDataTracer has already registered a model")
-        self._model = model
-
-    def register_optimizer(self, optimizer) -> None:
-        if self._opitimizer is not None:
-            self._logger.warning("ModelDataTracer has already registered an optimizer")
-        self._opitimizer = optimizer
-
-    @property
-    def cpu_usage(self):
-        _, cpu_usage = self._get_mem_usage()
-        return cpu_usage
-
-    @property
-    def cuda_usage(self):
-        cuda_usage, _ = self._get_mem_usage()
-        return cuda_usage
-
-    @property
-    def both_mem_usage(self):
-        return self._get_mem_usage()
-
-
-GLOBAL_MODEL_DATA_TRACER = ModelDataTracer()
colossalai/gemini/ophooks/__init__.py (view file @ e532679c)

-from .utils import register_ophooks_recursively, BaseOpHook
-from ._memtracer_ophook import MemTracerOpHook
+from .utils import BaseOpHook, register_ophooks_recursively

-__all__ = ["BaseOpHook", "MemTracerOpHook", "register_ophooks_recursively"]
+__all__ = ["BaseOpHook", "register_ophooks_recursively"]
colossalai/gemini/ophooks/_memtracer_ophook.py (deleted, 100644 → 0; view file @ c1492e50)

import json
import pickle
from pathlib import Path
from colossalai.context.parallel_mode import ParallelMode
import torch
from colossalai.gemini.ophooks import BaseOpHook
from colossalai.registry import OPHOOKS
from colossalai.logging import get_dist_logger
from colossalai.core import global_context as gpc
from typing import Union
import math


@OPHOOKS.register_module
class MemTracerOpHook(BaseOpHook):
    """
    Collect GPU memory usage information

    Args:
        warmup (int): This parameter indicates how many iterations to truncate before profiling, defaults to 50.
        refreshrate (int): This parameter decides the frequency of write file, defaults to 10.
        data_prefix (string): The prefix of the stats data file, defaults to "memstats".
    """

    def __init__(self, warmup: int = 50, refreshrate: int = 10, data_prefix: str = "memstats"):
        from colossalai.gemini.memory_tracer import AsyncMemoryMonitor
        super().__init__()
        self.async_mem_monitor = AsyncMemoryMonitor()
        self._curiter = 0
        self._logger = get_dist_logger()
        self._count = 0
        self._warmup = warmup
        self._refreshrate = refreshrate
        self._data_prefix = data_prefix
        # in distributed environment
        if gpc.is_initialized(ParallelMode.GLOBAL):
            self._rank = gpc.get_global_rank()
        else:
            self._rank = 0

    def _isvalid(self, module) -> bool:
        assert isinstance(module, torch.nn.Module)
        return module.training

    def _resample(self):
        # calculate the average iteration time
        total_time = (self.async_mem_monitor.time_stamps[-1] - self.async_mem_monitor.time_stamps[0])
        avg_it_time = total_time / self.warmup
        self._logger.debug(f"total time for {self.warmup} iterations is {total_time}s")
        # adjust the sampling power
        power: int = round(-math.log(avg_it_time, 10)) + 1
        self._logger.debug(f"the power is {power}")
        self.async_mem_monitor.set_interval(power)

    @property
    def refreshrate(self) -> int:
        return self._refreshrate

    @property
    def warmup(self) -> int:
        return self._warmup

    @property
    def curiter(self) -> int:
        return self._curiter

    @property
    def valid_iter(self) -> int:
        return self.curiter - self.warmup

    def pre_fwd_exec(self, module: torch.nn.Module, *args):
        if self._isvalid(module):
            self.async_mem_monitor.finish()
            self.async_mem_monitor.start()

    def post_fwd_exec(self, module: torch.nn.Module, *args):
        if self._isvalid(module):
            self.async_mem_monitor.finish()

    def pre_bwd_exec(self, module: torch.nn.Module, input, output):
        if self._isvalid(module):
            self.async_mem_monitor.finish()
            self.async_mem_monitor.start()

    def post_bwd_exec(self, module: torch.nn.Module, input):
        if self._isvalid(module):
            self.async_mem_monitor.finish()

    def pre_iter(self):
        pass

    def post_iter(self):
        self.async_mem_monitor.finish()
        # in the warmup stage
        if self.curiter < self.warmup:
            pass
        # adjust the sampling rate
        elif self.curiter == self.warmup:
            # use adaptive sample rate
            self._resample()
        # record data to log file
        else:
            # every `refreshrate` times, refresh the file
            if self.valid_iter != 0 and self.valid_iter % self.refreshrate == 0:
                # output file info
                self._logger.info(f"dump a memory statistics as pickle to {self._data_prefix}-{self._rank}.pkl")
                home_dir = Path.home()
                with open(home_dir.joinpath(f".cache/colossal/mem-{self._rank}.pkl"), "wb") as f:
                    pickle.dump(self.async_mem_monitor.state_dict, f)
                self._count += 1
                self._logger.debug(f"data file has been refreshed {self._count} times")
        # finish an iteration
        self._curiter += 1

    def save_results(self, data_file: Union[str, Path]):
        with open(data_file, "w") as f:
            f.write(json.dumps(self.async_mem_monitor.state_dict))
colossalai/gemini/ophooks/_shard_grad_ophook.py (view file @ e532679c)

 import torch

 from colossalai.registry import OPHOOKS

 from . import BaseOpHook


 @OPHOOKS.register_module
-class ShardGradHook(BaseOpHook):
+class ShardGradMemTracerHook(BaseOpHook):
     """
     A hook to process sharded params before and after FWD and BWD operators execute.
     """
...
colossalai/gemini/ophooks/runtime_mem_tracer_hook.py (new file, 0 → 100644; view file @ e532679c)

from contextlib import contextmanager
from enum import Enum
from functools import partial
from typing import List

import torch

from colossalai.gemini.memory_tracer import MemStats, SyncCudaMemoryMonitor
from colossalai.gemini.tensor_utils import alloc_storage, free_storage
from colossalai.tensor.param_op_hook import ColoParamOpHook


class TrainingPhase(Enum):
    FORWARD = 0
    BACKWARD = 1


class GradMemStats():

    def __init__(self) -> None:
        self.unreleased_grad_flag = {}
        self.unreleased_grad_volume = 0

    def clear(self):
        self.unreleased_grad_flag.clear()
        self.unreleased_grad_volume = 0


class GradMemTracerHook():

    def __init__(self, grad_stats: GradMemStats):
        self.grad_hook_list = []
        self._grad_stats = grad_stats

    def grad_handle(self, p, grad):
        assert self._grad_stats.unreleased_grad_flag[p]
        free_storage(grad)
        self._grad_stats.unreleased_grad_volume -= grad.numel() * grad.element_size()
        self._grad_stats.unreleased_grad_flag[p] = False

    def register_grad_hook(self, module: torch.nn.Module):
        for p in module.parameters():
            if p.requires_grad:
                self.grad_hook_list.append(p.register_hook(partial(self.grad_handle, p)))
                self._grad_stats.unreleased_grad_flag[p] = False

    def remove_grad_hook(self):
        for hook in self.grad_hook_list:
            hook.remove()


class ParamMemTracerHook(ColoParamOpHook):

    def __init__(self, memstats: MemStats, gradstats: GradMemStats) -> None:
        super().__init__()
        self._training_phase = TrainingPhase.FORWARD
        self._memstats = memstats
        self._grad_stats = gradstats
        self.mem_monitor = SyncCudaMemoryMonitor()

    def _free_cuda_params(self, params):
        for p in params:
            if p.data.device.type == "cpu":
                raise NotImplementedError("Only free cuda memory")
            free_storage(p.data)

    def _allocate_params_on_cuda(self, params: List[torch.nn.Parameter]):
        """
        Move params to cuda.

        Args:
            params (List[torch.nn.Parameter]): target params

        Raises:
            NotImplementedError: raise error when a param has a cpu grad
        """
        for p in params:
            cur_dev = p.data.device.type
            if cur_dev == "cpu":
                if p.grad is not None and p.grad.device.type == "cpu":
                    raise NotImplementedError("Only run in forward propagation")
                p.data = torch.empty(p.data.shape,
                                     device="cuda",
                                     dtype=p.data.dtype,
                                     requires_grad=p.data.requires_grad)
            elif cur_dev == "cuda":
                alloc_storage(p.data)

    def record_model_data_volume(self, params):
        """
        Get cuda model data used by the params.
        """
        data_volume = self._grad_stats.unreleased_grad_volume
        for p in params:
            cur_model_data_volume = p.data.numel() * p.data.element_size()
            data_volume += cur_model_data_volume
            if self._training_phase == TrainingPhase.BACKWARD and p.requires_grad:
                # add param.grad; actually param.grad is still None at this time
                data_volume += cur_model_data_volume
                if not self._grad_stats.unreleased_grad_flag[p]:
                    self._grad_stats.unreleased_grad_volume += cur_model_data_volume
                    self._grad_stats.unreleased_grad_flag[p] = True
        # record max non model data used for this Op
        self._memstats.record_max_cuda_model_data(data_volume)

    def pre_op(self, params):
        max_cuda_used_pre_op = self.mem_monitor.finish()
        # record max cuda overall data for prev OP.
        self._memstats.record_max_cuda_overall_data(max_cuda_used_pre_op)
        # record max cuda non model data for prev OP.
        self._memstats.calc_max_cuda_non_model_data()

        self._allocate_params_on_cuda(params)
        # record max cuda model data for current OP
        self.record_model_data_volume(params)

        self.mem_monitor.start()
        self._memstats.increase_preop_step(params)

    def post_op(self, params):
        self._free_cuda_params(params)

    def pre_forward(self, params: List[torch.Tensor]) -> None:
        self.pre_op(params)

    def post_forward(self, params: List[torch.Tensor]) -> None:
        self.post_op(params)

    def pre_backward(self, params: List[torch.Tensor]) -> None:
        self.pre_op(params)

    def post_backward(self, params: List[torch.Tensor]) -> None:
        self.post_op(params)

    @contextmanager
    def switch_training_phase(self, training_phase: TrainingPhase = TrainingPhase.BACKWARD):
        old_training_phase = self._training_phase
        try:
            self._training_phase = training_phase
            yield
        finally:
            self._training_phase = old_training_phase

    switch_to_backward = switch_training_phase
    switch_to_forward = partial(switch_to_backward, training_phase=TrainingPhase.FORWARD)
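The `switch_to_backward` alias is what drives backward-phase accounting; this simply restates the call pattern already used by `RuntimeMemTracer.backward` earlier in this commit:

    # As used in RuntimeMemTracer.backward (see runtime_mem_tracer.py above):
    with param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(param_op_hook):
        loss.backward()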
colossalai/gemini/placement_policy.py (view file @ e532679c)

+import functools
 from abc import ABC, abstractmethod
 from time import time
-from typing import List, Optional, Tuple, Dict
+from typing import Dict, List, Optional, Tuple, Type

 import torch

-from colossalai.utils import get_current_device
-from colossalai.utils.memory import colo_device_memory_capacity
-from colossalai.gemini.memory_tracer.memstats_collector import MemStatsCollectorV2
-from typing import Type
-import functools
 from colossalai.gemini.chunk import Chunk, ChunkManager
+from colossalai.gemini.memory_tracer import ChunkMemStatsCollector
+from colossalai.utils import get_current_device
+from colossalai.utils.memory import colo_device_memory_capacity


 class PlacementPolicy(ABC):
     need_mem_stats: bool = False

-    def __init__(self, chunk_manager: ChunkManager, mem_stats_collector: Optional[MemStatsCollectorV2] = None) -> None:
+    def __init__(self,
+                 chunk_manager: ChunkManager,
+                 mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
         self.chunk_manager = chunk_manager
-        self.mem_stats_collector: Optional[MemStatsCollectorV2] = mem_stats_collector
+        self.mem_stats_collector: Optional[ChunkMemStatsCollector] = mem_stats_collector

     @abstractmethod
     def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]:
...
@@ -29,7 +31,9 @@ class PlacementPolicy(ABC):
 class CPUPlacementPolicy(PlacementPolicy):

-    def __init__(self, chunk_manager: ChunkManager, mem_stats_collector: Optional[MemStatsCollectorV2] = None) -> None:
+    def __init__(self,
+                 chunk_manager: ChunkManager,
+                 mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
         super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)

     def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]:
...
@@ -44,7 +48,9 @@ class CPUPlacementPolicy(PlacementPolicy):
 class CUDAPlacementPolicy(PlacementPolicy):

-    def __init__(self, chunk_manager: ChunkManager, mem_stats_collector: Optional[MemStatsCollectorV2] = None) -> None:
+    def __init__(self,
+                 chunk_manager: ChunkManager,
+                 mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
         assert torch.cuda.is_available(), 'Cannot use CUDATensorPlacementPolicy when CUDA is not available'
         super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)
...
@@ -65,7 +71,9 @@ class AutoPlacementPolicy(PlacementPolicy):
     _warmup_non_model_data_ratio: float = 0.8
     _steady_cuda_cap_ratio: float = 0.9

-    def __init__(self, chunk_manager: ChunkManager, mem_stats_collector: Optional[MemStatsCollectorV2] = None) -> None:
+    def __init__(self,
+                 chunk_manager: ChunkManager,
+                 mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
         super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)

     def evict_tensors(self,
...
@@ -154,7 +162,9 @@ class ConstPlacementPolicy(PlacementPolicy):
     need_mem_stats: bool = False
     _accessed_memory_boundary = 512 * 1024**2

-    def __init__(self, chunk_manager: ChunkManager, mem_stats_collector: Optional[MemStatsCollectorV2] = None) -> None:
+    def __init__(self,
+                 chunk_manager: ChunkManager,
+                 mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
         super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)

     def evict_tensors(self,
...
@@ -226,7 +236,7 @@ class PlacementPolicyFactory:
         return PlacementPolicyFactory.policies[policy_name]

     @staticmethod
-    def get_polocy_names():
+    def get_policy_names():
         return tuple(PlacementPolicyFactory.policies.keys())

     @staticmethod
...
colossalai/gemini/tensor_utils.py (view file @ e532679c)

@@ -3,6 +3,20 @@ from colossalai.gemini.stateful_tensor import StatefulTensor
 from typing import Union, Tuple


+def is_storage_empty(tensor: torch.Tensor) -> bool:
+    return tensor.storage().size() == 0
+
+
+def free_storage(tensor: torch.Tensor) -> None:
+    if not is_storage_empty(tensor):
+        tensor.storage().resize_(0)
+
+
+def alloc_storage(tensor: torch.Tensor) -> None:
+    if is_storage_empty(tensor):
+        tensor.storage().resize_(tensor.numel())
+
+
 def colo_tensor_mem_usage(tensor: Union[torch.Tensor, StatefulTensor]) -> Tuple[int, int]:
     if isinstance(tensor, StatefulTensor):
         t = tensor.payload
...
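The three storage helpers release and re-grow a tensor's backing storage without touching its metadata, which is what lets the tracer hooks "free" and "allocate" parameters in place. A quick self-contained illustration:

    import torch

    t = torch.empty(1024)
    free_storage(t)              # backing storage resized to 0 bytes
    assert is_storage_empty(t)
    assert t.shape == (1024,)    # shape metadata survives
    alloc_storage(t)             # storage re-grown to numel() elements
    assert not is_storage_empty(t)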
colossalai/global_variables.py (view file @ e532679c)

@@ -22,7 +22,9 @@ class TensorParallelEnv(object):
                    depth_3d: int = None,
                    input_group_3d=None,
                    weight_group_3d=None,
-                   output_group_3d=None):
+                   output_group_3d=None,
+                   input_x_weight_group_3d=None,
+                   output_x_weight_group_3d=None):
         self.mode = mode
         self.vocab_parallel = vocab_parallel
         self.parallel_input_1d = parallel_input_1d
...
@@ -33,6 +35,8 @@ class TensorParallelEnv(object):
         self.input_group_3d = input_group_3d
         self.weight_group_3d = weight_group_3d
         self.output_group_3d = output_group_3d
+        self.input_x_weight_group_3d = input_x_weight_group_3d
+        self.output_x_weight_group_3d = output_x_weight_group_3d

     def save(self):
         return dict(mode=self.mode,
...
@@ -44,7 +48,9 @@ class TensorParallelEnv(object):
                     depth_3d=self.depth_3d,
                     input_group_3d=self.input_group_3d,
                     weight_group_3d=self.weight_group_3d,
-                    output_group_3d=self.output_group_3d)
+                    output_group_3d=self.output_group_3d,
+                    input_x_weight_group_3d=self.input_x_weight_group_3d,
+                    output_x_weight_group_3d=self.output_x_weight_group_3d)


 tensor_parallel_env = TensorParallelEnv()