OpenDAS / ColossalAI · Commits

Commit 9c9246c0 (unverified)
Authored Jan 05, 2023 by YuliangLiu0306; committed by GitHub on Jan 05, 2023
Parent: f1bc2418

[device] alpha beta profiler (#2311)

* [device] alpha beta profiler
* add usage
* fix variable name
Changes: 5 changed files with 226 additions and 128 deletions (+226 −128).
colossalai/auto_parallel/tensor_shard/initialize.py  +1 −1
colossalai/device/__init__.py  +2 −2
colossalai/device/alpha_beta_profiler.py  +199 −0
colossalai/device/profile_alpha_beta.py  +0 −120
tests/test_device/test_alpha_beta.py  +24 −5
colossalai/auto_parallel/tensor_shard/initialize.py (+1 −1)

```diff
@@ -16,8 +16,8 @@ from colossalai.auto_parallel.tensor_shard.solver import (
     SolverOptions,
     StrategiesConstructor,
 )
 
+from colossalai.device.alpha_beta_profiler import AlphaBetaProfiler
 from colossalai.device.device_mesh import DeviceMesh
-from colossalai.device.profile_alpha_beta import profile_alpha_beta
 from colossalai.fx.tracer import ColoTracer
 from colossalai.tensor.sharding_spec import ShardingSpec
```
colossalai/device/__init__.py (+2 −2)

```diff
+from .alpha_beta_profiler import AlphaBetaProfiler
 from .calc_pipeline_strategy import alpa_dp
-from .profile_alpha_beta import profile_alpha_beta
 
-__all__ = ['profile_alpha_beta', 'alpa_dp']
+__all__ = ['AlphaBetaProfiler', 'alpa_dp']
```
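For orientation (a sketch of mine, not part of the diff): after this commit, downstream code imports the profiler class from the package root instead of the removed `profile_alpha_beta` function. Assuming a distributed environment has already been launched:

```python
# Minimal usage sketch (not from the commit): the profiler is now re-exported
# at the package root, so callers use the class rather than the old function.
from colossalai.device import AlphaBetaProfiler

profiler = AlphaBetaProfiler([0, 1, 2, 3])    # global ranks to profile pairwise
ab_dict = profiler.profile_ab()               # {(rank_i, rank_j): (alpha, beta), ...}
```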
colossalai/device/alpha_beta_profiler.py (new file, 0 → 100644, +199 −0)

```python
import math
import time
from typing import Dict, List, Tuple

import torch
import torch.distributed as dist

from colossalai.logging import get_dist_logger

GB = int((1 << 30))
BYTE = 4
FRAMEWORK_LATENCY = 0


class AlphaBetaProfiler:
    '''
    Profile alpha and beta value for a given device list.

    Usage:
        # Note: the environment of execution is supposed to be
        # multi-process with multi-gpu in mpi style.
        >>> physical_devices = [0, 1, 4, 5]
        >>> ab_profiler = AlphaBetaProfiler(physical_devices)
        >>> ab_dict = ab_profiler.profile_ab()
        >>> print(ab_dict)
        {(0, 1): (1.9641406834125518e-05, 4.74049549614719e-12), (0, 4): (1.9506998360157013e-05, 6.97421973297474e-11), (0, 5): (2.293858677148819e-05, 7.129930361393644e-11),
         (1, 4): (1.9010603427886962e-05, 7.077968863788975e-11), (1, 5): (1.9807778298854827e-05, 6.928845708992215e-11), (4, 5): (1.8681809306144713e-05, 4.7522367291330524e-12),
         (1, 0): (1.9641406834125518e-05, 4.74049549614719e-12), (4, 0): (1.9506998360157013e-05, 6.97421973297474e-11), (5, 0): (2.293858677148819e-05, 7.129930361393644e-11),
         (4, 1): (1.9010603427886962e-05, 7.077968863788975e-11), (5, 1): (1.9807778298854827e-05, 6.928845708992215e-11), (5, 4): (1.8681809306144713e-05, 4.7522367291330524e-12)}
    '''

    def __init__(self,
                 physical_devices: List[int],
                 ctype: str = 'a',
                 warmup: int = 5,
                 repeat: int = 25,
                 latency_iters: int = 5):
        '''
        Args:
            physical_devices: A list of device ids; each element is the global rank of that device.
            ctype: 'a' for all-reduce, 'b' for broadcast.
            warmup: Number of warmup iterations.
            repeat: Number of iterations to measure.
            latency_iters: Number of iterations used to measure latency.
        '''
        self.physical_devices = physical_devices
        self.ctype = ctype
        self.world_size = len(physical_devices)
        self.warmup = warmup
        self.repeat = repeat
        self.latency_iters = latency_iters
        self.process_group_dict = None
        self._init_profiling()

    def _init_profiling(self):
        # Create the process group list based on global ranks: one group per device pair.
        process_group_list = []
        for f_index in range(self.world_size - 1):
            for b_index in range(f_index + 1, self.world_size):
                process_group_list.append((self.physical_devices[f_index], self.physical_devices[b_index]))

        # Create the process group dict which maps each process group to its handler.
        process_group_dict = {}
        for process_group in process_group_list:
            pg_handler = dist.new_group(process_group)
            process_group_dict[process_group] = pg_handler

        self.process_group_dict = process_group_dict

    def _profile(self, process_group, pg_handler, nbytes):
        logger = get_dist_logger()
        rank = dist.get_rank()
        src_device_num = process_group[0]
        world_size = len(process_group)

        device = torch.cuda.current_device()
        buf = torch.randn(nbytes // 4).to(device)

        torch.cuda.synchronize()
        # warmup
        for _ in range(self.warmup):
            if self.ctype == "a":
                dist.all_reduce(buf, op=dist.ReduceOp.SUM, group=pg_handler)
            elif self.ctype == "b":
                dist.broadcast(buf, src=src_device_num, group=pg_handler)
        torch.cuda.synchronize()

        dist.barrier(group=pg_handler)
        begin = time.perf_counter()
        for _ in range(self.repeat):
            if self.ctype == "a":
                dist.all_reduce(buf, op=dist.ReduceOp.SUM, group=pg_handler)
            elif self.ctype == "b":
                dist.broadcast(buf, src=src_device_num, group=pg_handler)
        torch.cuda.synchronize()
        end = time.perf_counter()
        dist.barrier(group=pg_handler)

        if rank == src_device_num:
            avg_time_s = (end - begin) / self.repeat - FRAMEWORK_LATENCY
            alg_band = nbytes / avg_time_s
            if self.ctype == "a":
                # Convert the algorithm bandwidth of all-reduce to the bus bandwidth of the hardware.
                bus_band = 2 * (world_size - 1) / world_size * alg_band
            elif self.ctype == "b":
                bus_band = alg_band

            logger.info(
                f"GPU: {rank}, Bytes: {nbytes} B, Time: {round(avg_time_s * 1e6, 2)} us, Bus bandwidth: {round(bus_band / GB, 2)} GB/s"
            )
            return (avg_time_s, alg_band)
        else:
            # Just a placeholder
            return (None, None)

    def profile_latency(self, process_group, pg_handler):
        '''
        Profile the latency of the given process group with a series of byte sizes.

        Args:
            process_group: A tuple of the global ranks of the process group.
            pg_handler: The handler of the process group.

        Returns:
            latency: None if the latency is not measured on this rank, otherwise the median of latency_list.
        '''
        latency_list = []
        for i in range(self.latency_iters):
            nbytes = int(BYTE << i)
            (t, _) = self._profile(process_group, pg_handler, nbytes)
            latency_list.append(t)

        if latency_list[0] is None:
            latency = None
        else:
            median_index = math.floor(self.latency_iters / 2)
            latency = latency_list[median_index]

        return latency

    def profile_bandwidth(self, process_group, pg_handler, maxbytes):
        '''
        Profile the bandwidth of the given process group.

        Args:
            process_group: A tuple of the global ranks of the process group.
            pg_handler: The handler of the process group.
        '''
        (_, bandwidth) = self._profile(process_group, pg_handler, maxbytes)
        return bandwidth

    def profile_ab(self):
        '''
        Profile the alpha and beta values for the given device list.

        Returns:
            alpha_beta_dict: A dict which maps each process group to its alpha and beta values.
        '''
        alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = {}
        rank = dist.get_rank()

        def get_max_nbytes(process_group: Tuple[int], pg_handler: dist.ProcessGroup):
            assert rank in process_group
            device = torch.cuda.current_device()
            rank_max_nbytes = torch.cuda.mem_get_info(device)[0]
            rank_max_nbytes = torch.tensor(rank_max_nbytes, device=device)
            dist.all_reduce(rank_max_nbytes, op=dist.ReduceOp.MIN, group=pg_handler)
            max_nbytes = min(int(1 * GB), int(GB << int(math.log2(rank_max_nbytes.item() / GB))))
            return max_nbytes

        for process_group, pg_handler in self.process_group_dict.items():
            if rank not in process_group:
                max_nbytes = None
                alpha = None
                bandwidth = None
            else:
                max_nbytes = get_max_nbytes(process_group, pg_handler)
                alpha = self.profile_latency(process_group, pg_handler)
                bandwidth = self.profile_bandwidth(process_group, pg_handler, maxbytes=max_nbytes)

            if bandwidth is None:
                beta = None
            else:
                beta = 1 / bandwidth

            broadcast_list = [alpha, beta]
            dist.broadcast_object_list(broadcast_list, src=process_group[0])
            alpha_beta_dict[process_group] = tuple(broadcast_list)

        # Add the symmetric pair of each process group to alpha_beta_dict.
        symmetry_ab_dict = {}
        for process_group, alpha_beta_pair in alpha_beta_dict.items():
            symmetry_process_group = (process_group[1], process_group[0])
            symmetry_ab_dict[symmetry_process_group] = alpha_beta_pair

        alpha_beta_dict.update(symmetry_ab_dict)

        return alpha_beta_dict
```
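As a side note on how these measurements are typically consumed (my illustration, not code from this commit): under the alpha-beta model, communicating n bytes between a device pair costs roughly T(n) = alpha + beta · n seconds, where alpha is the fixed latency and beta the inverse bandwidth. A quick sanity check using the (0, 1) pair from the docstring sample above:

```python
# Hypothetical sketch: estimate pairwise communication time with the
# alpha-beta model T(n) = alpha + beta * n, using the (0, 1) pair from the
# docstring above. The values are illustrative, taken from the sample output.
alpha, beta = 1.9641406834125518e-05, 4.74049549614719e-12


def estimated_comm_time_s(nbytes: int) -> float:
    # alpha: fixed latency in seconds; beta: seconds per byte (1 / bandwidth).
    return alpha + beta * nbytes


# A 64 MiB message is bandwidth-dominated; a 4 KiB message is latency-dominated.
print(estimated_comm_time_s(64 * 2**20))    # ~3.4e-04 s
print(estimated_comm_time_s(4 * 2**10))     # ~2.0e-05 s
```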
colossalai/device/profile_alpha_beta.py (deleted, 100644 → 0, +0 −120)

```python
import fcntl
import math
import os
import time

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

MB = int((1 << 10) * 1e3)
GB = int((1 << 20) * 1e3)
Byte = 4
FRAMEWORK = 0
NON_SENSE = (0.1, 0.1)


def printflock(*msgs):
    """ solves multi-process interleaved print problem """
    with open(__file__, "r") as fh:
        fcntl.flock(fh, fcntl.LOCK_EX)
        try:
            print(*msgs)
        finally:
            fcntl.flock(fh, fcntl.LOCK_UN)


def profile(device1d, nbytes, ctype):
    warmup = 5
    repeat = 25
    rank = dist.get_rank()
    src_device_num = device1d[0]
    wsize = len(device1d)
    group = dist.new_group(device1d)

    torch.cuda.set_device(rank)
    device = torch.device("cuda", rank)
    buf = torch.randn(nbytes // 4).to(device)

    torch.cuda.synchronize()
    # warmup
    for _ in range(warmup):
        if ctype == "a":
            dist.all_reduce(buf, op=dist.ReduceOp.SUM, group=group)
        elif ctype == "b":
            dist.broadcast(buf, src=src_device_num, group=group)
    torch.cuda.synchronize()

    dist.barrier()
    begin = time.perf_counter()
    for _ in range(repeat):
        if ctype == "a":
            dist.all_reduce(buf, op=dist.ReduceOp.SUM, group=group)
        elif ctype == "b":
            dist.broadcast(buf, src=src_device_num, group=group)
    torch.cuda.synchronize()
    end = time.perf_counter()
    dist.barrier()

    if rank == src_device_num:
        avg_time_s = (end - begin) / repeat - FRAMEWORK
        alg_band = nbytes / avg_time_s
        if ctype == "b":
            bus_band = alg_band
        elif ctype == "a":
            bus_band = 2 * (wsize - 1) / wsize * alg_band
        print(
            f"GPU: {rank}, Bytes: {nbytes} B, Time: {round(avg_time_s * 1e6, 2)} us, Bus bandwidth: {round(bus_band / GB, 2)} GB/s"
        )
        return (avg_time_s, alg_band)
    else:
        return NON_SENSE    # Just a placeholder


def profile_latency(device1d, it=3, ctype="a"):
    latency = []
    for i in range(it):
        nbytes = int(Byte << i)
        (t, _) = profile(device1d, nbytes, ctype)
        latency.append(t)
    return min(latency)


def profile_bandwidth(device1d, maxbytes, ctype="a"):
    (_, bandwidth) = profile(device1d, maxbytes, ctype)
    return bandwidth


def profile_ab(rank, *args):
    wsize = int(torch.cuda.device_count())
    device1d = args[0]
    return_dict = args[1]
    ctype = args[2]
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29020'
    dist.init_process_group(backend=dist.Backend.NCCL, init_method='env://', world_size=wsize, rank=rank)

    device = torch.device("cuda", rank)
    max_nbytes = torch.tensor(torch.cuda.mem_get_info(device)[0]).to(device)
    max_nbytes = min(int(4 * GB), int(GB << int(math.log2(max_nbytes.item() / GB))))

    if rank == device1d[0]:
        print(f"max_nbytes: {max_nbytes} B")

    alpha = profile_latency(device1d, it=5, ctype=ctype)
    beta = 1 / profile_bandwidth(device1d, maxbytes=max_nbytes, ctype=ctype)

    if rank == device1d[0]:
        print(f"alpha(us): {round(alpha * 1e6, 2)}, beta(us/GB): {round(beta * 1e6 * GB, 2)}")
    return_dict[rank] = (alpha, beta)


def profile_alpha_beta(device1d):
    assert torch.cuda.is_available()
    assert len(device1d) > 0 and len(device1d) <= int(torch.cuda.device_count())

    manager = mp.Manager()
    return_dict = manager.dict()
    ctype = "a"
    mp.spawn(profile_ab, args=[device1d, return_dict, ctype], nprocs=int(torch.cuda.device_count()))
    return return_dict[device1d[0]]
```
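For contrast (my sketch, not part of the diff): the deleted module spawned its own worker processes internally via `mp.spawn`, so it could be driven from a plain single-process script, whereas the new `AlphaBetaProfiler` assumes an already-launched multi-process environment:

```python
# Hypothetical driver for the removed API (works only against pre-commit code):
# unlike AlphaBetaProfiler, profile_alpha_beta launched its own workers.
from colossalai.device.profile_alpha_beta import profile_alpha_beta

if __name__ == '__main__':
    # Returns the (alpha, beta) pair collected at device1d[0].
    alpha, beta = profile_alpha_beta([0, 1, 2, 3])
    print(f"alpha: {alpha:.2e} s, beta: {beta:.2e} s/B")
```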
tests/test_device/test_alpha_beta.py (+24 −5)

```diff
@@ -1,14 +1,33 @@
+from functools import partial
+
 import pytest
+import torch.multiprocessing as mp
 
-from colossalai.device import profile_alpha_beta
+from colossalai.device import AlphaBetaProfiler
+from colossalai.initialize import launch
+from colossalai.logging import disable_existing_loggers
+from colossalai.testing import parameterize, rerun_if_address_is_in_use
+from colossalai.utils import free_port
+
+
+def check_alpha_beta(rank, physical_devices, world_size, port):
+    disable_existing_loggers()
+    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    profiler = AlphaBetaProfiler(physical_devices)
+    ab_dict = profiler.profile_ab()
+    for _, (alpha, beta) in ab_dict.items():
+        assert alpha > 0 and alpha < 1e-4 and beta > 0 and beta < 1e-10
 
 
 @pytest.mark.skip(reason="Skip because assertion fails for CI devices")
-def test_profile_alpha_beta():
-    physical_devices = [0, 1, 2, 3]
-    (alpha, beta) = profile_alpha_beta(physical_devices)
-    assert alpha > 0 and alpha < 1e-4 and beta > 0 and beta < 1e-10
+@pytest.mark.dist
+@parameterize('physical_devices', [[0, 1, 2, 3], [0, 3]])
+@rerun_if_address_is_in_use()
+def test_profile_alpha_beta(physical_devices):
+    world_size = 4
+    run_func = partial(check_alpha_beta, physical_devices=physical_devices, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
 
 
 if __name__ == '__main__':
     test_profile_alpha_beta()
```
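The test drives the profiler through `mp.spawn` plus `colossalai.initialize.launch`; outside the test suite the same pattern could be used to profile a node by hand. A rough standalone sketch (the port and device list below are placeholders I chose, not values from the commit):

```python
# Hypothetical standalone launcher mirroring the test's spawn pattern.
from functools import partial

import torch.multiprocessing as mp

from colossalai.device import AlphaBetaProfiler
from colossalai.initialize import launch


def run(rank, world_size, port):
    # Each spawned process joins the NCCL group, then profiles pairwise links.
    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    profiler = AlphaBetaProfiler([0, 1, 2, 3])
    ab_dict = profiler.profile_ab()
    if rank == 0:
        print(ab_dict)


if __name__ == '__main__':
    world_size = 4
    mp.spawn(partial(run, world_size=world_size, port=29500), nprocs=world_size)
```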