OpenDAS / ColossalAI · Commits

Commit 4e960396 (unverified)
Authored Jan 07, 2023 by Jiarui Fang; committed via GitHub on Jan 07, 2023

[device] find best logical mesh

Parents: 8f72b6f8, b5a3a4a6
Showing 3 changed files, with 265 additions and 3 deletions:

    colossalai/device/alpha_beta_profiler.py                +190  -3
    tests/test_device/test_extract_alpha_beta.py             +39  -0
    tests/test_device/test_search_logical_device_mesh.py     +36  -0
colossalai/device/alpha_beta_profiler.py
@@ -21,7 +21,7 @@ class AlphaBetaProfiler:
         # multi-process with multi-gpu in mpi style.
         >>> physical_devices = [0, 1, 4, 5]
         >>> ab_profiler = AlphaBetaProfiler(physical_devices)
-        >>> ab_dict = ab_profiler.profile_ab()
+        >>> ab_dict = ab_profiler.alpha_beta_dict
         >>> print(ab_dict)
         {(0, 1): (1.9641406834125518e-05, 4.74049549614719e-12), (0, 4): (1.9506998360157013e-05, 6.97421973297474e-11), (0, 5): (2.293858677148819e-05, 7.129930361393644e-11),
          (1, 4): (1.9010603427886962e-05, 7.077968863788975e-11), (1, 5): (1.9807778298854827e-05, 6.928845708992215e-11), (4, 5): (1.8681809306144713e-05, 4.7522367291330524e-12),
          ...
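Each value pair reads as (alpha, beta): alpha is the fixed per-message latency in seconds, beta is the per-byte transfer time (inverse bandwidth). A standalone sanity check of the (0, 1) pair above, in plain Python:

    # alpha-beta model: sending n bytes costs roughly t(n) = alpha + n * beta
    alpha, beta = 1.9641406834125518e-05, 4.74049549614719e-12    # the (0, 1) pair above
    n = 64 * 1024**2    # a 64 MiB message
    print(f"predicted transfer time: {(alpha + n * beta) * 1e3:.3f} ms")    # ~0.338 ms
    print(f"implied bandwidth: {1 / beta / 1e9:.0f} GB/s")                  # ~211 GB/s, NVLink-class

The pairs with beta near 7e-11 correspond to links roughly 15x slower, which is exactly the contrast the logical-mesh search below exploits.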
@@ -31,13 +31,16 @@ class AlphaBetaProfiler:
     def __init__(self,
                  physical_devices: List[int],
                  alpha_beta_dict: Dict[Tuple[int, int], Tuple[float, float]] = None,
                  ctype: str = 'a',
                  warmup: int = 5,
                  repeat: int = 25,
-                 latency_iters: int = 5):
+                 latency_iters: int = 5,
+                 homogeneous_tolerance: float = 0.1):
         '''
         Args:
             physical_devices: A list of device ids; each element is the global rank of that device.
             alpha_beta_dict: A dict which maps a process group to its (alpha, beta) value pair.
             ctype: 'a' for all-reduce, 'b' for broadcast.
             warmup: Number of warmup iterations.
             repeat: Number of iterations to measure.
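Because alpha_beta_dict can be supplied up front, a profile measured once can be reused without re-running the collectives. A minimal sketch, assuming the process group has already been launched across the same ranks; the values are placeholders, and note that profile_ab() also records the symmetric (j, i) pairs, so a hand-built dict may need both orders:

    from colossalai.device import AlphaBetaProfiler

    # hypothetical pre-measured (alpha, beta) pairs for ranks 0..3
    precomputed = {
        (0, 1): (1.96e-05, 4.74e-12),
        (0, 2): (1.95e-05, 6.97e-11),
        (0, 3): (2.29e-05, 7.13e-11),
        (1, 2): (1.90e-05, 7.08e-11),
        (1, 3): (1.98e-05, 6.93e-11),
        (2, 3): (1.87e-05, 4.75e-12),
    }
    # __init__ still runs _init_profiling() to build the process groups, but
    # skips profile_ab() because alpha_beta_dict is not None
    profiler = AlphaBetaProfiler([0, 1, 2, 3], alpha_beta_dict=precomputed)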
@@ -49,8 +52,13 @@ class AlphaBetaProfiler:
        self.warmup = warmup
        self.repeat = repeat
        self.latency_iters = latency_iters
        self.homogeneous_tolerance = homogeneous_tolerance
        self.process_group_dict = None
        self._init_profiling()
        if alpha_beta_dict is None:
            self.alpha_beta_dict = self.profile_ab()
        else:
            self.alpha_beta_dict = alpha_beta_dict

    def _init_profiling(self):
        # Create process group list based on its global rank
        ...
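Since the constructor only profiles when no dict is given, persisting a measured profile between runs follows naturally; pickle handles the tuple keys. A hypothetical helper, not part of the commit (in a multi-process launch, only one rank should write the file):

    import os
    import pickle

    import torch.distributed as dist

    from colossalai.device import AlphaBetaProfiler

    CACHE_PATH = 'alpha_beta_cache.pkl'    # hypothetical location

    def load_or_profile(physical_devices):
        # reuse a cached profile when present; otherwise measure once and cache it
        if os.path.exists(CACHE_PATH):
            with open(CACHE_PATH, 'rb') as f:
                return AlphaBetaProfiler(physical_devices, alpha_beta_dict=pickle.load(f))
        profiler = AlphaBetaProfiler(physical_devices)
        if dist.get_rank() == 0:    # write from a single rank to avoid races
            with open(CACHE_PATH, 'wb') as f:
                pickle.dump(profiler.alpha_beta_dict, f)
        return profiler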
@@ -139,7 +147,7 @@ class AlphaBetaProfiler:
        return latency

-    def profile_bandwidth(self, process_group, pg_handler, maxbytes):
+    def profile_bandwidth(self, process_group, pg_handler, maxbytes=(1 * GB)):
        '''
        This function is used to profile the bandwidth of the given process group.
        ...
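The one-gigabyte default gives the bandwidth probe a message large enough that the alpha term is negligible. The arithmetic behind the resulting beta, as a standalone sketch (assuming the module's GB constant is 2**30 bytes; the timing is illustrative):

    GB = 1 << 30    # assumed value of the module's GB constant

    nbytes = 1 * GB        # the new default maxbytes
    elapsed = 5.2e-3       # hypothetical measured collective time, in seconds
    bandwidth = nbytes / elapsed    # ~2.1e11 B/s, i.e. ~206 GiB/s
    beta = 1 / bandwidth            # ~4.8e-12 s/byte, the scale seen in the docstring output
    print(beta)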
@@ -159,6 +167,7 @@ class AlphaBetaProfiler:
        '''
        alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = {}
        rank = dist.get_rank()
        global_pg_handler = dist.new_group(self.physical_devices)

        def get_max_nbytes(process_group: Tuple[int], pg_handler: dist.ProcessGroup):
            assert rank in process_group
            ...
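The body of get_max_nbytes is elided in this view; its job is to pick a probe size that fits on every device in the group. A plausible standalone sketch of that idea, not the commit's code (assumes an initialized NCCL group):

    import torch
    import torch.distributed as dist

    def max_nbytes_sketch(pg_handler, cap=1 << 30):
        # each rank reports its free device memory; the group takes the minimum so
        # the probe tensor fits everywhere (halved for headroom, capped at 1 GB)
        free_bytes, _ = torch.cuda.mem_get_info()
        t = torch.tensor([free_bytes], device='cuda')
        dist.all_reduce(t, op=dist.ReduceOp.MIN, group=pg_handler)
        return min(int(t.item()) // 2, cap)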
@@ -197,3 +206,181 @@ class AlphaBetaProfiler:
        alpha_beta_dict.update(symmetry_ab_dict)

        return alpha_beta_dict

    def search_best_logical_mesh(self):
        '''
        This method searches the best logical mesh for the given device list.

        The best logical mesh is searched in the following steps:
            1. Detect homogeneous device groups: devices in the alpha_beta_dict are
               considered homogeneous if their beta values are close enough.
            2. Find the best homogeneous device group that contains all the physical
               devices. The best group is the one with the lowest beta value among
               those covering every physical device; full coverage is required because
               devices outside the group would decrease the group's bandwidth.
            3. If such a group is found, construct the largest ring for each device
               based on it; the best logical mesh is the union of all the rings.
               Otherwise, fall back to the balanced logical mesh, e.g. shape (2, 2)
               for 4 devices.

        Returns:
            best_logical_mesh: The best logical mesh for the given device list.

        Usage:
            >>> physical_devices = [0, 1, 2, 3]
            >>> ab_profiler = AlphaBetaProfiler(physical_devices)
            >>> best_logical_mesh = ab_profiler.search_best_logical_mesh()
            >>> print(best_logical_mesh)
            [[0, 1], [2, 3]]
        '''

        def _power_of_two(integer):
            return integer & (integer - 1) == 0

        def _detect_homogeneous_device(alpha_beta_dict):
            '''
            Detect whether the devices in the alpha_beta_dict are homogeneous.

            Note: devices are treated as homogeneous if their beta values fall within
                [(1 - self.homogeneous_tolerance), (1 + self.homogeneous_tolerance)] * base_beta.
            '''
            homogeneous_device_dict: Dict[float, List[Tuple[int]]] = {}
            for process_group, (_, beta) in alpha_beta_dict.items():
                match_beta = None
                for beta_value in homogeneous_device_dict.keys():
                    if beta_value * (1 - self.homogeneous_tolerance) <= beta <= beta_value * (1 + self.homogeneous_tolerance):
                        match_beta = beta_value
                        break

                if match_beta is not None:
                    homogeneous_device_dict[match_beta].append(process_group)
                else:
                    # no existing bucket matches (including the very first entry),
                    # so open a new bucket keyed by this beta value
                    homogeneous_device_dict[beta] = [process_group]

            return homogeneous_device_dict

        def _check_contain_all_devices(homogeneous_group: List[Tuple[int]]):
            '''
            Check whether the homogeneous_group covers all physical devices.
            '''
            flatten_mesh = []
            for process_group in homogeneous_group:
                flatten_mesh.extend(process_group)
            non_duplicated_flatten_mesh = set(flatten_mesh)
            return len(non_duplicated_flatten_mesh) == len(self.physical_devices)

        def _construct_largest_ring(homogeneous_group: List[Tuple[int]]):
            '''
            Construct the largest ring in the homogeneous_group for each rank.
            '''
            ring = []
            ranks_in_ring = []
            for rank in self.physical_devices:
                if rank in ranks_in_ring:
                    continue
                stable_status = False
                ring_for_rank = [rank]
                check_rank_list = [rank]
                rank_to_check_list = []

                while not stable_status:
                    stable_status = True
                    check_rank_list.extend(rank_to_check_list)
                    rank_to_check_list = []
                    for _ in range(len(check_rank_list)):
                        check_rank = check_rank_list.pop()
                        for process_group in homogeneous_group:
                            if check_rank in process_group:
                                # the peer of check_rank in this two-rank process group
                                rank_to_append = process_group[0] if process_group[1] == check_rank else process_group[1]
                                if rank_to_append not in ring_for_rank:
                                    stable_status = False
                                    rank_to_check_list.append(rank_to_append)
                                    ring_for_rank.append(rank_to_append)

                ring.append(ring_for_rank)
                ranks_in_ring.extend(ring_for_rank)

            return ring

        assert _power_of_two(self.world_size)
        power_of_two = int(math.log2(self.world_size))
        median = power_of_two // 2
        balanced_logical_mesh_shape = (2**median, 2**(power_of_two - median))

        row_size, column_size = balanced_logical_mesh_shape
        balanced_logical_mesh = []
        for row_index in range(row_size):
            balanced_logical_mesh.append([])
            for column_index in range(column_size):
                balanced_logical_mesh[row_index].append(self.physical_devices[row_index * column_size + column_index])

        homogeneous_device_dict = _detect_homogeneous_device(self.alpha_beta_dict)
        # betas in descending order, so pop() yields the lowest remaining beta
        beta_list = sorted(homogeneous_device_dict.keys(), reverse=True)
        homogeneous_types = len(beta_list)
        best_logical_mesh = None
        if homogeneous_types >= 2:
            for _ in range(homogeneous_types - 1):
                lowest_beta = beta_list.pop()
                best_homogeneous_group = homogeneous_device_dict[lowest_beta]
                # if the best homogeneous group contains all physical devices, build the
                # logical device mesh from it; otherwise check the next-level group.
                if _check_contain_all_devices(best_homogeneous_group):
                    # choose the largest ring for each rank to maximize bus utilization
                    best_logical_mesh = _construct_largest_ring(best_homogeneous_group)
                    break

        if homogeneous_types == 1 or best_logical_mesh is None:
            # in this case, use the balanced logical mesh as the best logical mesh
            best_logical_mesh = balanced_logical_mesh

        return best_logical_mesh
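The interplay of the helpers is easiest to see on a toy profile. A standalone sketch with hypothetical beta values that mirrors the bucketing logic above (plain Python; it does not call the class):

    # hypothetical pairwise betas for 4 GPUs: (0, 1) and (2, 3) share fast
    # NVLink-like links, all cross links go through a slower path
    alpha_beta_dict = {
        (0, 1): (2e-05, 4.7e-12),
        (2, 3): (2e-05, 4.8e-12),
        (0, 2): (2e-05, 7.0e-11),
        (0, 3): (2e-05, 7.1e-11),
        (1, 2): (2e-05, 7.0e-11),
        (1, 3): (2e-05, 6.9e-11),
    }
    tol = 0.1

    buckets = {}
    for pg, (_, beta) in alpha_beta_dict.items():
        for base in buckets:    # match against an existing bucket within +/- tol
            if base * (1 - tol) <= beta <= base * (1 + tol):
                buckets[base].append(pg)
                break
        else:
            buckets[beta] = [pg]

    print(buckets)
    # {4.7e-12: [(0, 1), (2, 3)], 7e-11: [(0, 2), (0, 3), (1, 2), (1, 3)]}

The low-beta bucket [(0, 1), (2, 3)] covers all four devices, so ring construction yields [[0, 1], [2, 3]], the mesh shown in the docstring; the high-beta bucket is only consulted if the fastest links miss a device.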
    def extract_alpha_beta_for_device_mesh(self):
        '''
        Extract the mesh_alpha list and mesh_beta list based on the
            best logical mesh, which will be used to initialize the device mesh.

        Usage:
            >>> physical_devices = [0, 1, 2, 3]
            >>> ab_profiler = AlphaBetaProfiler(physical_devices)
            >>> mesh_alpha, mesh_beta = ab_profiler.extract_alpha_beta_for_device_mesh()
            >>> print(mesh_alpha)
            [2.5917552411556242e-05, 0.00010312341153621673]
            >>> print(mesh_beta)
            [5.875573704655635e-11, 4.7361584445959614e-12]
        '''
        best_logical_mesh = self.search_best_logical_mesh()

        first_axis = [row[0] for row in best_logical_mesh]
        second_axis = best_logical_mesh[0]

        # init process group for both axes
        first_axis_process_group = dist.new_group(first_axis)
        second_axis_process_group = dist.new_group(second_axis)

        # extract alpha and beta for both axes
        def _extract_alpha_beta(pg, pg_handler):
            latency = self.profile_latency(pg, pg_handler)
            bandwidth = self.profile_bandwidth(pg, pg_handler)
            broadcast_object = [latency, bandwidth]
            dist.broadcast_object_list(broadcast_object, src=pg[0])
            return broadcast_object

        first_latency, first_bandwidth = _extract_alpha_beta(first_axis, first_axis_process_group)
        second_latency, second_bandwidth = _extract_alpha_beta(second_axis, second_axis_process_group)
        mesh_alpha = [first_latency, second_latency]
        mesh_beta = [1 / first_bandwidth, 1 / second_bandwidth]

        return mesh_alpha, mesh_beta
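Downstream, these two lists parameterize an alpha-beta cost model per mesh axis. As a rough illustration (a generic ring all-reduce estimate, not the exact formula ColossalAI's cost model uses), with p = 2 ranks per axis of the (2, 2) mesh:

    # generic alpha-beta estimate for a ring all-reduce of n bytes over p ranks:
    #   t(n, p) = 2 * (p - 1) * alpha + (2 * (p - 1) / p) * n * beta
    def allreduce_cost(n, p, alpha, beta):
        return 2 * (p - 1) * alpha + 2 * (p - 1) / p * n * beta

    mesh_alpha = [2.5917552411556242e-05, 0.00010312341153621673]    # values from the docstring above
    mesh_beta = [5.875573704655635e-11, 4.7361584445959614e-12]

    n = 256 * 1024**2    # 256 MiB of gradients
    for axis in (0, 1):
        t = allreduce_cost(n, 2, mesh_alpha[axis], mesh_beta[axis])
        print(f"axis {axis}: ~{t * 1e3:.2f} ms")
    # the low-beta axis comes out roughly 10x cheaper, which is why the mesh
    # search puts the fastest links on their own axis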
tests/test_device/test_extract_alpha_beta.py (new file, 0 → 100644)
from functools import partial

import pytest
import torch.multiprocessing as mp

from colossalai.device import AlphaBetaProfiler
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port


def check_extract_alpha_beta(rank, physical_devices, world_size, port):
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    profiler = AlphaBetaProfiler(physical_devices)

    mesh_alpha, mesh_beta = profiler.extract_alpha_beta_for_device_mesh()
    for alpha in mesh_alpha:
        assert alpha > 0 and alpha < 1e-3
    for beta in mesh_beta:
        assert beta > 0 and beta < 1e-10


@pytest.mark.skip(reason="Skip because assertion may fail for CI devices")
@pytest.mark.dist
@parameterize('physical_devices', [[0, 1, 2, 3], [0, 3]])
@rerun_if_address_is_in_use()
def test_profile_alpha_beta(physical_devices):
    world_size = 4
    run_func = partial(check_extract_alpha_beta,
                       physical_devices=physical_devices,
                       world_size=world_size,
                       port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_profile_alpha_beta()
tests/test_device/test_search_logical_device_mesh.py (new file, 0 → 100644)
from functools import partial

import pytest
import torch.multiprocessing as mp

from colossalai.device import AlphaBetaProfiler
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port


def check_alpha_beta(rank, physical_devices, world_size, port):
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    profiler = AlphaBetaProfiler(physical_devices)
    best_logical_mesh = profiler.search_best_logical_mesh()

    if physical_devices == [0, 1, 2, 3]:
        assert best_logical_mesh == [[0, 1], [2, 3]]
    elif physical_devices == [0, 3]:
        assert best_logical_mesh == [[0, 3]]


@pytest.mark.skip(reason="Skip because assertion may fail for CI devices")
@pytest.mark.dist
@parameterize('physical_devices', [[0, 1, 2, 3], [0, 3]])
@rerun_if_address_is_in_use()
def test_profile_alpha_beta(physical_devices):
    world_size = 4
    run_func = partial(check_alpha_beta,
                       physical_devices=physical_devices,
                       world_size=world_size,
                       port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_profile_alpha_beta()
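Both tests follow the same launch pattern: @parameterize multiplies the test over the physical_devices lists, free_port() picks an open rendezvous port, and mp.spawn starts world_size processes that each join the NCCL group via launch(). Since the skip marker disables them under pytest, they are intended to be run directly on a node with four visible GPUs:

    # the skip marker disables pytest collection by design; run directly instead:
    #   python tests/test_device/test_extract_alpha_beta.py
    #   python tests/test_device/test_search_logical_device_mesh.py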