OpenDAS / ColossalAI · Commits · c89c66a8

Unverified commit c89c66a8, authored Dec 14, 2022 by Jiarui Fang, committed via GitHub on Dec 14, 2022
[Gemini] update API of the chunkmemstatscollector. (#2129)
parent 2938edf4

Showing 8 changed files with 32 additions and 163 deletions (+32 -163)
colossalai/gemini/gemini_mgr.py                              +1   -1
colossalai/gemini/memory_tracer/chunk_memstats_collector.py  +11  -4
colossalai/gemini/memory_tracer/memory_stats.py              +6   -61
colossalai/gemini/memory_tracer/memstats_collector.py        +9   -15
colossalai/zero/sharded_model/sharded_model_v2.py            +2   -3
colossalai/zero/utils/gemini_hook.py                         +2   -0
tests/test_gemini/update/test_gemini_use_rmt.py              +1   -2
tests/test_zero/test_mem_collector.py                        +0   -77
colossalai/gemini/gemini_mgr.py

@@ -55,7 +55,7 @@ class GeminiManager:
         get the memory statistics during training.
         The stats could be collected by a runtime memory tracer, or collected by the GeminiManager.
         Note, for the latter, you can not access the memstats before warmup iteration finishes.
         """
         if self._premade_memstats_:
             return self._memstats
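The guarded return above is the "premade" path: when a MemStats collected by the runtime memory tracer (RMT) is handed to the manager, it is served directly; otherwise the manager collects its own stats, which only become meaningful after warmup. A compact sketch of that dispatch, using a stand-in class rather than the real GeminiManager:

    # Compact sketch of the premade-memstats dispatch. SketchManager is a
    # stand-in for GeminiManager; the RuntimeError is illustrative only.

    class SketchManager:

        def __init__(self, memstats=None):
            self._premade_memstats_ = memstats is not None
            self._memstats = memstats        # RMT-collected stats, if provided
            self._warmup_done = False

        def memstats(self):
            if self._premade_memstats_:
                return self._memstats        # outside stats are available immediately
            if not self._warmup_done:
                raise RuntimeError('memstats unavailable before warmup finishes')
            return self._memstats

    premade = object()                       # pretend this came from the tracer
    assert SketchManager(memstats=premade).memstats() is premade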
colossalai/gemini/memory_tracer/chunk_memstats_collector.py

@@ -11,18 +11,25 @@ from .memstats_collector import MemStatsCollector

 class ChunkMemStatsCollector(MemStatsCollector):

     def __init__(self, chunk_manager: ChunkManager, memstats: Optional[MemStats] = None) -> None:
+        """
+        Memory Statistic Collector for Chunks.
+
+        Args:
+            chunk_manager (ChunkManager): the chunk manager.
+            memstats (Optional[MemStats], optional): memory statistics collected by RMT. Defaults to None.
+        """
         super().__init__(memstats)
         self._chunk_manager = chunk_manager

     # override
     def record_model_data_volume(self) -> None:
-        """Sampling model data statistics.
-        """
+        """
+        record model data volume on cuda and cpu.
+        """
         if self._start_flag and not self.use_outside_memstats:
             cuda_mem = self._chunk_manager.total_mem['cuda']
             cpu_mem = self._chunk_manager.total_mem['cpu']
             self._memstats.append_model_data('cuda', cuda_mem)
             self._memstats.append_model_data('cpu', cpu_mem)
             self._memstats.record_max_cuda_model_data(cuda_mem)

     @property
     def cuda_margin_mem(self) -> float:
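For orientation, a minimal self-contained sketch of the pattern this override implements: read per-device totals from a chunk manager and push them into a shared stats object, additionally folding the CUDA sample into a running peak. FakeChunkManager and FakeMemStats are hypothetical stand-ins, not ColossalAI's real classes:

    # Minimal sketch of the record_model_data_volume() pattern above.
    # FakeChunkManager and FakeMemStats are stand-ins, not ColossalAI classes.

    class FakeChunkManager:
        def __init__(self):
            # total bytes of model (chunk) data currently resident per device
            self.total_mem = {'cuda': 0, 'cpu': 0}

    class FakeMemStats:
        def __init__(self):
            self._model_data = {'cuda': [], 'cpu': []}
            self._max_cuda_model_data = 0

        def append_model_data(self, device_type, val):
            self._model_data[device_type].append(val)

        def record_max_cuda_model_data(self, val):
            # keep a running maximum instead of only the full sample list
            self._max_cuda_model_data = max(self._max_cuda_model_data, val)

    mgr, stats = FakeChunkManager(), FakeMemStats()
    mgr.total_mem['cuda'] = 1836032    # e.g. bytes after an op accesses its chunks

    # the collector's override boils down to:
    cuda_mem, cpu_mem = mgr.total_mem['cuda'], mgr.total_mem['cpu']
    stats.append_model_data('cuda', cuda_mem)
    stats.append_model_data('cpu', cpu_mem)
    stats.record_max_cuda_model_data(cuda_mem)
    print(stats._max_cuda_model_data)    # 1836032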
colossalai/gemini/memory_tracer/memory_stats.py

@@ -22,6 +22,7 @@ class MemStats(object):
         self._preop_step = 0

         self._prev_overall_cuda = -1
+        self._max_overall_cuda = 0
         self._prev_md_cuda = -1

         # old version

@@ -46,6 +47,11 @@ class MemStats(object):
     def record_max_cuda_overall_data(self, val):
         self._prev_overall_cuda = val
+        self._max_overall_cuda = max(self._max_overall_cuda, val)
+
+    @property
+    def max_overall_cuda(self):
+        return self._max_overall_cuda

     def increase_preop_step(self, param_list: List[torch.nn.Parameter]):
         """

@@ -85,67 +91,6 @@ class MemStats(object):
         else:
             return self._param_runtime_order

     ## APIs to be deprecated
     def append_overall_data(self, device_type: str, val: float):
         if device_type == 'cuda':
             self._overall_cuda_list.append(val)
         elif device_type == 'cpu':
             self._overall_cpu_list.append(val)
         else:
             raise TypeError

     def append_model_data(self, device_type: str, val: float):
         if device_type == 'cuda':
             self._model_data_cuda_list.append(val)
         elif device_type == 'cpu':
             self._model_data_cpu_list.append(val)
         else:
             raise TypeError

     def last_model_data(self, device_type: str):
         if len(self._model_data_cuda_list) == 0:
             return None
         if device_type == 'cuda':
             return self._model_data_cuda_list[-1]
         elif device_type == 'cpu':
             return self._model_data_cpu_list[-1]
         else:
             raise TypeError

     def append_non_model_data(self, device_type: str, val=None):
         if device_type == 'cuda':
             if val is None:
                 if len(self._overall_cuda_list) == 0 or len(self._model_data_cuda_list) == 0:
                     return
                 self._non_model_data_cuda_list.append(self._overall_cuda_list[-1] - self._model_data_cuda_list[-1])
             else:
                 self._non_model_data_cuda_list.append(val)
         elif device_type == 'cpu':
             if val is None:
                 if len(self._overall_cuda_list) == 0 or len(self._model_data_cuda_list) == 0:
                     return
                 self._non_model_data_cpu_list.append(self._overall_cpu_list[-1] - self._model_data_cpu_list[-1])
             else:
                 self._non_model_data_cpu_list.append(val)
         else:
             raise TypeError

     def overall_mem_stats(self, device_type: str) -> List[int]:
         if device_type == 'cuda':
             return self._overall_cuda_list
         elif device_type == 'cpu':
             return self._overall_cpu_list
         else:
             raise TypeError

     def model_data_list(self, device_type: str) -> List[int]:
         if device_type == 'cuda':
             return self._model_data_cuda_list
         elif device_type == 'cpu':
             return self._model_data_cpu_list
         else:
             raise TypeError

     def non_model_data_list(self, device_type: str) -> List[int]:
         if device_type == 'cuda':
             return self._non_model_data_cuda_list
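The thrust of this file's change: instead of only appending every sample to the per-device lists (the block marked "APIs to be deprecated" above, 61 lines of which are removed), MemStats now folds CUDA samples into a running maximum exposed via max_overall_cuda. A stripped-down sketch of the new pattern, not the real class:

    # Stripped-down sketch of the running-max bookkeeping MemStats gains here.
    # SketchMemStats is an illustration, not ColossalAI's actual MemStats class.

    class SketchMemStats:
        def __init__(self):
            self._prev_overall_cuda = -1
            self._max_overall_cuda = 0      # new in this commit

        def record_max_cuda_overall_data(self, val: float):
            self._prev_overall_cuda = val
            # fold each sample into a single peak value instead of storing a list
            self._max_overall_cuda = max(self._max_overall_cuda, val)

        @property
        def max_overall_cuda(self) -> float:
            return self._max_overall_cuda

    stats = SketchMemStats()
    for sample in (2.1e9, 3.4e9, 2.8e9):    # overall CUDA usage per step, in bytes
        stats.record_max_cuda_overall_data(sample)
    assert stats.max_overall_cuda == 3.4e9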
colossalai/gemini/memory_tracer/memstats_collector.py

@@ -59,6 +59,7 @@ class MemStatsCollector:
         return [t - self._sampling_time[0] for t in self._sampling_time]

     def start_collection(self):
+        print('start collection')
         self._start_flag = True
         self._mem_monitor.start()

@@ -68,31 +69,24 @@ class MemStatsCollector:
         self._step_total = len(self._memstats.non_model_data_list('cuda'))
         self._start_flag = False
         self._mem_monitor.finish()
+        print(f'finish_collection {self._step_total}')

     # deprecated
     def record_model_data_volume(self) -> None:
         """Sampling model data statistics.
         """
-        if self._start_flag and not self.use_outside_memstats:
-            cuda_mem = StatefulTensor.GST_MGR.total_mem['cuda']
-            cpu_mem = StatefulTensor.GST_MGR.total_mem['cpu']
-            self._memstats.append_model_data('cuda', cuda_mem)
-            self._memstats.append_model_data('cpu', cpu_mem)
+        raise NotImplementedError("MemStatsCollector has not implemented record_model_data_volume")

     def sample_overall_data(self) -> None:
-        """Sampling non model data statistics.
-        """
+        """
+        Sampling overall and non model data cuda memory statistics.
+        """
         if self._start_flag and not self.use_outside_memstats:
-            # overall data recording is after model data recording
-            if len(self._memstats._model_data_cuda_list) == 0:
-                return
-            self._memstats.append_overall_data('cuda', self._mem_monitor.finish())
-            self._memstats.append_overall_data('cpu', colo_device_memory_used(torch.device('cpu')))
-            assert len(self._memstats._model_data_cuda_list) == len(self._memstats._overall_cuda_list)
+            cuda_overall = self._mem_monitor.finish()
+            self._memstats.record_max_cuda_overall_data(cuda_overall)
+            self._memstats.calc_max_cuda_non_model_data()
-            self._memstats.append_non_model_data('cuda')
-            self._memstats.append_non_model_data('cpu')
             self._mem_monitor.start()

         if self._start_flag:
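Putting the collector's lifecycle together: start_collection arms the monitor, each sampling call folds the peak observed since the last sample into the stats and re-arms, and finish_collection stops monitoring. A self-contained sketch under those assumptions; FakeMonitor and SketchCollector are stand-ins, and feed() is a test hook, not a real API:

    # Self-contained sketch of the collection lifecycle in this diff.
    # FakeMonitor stands in for the real memory monitor; numbers are made up.

    class FakeMonitor:
        def __init__(self):
            self._peak = 0
        def start(self):
            self._peak = 0
        def finish(self):
            return self._peak
        def feed(self, val):                 # test hook, not a real API
            self._peak = max(self._peak, val)

    class SketchCollector:
        def __init__(self):
            self._start_flag = False
            self._mem_monitor = FakeMonitor()
            self.max_overall_cuda = 0

        def start_collection(self):
            self._start_flag = True
            self._mem_monitor.start()

        def sample_overall_data(self):
            if self._start_flag:
                cuda_overall = self._mem_monitor.finish()   # peak since last sample
                self.max_overall_cuda = max(self.max_overall_cuda, cuda_overall)
                self._mem_monitor.start()                   # re-arm for the next op

        def finish_collection(self):
            self._start_flag = False
            self._mem_monitor.finish()

    collector = SketchCollector()
    collector.start_collection()
    for usage in (1.2e9, 2.5e9, 1.8e9):      # simulated per-op peak CUDA bytes
        collector._mem_monitor.feed(usage)
        collector.sample_overall_data()
    collector.finish_collection()
    print(collector.max_overall_cuda)        # 2.5e9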
colossalai/zero/sharded_model/sharded_model_v2.py

@@ -206,7 +206,6 @@ class ShardedModelV2(nn.Module):
             f.write(f'cuda reserved {torch.cuda.memory_reserved(get_current_device()) / 1e9} GB\n')
             f.write(f'cuda max allocated {torch.cuda.max_memory_allocated(get_current_device()) / 1e9} GB\n')
             f.write('CUDA model data (GB)\n')
             f.write(str(self._memstats_collector._memstats.model_data_list('cuda')))
             f.write('\n')
             f.write('CUDA non model data (GB)\n')
             f.write(str(self._memstats_collector._memstats.non_model_data_list('cuda')))

@@ -256,8 +255,8 @@ class ShardedModelV2(nn.Module):
             # the way to calculate margin space is based on the assumption that
             # model data is fixed in cuda during training.
             # cuda margin space can be used to store OS.
-            self._cuda_margin_space = colo_device_memory_capacity(get_current_device()) - max(
-                self._memstats_collector._memstats.overall_mem_stats('cuda'))
+            self._cuda_margin_space = colo_device_memory_capacity(
+                get_current_device()) - self._memstats_collector._memstats.max_overall_cuda

     @torch.no_grad()
     def _post_backward_operations(self) -> None:
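The margin-space change is the core of the API update: rather than taking max() over the full list of per-step overall samples, the code now reads the precomputed running peak. The arithmetic, with made-up numbers:

    # Made-up numbers illustrating the cuda margin space computation above.
    capacity = 80e9            # colo_device_memory_capacity(...): total CUDA memory, bytes
    max_overall_cuda = 63.5e9  # peak overall CUDA usage observed during warmup

    # memory left over once peak usage is accounted for; per the comment in the
    # diff, Gemini can place optimizer states (OS) in this margin.
    cuda_margin_space = capacity - max_overall_cuda
    print(f'{cuda_margin_space / 1e9:.1f} GB')   # 16.5 GB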
colossalai/zero/utils/gemini_hook.py

@@ -32,6 +32,8 @@ class GeminiZeROHook(ColoParamOpHook):
         self._gemini_manager.adjust_layout(chunks)
         for chunk in chunks:
             self._chunk_manager.access_chunk(chunk)

+        # record cuda model data of the current OP
+        self._gemini_manager.record_model_data_volume()

     def post_op(self, params):
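The hook wiring means every parameter op refreshes the model-data sample only after its chunks are resident, so the sample reflects what the op actually uses. A runnable toy illustrating that ordering; all classes are fakes, and the chunk sizes echo the deleted test below:

    # Toy illustrating the hook ordering above: layout is adjusted and chunks
    # made resident BEFORE the model-data volume is recorded. All fakes.

    class ToyChunkManager:
        def __init__(self):
            self.total_mem = {'cuda': 0, 'cpu': 0}
        def access_chunk(self, chunk_bytes):
            self.total_mem['cuda'] += chunk_bytes     # pretend the chunk moved to CUDA

    class ToyGeminiManager:
        def __init__(self, chunk_manager):
            self._chunk_manager = chunk_manager
            self.samples = []
        def adjust_layout(self, chunks):
            pass                                      # eviction logic elided
        def record_model_data_volume(self):
            self.samples.append(self._chunk_manager.total_mem['cuda'])

    cm = ToyChunkManager()
    gm = ToyGeminiManager(cm)

    def pre_op(chunks):                               # mirrors GeminiZeROHook.pre_op
        gm.adjust_layout(chunks)
        for chunk in chunks:
            cm.access_chunk(chunk)
        # record cuda model data of the current OP
        gm.record_model_data_volume()

    pre_op([1311744])          # chunk sizes in bytes
    pre_op([524288])
    print(gm.samples)          # [1311744, 1836032]: cumulative resident model data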
tests/test_gemini/update/test_gemini_use_rmt.py

@@ -57,11 +57,10 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_
     if model_name == 'repeated_computed_layers':
         for idx, p in enumerate(model.parameters()):
-            step_list = memstats.param_used_timestep(p)
+            step_list = memstats.param_used_step(p)
             if idx < 4:
                 assert len(step_list) == 4

     world_size = torch.distributed.get_world_size()
     config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
     config_dict[world_size]['chunk_size'] = 5000
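The renamed accessor implies per-parameter step tracking: each time a parameter is used in an op, the current step is appended to that parameter's list. A hypothetical sketch of such bookkeeping (everything beyond the accessor name is invented for illustration), matching the test's expectation that a repeatedly computed layer's parameter shows four uses:

    # Hypothetical sketch of per-parameter step bookkeeping behind an accessor
    # like memstats.param_used_step(p). StepTracker is invented for illustration.
    from collections import defaultdict

    class StepTracker:
        def __init__(self):
            self._param_steps = defaultdict(list)
            self._step = 0

        def record_use(self, param_id):
            self._param_steps[param_id].append(self._step)

        def advance(self):
            self._step += 1

        def param_used_step(self, param_id):
            return self._param_steps[param_id]

    tracker = StepTracker()
    for step in range(4):            # a repeated-computation layer touches its
        tracker.record_use('p0')     # parameter once per step
        tracker.advance()
    assert len(tracker.param_used_step('p0')) == 4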
tests/test_zero/test_mem_collector.py  (deleted, 100644 → 0)

from functools import partial

import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F

import colossalai
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.utils.cuda import get_current_device
from colossalai.utils.memory import colo_device_memory_capacity, colo_set_process_memory_fraction
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import BucketTensorShardStrategy
from colossalai.zero.sharded_model import ShardedModelV2


class MyTestModel(torch.nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.proj1 = nn.Linear(512, 512)
        self.weight = nn.Parameter(torch.randn(1024, 512))
        self.proj2 = nn.Linear(1024, 512)

    def forward(self, x):
        x = self.proj1(x)
        x = F.linear(x, self.weight)
        x = self.proj2(x)
        return x


def run_mem_collector_testing():
    cuda_capacity = colo_device_memory_capacity(get_current_device())
    fraction = (50 * 1024**2) / cuda_capacity
    # limit max memory to 50MB
    colo_set_process_memory_fraction(fraction)
    shard_strategy = BucketTensorShardStrategy()
    with ZeroInitContext(target_device=get_current_device(), shard_strategy=shard_strategy, shard_param=True):
        model = MyTestModel()

    model = ShardedModelV2(module=model,
                           shard_strategy=shard_strategy,
                           reduce_scatter_bucket_size_mb=1,
                           tensor_placement_policy='auto')

    data = torch.randn(2, 512, device=get_current_device())

    output = model(data)
    loss = torch.mean(output)
    model.backward(loss)

    cuda_model_data_list = model._memstats_collector._memstats.model_data_list('cuda')
    assert cuda_model_data_list == [1311744, 1836032, 1836032, 1311744, 1836032, 1836032]

    cuda_non_model_data_list = model._memstats_collector._memstats.non_model_data_list('cuda')
    print('cuda_non_model_data_list ', cuda_non_model_data_list)

    assert cuda_non_model_data_list[0] > cuda_non_model_data_list[1]
    assert cuda_non_model_data_list[-2] > cuda_non_model_data_list[-1]


def run_dist(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_mem_collector_testing()


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_mem_collector(world_size=2):
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_mem_collector()