Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
33f44121
Unverified
Commit
33f44121
authored
Dec 06, 2022
by
Jiarui Fang
Committed by
GitHub
Dec 06, 2022
Browse files
[Gemini] use MemStats to store the tracing data. Seperate it from Collector. (#2084)
parent
1f992058
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
193 additions
and
139 deletions
+193
-139
colossalai/gemini/memory_tracer/chunk_memstats_collector.py
colossalai/gemini/memory_tracer/chunk_memstats_collector.py
+4
-3
colossalai/gemini/memory_tracer/memory_stats.py
colossalai/gemini/memory_tracer/memory_stats.py
+94
-0
colossalai/gemini/memory_tracer/memstats_collector.py
colossalai/gemini/memory_tracer/memstats_collector.py
+13
-50
colossalai/zero/sharded_model/sharded_model_v2.py
colossalai/zero/sharded_model/sharded_model_v2.py
+5
-12
tests/test_zero/test_mem_collector.py
tests/test_zero/test_mem_collector.py
+77
-74
No files found.
colossalai/gemini/memory_tracer/chunk_memstats_collector.py
View file @
33f44121
...
...
@@ -11,15 +11,16 @@ class ChunkMemStatsCollector(MemStatsCollector):
super
().
__init__
()
self
.
_chunk_manager
=
chunk_manager
# override
def
sample_model_data
(
self
)
->
None
:
"""Sampling model data statistics.
"""
if
self
.
_start_flag
:
cuda_mem
=
self
.
_chunk_manager
.
total_mem
[
'cuda'
]
cpu_mem
=
self
.
_chunk_manager
.
total_mem
[
'cpu'
]
self
.
_model_data
_
cuda
_list
.
append
(
cuda_mem
)
self
.
_model_data
_
cpu
_list
.
append
(
cpu_mem
)
self
.
_
memstats
.
append_
model_data
(
'
cuda
'
,
cuda_mem
)
self
.
_
memstats
.
append_
model_data
(
'
cpu
'
,
cpu_mem
)
@
property
def
cuda_margin_mem
(
self
)
->
float
:
return
colo_device_memory_capacity
(
get_current_device
())
-
max
(
self
.
overall
_mem
_
stats
(
'cuda'
)
)
return
colo_device_memory_capacity
(
get_current_device
())
-
self
.
_memstats
.
max_overall_cuda
(
'cuda'
)
colossalai/gemini/memory_tracer/memory_stats.py
0 → 100644
View file @
33f44121
from
typing
import
Any
,
Dict
,
List
class
MemStats
(
object
):
def
__init__
(
self
)
->
None
:
"""
Store the non model data statistics used for Gemini and ZeroOptimizer.
"""
# p -> list of non_model data volumn visied in order.
self
.
param_non_model_data_map
:
Dict
(
Any
,
List
[
int
])
=
{}
self
.
_model_data_cuda_list
=
[]
self
.
_model_data_cpu_list
=
[]
self
.
_overall_cuda_list
=
[]
self
.
_overall_cpu_list
=
[]
self
.
_non_model_data_cuda_list
=
[]
self
.
_non_model_data_cpu_list
=
[]
def
append_overall_data
(
self
,
device_type
:
str
,
val
:
float
):
if
device_type
==
'cuda'
:
self
.
_overall_cuda_list
.
append
(
val
)
elif
device_type
==
'cpu'
:
self
.
_overall_cpu_list
.
append
(
val
)
else
:
raise
TypeError
def
append_model_data
(
self
,
device_type
:
str
,
val
:
float
):
if
device_type
==
'cuda'
:
self
.
_model_data_cuda_list
.
append
(
val
)
elif
device_type
==
'cpu'
:
self
.
_model_data_cpu_list
.
append
(
val
)
else
:
raise
TypeError
def
append_non_model_data
(
self
,
device_type
:
str
):
if
device_type
==
'cuda'
:
self
.
_non_model_data_cuda_list
.
append
(
self
.
_overall_cuda_list
[
-
1
]
-
self
.
_model_data_cuda_list
[
-
1
])
elif
device_type
==
'cpu'
:
self
.
_non_model_data_cpu_list
.
append
(
self
.
_overall_cpu_list
[
-
1
]
-
self
.
_model_data_cpu_list
[
-
1
])
else
:
raise
TypeError
def
overall_mem_stats
(
self
,
device_type
:
str
)
->
List
[
int
]:
if
device_type
==
'cuda'
:
return
self
.
_overall_cuda_list
elif
device_type
==
'cpu'
:
return
self
.
_overall_cpu_list
else
:
raise
TypeError
def
model_data_list
(
self
,
device_type
:
str
)
->
List
[
int
]:
if
device_type
==
'cuda'
:
return
self
.
_model_data_cuda_list
elif
device_type
==
'cpu'
:
return
self
.
_model_data_cpu_list
else
:
raise
TypeError
def
non_model_data_list
(
self
,
device_type
:
str
)
->
List
[
int
]:
if
device_type
==
'cuda'
:
return
self
.
_non_model_data_cuda_list
elif
device_type
==
'cpu'
:
return
self
.
_non_model_data_cpu_list
else
:
raise
TypeError
def
max_non_model_data
(
self
,
device_type
:
str
)
->
float
:
if
device_type
==
'cuda'
:
return
max
(
self
.
_non_model_data_cuda_list
)
elif
device_type
==
'cpu'
:
return
max
(
self
.
_non_model_data_cpu_list
)
else
:
raise
TypeError
def
max_overall_cuda
(
self
,
device_type
:
str
)
->
float
:
if
device_type
==
'cuda'
:
return
max
(
self
.
_overall_cuda_list
)
elif
device_type
==
'cpu'
:
return
max
(
self
.
_overall_cpu_list
)
else
:
raise
TypeError
def
clear
(
self
):
self
.
_model_data_cuda_list
=
[]
self
.
_overall_cuda_list
=
[]
self
.
_model_data_cpu_list
=
[]
self
.
_overall_cpu_list
=
[]
self
.
_non_model_data_cpu_list
=
[]
self
.
_non_model_data_cuda_list
=
[]
colossalai/gemini/memory_tracer/memstats_collector.py
View file @
33f44121
...
...
@@ -7,6 +7,8 @@ from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
from
colossalai.gemini.stateful_tensor
import
StatefulTensor
from
colossalai.utils.memory
import
colo_device_memory_used
from
.memory_stats
import
MemStats
class
MemStatsCollector
:
"""
...
...
@@ -22,43 +24,12 @@ class MemStatsCollector:
def
__init__
(
self
)
->
None
:
self
.
_mem_monitor
=
SyncCudaMemoryMonitor
()
self
.
_model_data_cuda_list
=
[]
self
.
_overall_cuda_list
=
[]
self
.
_model_data_cpu_list
=
[]
self
.
_overall_cpu_list
=
[]
self
.
_non_model_data_cuda_list
=
[]
self
.
_non_model_data_cpu_list
=
[]
self
.
_sampling_time
=
[]
self
.
_start_flag
=
False
self
.
_step_idx
=
0
self
.
_step_total
=
0
def
overall_mem_stats
(
self
,
device_type
:
str
)
->
List
[
int
]:
if
device_type
==
'cuda'
:
return
self
.
_overall_cuda_list
elif
device_type
==
'cpu'
:
return
self
.
_overall_cpu_list
else
:
raise
TypeError
def
model_data_list
(
self
,
device_type
:
str
)
->
List
[
int
]:
if
device_type
==
'cuda'
:
return
self
.
_model_data_cuda_list
elif
device_type
==
'cpu'
:
return
self
.
_model_data_cpu_list
else
:
raise
TypeError
def
non_model_data_list
(
self
,
device_type
:
str
)
->
List
[
int
]:
if
device_type
==
'cuda'
:
return
self
.
_non_model_data_cuda_list
elif
device_type
==
'cpu'
:
return
self
.
_non_model_data_cpu_list
else
:
raise
TypeError
self
.
_memstats
=
MemStats
()
def
next_period_non_model_data_usage
(
self
,
device_type
:
str
)
->
int
:
"""Get max non model data memory usage of current sampling period
...
...
@@ -71,7 +42,7 @@ class MemStatsCollector:
"""
assert
not
self
.
_start_flag
,
'Cannot get mem stats info during collection phase.'
assert
self
.
_step_total
>
0
,
'Cannot get mem stats info before collection phase.'
next_non_model_data
=
self
.
non_model_data_list
(
device_type
)[
self
.
_step_idx
]
next_non_model_data
=
self
.
_memstats
.
non_model_data_list
(
device_type
)[
self
.
_step_idx
]
self
.
_step_idx
=
(
self
.
_step_idx
+
1
)
%
self
.
_step_total
return
next_non_model_data
...
...
@@ -95,37 +66,29 @@ class MemStatsCollector:
if
self
.
_start_flag
:
cuda_mem
=
StatefulTensor
.
GST_MGR
.
total_mem
[
'cuda'
]
cpu_mem
=
StatefulTensor
.
GST_MGR
.
total_mem
[
'cpu'
]
self
.
_model_data
_
cuda
_list
.
append
(
cuda_mem
)
self
.
_model_data
_
cpu
_list
.
append
(
cpu_mem
)
self
.
_
memstats
.
append_
model_data
(
'
cuda
'
,
cuda_mem
)
self
.
_
memstats
.
append_
model_data
(
'
cpu
'
,
cpu_mem
)
def
sample_overall_data
(
self
)
->
None
:
"""Sampling non model data statistics.
"""
if
self
.
_start_flag
:
# overall data recording is after model data recording
if
len
(
self
.
_model_data_cuda_list
)
==
0
:
if
len
(
self
.
_
memstats
.
_
model_data_cuda_list
)
==
0
:
return
self
.
_
overall_cuda_list
.
append
(
self
.
_mem_monitor
.
finish
())
self
.
_
overall_cpu_list
.
append
(
colo_device_memory_used
(
torch
.
device
(
'cpu'
)))
self
.
_
memstats
.
append_overall_data
(
'cuda'
,
self
.
_mem_monitor
.
finish
())
self
.
_
memstats
.
append_overall_data
(
'cpu'
,
colo_device_memory_used
(
torch
.
device
(
'cpu'
)))
assert
len
(
self
.
_model_data_cuda_list
)
==
len
(
self
.
_overall_cuda_list
)
assert
len
(
self
.
_
memstats
.
_
model_data_cuda_list
)
==
len
(
self
.
_
memstats
.
_
overall_cuda_list
)
self
.
_
non_model_data_cuda_list
.
append
(
self
.
_overall_cuda_list
[
-
1
]
-
self
.
_model_data
_
cuda
_list
[
-
1
]
)
self
.
_
non_model_data_cpu_list
.
append
(
self
.
_overall_cpu_list
[
-
1
]
-
self
.
_model_data
_
cpu
_list
[
-
1
]
)
self
.
_
memstats
.
append_non
_model_data
(
'
cuda
'
)
self
.
_
memstats
.
append_non
_model_data
(
'
cpu
'
)
self
.
_sampling_time
.
append
(
time
.
time
())
self
.
_mem_monitor
.
start
()
def
clear
(
self
)
->
None
:
self
.
_model_data_cuda_list
=
[]
self
.
_overall_cuda_list
=
[]
self
.
_model_data_cpu_list
=
[]
self
.
_overall_cpu_list
=
[]
self
.
_non_model_data_cpu_list
=
[]
self
.
_non_model_data_cuda_list
=
[]
self
.
_memstats
.
clear
()
self
.
_start_flag
=
False
self
.
_step_idx
=
0
self
.
_step_total
=
0
colossalai/zero/sharded_model/sharded_model_v2.py
View file @
33f44121
...
...
@@ -85,7 +85,6 @@ class ShardedModelV2(nn.Module):
tensor_placement_policy
:
str
=
'cuda'
,
gradient_predivide_factor
:
Optional
[
float
]
=
1.0
,
reuse_fp16_shard
:
bool
=
False
,
user_static_memstats
:
bool
=
False
,
*
args
,
**
kwargs
):
assert
not
isinstance
(
module
,
ShardedModelV2
),
'Nested ShardedModelV2 is not supported.'
...
...
@@ -119,14 +118,10 @@ class ShardedModelV2(nn.Module):
self
.
world_size
=
dist
.
get_world_size
(
self
.
process_group
)
self
.
rank
=
dist
.
get_rank
(
self
.
process_group
)
self
.
shard_strategy
=
shard_strategy
self
.
user_static_memstats
=
user_static_memstats
self
.
_use_memory_tracer
=
tensor_placement_policy
==
'auto'
if
self
.
_use_memory_tracer
:
if
self
.
user_static_memstats
:
self
.
_memstats_collector
=
StaticMemStatsCollector
(
self
.
module
)
else
:
self
.
_memstats_collector
=
MemStatsCollector
()
self
.
_memstats_collector
=
MemStatsCollector
()
self
.
_start_collect_memstats
=
disposable
(
self
.
_memstats_collector
.
start_collection
)
self
.
_finish_collect_memstats
=
disposable
(
self
.
_memstats_collector
.
finish_collection
)
else
:
...
...
@@ -211,19 +206,17 @@ class ShardedModelV2(nn.Module):
f
.
write
(
f
'cuda reserved
{
torch
.
cuda
.
memory_reserved
(
get_current_device
())
/
1e9
}
GB
\n
'
)
f
.
write
(
f
'cuda max allocated
{
torch
.
cuda
.
max_memory_allocated
(
get_current_device
())
/
1e9
}
GB
\n
'
)
f
.
write
(
'CUDA model data (GB)
\n
'
)
f
.
write
(
str
(
self
.
_memstats_collector
.
model_data_list
(
'cuda'
,
'GB'
)))
f
.
write
(
str
(
self
.
_memstats_collector
.
_memstats
.
model_data_list
(
'cuda'
)))
f
.
write
(
'
\n
'
)
f
.
write
(
'CUDA non model data (GB)
\n
'
)
f
.
write
(
str
(
self
.
_memstats_collector
.
non_model_data_list
(
'cuda'
,
'GB'
)))
f
.
write
(
str
(
self
.
_memstats_collector
.
_memstats
.
non_model_data_list
(
'cuda'
)))
f
.
write
(
'CPU non model data (GB)
\n
'
)
f
.
write
(
str
(
self
.
_memstats_collector
.
non_model_data_list
(
'cpu'
,
'GB'
)))
f
.
write
(
str
(
self
.
_memstats_collector
.
_memstats
.
non_model_data_list
(
'cpu'
)))
f
.
write
(
'
\n
'
)
def
_pre_forward_operations
(
self
,
*
args
):
# the operation will affect the memory tracer behavior in ZeroHook
if
self
.
_memstats_collector
:
if
self
.
user_static_memstats
:
self
.
init_mem_stats
(
*
args
)
self
.
_start_collect_memstats
()
for
p
in
self
.
module
.
parameters
():
...
...
@@ -264,7 +257,7 @@ class ShardedModelV2(nn.Module):
# model data is fixed in cuda during training.
# cuda margin space can be used to store OS.
self
.
_cuda_margin_space
=
colo_device_memory_capacity
(
get_current_device
())
-
max
(
self
.
_memstats_collector
.
overall_mem_stats
(
'cuda'
))
self
.
_memstats_collector
.
_memstats
.
overall_mem_stats
(
'cuda'
))
@
torch
.
no_grad
()
def
_post_backward_operations
(
self
)
->
None
:
...
...
tests/test_zero/test_mem_collector.py
View file @
33f44121
import
torch
import
colossalai
import
pytest
import
torch.multiprocessing
as
mp
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.utils.memory
import
colo_device_memory_capacity
,
colo_set_process_memory_fraction
from
colossalai.zero.init_ctx
import
ZeroInitContext
from
colossalai.zero.sharded_model
import
ShardedModelV2
from
colossalai.zero.shard_utils
import
BucketTensorShardStrategy
from
colossalai.utils
import
free_port
from
colossalai.testing
import
rerun_if_address_is_in_use
from
functools
import
partial
class
MyTestModel
(
torch
.
nn
.
Module
):
def
__init__
(
self
)
->
None
:
super
().
__init__
()
self
.
proj1
=
nn
.
Linear
(
512
,
512
)
self
.
weight
=
nn
.
Parameter
(
torch
.
randn
(
1024
,
512
))
self
.
proj2
=
nn
.
Linear
(
1024
,
512
)
def
forward
(
self
,
x
):
x
=
self
.
proj1
(
x
)
x
=
F
.
linear
(
x
,
self
.
weight
)
x
=
self
.
proj2
(
x
)
return
x
def
run_mem_collector_testing
():
cuda_capacity
=
colo_device_memory_capacity
(
get_current_device
())
fraction
=
(
50
*
1024
**
2
)
/
cuda_capacity
# limit max memory to 50MB
colo_set_process_memory_fraction
(
fraction
)
shard_strategy
=
BucketTensorShardStrategy
()
with
ZeroInitContext
(
target_device
=
get_current_device
(),
shard_strategy
=
shard_strategy
,
shard_param
=
True
):
model
=
MyTestModel
()
model
=
ShardedModelV2
(
module
=
model
,
shard_strategy
=
shard_strategy
,
reduce_scatter_bucket_size_mb
=
1
,
tensor_placement_policy
=
'auto'
)
data
=
torch
.
randn
(
2
,
512
,
device
=
get_current_device
())
output
=
model
(
data
)
loss
=
torch
.
mean
(
output
)
model
.
backward
(
loss
)
cuda_model_data_list
=
model
.
_memstats_collector
.
model_data_list
(
'cuda'
)
assert
cuda_model_data_list
==
[
1311744
,
1836032
,
1836032
,
1311744
,
1836032
,
1836032
]
cuda_non_model_data_list
=
model
.
_memstats_collector
.
non_model_data_list
(
'cuda'
)
assert
cuda_non_model_data_list
[
0
]
>
cuda_non_model_data_list
[
1
]
assert
cuda_non_model_data_list
[
-
2
]
>
cuda_non_model_data_list
[
-
1
]
def
run_dist
(
rank
,
world_size
,
port
):
colossalai
.
launch
(
config
=
{},
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
run_mem_collector_testing
()
@
pytest
.
mark
.
dist
@
rerun_if_address_is_in_use
()
def
test_mem_collector
(
world_size
=
2
):
run_func
=
partial
(
run_dist
,
world_size
=
world_size
,
port
=
free_port
())
mp
.
spawn
(
run_func
,
nprocs
=
world_size
)
if
__name__
==
'__main__'
:
test_mem_collector
()
from
functools
import
partial
import
pytest
import
torch
import
torch.multiprocessing
as
mp
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
colossalai
from
colossalai.testing
import
rerun_if_address_is_in_use
from
colossalai.utils
import
free_port
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.utils.memory
import
colo_device_memory_capacity
,
colo_set_process_memory_fraction
from
colossalai.zero.init_ctx
import
ZeroInitContext
from
colossalai.zero.shard_utils
import
BucketTensorShardStrategy
from
colossalai.zero.sharded_model
import
ShardedModelV2
class
MyTestModel
(
torch
.
nn
.
Module
):
def
__init__
(
self
)
->
None
:
super
().
__init__
()
self
.
proj1
=
nn
.
Linear
(
512
,
512
)
self
.
weight
=
nn
.
Parameter
(
torch
.
randn
(
1024
,
512
))
self
.
proj2
=
nn
.
Linear
(
1024
,
512
)
def
forward
(
self
,
x
):
x
=
self
.
proj1
(
x
)
x
=
F
.
linear
(
x
,
self
.
weight
)
x
=
self
.
proj2
(
x
)
return
x
def
run_mem_collector_testing
():
cuda_capacity
=
colo_device_memory_capacity
(
get_current_device
())
fraction
=
(
50
*
1024
**
2
)
/
cuda_capacity
# limit max memory to 50MB
colo_set_process_memory_fraction
(
fraction
)
shard_strategy
=
BucketTensorShardStrategy
()
with
ZeroInitContext
(
target_device
=
get_current_device
(),
shard_strategy
=
shard_strategy
,
shard_param
=
True
):
model
=
MyTestModel
()
model
=
ShardedModelV2
(
module
=
model
,
shard_strategy
=
shard_strategy
,
reduce_scatter_bucket_size_mb
=
1
,
tensor_placement_policy
=
'auto'
)
data
=
torch
.
randn
(
2
,
512
,
device
=
get_current_device
())
output
=
model
(
data
)
loss
=
torch
.
mean
(
output
)
model
.
backward
(
loss
)
cuda_model_data_list
=
model
.
_memstats_collector
.
_memstats
.
model_data_list
(
'cuda'
)
assert
cuda_model_data_list
==
[
1311744
,
1836032
,
1836032
,
1311744
,
1836032
,
1836032
]
cuda_non_model_data_list
=
model
.
_memstats_collector
.
_memstats
.
non_model_data_list
(
'cuda'
)
print
(
'cuda_non_model_data_list '
,
cuda_non_model_data_list
)
assert
cuda_non_model_data_list
[
0
]
>
cuda_non_model_data_list
[
1
]
assert
cuda_non_model_data_list
[
-
2
]
>
cuda_non_model_data_list
[
-
1
]
def
run_dist
(
rank
,
world_size
,
port
):
colossalai
.
launch
(
config
=
{},
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
run_mem_collector_testing
()
@
pytest
.
mark
.
dist
@
rerun_if_address_is_in_use
()
def
test_mem_collector
(
world_size
=
2
):
run_func
=
partial
(
run_dist
,
world_size
=
world_size
,
port
=
free_port
())
mp
.
spawn
(
run_func
,
nprocs
=
world_size
)
if
__name__
==
'__main__'
:
test_mem_collector
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment