OpenDAS / ColossalAI / Commits / 9214d1fe

Unverified commit 9214d1fe, authored Dec 12, 2022 by Jiarui Fang, committed by GitHub on Dec 12, 2022.
Parent: e7d3afc9

[Gemini] chunk init using runtime visited param order (#2115)

Showing 10 changed files with 77 additions and 29 deletions.
colossalai/gemini/chunk/search_utils.py                  +11  -6
colossalai/gemini/chunk/utils.py                          +3  -2
colossalai/gemini/gemini_mgr.py                          +18  -1
colossalai/gemini/memory_tracer/memory_stats.py           +6  -0
colossalai/gemini/memory_tracer/memstats_collector.py     +6  -2
colossalai/gemini/memory_tracer/param_runtime_order.py    +3  -0
colossalai/nn/parallel/data_parallel.py                  +14  -3
colossalai/nn/parallel/gemini_parallel.py                 +5  -2
tests/test_gemini/update/test_gemini_use_rmt.py          +11 -12
tests/test_gemini/update/test_optim.py                    +0  -1
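The changes below wire a runtime-traced parameter order into Gemini's chunk initialization: a RuntimeMemTracer warmup produces a MemStats object whose parameter visit order is then used, via the new memstats keyword of GeminiDDP and GeminiManager, when chunks are built. As a rough end-to-end sketch (not part of the commit; the model, data, criterion, and the tracer's backward call are assumed placeholders):

    # Rough usage sketch of the flow enabled by this commit. build_model, input_ids,
    # label and criterion are hypothetical; the ColossalAI identifiers come from the diffs below.
    from colossalai.gemini.memory_tracer.runtime_mem_tracer import RuntimeMemTracer
    from colossalai.nn.parallel import GeminiDDP

    model = build_model()                          # a model whose parameters are ColoParameters
    runtime_mem_tracer = RuntimeMemTracer(model)   # constructor signature assumed

    # one traced fwd+bwd warmup records non-model data usage and the runtime visited param order
    loss = criterion(runtime_mem_tracer(input_ids), label)
    runtime_mem_tracer.backward(loss)              # backward() on the tracer is assumed, mirroring the test helper

    memstats = runtime_mem_tracer.memstats()       # MemStats carrying _param_runtime_order
    # chunks are then initialized following that runtime order instead of module.parameters() order
    model = GeminiDDP(model, device='cuda', placement_policy='auto',
                      search_range_mb=1, memstats=memstats)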
colossalai/gemini/chunk/search_utils.py

 import math
-from typing import Dict, List, Tuple
+from typing import Dict, List, Optional, Tuple
 
 import numpy as np
 import torch.nn as nn
 
-from colossalai.gemini.memory_tracer import OrderedParamGenerator
+from colossalai.gemini.memory_tracer import MemStats, OrderedParamGenerator
 from colossalai.tensor import ColoParameter

@@ -73,7 +73,8 @@ def search_chunk_configuration(
         search_range_mb: float,
         search_interval_byte: int,    # hidden size is the best value for the interval
         min_chunk_size_mb: float = 32,
-        filter_exlarge_params: bool = True) -> Tuple[Dict, int]:
+        filter_exlarge_params: bool = True,
+        memstas: Optional[MemStats] = None) -> Tuple[Dict, int]:
     """search_chunk_configuration
 
     Args:

@@ -86,9 +87,13 @@ def search_chunk_configuration(
         Tuple[Dict, int]: chunk config (a dict of dp_degree -> chunk init args) and its memory chunk waste in byte.
     """
-    param_order = OrderedParamGenerator()
-    for p in model.parameters():
-        param_order.append(p)
+    if memstas is not None:
+        param_order = memstas.param_order()
+    else:
+        # build the param visited order right now
+        param_order = OrderedParamGenerator()
+        for p in model.parameters():
+            param_order.append(p)
 
     search_range_byte = round(search_range_mb * 1024**2)
     min_chunk_size_byte = round(min_chunk_size_mb * 1024**2)
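The net effect of this change: when a MemStats object is passed in, the chunk-size search walks parameters in the order they were actually visited at runtime; otherwise it falls back to the module.parameters() definition order. A hedged call sketch (keyword values borrowed from the test file further down; note the parameter is spelled memstas in the source):

    # Hedged call sketch; `model` and a traced `memstats` are assumed to exist.
    # Without memstas, the search iterates model.parameters() in definition order.
    config_dict, wasted_byte = search_chunk_configuration(model,
                                                          search_range_mb=1,
                                                          search_interval_byte=100)

    # With memstas, it iterates memstats.param_order(), i.e. the runtime visited order.
    config_dict, wasted_byte = search_chunk_configuration(model,
                                                          search_range_mb=1,
                                                          search_interval_byte=100,
                                                          memstas=memstats)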
colossalai/gemini/chunk/utils.py

@@ -7,6 +7,7 @@ import torch.nn as nn
 from colossalai.gemini.chunk import ChunkManager
 from colossalai.gemini.chunk.search_utils import in_ddp, search_chunk_configuration
+from colossalai.gemini.memory_tracer import MemStats
 
 
 def init_chunk_manager(model: nn.Module,

@@ -37,13 +38,13 @@ def init_chunk_manager(model: nn.Module,
     total_size = sum(params_sizes) / 1024**2
 
     dist.barrier()
-    begine = time()
+    begin = time()
 
     config_dict, wasted_size = search_chunk_configuration(model, **kwargs_dict)
 
     dist.barrier()
     end = time()
-    span_s = end - begine
+    span_s = end - begin
     wasted_size /= 1024**2
 
     if dist.get_rank() == 0:
colossalai/gemini/gemini_mgr.py

@@ -25,6 +25,7 @@ class GeminiManager:
             If it's 'auto', they are moving dynamically based on CPU and CUDA memory usage. It will utilize heterogeneous memory space evenly and well.
             Note that 'auto' policy can only work well when no other processes use CUDA during your training.
         chunk_manager (ChunkManager): A ``ChunkManager`` instance.
+        memstats (MemStats, optional): a mem stats collected by a runtime mem tracer. if None then GeminiManager will collect it during a warmup iteration.
     """
 
     def __init__(self, placement_policy: str, chunk_manager: ChunkManager, memstats: Optional[MemStats] = None) -> None:

@@ -33,8 +34,11 @@ class GeminiManager:
         self.policy_name = placement_policy
         policy_cls = PlacementPolicyFactory.create(placement_policy)
         self._chunk_manager = chunk_manager
-        self._mem_stats_collector = ChunkMemStatsCollector(chunk_manager, memstats) if policy_cls.need_mem_stats else None
+        self._premade_memstats_ = memstats is not None
+        self._memstats = memstats
+        self._mem_stats_collector = ChunkMemStatsCollector(chunk_manager,
+                                                           self._memstats) if policy_cls.need_mem_stats else None
         self._placement_policy = policy_cls(chunk_manager, self._mem_stats_collector)
         self._compute_list: List[Tuple[Chunk, ...]] = []
         self._compute_idx: int = -1

@@ -46,6 +50,19 @@ class GeminiManager:
         self._warmup = True
         self._comp_cuda_demand_time = 0
 
+    def memstats(self):
+        """memstats
+
+        get the memory statistics during training.
+        The stats could be collected by a runtime memory tracer, or collected by the GeminiManager.
+        Note, for the latter, you can not access the memstats before warmup iteration finishes.
+        """
+        if self._premade_memstats_:
+            return self._memstats
+        else:
+            assert not self._warmup, "Gemini Manager has memstats after warm up! Now is during warmup."
+            return self._mem_stats_collector._memstats
+
     def pre_iter(self, *args):
         if self._mem_stats_collector and self._warmup:
             self._mem_stats_collector.start_collection()
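The new memstats() accessor has two paths: pre-made stats from a runtime tracer are returned immediately, while stats collected by the GeminiManager itself are only available once the warmup iteration has finished. A small behaviour sketch, assuming an existing ChunkManager and a pre-made MemStats object:

    # Behaviour sketch for GeminiManager.memstats(); chunk_manager and premade are assumed.
    mgr = GeminiManager('auto', chunk_manager, memstats=premade)
    stats = mgr.memstats()      # returns premade directly, even before any iteration has run

    mgr = GeminiManager('auto', chunk_manager)
    # mgr.memstats() here would trip the assert until warmup finishes; afterwards it
    # returns the stats gathered by the ChunkMemStatsCollector during the warmup iteration.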
colossalai/gemini/memory_tracer/memory_stats.py

@@ -23,6 +23,12 @@ class MemStats(object):
 
         self._param_runtime_order = OrderedParamGenerator()
 
+    def param_order(self):
+        if self._param_runtime_order.is_empty():
+            raise RuntimeError
+        else:
+            return self._param_runtime_order
+
     def append_overall_data(self, device_type: str, val: float):
         if device_type == 'cuda':
             self._overall_cuda_list.append(val)
colossalai/gemini/memory_tracer/memstats_collector.py

@@ -37,7 +37,7 @@ class MemStatsCollector:
         self._memstats = MemStats()
 
     def next_period_non_model_data_usage(self, device_type: str) -> int:
-        """Get max non model data memory usage of current sampling period
+        """Maximum non model data memory usage during the next Op run
 
         Args:
             device_type (str): device type, can be 'cpu' or 'cuda'.

@@ -47,6 +47,9 @@ class MemStatsCollector:
         """
         assert not self._start_flag, 'Cannot get mem stats info during collection phase.'
         assert self._step_total > 0, 'Cannot get mem stats info before collection phase.'
+        assert len(self._memstats.non_model_data_list(device_type)) > self._step_idx, \
+            f"{len(self._memstats.non_model_data_list(device_type))} should be > than step idx {self._step_idx}, " \
+            f"step total {self._step_total}"
         next_non_model_data = self._memstats.non_model_data_list(device_type)[self._step_idx]
         self._step_idx = (self._step_idx + 1) % self._step_total
         return next_non_model_data

@@ -61,7 +64,8 @@ class MemStatsCollector:
     def finish_collection(self):
         self.sample_overall_data()
-        self._step_total = len(self._sampling_time)
+        # self._step_total = len(self._sampling_time)
+        self._step_total = len(self._memstats.non_model_data_list('cuda'))
         self._start_flag = False
         self._mem_monitor.finish()
colossalai/gemini/memory_tracer/param_runtime_order.py

@@ -35,5 +35,8 @@ class OrderedParamGenerator(ParamGenerator):
             visited_set.add(p)
         del visited_set
 
+    def is_empty(self):
+        return len(self.param_visited_order) > 0
+
     def clear(self):
         self.param_visited_order = []
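For context, OrderedParamGenerator records parameters in the order append() is called, and judging from the visited_set logic in generate() shown above, each parameter is yielded only once even if it was appended several times. A minimal sketch under that assumption, using plain tensors as stand-ins for ColoParameters:

    # Minimal sketch; the once-per-parameter behaviour of generate() is inferred
    # from the visited_set code above, not stated by the commit.
    import torch

    from colossalai.gemini.memory_tracer import OrderedParamGenerator

    p1, p2 = torch.zeros(2), torch.zeros(3)
    gen = OrderedParamGenerator()
    for p in (p1, p2, p1):            # p1 is visited twice, e.g. a reused parameter
        gen.append(p)
    order = list(gen.generate())      # expected: [p1, p2], first-visit order
    gen.clear()                       # resets param_visited_order to []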
colossalai/nn/parallel/data_parallel.py

@@ -8,6 +8,7 @@ import torch.distributed as dist
 
 from colossalai.gemini.chunk import Chunk, ChunkManager, TensorState
 from colossalai.gemini.gemini_mgr import GeminiManager
+from colossalai.gemini.memory_tracer import OrderedParamGenerator
 from colossalai.logging import get_dist_logger
 from colossalai.nn.parallel.utils import get_temp_total_chunk_on_cuda
 from colossalai.tensor import ProcessGroup as ColoProcessGroup

@@ -216,8 +217,18 @@ class ZeroDDP(ColoDDP):
         self.grads_device: Dict[torch.Tensor, torch.device] = {}
 
         cpu_offload = self.gemini_manager.policy_name != 'cuda'
-        # TODO: get param order and filter unused params
-        for p in module.parameters():
+
+        if self.gemini_manager._premade_memstats_:
+            # build chunk in param runtime visited order.
+            param_order = self.gemini_manager.memstats()._param_runtime_order
+        else:
+            # build chunk in param initialized order.
+            # Note: in this way, it can not get filter unused params during runtime.
+            param_order = OrderedParamGenerator()
+            for p in module.parameters():
+                param_order.append(p)
+
+        for p in param_order.generate():
             assert isinstance(p, ColoParameter)
 
             if getattr(p, '_ddp_to_ignore', False):

@@ -243,7 +254,7 @@ class ZeroDDP(ColoDDP):
         self.chunk_manager.close_all_groups()
         self._cast_buffers()
 
-        params_list = [p for p in module.parameters() if not getattr(p, '_ddp_to_ignore', False)]
+        params_list = [p for p in param_order.generate() if not getattr(p, '_ddp_to_ignore', False)]
         for p, fp32_p in zip(params_list, self.fp32_params):
             chunk_16 = self.chunk_manager.get_chunk(p)
             chunk_32 = self.chunk_manager.get_chunk(fp32_p)
colossalai/nn/parallel/gemini_parallel.py

@@ -4,6 +4,7 @@ import torch
 
 from colossalai.gemini.chunk import init_chunk_manager
 from colossalai.gemini.gemini_mgr import GeminiManager
+from colossalai.gemini.memory_tracer import MemStats
 
 from .data_parallel import ZeroDDP

@@ -18,7 +19,8 @@ class GeminiDDP(ZeroDDP):
                  force_outputs_fp32: bool = False,
                  search_range_mb: int = 32,
                  hidden_dim: Optional[int] = None,
-                 min_chunk_size_mb: Optional[float] = None) -> None:
+                 min_chunk_size_mb: Optional[float] = None,
+                 memstats: Optional[MemStats] = None) -> None:
         """
         A torch.Module warpper using ZeRO-DP and Genimi.
         ZeRO is for parallel. Gemini is for memory management.

@@ -44,11 +46,12 @@ class GeminiDDP(ZeroDDP):
             min_chunk_size_mb (float, optional): the minimum chunk size in MegaByte.
                 If the aggregate size of parameters is still samller than the minimum chunk size,
                 all parameters will be compacted into one small chunk.
+            memstats (MemStats, optional) the memory statistics collector by a runtime memory tracer.
         """
         chunk_manager = init_chunk_manager(model=module,
                                            init_device=device,
                                            hidden_dim=hidden_dim,
                                            search_range_mb=search_range_mb,
                                            min_chunk_size_mb=min_chunk_size_mb)
-        gemini_manager = GeminiManager(placement_policy, chunk_manager)
+        gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats)
         super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32)
tests/test_gemini/update/test_gemini_use_rmt.py

@@ -8,7 +8,8 @@ import colossalai
 from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration
 from colossalai.gemini.gemini_mgr import GeminiManager
 from colossalai.gemini.memory_tracer.runtime_mem_tracer import RuntimeMemTracer
-from colossalai.nn.parallel import ZeroDDP
+from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
+from colossalai.nn.parallel import GeminiDDP, ZeroDDP
 from colossalai.tensor import ProcessGroup
 from colossalai.testing import parameterize, rerun_if_address_is_in_use
 from colossalai.utils import free_port

@@ -44,29 +45,27 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_
     run_fwd_bwd(runtime_mem_tracer, input_ids, label, criterion, runtime_mem_tracer)
     memstats = runtime_mem_tracer.memstats()
     runtime_tracer_non_model_data = runtime_mem_tracer._memstats._non_model_data_cuda_list
-    print('runtime tracer: ', runtime_tracer_non_model_data)
+    print('runtime tracer non model data points: ', len(runtime_tracer_non_model_data))
 
-    world_size = torch.distributed.get_world_size()
-    config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
-    config_dict[world_size]['chunk_size'] = 5000
-    config_dict[world_size]['keep_gathered'] = keep_gather
-    chunk_manager = ChunkManager(config_dict)
-    gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats)
-    model = ZeroDDP(model, gemini_manager, pin_memory=True)
+    model = GeminiDDP(model, device='cuda', placement_policy=placement_policy, search_range_mb=1, memstats=memstats)
 
     zero_optim = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=1)
 
     pg = ProcessGroup()
     set_seed(pg.dp_local_rank())
     for i, (input_ids, label) in enumerate(train_dataloader):
         # you can only test a single fwd + bwd.
         # after bwd param is grad for Gemini, due to the chunk reuse optimization.
-        if i > 1:
+        # print(f'iteration {i}')
+        if i > 4:
             break
         input_ids, label = input_ids.cuda(), label.cuda()
 
         zero_optim.zero_grad()
         set_seed(42)
-        loss = run_fwd_bwd(model, input_ids, label, criterion, model)
+        loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim)
         zero_optim.step()
 
-    gemini_non_model_data = gemini_manager._mem_stats_collector._memstats.non_model_data_list('cuda')
+    gemini_non_model_data = model.gemini_manager._mem_stats_collector._memstats.non_model_data_list('cuda')
 
     # print('gemini non model data:', gemini_non_model_data)
tests/test_gemini/update/test_optim.py

 from functools import partial
-from time import time
 
 import pytest
 import torch