Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
9056677b
Unverified
Commit
9056677b
authored
Aug 11, 2022
by
HELSON
Committed by
GitHub
Aug 11, 2022
Browse files
[zero] add chunk size searching algorithm for parameters in different groups (#1436)
parent
c9427a32
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
164 additions
and
0 deletions
+164
-0
colossalai/gemini/update/__init__.py
colossalai/gemini/update/__init__.py
+1
-0
colossalai/gemini/update/search_utils.py
colossalai/gemini/update/search_utils.py
+96
-0
tests/test_gemini/update/test_search.py
tests/test_gemini/update/test_search.py
+67
-0
No files found.
colossalai/gemini/update/__init__.py
View file @
9056677b
from
.chunkv2
import
ChunkV2
from
.search_utils
import
clasify_params
,
search_chunk_configuration
colossalai/gemini/update/search_utils.py
0 → 100644
View file @
9056677b
from
typing
import
Dict
,
List
import
numpy
as
np
import
torch.nn
as
nn
from
colossalai.tensor
import
ColoParameter
def _filter_exlarge_params(model: nn.Module, size_dict: Dict[int, List[int]]) -> None:
    """Remove extra-large parameter sizes from every group of ``size_dict``, in place.

    A size is treated as extra-large when it exceeds the mean plus three
    standard deviations of the element counts (``numel()``) of all parameters
    of ``model``.

    Args:
        model: module whose parameters define the size statistics.
        size_dict: mapping from group key to a list of parameter sizes; each
            list is replaced by a new list that keeps only the sizes not
            larger than the threshold. A list may end up empty.
    """
    numel_arr = np.array([param.numel() for param in model.parameters()])
    # three-sigma rule: anything above mean + 3 * std is considered an outlier
    threshold = np.mean(numel_arr) + 3 * np.std(numel_arr)

    for group_key, group_sizes in size_dict.items():
        # re-binding an existing key while iterating is safe: the key set is unchanged
        size_dict[group_key] = [size for size in group_sizes if size <= threshold]
def _get_unused_byte(size_list: List[int], chunk_size: int) -> int:
    """Return the total unused space when packing ``size_list`` into chunks in order.

    Sizes are placed one after another into the current chunk. When a size
    does not fit into the space that is left, the leftover of the current
    chunk is counted as unused and a fresh chunk of ``chunk_size`` is opened.
    The free space of the last chunk is counted as unused as well.

    Args:
        size_list: sizes to pack, in packing order.
        chunk_size: capacity of every chunk, in the same unit as the sizes.

    Returns:
        The accumulated unused space over all chunks (0 for an empty list).
    """
    wasted = 0
    remaining = 0    # no chunk is open before the first size arrives
    for size in size_list:
        if size <= remaining:
            remaining -= size
            continue
        # does not fit: give up the rest of the current chunk and start a new one
        wasted += remaining
        remaining = chunk_size - size
    return remaining + wasted
def clasify_params(model: nn.Module) -> Dict[int, List[ColoParameter]]:
    """Group the parameters of ``model`` by the data-parallel world size of their process group.

    Args:
        model: module whose parameters must all be ``ColoParameter`` instances.

    Returns:
        A dict mapping ``param.process_group.dp_world_size()`` to the list of
        parameters sharing that value, in ``model.parameters()`` order.

    Raises:
        AssertionError: if a parameter is not a ``ColoParameter``.
    """
    grouped: Dict[int, List[ColoParameter]] = {}
    for p in model.parameters():
        assert isinstance(p, ColoParameter), "please init model in the ColoInitContext"
        dp_size = p.process_group.dp_world_size()
        grouped.setdefault(dp_size, []).append(p)
    return grouped
def search_chunk_configuration(
        model: nn.Module,
        search_range_mb: int,
        search_interval_byte: int,    # hidden size is the best value for the interval
        min_chunk_size_mb: int = 32,
        filter_exlarge_params: bool = True) -> Dict[int, Dict]:
    """Search a chunk size that minimizes the unused space over all parameter groups.

    Parameters are grouped with ``clasify_params``. A group whose total size is
    below the minimum chunk size gets its own chunk that is kept gathered.
    For the remaining groups, candidate chunk sizes are scanned from the
    largest (non-filtered) parameter size upwards and the candidate with the
    smallest total waste (see ``_get_unused_byte``) is shared by all of them.

    NOTE: parameter sizes are taken from ``numel()`` and compared directly
    with the ``*_mb``-derived thresholds (``mb * 1024**2``), so all quantities
    in the search are effectively counted in the same unit as ``numel()``.

    Args:
        model: module whose parameters are all ``ColoParameter`` instances.
        search_range_mb: width of the scanned range, scaled by ``1024**2``.
        search_interval_byte: step between two candidate chunk sizes; must
            divide the scaled search range.
        min_chunk_size_mb: lower bound of the chunk size, scaled by ``1024**2``.
        filter_exlarge_params: when True, sizes above mean + 3 * std are
            ignored while searching (see ``_filter_exlarge_params``).

    Returns:
        A dict mapping each group key to
        ``dict(chunk_size=<int>, keep_gathered=<bool>)``.

    Raises:
        AssertionError: if the search range is not a multiple of the interval.
    """
    search_range_byte = search_range_mb * 1024**2
    min_chunk_size_byte = min_chunk_size_mb * 1024**2
    assert search_range_byte % search_interval_byte == 0, \
        "search range must be a multiple of search_interval_byte"

    params_dict = clasify_params(model)
    config_dict: Dict[int, Dict] = dict()

    size_dict: Dict[int, List[int]] = dict()
    for key, params_list in params_dict.items():
        size_list = [p.numel() for p in params_list]
        # let small parameters keep gathered in CUDA all the time
        total_size = sum(size_list)
        if total_size < min_chunk_size_byte:
            config_dict[key] = dict(chunk_size=total_size, keep_gathered=True)
        else:
            size_dict[key] = size_list

    if filter_exlarge_params:
        _filter_exlarge_params(model, size_dict)

    # The smallest candidate must hold the largest remaining parameter.
    # A group may be empty after filtering (all of its parameters were
    # outliers); `default=0` keeps `max` from raising ValueError for it.
    max_size = min_chunk_size_byte
    for size_list in size_dict.values():
        max_size = max(max_size, max(size_list, default=0))

    min_chunk_waste = float('+inf')
    best_chunk_size = max_size

    for chunk_size in range(max_size, max_size + search_range_byte + 1, search_interval_byte):
        temp_waste = sum(_get_unused_byte(size_list, chunk_size) for size_list in size_dict.values())
        # strict comparison: on ties the smallest candidate wins
        if temp_waste < min_chunk_waste:
            min_chunk_waste = temp_waste
            best_chunk_size = chunk_size

    # every group that was not configured as "keep gathered" shares the best chunk size
    for key in params_dict:
        if key in config_dict:
            continue
        config_dict[key] = dict(chunk_size=best_chunk_size, keep_gathered=False)

    return config_dict
tests/test_gemini/update/test_search.py
0 → 100644
View file @
9056677b
import
pytest
from
functools
import
partial
import
torch
import
torch.multiprocessing
as
mp
import
torch.distributed
as
dist
import
colossalai
from
colossalai.testing
import
rerun_if_address_is_in_use
from
colossalai.gemini.update
import
search_chunk_configuration
from
colossalai.utils
import
free_port
,
get_current_device
from
colossalai.utils.model.colo_init_context
import
ColoInitContext
from
colossalai.tensor
import
ShardSpec
,
ComputePattern
,
ComputeSpec
,
ProcessGroup
from
tests.components_to_test.registry
import
non_distributed_component_funcs
def init_1d_row_spec(model, pg: ProcessGroup):
    """Attach ``pg`` and a 1D tensor-parallel spec to the selected weights of ``model``.

    A parameter is selected when its name contains ``'weight'`` and does not
    contain ``'ln'``. Every selected parameter receives the same
    ``ShardSpec([0], [pg.tp_world_size()])`` / ``ComputeSpec(ComputePattern.TP1D)`` pair.
    """
    # presumably: shard dimension 0 across the tensor-parallel ranks -- confirm against ShardSpec
    shard_spec = ShardSpec([0], [pg.tp_world_size()])
    compute_spec = ComputeSpec(ComputePattern.TP1D)

    for name, param in model.named_parameters():
        if 'weight' not in name or 'ln' in name:
            continue
        param.set_process_group(pg)
        param.set_tensor_spec(shard_spec, compute_spec)
def exam_search_chunk_size():
    """Check the chunk size found by ``search_chunk_configuration`` for the 'gpt2' test model.

    The model is built inside ``ColoInitContext`` and its weights are given a
    1D tensor-parallel spec over the whole world. Every entry of the returned
    configuration must report the expected chunk size: 31616 when running on
    a single process, 1024 otherwise.
    """
    world_size = dist.get_world_size()
    pg_tp = ProcessGroup(tp_degree=world_size)

    get_components_func = non_distributed_component_funcs.get_callable('gpt2')
    # only the model builder is needed; the other components are ignored
    model_builder, _train_dataloader, _test_dataloader, _optimizer_class, _criterion = get_components_func()

    # build the model inside ColoInitContext on the current device
    with ColoInitContext(device=get_current_device()):
        model = model_builder()
    init_1d_row_spec(model, pg_tp)

    config_dict = search_chunk_configuration(model,
                                             search_range_mb=1,
                                             search_interval_byte=16,
                                             min_chunk_size_mb=0,
                                             filter_exlarge_params=True)

    for group_config in config_dict.values():
        chunk_size = group_config['chunk_size']
        if world_size != 1:
            assert chunk_size == 1024
        else:
            assert chunk_size == 31616
def run_dist(rank, world_size, port):
    """Per-process entry point: launch colossalai on localhost with NCCL, then run the check.

    Args:
        rank: rank of this process.
        world_size: total number of processes.
        port: port used for the rendezvous on ``localhost``.
    """
    launch_kwargs = dict(
        config={},
        rank=rank,
        world_size=world_size,
        host='localhost',
        port=port,
        backend='nccl',
    )
    colossalai.launch(**launch_kwargs)
    exam_search_chunk_size()
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_search(world_size):
    """Spawn ``world_size`` processes that each run ``run_dist`` on a free port."""
    port = free_port()
    worker = partial(run_dist, world_size=world_size, port=port)
    # mp.spawn passes the process index as the first positional argument (the rank)
    mp.spawn(worker, nprocs=world_size)
if __name__ == '__main__':
    # When executed as a script, run the search test once with a world size of 4.
    test_search(4)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment