Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
6630d455
Unverified
Commit
6630d455
authored
Nov 17, 2022
by
Genghan Zhang
Committed by
GitHub
Nov 17, 2022
Browse files
[autoparallel] Add alpha beta (#1973)
* Add alpha beta * Fix test * Fix test
parent
cc0ed7cf
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
265 additions
and
0 deletions
+265
-0
colossalai/device/__init__.py
colossalai/device/__init__.py
+4
-0
colossalai/device/calc_pipeline_strategy.py
colossalai/device/calc_pipeline_strategy.py
+127
-0
colossalai/device/profile_alpha_beta.py
colossalai/device/profile_alpha_beta.py
+120
-0
tests/test_device/test_alpha_beta.py
tests/test_device/test_alpha_beta.py
+14
-0
No files found.
colossalai/device/__init__.py
View file @
6630d455
from
.calc_pipeline_strategy
import
alpa_dp
from
.profile_alpha_beta
import
profile_alpha_beta
# Names re-exported by the package.
__all__ = [
    'profile_alpha_beta',
    'alpa_dp',
]
colossalai/device/calc_pipeline_strategy.py
0 → 100644
View file @
6630d455
from
math
import
pow
import
numpy
as
np
def get_submesh_choices(num_hosts, num_devices_per_host, mode="new"):
    """Enumerate the candidate submesh shapes for a cluster.

    Arguments:
        num_hosts: number of hosts; only read in "alpa" mode.
        num_devices_per_host: number of devices on each host, must be a power of two.
        mode: "alpa" yields (1, 2**i) for every i in [0, log2(num_devices_per_host)]
            followed by (h, num_devices_per_host) for h in [2, num_hosts].
            "new" yields (2**i, 2**j) for every i <= j with
            i + j <= log2(num_devices_per_host).
            Any other value yields an empty list.

    Returns:
        A list of (int, int) tuples. The entries are plain Python ints so that
        their product can be used directly as an array index.

    Raises:
        AssertionError: if num_devices_per_host is not a power of two.
    """
    # After the loop p == floor(log2(num_devices_per_host)), or -1 when the
    # device count is smaller than 1.
    i = 1
    p = -1
    while i <= num_devices_per_host:
        i *= 2
        p += 1
    # Integer exponentiation is used on purpose: math.pow would return floats,
    # and a negative p must be rejected before it is used as an exponent.
    assert p >= 0 and 2**p == num_devices_per_host, (
        "Only supports the cases where num_devices_per_host is power of two, "
        f"while now num_devices_per_host = {num_devices_per_host}")

    if mode == "alpa":
        submesh_choices = [(1, 2**i) for i in range(p + 1)]
        submesh_choices.extend((h, num_devices_per_host) for h in range(2, num_hosts + 1))
    elif mode == "new":
        submesh_choices = [(2**i, 2**j) for i in range(p // 2 + 1) for j in range(i, p - i + 1)]
    else:
        # Unknown modes keep the historical behaviour of producing no choices.
        submesh_choices = []
    return submesh_choices
def alpa_dp_impl(num_layers, num_devices, num_microbatches, submesh_choices, compute_cost, max_stage_cost,
                 best_configs):
    """Implementation of Alpa DP for pipeline strategy
    Paper reference: https://www.usenix.org/system/files/osdi22-zheng-lianmin.pdf

    Arguments:
        num_layers: K
        num_devices: N*M
        num_microbatches: B
        submesh_choices: List[(n_i,m_i)]
        compute_cost: t_intra, indexed as [first layer, last layer (inclusive), submesh index]
        max_stage_cost: stages whose cost exceeds this bound are not considered
        best_configs: autosharding config id, indexed like compute_cost

    Returns:
        (total_cost, res). res is a list with one entry per stage of the form
        ((start_layer, next_start_layer), submesh index, autosharding config id).
        When no valid slicing exists under max_stage_cost, (np.inf, None) is returned.
    """
    # Device count of every submesh, computed once. It is forced to a Python
    # int so it can be used as an array index even if the shapes hold floats.
    submesh_sizes = [int(np.prod(np.array(submesh))) for submesh in submesh_choices]

    # For f, layer ID start from 0
    # f[#pipeline stages, layer id that is currently being considered, number of devices used]
    f = np.full((num_layers + 1, num_layers + 1, num_devices + 1), np.inf, dtype=np.float32)
    f_stage_max = np.full((num_layers + 1, num_layers + 1, num_devices + 1), 0.0, dtype=np.float32)
    f_argmin = np.full((num_layers + 1, num_layers + 1, num_devices + 1, 3), -1, dtype=np.int32)
    # Base case: zero stages cover zero remaining layers with zero devices.
    f[0, num_layers, 0] = 0
    for s in range(1, num_layers + 1):
        for k in range(num_layers - 1, -1, -1):
            for d in range(1, num_devices + 1):
                for m, n_submesh_devices in enumerate(submesh_sizes):
                    if n_submesh_devices > d:
                        continue
                    # TODO: [luzgh]: Why alpa needs max_n_succ_stages? Delete.
                    # if s - 1 <= max_n_succ_stages[i, k - 1, m, n_config]:
                    # ...
                    # The current stage covers layers k .. i - 1 on submesh m;
                    # the remaining s - 1 stages start at layer i.
                    for i in range(num_layers, k, -1):
                        stage_cost = compute_cost[k, i - 1, m]
                        new_cost = f[s - 1, i, d - n_submesh_devices] + stage_cost
                        if stage_cost <= max_stage_cost and new_cost < f[s, k, d]:
                            f[s, k, d] = new_cost
                            f_stage_max[s, k, d] = max(stage_cost, f_stage_max[s - 1, i, d - n_submesh_devices])
                            f_argmin[s, k, d] = (i, m, best_configs[k, i - 1, m])

    # Pick the number of stages with the smallest summed stage cost.
    best_s = -1
    best_total_cost = np.inf
    for s in range(1, num_layers + 1):
        if f[s, 0, num_devices] < best_total_cost:
            best_s = s
            best_total_cost = f[s, 0, num_devices]

    if np.isinf(best_total_cost):
        return np.inf, None

    total_cost = f[best_s, 0, num_devices] + (num_microbatches - 1) * f_stage_max[best_s, 0, num_devices]

    # Walk the argmin table from the first layer to recover the chosen stages.
    current_s = best_s
    current_layer = 0
    current_devices = num_devices

    res = []
    while current_s > 0 and current_layer < num_layers and current_devices > 0:
        next_start_layer, submesh_choice, autosharding_choice = (f_argmin[current_s, current_layer, current_devices])
        assert next_start_layer != -1 and current_devices != -1
        res.append(((current_layer, next_start_layer), submesh_choice, autosharding_choice))
        current_s -= 1
        current_layer = next_start_layer
        current_devices -= submesh_sizes[submesh_choice]
    assert (current_s == 0 and current_layer == num_layers and current_devices == 0)

    return total_cost, res
def alpa_dp(num_layers,
            num_devices,
            num_microbatches,
            submesh_choices,
            num_autosharding_configs,
            compute_cost,
            gap=1e-6):
    """Alpa auto stage dynamic programming.
    Code reference: https://github.com/alpa-projects/alpa/blob/main/alpa/pipeline_parallel/stage_construction.py

    Every distinct value in compute_cost is tried, in ascending order, as the
    upper bound on a single stage's cost, and the cheapest result returned by
    alpa_dp_impl is kept.

    Arguments:
        submesh_choices: List[(int,int)]
        num_autosharding_configs: Max number of t_intra(start_layer, end_layer, LogicalMesh)
        compute_cost: np.array(num_layers,num_layers,num_submesh_choices,num_autosharding_configs)
        gap: candidate bounds closer than this to the last bound tried are skipped

    Returns:
        (best_cost, best_solution); best_solution stays None if no bound produced
        a cost below infinity.
    """
    expected_shape = (num_layers, num_layers, len(submesh_choices), num_autosharding_configs)
    assert np.shape(compute_cost) == expected_shape, "Cost shape wrong."

    distinct_costs = np.unique(compute_cost)
    all_possible_stage_costs = np.sort(distinct_costs)

    # TODO: [luzgh]: Why alpa needs the num_autosharding_configs dimension in compute_cost?
    # In dp_impl it seems the argmin n_config will be chosen. Just amin here.
    best_configs = np.argmin(compute_cost, axis=3)
    best_compute_cost = np.amin(compute_cost, axis=3)
    assert len(all_possible_stage_costs), "no solution in auto stage construction."

    best_cost, best_solution = np.inf, None
    last_max_stage_cost = 0.0
    for max_stage_cost in all_possible_stage_costs:
        # Larger bounds cannot beat the best cost found so far; stop searching.
        if max_stage_cost * num_microbatches >= best_cost:
            break
        # Skip bounds that are (almost) identical to the one tried last.
        if max_stage_cost - last_max_stage_cost < gap:
            continue
        cost, solution = alpa_dp_impl(num_layers, num_devices, num_microbatches, submesh_choices,
                                      best_compute_cost, max_stage_cost, best_configs)
        if cost < best_cost:
            best_cost, best_solution = cost, solution
        last_max_stage_cost = max_stage_cost

    return best_cost, best_solution
colossalai/device/profile_alpha_beta.py
0 → 100644
View file @
6630d455
import
fcntl
import
math
import
os
import
time
import
torch
import
torch.distributed
as
dist
import
torch.multiprocessing
as
mp
# Size units used by the profiler.
# NOTE(review): these are 2**10 * 1000 and 2**20 * 1000, i.e. neither the
# decimal (10**6 / 10**9) nor the binary (2**20 / 2**30) units -- confirm intent.
MB = (1 << 10) * 1000
GB = (1 << 20) * 1000
# Base message size; profile_latency shifts this left to get its probe sizes.
Byte = 4
# Constant subtracted from every averaged timing in profile().
FRAMEWORK = 0
# Placeholder returned by profile() on ranks that do not report a measurement.
NON_SENSE = (0.1, 0.1)
def printflock(*msgs):
    """ solves multi-process interleaved print problem

    This source file is opened read-only and used as the lock target: an
    exclusive flock is held while the messages are printed and released
    afterwards, even if print raises.
    """
    with open(__file__, "r") as lock_file:
        fcntl.flock(lock_file, fcntl.LOCK_EX)
        try:
            print(*msgs)
        finally:
            fcntl.flock(lock_file, fcntl.LOCK_UN)
def profile(device1d, nbytes, ctype):
    """Time one collective over the ranks in ``device1d``.

    Arguments:
        device1d: list of ranks forming the communication group; the first
            entry is the broadcast source and the only rank that reports.
        nbytes: message size; the buffer holds ``nbytes // 4`` elements.
        ctype: "a" runs all_reduce (SUM), "b" runs broadcast.

    Returns:
        (average seconds per call, nbytes / average seconds) on the rank equal
        to ``device1d[0]``; the NON_SENSE placeholder on every other rank.
    """
    warmup_rounds = 5
    repeat = 25

    rank = dist.get_rank()
    src_device_num = device1d[0]
    wsize = len(device1d)
    group = dist.new_group(device1d)

    torch.cuda.set_device(rank)
    device = torch.device("cuda", rank)
    buf = torch.randn(nbytes // 4).to(device)

    def run_collective(times):
        # Issue the selected collective `times` times on the shared buffer.
        for _ in range(times):
            if ctype == "a":
                dist.all_reduce(buf, op=dist.ReduceOp.SUM, group=group)
            elif ctype == "b":
                dist.broadcast(buf, src=src_device_num, group=group)

    torch.cuda.synchronize()
    # warmup
    run_collective(warmup_rounds)
    torch.cuda.synchronize()

    # Timed section, fenced by barriers on both sides.
    dist.barrier()
    begin = time.perf_counter()
    run_collective(repeat)
    torch.cuda.synchronize()
    end = time.perf_counter()
    dist.barrier()

    if rank != src_device_num:
        return NON_SENSE    # Just a placeholder

    avg_time_s = (end - begin) / repeat - FRAMEWORK
    alg_band = nbytes / avg_time_s
    if ctype == "b":
        bus_band = alg_band
    elif ctype == "a":
        bus_band = 2 * (wsize - 1) / wsize * alg_band
    print(
        f"GPU:{rank}, Bytes: {nbytes} B,Time: {round(avg_time_s * 1e6,2)} us, Bus bandwidth: {round(bus_band / GB,2)} GB/s"
    )
    return (avg_time_s, alg_band)
def profile_latency(device1d, it=3, ctype="a"):
    """Return the smallest average time measured over ``it`` tiny messages.

    Message sizes are ``Byte << i`` for i in [0, it); each one is timed with
    profile() and only the time component of its result is kept.
    """
    timings = []
    for shift in range(it):
        elapsed, _unused_band = profile(device1d, int(Byte << shift), ctype)
        timings.append(elapsed)
    return min(timings)
def profile_bandwidth(device1d, maxbytes, ctype="a"):
    """Return the bandwidth component of profile() for a ``maxbytes`` message."""
    _elapsed, measured_band = profile(device1d, maxbytes, ctype)
    return measured_band
def profile_ab(rank, *args):
    """Per-process worker that measures alpha (latency) and beta (1 / bandwidth).

    Arguments:
        rank: process index; also used as the CUDA device index.
        args: (device1d, return_dict, ctype) in that order.

    The result ``(alpha, beta)`` is stored in ``return_dict[rank]``.
    """
    wsize = int(torch.cuda.device_count())
    device1d = args[0]
    return_dict = args[1]
    ctype = args[2]

    # Rendezvous settings read by the 'env://' init method.
    os.environ.update({'MASTER_ADDR': 'localhost', 'MASTER_PORT': '29020'})
    dist.init_process_group(backend=dist.Backend.NCCL, init_method='env://', world_size=wsize, rank=rank)

    device = torch.device("cuda", rank)
    free_bytes = torch.tensor(torch.cuda.mem_get_info(device)[0]).to(device)
    # Largest GB << k that does not exceed the free memory, capped at 4 * GB.
    shift = int(math.log2(free_bytes.item() / GB))
    max_nbytes = min(int(4 * GB), int(GB << shift))

    is_reporter = rank == device1d[0]
    if is_reporter:
        print(f"max_nbytes: {max_nbytes} B")

    alpha = profile_latency(device1d, it=5, ctype=ctype)
    beta = 1 / profile_bandwidth(device1d, maxbytes=max_nbytes, ctype=ctype)

    if is_reporter:
        print(f"alpha(us): {round(alpha * 1e6,2)}, beta(us/GB): {round(beta * 1e6 * GB,2)}")
    return_dict[rank] = (alpha, beta)
def profile_alpha_beta(device1d):
    """Spawn one worker per visible CUDA device and return (alpha, beta).

    Arguments:
        device1d: non-empty list of ranks, no longer than the CUDA device count.

    Returns:
        The (alpha, beta) tuple stored by the worker whose rank is ``device1d[0]``.
    """
    assert torch.cuda.is_available()
    assert 0 < len(device1d) <= int(torch.cuda.device_count())

    # A manager-backed dict lets the spawned workers hand their results back.
    return_dict = mp.Manager().dict()
    ctype = "a"
    num_workers = int(torch.cuda.device_count())
    mp.spawn(profile_ab, args=[device1d, return_dict, ctype], nprocs=num_workers)
    return return_dict[device1d[0]]
tests/test_device/test_alpha_beta.py
0 → 100644
View file @
6630d455
import
pytest
from
colossalai.device
import
profile_alpha_beta
@pytest.mark.skip(reason="Skip because assertion fails for CI devices")
def test_profile_alpha_beta():
    """Profile devices 0-3 and check alpha and beta fall in the expected ranges."""
    physical_devices = list(range(4))
    alpha, beta = profile_alpha_beta(physical_devices)
    assert 0 < alpha < 1e-4 and 0 < beta < 1e-10


if __name__ == '__main__':
    test_profile_alpha_beta()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment