Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
33f44121
Unverified
Commit
33f44121
authored
Dec 06, 2022
by
Jiarui Fang
Committed by
GitHub
Dec 06, 2022
Browse files
[Gemini] use MemStats to store the tracing data. Seperate it from Collector. (#2084)
parent
1f992058
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
193 additions
and
139 deletions
+193
-139
colossalai/gemini/memory_tracer/chunk_memstats_collector.py
colossalai/gemini/memory_tracer/chunk_memstats_collector.py
+4
-3
colossalai/gemini/memory_tracer/memory_stats.py
colossalai/gemini/memory_tracer/memory_stats.py
+94
-0
colossalai/gemini/memory_tracer/memstats_collector.py
colossalai/gemini/memory_tracer/memstats_collector.py
+13
-50
colossalai/zero/sharded_model/sharded_model_v2.py
colossalai/zero/sharded_model/sharded_model_v2.py
+5
-12
tests/test_zero/test_mem_collector.py
tests/test_zero/test_mem_collector.py
+77
-74
No files found.
colossalai/gemini/memory_tracer/chunk_memstats_collector.py
View file @
33f44121
...
@@ -11,15 +11,16 @@ class ChunkMemStatsCollector(MemStatsCollector):
...
@@ -11,15 +11,16 @@ class ChunkMemStatsCollector(MemStatsCollector):
super
().
__init__
()
super
().
__init__
()
self
.
_chunk_manager
=
chunk_manager
self
.
_chunk_manager
=
chunk_manager
# override
def
sample_model_data
(
self
)
->
None
:
def
sample_model_data
(
self
)
->
None
:
"""Sampling model data statistics.
"""Sampling model data statistics.
"""
"""
if
self
.
_start_flag
:
if
self
.
_start_flag
:
cuda_mem
=
self
.
_chunk_manager
.
total_mem
[
'cuda'
]
cuda_mem
=
self
.
_chunk_manager
.
total_mem
[
'cuda'
]
cpu_mem
=
self
.
_chunk_manager
.
total_mem
[
'cpu'
]
cpu_mem
=
self
.
_chunk_manager
.
total_mem
[
'cpu'
]
self
.
_model_data
_
cuda
_list
.
append
(
cuda_mem
)
self
.
_
memstats
.
append_
model_data
(
'
cuda
'
,
cuda_mem
)
self
.
_model_data
_
cpu
_list
.
append
(
cpu_mem
)
self
.
_
memstats
.
append_
model_data
(
'
cpu
'
,
cpu_mem
)
@
property
@
property
def
cuda_margin_mem
(
self
)
->
float
:
def
cuda_margin_mem
(
self
)
->
float
:
return
colo_device_memory_capacity
(
get_current_device
())
-
max
(
self
.
overall
_mem
_
stats
(
'cuda'
)
)
return
colo_device_memory_capacity
(
get_current_device
())
-
self
.
_memstats
.
max_overall_cuda
(
'cuda'
)
colossalai/gemini/memory_tracer/memory_stats.py
0 → 100644
View file @
33f44121
from
typing
import
Any
,
Dict
,
List
class
MemStats
(
object
):
def
__init__
(
self
)
->
None
:
"""
Store the non model data statistics used for Gemini and ZeroOptimizer.
"""
# p -> list of non_model data volumn visied in order.
self
.
param_non_model_data_map
:
Dict
(
Any
,
List
[
int
])
=
{}
self
.
_model_data_cuda_list
=
[]
self
.
_model_data_cpu_list
=
[]
self
.
_overall_cuda_list
=
[]
self
.
_overall_cpu_list
=
[]
self
.
_non_model_data_cuda_list
=
[]
self
.
_non_model_data_cpu_list
=
[]
def
append_overall_data
(
self
,
device_type
:
str
,
val
:
float
):
if
device_type
==
'cuda'
:
self
.
_overall_cuda_list
.
append
(
val
)
elif
device_type
==
'cpu'
:
self
.
_overall_cpu_list
.
append
(
val
)
else
:
raise
TypeError
def
append_model_data
(
self
,
device_type
:
str
,
val
:
float
):
if
device_type
==
'cuda'
:
self
.
_model_data_cuda_list
.
append
(
val
)
elif
device_type
==
'cpu'
:
self
.
_model_data_cpu_list
.
append
(
val
)
else
:
raise
TypeError
def
append_non_model_data
(
self
,
device_type
:
str
):
if
device_type
==
'cuda'
:
self
.
_non_model_data_cuda_list
.
append
(
self
.
_overall_cuda_list
[
-
1
]
-
self
.
_model_data_cuda_list
[
-
1
])
elif
device_type
==
'cpu'
:
self
.
_non_model_data_cpu_list
.
append
(
self
.
_overall_cpu_list
[
-
1
]
-
self
.
_model_data_cpu_list
[
-
1
])
else
:
raise
TypeError
def
overall_mem_stats
(
self
,
device_type
:
str
)
->
List
[
int
]:
if
device_type
==
'cuda'
:
return
self
.
_overall_cuda_list
elif
device_type
==
'cpu'
:
return
self
.
_overall_cpu_list
else
:
raise
TypeError
def
model_data_list
(
self
,
device_type
:
str
)
->
List
[
int
]:
if
device_type
==
'cuda'
:
return
self
.
_model_data_cuda_list
elif
device_type
==
'cpu'
:
return
self
.
_model_data_cpu_list
else
:
raise
TypeError
def
non_model_data_list
(
self
,
device_type
:
str
)
->
List
[
int
]:
if
device_type
==
'cuda'
:
return
self
.
_non_model_data_cuda_list
elif
device_type
==
'cpu'
:
return
self
.
_non_model_data_cpu_list
else
:
raise
TypeError
def
max_non_model_data
(
self
,
device_type
:
str
)
->
float
:
if
device_type
==
'cuda'
:
return
max
(
self
.
_non_model_data_cuda_list
)
elif
device_type
==
'cpu'
:
return
max
(
self
.
_non_model_data_cpu_list
)
else
:
raise
TypeError
def
max_overall_cuda
(
self
,
device_type
:
str
)
->
float
:
if
device_type
==
'cuda'
:
return
max
(
self
.
_overall_cuda_list
)
elif
device_type
==
'cpu'
:
return
max
(
self
.
_overall_cpu_list
)
else
:
raise
TypeError
def
clear
(
self
):
self
.
_model_data_cuda_list
=
[]
self
.
_overall_cuda_list
=
[]
self
.
_model_data_cpu_list
=
[]
self
.
_overall_cpu_list
=
[]
self
.
_non_model_data_cpu_list
=
[]
self
.
_non_model_data_cuda_list
=
[]
colossalai/gemini/memory_tracer/memstats_collector.py
View file @
33f44121
...
@@ -7,6 +7,8 @@ from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
...
@@ -7,6 +7,8 @@ from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
from
colossalai.gemini.stateful_tensor
import
StatefulTensor
from
colossalai.gemini.stateful_tensor
import
StatefulTensor
from
colossalai.utils.memory
import
colo_device_memory_used
from
colossalai.utils.memory
import
colo_device_memory_used
from
.memory_stats
import
MemStats
class
MemStatsCollector
:
class
MemStatsCollector
:
"""
"""
...
@@ -22,43 +24,12 @@ class MemStatsCollector:
...
@@ -22,43 +24,12 @@ class MemStatsCollector:
def
__init__
(
self
)
->
None
:
def
__init__
(
self
)
->
None
:
self
.
_mem_monitor
=
SyncCudaMemoryMonitor
()
self
.
_mem_monitor
=
SyncCudaMemoryMonitor
()
self
.
_model_data_cuda_list
=
[]
self
.
_overall_cuda_list
=
[]
self
.
_model_data_cpu_list
=
[]
self
.
_overall_cpu_list
=
[]
self
.
_non_model_data_cuda_list
=
[]
self
.
_non_model_data_cpu_list
=
[]
self
.
_sampling_time
=
[]
self
.
_sampling_time
=
[]
self
.
_start_flag
=
False
self
.
_start_flag
=
False
self
.
_step_idx
=
0
self
.
_step_idx
=
0
self
.
_step_total
=
0
self
.
_step_total
=
0
self
.
_memstats
=
MemStats
()
def
overall_mem_stats
(
self
,
device_type
:
str
)
->
List
[
int
]:
if
device_type
==
'cuda'
:
return
self
.
_overall_cuda_list
elif
device_type
==
'cpu'
:
return
self
.
_overall_cpu_list
else
:
raise
TypeError
def
model_data_list
(
self
,
device_type
:
str
)
->
List
[
int
]:
if
device_type
==
'cuda'
:
return
self
.
_model_data_cuda_list
elif
device_type
==
'cpu'
:
return
self
.
_model_data_cpu_list
else
:
raise
TypeError
def
non_model_data_list
(
self
,
device_type
:
str
)
->
List
[
int
]:
if
device_type
==
'cuda'
:
return
self
.
_non_model_data_cuda_list
elif
device_type
==
'cpu'
:
return
self
.
_non_model_data_cpu_list
else
:
raise
TypeError
def
next_period_non_model_data_usage
(
self
,
device_type
:
str
)
->
int
:
def
next_period_non_model_data_usage
(
self
,
device_type
:
str
)
->
int
:
"""Get max non model data memory usage of current sampling period
"""Get max non model data memory usage of current sampling period
...
@@ -71,7 +42,7 @@ class MemStatsCollector:
...
@@ -71,7 +42,7 @@ class MemStatsCollector:
"""
"""
assert
not
self
.
_start_flag
,
'Cannot get mem stats info during collection phase.'
assert
not
self
.
_start_flag
,
'Cannot get mem stats info during collection phase.'
assert
self
.
_step_total
>
0
,
'Cannot get mem stats info before collection phase.'
assert
self
.
_step_total
>
0
,
'Cannot get mem stats info before collection phase.'
next_non_model_data
=
self
.
non_model_data_list
(
device_type
)[
self
.
_step_idx
]
next_non_model_data
=
self
.
_memstats
.
non_model_data_list
(
device_type
)[
self
.
_step_idx
]
self
.
_step_idx
=
(
self
.
_step_idx
+
1
)
%
self
.
_step_total
self
.
_step_idx
=
(
self
.
_step_idx
+
1
)
%
self
.
_step_total
return
next_non_model_data
return
next_non_model_data
...
@@ -95,37 +66,29 @@ class MemStatsCollector:
...
@@ -95,37 +66,29 @@ class MemStatsCollector:
if
self
.
_start_flag
:
if
self
.
_start_flag
:
cuda_mem
=
StatefulTensor
.
GST_MGR
.
total_mem
[
'cuda'
]
cuda_mem
=
StatefulTensor
.
GST_MGR
.
total_mem
[
'cuda'
]
cpu_mem
=
StatefulTensor
.
GST_MGR
.
total_mem
[
'cpu'
]
cpu_mem
=
StatefulTensor
.
GST_MGR
.
total_mem
[
'cpu'
]
self
.
_model_data
_
cuda
_list
.
append
(
cuda_mem
)
self
.
_
memstats
.
append_
model_data
(
'
cuda
'
,
cuda_mem
)
self
.
_model_data
_
cpu
_list
.
append
(
cpu_mem
)
self
.
_
memstats
.
append_
model_data
(
'
cpu
'
,
cpu_mem
)
def
sample_overall_data
(
self
)
->
None
:
def
sample_overall_data
(
self
)
->
None
:
"""Sampling non model data statistics.
"""Sampling non model data statistics.
"""
"""
if
self
.
_start_flag
:
if
self
.
_start_flag
:
# overall data recording is after model data recording
# overall data recording is after model data recording
if
len
(
self
.
_model_data_cuda_list
)
==
0
:
if
len
(
self
.
_
memstats
.
_
model_data_cuda_list
)
==
0
:
return
return
self
.
_
overall_cuda_list
.
append
(
self
.
_mem_monitor
.
finish
())
self
.
_
memstats
.
append_overall_data
(
'cuda'
,
self
.
_mem_monitor
.
finish
())
self
.
_
overall_cpu_list
.
append
(
colo_device_memory_used
(
torch
.
device
(
'cpu'
)))
self
.
_
memstats
.
append_overall_data
(
'cpu'
,
colo_device_memory_used
(
torch
.
device
(
'cpu'
)))
assert
len
(
self
.
_model_data_cuda_list
)
==
len
(
self
.
_overall_cuda_list
)
assert
len
(
self
.
_
memstats
.
_
model_data_cuda_list
)
==
len
(
self
.
_
memstats
.
_
overall_cuda_list
)
self
.
_
non_model_data_cuda_list
.
append
(
self
.
_overall_cuda_list
[
-
1
]
-
self
.
_model_data
_
cuda
_list
[
-
1
]
)
self
.
_
memstats
.
append_non
_model_data
(
'
cuda
'
)
self
.
_
non_model_data_cpu_list
.
append
(
self
.
_overall_cpu_list
[
-
1
]
-
self
.
_model_data
_
cpu
_list
[
-
1
]
)
self
.
_
memstats
.
append_non
_model_data
(
'
cpu
'
)
self
.
_sampling_time
.
append
(
time
.
time
())
self
.
_sampling_time
.
append
(
time
.
time
())
self
.
_mem_monitor
.
start
()
self
.
_mem_monitor
.
start
()
def
clear
(
self
)
->
None
:
def
clear
(
self
)
->
None
:
self
.
_model_data_cuda_list
=
[]
self
.
_memstats
.
clear
()
self
.
_overall_cuda_list
=
[]
self
.
_model_data_cpu_list
=
[]
self
.
_overall_cpu_list
=
[]
self
.
_non_model_data_cpu_list
=
[]
self
.
_non_model_data_cuda_list
=
[]
self
.
_start_flag
=
False
self
.
_start_flag
=
False
self
.
_step_idx
=
0
self
.
_step_idx
=
0
self
.
_step_total
=
0
self
.
_step_total
=
0
colossalai/zero/sharded_model/sharded_model_v2.py
View file @
33f44121
...
@@ -85,7 +85,6 @@ class ShardedModelV2(nn.Module):
...
@@ -85,7 +85,6 @@ class ShardedModelV2(nn.Module):
tensor_placement_policy
:
str
=
'cuda'
,
tensor_placement_policy
:
str
=
'cuda'
,
gradient_predivide_factor
:
Optional
[
float
]
=
1.0
,
gradient_predivide_factor
:
Optional
[
float
]
=
1.0
,
reuse_fp16_shard
:
bool
=
False
,
reuse_fp16_shard
:
bool
=
False
,
user_static_memstats
:
bool
=
False
,
*
args
,
*
args
,
**
kwargs
):
**
kwargs
):
assert
not
isinstance
(
module
,
ShardedModelV2
),
'Nested ShardedModelV2 is not supported.'
assert
not
isinstance
(
module
,
ShardedModelV2
),
'Nested ShardedModelV2 is not supported.'
...
@@ -119,13 +118,9 @@ class ShardedModelV2(nn.Module):
...
@@ -119,13 +118,9 @@ class ShardedModelV2(nn.Module):
self
.
world_size
=
dist
.
get_world_size
(
self
.
process_group
)
self
.
world_size
=
dist
.
get_world_size
(
self
.
process_group
)
self
.
rank
=
dist
.
get_rank
(
self
.
process_group
)
self
.
rank
=
dist
.
get_rank
(
self
.
process_group
)
self
.
shard_strategy
=
shard_strategy
self
.
shard_strategy
=
shard_strategy
self
.
user_static_memstats
=
user_static_memstats
self
.
_use_memory_tracer
=
tensor_placement_policy
==
'auto'
self
.
_use_memory_tracer
=
tensor_placement_policy
==
'auto'
if
self
.
_use_memory_tracer
:
if
self
.
_use_memory_tracer
:
if
self
.
user_static_memstats
:
self
.
_memstats_collector
=
StaticMemStatsCollector
(
self
.
module
)
else
:
self
.
_memstats_collector
=
MemStatsCollector
()
self
.
_memstats_collector
=
MemStatsCollector
()
self
.
_start_collect_memstats
=
disposable
(
self
.
_memstats_collector
.
start_collection
)
self
.
_start_collect_memstats
=
disposable
(
self
.
_memstats_collector
.
start_collection
)
self
.
_finish_collect_memstats
=
disposable
(
self
.
_memstats_collector
.
finish_collection
)
self
.
_finish_collect_memstats
=
disposable
(
self
.
_memstats_collector
.
finish_collection
)
...
@@ -211,19 +206,17 @@ class ShardedModelV2(nn.Module):
...
@@ -211,19 +206,17 @@ class ShardedModelV2(nn.Module):
f
.
write
(
f
'cuda reserved
{
torch
.
cuda
.
memory_reserved
(
get_current_device
())
/
1e9
}
GB
\n
'
)
f
.
write
(
f
'cuda reserved
{
torch
.
cuda
.
memory_reserved
(
get_current_device
())
/
1e9
}
GB
\n
'
)
f
.
write
(
f
'cuda max allocated
{
torch
.
cuda
.
max_memory_allocated
(
get_current_device
())
/
1e9
}
GB
\n
'
)
f
.
write
(
f
'cuda max allocated
{
torch
.
cuda
.
max_memory_allocated
(
get_current_device
())
/
1e9
}
GB
\n
'
)
f
.
write
(
'CUDA model data (GB)
\n
'
)
f
.
write
(
'CUDA model data (GB)
\n
'
)
f
.
write
(
str
(
self
.
_memstats_collector
.
model_data_list
(
'cuda'
,
'GB'
)))
f
.
write
(
str
(
self
.
_memstats_collector
.
_memstats
.
model_data_list
(
'cuda'
)))
f
.
write
(
'
\n
'
)
f
.
write
(
'
\n
'
)
f
.
write
(
'CUDA non model data (GB)
\n
'
)
f
.
write
(
'CUDA non model data (GB)
\n
'
)
f
.
write
(
str
(
self
.
_memstats_collector
.
non_model_data_list
(
'cuda'
,
'GB'
)))
f
.
write
(
str
(
self
.
_memstats_collector
.
_memstats
.
non_model_data_list
(
'cuda'
)))
f
.
write
(
'CPU non model data (GB)
\n
'
)
f
.
write
(
'CPU non model data (GB)
\n
'
)
f
.
write
(
str
(
self
.
_memstats_collector
.
non_model_data_list
(
'cpu'
,
'GB'
)))
f
.
write
(
str
(
self
.
_memstats_collector
.
_memstats
.
non_model_data_list
(
'cpu'
)))
f
.
write
(
'
\n
'
)
f
.
write
(
'
\n
'
)
def
_pre_forward_operations
(
self
,
*
args
):
def
_pre_forward_operations
(
self
,
*
args
):
# the operation will affect the memory tracer behavior in ZeroHook
# the operation will affect the memory tracer behavior in ZeroHook
if
self
.
_memstats_collector
:
if
self
.
_memstats_collector
:
if
self
.
user_static_memstats
:
self
.
init_mem_stats
(
*
args
)
self
.
_start_collect_memstats
()
self
.
_start_collect_memstats
()
for
p
in
self
.
module
.
parameters
():
for
p
in
self
.
module
.
parameters
():
...
@@ -264,7 +257,7 @@ class ShardedModelV2(nn.Module):
...
@@ -264,7 +257,7 @@ class ShardedModelV2(nn.Module):
# model data is fixed in cuda during training.
# model data is fixed in cuda during training.
# cuda margin space can be used to store OS.
# cuda margin space can be used to store OS.
self
.
_cuda_margin_space
=
colo_device_memory_capacity
(
get_current_device
())
-
max
(
self
.
_cuda_margin_space
=
colo_device_memory_capacity
(
get_current_device
())
-
max
(
self
.
_memstats_collector
.
overall_mem_stats
(
'cuda'
))
self
.
_memstats_collector
.
_memstats
.
overall_mem_stats
(
'cuda'
))
@
torch
.
no_grad
()
@
torch
.
no_grad
()
def
_post_backward_operations
(
self
)
->
None
:
def
_post_backward_operations
(
self
)
->
None
:
...
...
tests/test_zero/test_mem_collector.py
View file @
33f44121
import
torch
from
functools
import
partial
import
colossalai
import
pytest
import
pytest
import
torch
import
torch.multiprocessing
as
mp
import
torch.multiprocessing
as
mp
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
import
colossalai
from
colossalai.testing
import
rerun_if_address_is_in_use
from
colossalai.utils
import
free_port
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.utils.memory
import
colo_device_memory_capacity
,
colo_set_process_memory_fraction
from
colossalai.utils.memory
import
colo_device_memory_capacity
,
colo_set_process_memory_fraction
from
colossalai.zero.init_ctx
import
ZeroInitContext
from
colossalai.zero.init_ctx
import
ZeroInitContext
from
colossalai.zero.sharded_model
import
ShardedModelV2
from
colossalai.zero.shard_utils
import
BucketTensorShardStrategy
from
colossalai.zero.shard_utils
import
BucketTensorShardStrategy
from
colossalai.utils
import
free_port
from
colossalai.zero.sharded_model
import
ShardedModelV2
from
colossalai.testing
import
rerun_if_address_is_in_use
from
functools
import
partial
class
MyTestModel
(
torch
.
nn
.
Module
):
class
MyTestModel
(
torch
.
nn
.
Module
):
...
@@ -50,10 +52,11 @@ def run_mem_collector_testing():
...
@@ -50,10 +52,11 @@ def run_mem_collector_testing():
loss
=
torch
.
mean
(
output
)
loss
=
torch
.
mean
(
output
)
model
.
backward
(
loss
)
model
.
backward
(
loss
)
cuda_model_data_list
=
model
.
_memstats_collector
.
model_data_list
(
'cuda'
)
cuda_model_data_list
=
model
.
_memstats_collector
.
_memstats
.
model_data_list
(
'cuda'
)
assert
cuda_model_data_list
==
[
1311744
,
1836032
,
1836032
,
1311744
,
1836032
,
1836032
]
assert
cuda_model_data_list
==
[
1311744
,
1836032
,
1836032
,
1311744
,
1836032
,
1836032
]
cuda_non_model_data_list
=
model
.
_memstats_collector
.
non_model_data_list
(
'cuda'
)
cuda_non_model_data_list
=
model
.
_memstats_collector
.
_memstats
.
non_model_data_list
(
'cuda'
)
print
(
'cuda_non_model_data_list '
,
cuda_non_model_data_list
)
assert
cuda_non_model_data_list
[
0
]
>
cuda_non_model_data_list
[
1
]
assert
cuda_non_model_data_list
[
0
]
>
cuda_non_model_data_list
[
1
]
assert
cuda_non_model_data_list
[
-
2
]
>
cuda_non_model_data_list
[
-
1
]
assert
cuda_non_model_data_list
[
-
2
]
>
cuda_non_model_data_list
[
-
1
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment