Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
85efb7ac
Unverified
Commit
85efb7ac
authored
Dec 07, 2022
by
Jiarui Fang
Committed by
GitHub
Dec 07, 2022
Browse files
[Gemini] gemini use the runtime memory tracer (RMT) (#2099)
parent
2bf2d1cd
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
120 additions
and
13 deletions
+120
-13
colossalai/gemini/gemini_mgr.py
colossalai/gemini/gemini_mgr.py
+5
-3
colossalai/gemini/memory_tracer/__init__.py
colossalai/gemini/memory_tracer/__init__.py
+1
-1
colossalai/gemini/memory_tracer/chunk_memstats_collector.py
colossalai/gemini/memory_tracer/chunk_memstats_collector.py
+6
-3
colossalai/gemini/memory_tracer/memstats_collector.py
colossalai/gemini/memory_tracer/memstats_collector.py
+13
-6
colossalai/gemini/memory_tracer/runtime_mem_tracer.py
colossalai/gemini/memory_tracer/runtime_mem_tracer.py
+3
-0
tests/test_gemini/update/test_gemini_use_rmt.py
tests/test_gemini/update/test_gemini_use_rmt.py
+92
-0
No files found.
colossalai/gemini/gemini_mgr.py
View file @
85efb7ac
...
...
@@ -5,8 +5,9 @@ from typing import List, Optional, Tuple
import
torch
from
colossalai.gemini.chunk
import
Chunk
,
ChunkManager
from
colossalai.gemini.memory_tracer
import
MemStats
from
.memory_tracer
import
ChunkMemStatsCollector
,
StaticMemStatsCollector
from
.memory_tracer
import
ChunkMemStatsCollector
from
.placement_policy
import
PlacementPolicyFactory
...
...
@@ -26,13 +27,14 @@ class GeminiManager:
chunk_manager (ChunkManager): A ``ChunkManager`` instance.
"""
def
__init__
(
self
,
placement_policy
:
str
,
chunk_manager
:
ChunkManager
)
->
None
:
def
__init__
(
self
,
placement_policy
:
str
,
chunk_manager
:
ChunkManager
,
memstats
:
Optional
[
MemStats
]
=
None
)
->
None
:
assert
placement_policy
in
PlacementPolicyFactory
.
get_polocy_names
()
self
.
policy_name
=
placement_policy
policy_cls
=
PlacementPolicyFactory
.
create
(
placement_policy
)
self
.
_chunk_manager
=
chunk_manager
self
.
_mem_stats_collector
=
ChunkMemStatsCollector
(
chunk_manager
)
if
policy_cls
.
need_mem_stats
else
None
self
.
_mem_stats_collector
=
ChunkMemStatsCollector
(
chunk_manager
,
memstats
)
if
policy_cls
.
need_mem_stats
else
None
self
.
_placement_policy
=
policy_cls
(
chunk_manager
,
self
.
_mem_stats_collector
)
self
.
_compute_list
:
List
[
Tuple
[
Chunk
,
...]]
=
[]
self
.
_compute_idx
:
int
=
-
1
...
...
colossalai/gemini/memory_tracer/__init__.py
View file @
85efb7ac
from
.memory_stats
import
MemStats
# isort:skip
from
.memory_monitor
import
AsyncMemoryMonitor
,
SyncCudaMemoryMonitor
# isort:skip
from
.memstats_collector
import
MemStatsCollector
# isort:skip
from
.chunk_memstats_collector
import
ChunkMemStatsCollector
# isort:skip
from
.static_memstats_collector
import
StaticMemStatsCollector
# isort:skip
from
.memory_stats
import
MemStats
__all__
=
[
'AsyncMemoryMonitor'
,
'SyncCudaMemoryMonitor'
,
'MemStatsCollector'
,
'ChunkMemStatsCollector'
,
...
...
colossalai/gemini/memory_tracer/chunk_memstats_collector.py
View file @
85efb7ac
from
typing
import
Optional
from
colossalai.gemini.chunk
import
ChunkManager
from
colossalai.gemini.memory_tracer
import
MemStats
from
colossalai.utils
import
get_current_device
from
colossalai.utils.memory
import
colo_device_memory_capacity
...
...
@@ -7,15 +10,15 @@ from .memstats_collector import MemStatsCollector
class
ChunkMemStatsCollector
(
MemStatsCollector
):
def
__init__
(
self
,
chunk_manager
:
ChunkManager
)
->
None
:
super
().
__init__
()
def
__init__
(
self
,
chunk_manager
:
ChunkManager
,
memstats
:
Optional
[
MemStats
]
=
None
)
->
None
:
super
().
__init__
(
memstats
)
self
.
_chunk_manager
=
chunk_manager
# override
def
sample_model_data
(
self
)
->
None
:
"""Sampling model data statistics.
"""
if
self
.
_start_flag
:
if
self
.
_start_flag
and
not
self
.
use_outside_memstats
:
cuda_mem
=
self
.
_chunk_manager
.
total_mem
[
'cuda'
]
cpu_mem
=
self
.
_chunk_manager
.
total_mem
[
'cpu'
]
self
.
_memstats
.
append_model_data
(
'cuda'
,
cuda_mem
)
...
...
colossalai/gemini/memory_tracer/memstats_collector.py
View file @
85efb7ac
import
time
from
typing
import
List
from
typing
import
List
,
Optional
import
torch
...
...
@@ -22,14 +22,19 @@ class MemStatsCollector:
It has a Sampling counter which is reset after DNN training iteration.
"""
def
__init__
(
self
)
->
None
:
def
__init__
(
self
,
memstats
:
Optional
[
MemStats
]
=
None
)
->
None
:
self
.
_mem_monitor
=
SyncCudaMemoryMonitor
()
self
.
_sampling_time
=
[]
self
.
_start_flag
=
False
self
.
_step_idx
=
0
self
.
_step_total
=
0
self
.
_memstats
=
MemStats
()
if
memstats
is
not
None
:
self
.
use_outside_memstats
=
True
self
.
_memstats
=
memstats
else
:
self
.
use_outside_memstats
=
False
self
.
_memstats
=
MemStats
()
def
next_period_non_model_data_usage
(
self
,
device_type
:
str
)
->
int
:
"""Get max non model data memory usage of current sampling period
...
...
@@ -63,7 +68,7 @@ class MemStatsCollector:
def
sample_model_data
(
self
)
->
None
:
"""Sampling model data statistics.
"""
if
self
.
_start_flag
:
if
self
.
_start_flag
and
not
self
.
use_outside_memstats
:
cuda_mem
=
StatefulTensor
.
GST_MGR
.
total_mem
[
'cuda'
]
cpu_mem
=
StatefulTensor
.
GST_MGR
.
total_mem
[
'cpu'
]
self
.
_memstats
.
append_model_data
(
'cuda'
,
cuda_mem
)
...
...
@@ -72,7 +77,7 @@ class MemStatsCollector:
def
sample_overall_data
(
self
)
->
None
:
"""Sampling non model data statistics.
"""
if
self
.
_start_flag
:
if
self
.
_start_flag
and
not
self
.
use_outside_memstats
:
# overall data recording is after model data recording
if
len
(
self
.
_memstats
.
_model_data_cuda_list
)
==
0
:
return
...
...
@@ -84,9 +89,11 @@ class MemStatsCollector:
self
.
_memstats
.
append_non_model_data
(
'cuda'
)
self
.
_memstats
.
append_non_model_data
(
'cpu'
)
self
.
_sampling_time
.
append
(
time
.
time
())
self
.
_mem_monitor
.
start
()
if
self
.
_start_flag
:
self
.
_sampling_time
.
append
(
time
.
time
())
def
clear
(
self
)
->
None
:
self
.
_memstats
.
clear
()
self
.
_start_flag
=
False
...
...
colossalai/gemini/memory_tracer/runtime_mem_tracer.py
View file @
85efb7ac
...
...
@@ -35,6 +35,9 @@ class RuntimeMemTracer():
self
.
_cast_buffers_to_cuda_dtype
()
def
memstats
(
self
):
return
self
.
_memstats
def
__call__
(
self
,
*
args
,
**
kwargs
):
return
self
.
forward
(
*
args
,
**
kwargs
)
...
...
tests/test_gemini/update/test_gemini_use_rmt.py
0 → 100644
View file @
85efb7ac
from
functools
import
partial
import
pytest
import
torch
import
torch.multiprocessing
as
mp
import
colossalai
from
colossalai.gemini.chunk
import
ChunkManager
,
search_chunk_configuration
from
colossalai.gemini.gemini_mgr
import
GeminiManager
from
colossalai.gemini.memory_tracer.runtime_mem_tracer
import
RuntimeMemTracer
from
colossalai.nn.parallel
import
ZeroDDP
from
colossalai.tensor
import
ProcessGroup
from
colossalai.testing
import
parameterize
,
rerun_if_address_is_in_use
from
colossalai.utils
import
free_port
from
colossalai.utils.model.colo_init_context
import
ColoInitContext
from
tests.components_to_test
import
run_fwd_bwd
from
tests.components_to_test.registry
import
non_distributed_component_funcs
from
tests.test_tensor.common_utils
import
set_seed
# run gemini use the runtime memory tracer
@
parameterize
(
'placement_policy'
,
[
'auto'
])
@
parameterize
(
'keep_gather'
,
[
False
])
@
parameterize
(
'model_name'
,
[
'bert'
,
'albert'
,
'gpt2'
])
@
parameterize
(
'use_grad_checkpoint'
,
[
False
,
True
])
def
run_gemini_use_rmt
(
placement_policy
,
keep_gather
,
model_name
:
str
,
use_grad_checkpoint
:
bool
=
False
):
set_seed
(
42
)
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
model_name
)
model_builder
,
train_dataloader
,
test_dataloader
,
optimizer_class
,
criterion
=
get_components_func
()
with
ColoInitContext
(
device
=
'cpu'
):
model
=
model_builder
(
use_grad_checkpoint
)
print
(
f
'model_name
{
model_name
}
'
)
runtime_mem_tracer
=
RuntimeMemTracer
(
model
)
for
i
,
(
input_ids
,
label
)
in
enumerate
(
train_dataloader
):
if
i
>
0
:
break
input_ids
,
label
=
input_ids
.
cuda
(),
label
.
cuda
()
# mem tracing
if
i
==
0
:
run_fwd_bwd
(
runtime_mem_tracer
,
input_ids
,
label
,
criterion
,
runtime_mem_tracer
)
memstats
=
runtime_mem_tracer
.
memstats
()
runtime_tracer_non_model_data
=
runtime_mem_tracer
.
_memstats
.
_non_model_data_cuda_list
print
(
'runtime tracer: '
,
runtime_tracer_non_model_data
)
world_size
=
torch
.
distributed
.
get_world_size
()
config_dict
,
_
=
search_chunk_configuration
(
model
,
search_range_mb
=
1
,
search_interval_byte
=
100
)
config_dict
[
world_size
][
'chunk_size'
]
=
5000
config_dict
[
world_size
][
'keep_gathered'
]
=
keep_gather
chunk_manager
=
ChunkManager
(
config_dict
)
gemini_manager
=
GeminiManager
(
placement_policy
,
chunk_manager
,
memstats
)
model
=
ZeroDDP
(
model
,
gemini_manager
,
pin_memory
=
True
)
pg
=
ProcessGroup
()
set_seed
(
pg
.
dp_local_rank
())
for
i
,
(
input_ids
,
label
)
in
enumerate
(
train_dataloader
):
# you can only test a single fwd + bwd.
# after bwd param is grad for Gemini, due to the chunk reuse optimization.
if
i
>
1
:
break
input_ids
,
label
=
input_ids
.
cuda
(),
label
.
cuda
()
set_seed
(
42
)
loss
=
run_fwd_bwd
(
model
,
input_ids
,
label
,
criterion
,
model
)
gemini_non_model_data
=
gemini_manager
.
_mem_stats_collector
.
_memstats
.
non_model_data_list
(
'cuda'
)
# print('gemini non model data:', gemini_non_model_data)
assert
len
(
gemini_non_model_data
)
==
len
(
runtime_tracer_non_model_data
),
\
f
'model_name
{
model_name
}
{
len
(
gemini_non_model_data
)
}
vs
{
len
(
runtime_tracer_non_model_data
)
}
'
def
run_dist
(
rank
,
world_size
,
port
):
config
=
{}
colossalai
.
launch
(
config
=
config
,
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
run_gemini_use_rmt
()
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
parametrize
(
'world_size'
,
[
1
,
4
])
@
rerun_if_address_is_in_use
()
def
test_gemini_use_rmt
(
world_size
):
run_func
=
partial
(
run_dist
,
world_size
=
world_size
,
port
=
free_port
())
mp
.
spawn
(
run_func
,
nprocs
=
world_size
)
if
__name__
==
'__main__'
:
test_gemini_use_rmt
(
1
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment