Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
425bb0df
Commit
425bb0df
authored
Mar 09, 2022
by
HELSON
Committed by
Frank Lee
Mar 11, 2022
Browse files
Added Profiler Context to manage all profilers (#340)
parent
d0ae0f22
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
193 additions
and
118 deletions
+193
-118
colossalai/utils/profiler/__init__.py
colossalai/utils/profiler/__init__.py
+2
-1
colossalai/utils/profiler/comm_profiler.py
colossalai/utils/profiler/comm_profiler.py
+106
-77
colossalai/utils/profiler/prof_utils.py
colossalai/utils/profiler/prof_utils.py
+85
-0
tests/test_profiler/test_comm_prof.py
tests/test_profiler/test_comm_prof.py
+0
-40
No files found.
colossalai/utils/profiler/__init__.py
View file @
425bb0df
from
.comm_profiler
import
enable_communication_prof
,
communication_prof_show
from
.comm_profiler
import
CommProfiler
from
.prof_utils
import
ProfilerContext
colossalai/utils/profiler/comm_profiler.py
View file @
425bb0df
import
inspect
from
pathlib
import
Path
from
functools
import
partial
import
torch
from
torch.autograd.profiler
import
profile
import
torch.distributed
as
dist
from
torch.distributed
import
ReduceOp
import
torch.utils.tensorboard
as
tb
from
colossalai.utils
import
get_current_device
from
.prof_utils
import
BaseProfiler
from
typing
import
List
,
Optional
...
...
@@ -57,6 +61,13 @@ def _format_bandwith(volme: float, time_us: int):
return
'{:.3f} MB/s'
.
format
(
mb_per_sec
)
# Keep references to the original torch.distributed collectives at import
# time, so that CommProfiler.enable() can monkey-patch dist.* with the
# profiled wrappers below and CommProfiler.disable() can restore them.
torch_all_reduce = dist.all_reduce
torch_all_gather = dist.all_gather
torch_reduce_scatter = dist.reduce_scatter
torch_broadcast = dist.broadcast
torch_reduce = dist.reduce
class
CommEvent
(
object
):
"""Communication Event. Used for communication time and communication
volume recording.
...
...
@@ -73,16 +84,16 @@ class CommEvent(object):
self
.
self_cuda_time
+=
rhs
.
self_cuda_time
class
CommProfiler
(
object
):
class
CommProfiler
(
BaseProfiler
):
"""Communication profiler. Records all communication events.
"""
def
__init__
(
self
,
total_count
:
int
=
0
,
total_comm_vol
:
float
=
0
,
total_cuda_time
:
int
=
0
,
prof_depth
:
int
=
3
):
super
().
__init__
()
def
__init__
(
self
,
depth
:
int
=
0
,
total_count
:
int
=
0
,
total_comm_vol
:
float
=
0
,
total_cuda_time
:
int
=
0
):
super
().
__init__
(
profiler_name
=
"Collective_Communication"
,
priority
=
0
)
self
.
depth
=
3
+
depth
self
.
total_count
=
total_count
self
.
total_comm_vol
=
total_comm_vol
self
.
total_cuda_time
=
total_cuda_time
self
.
depth
=
prof_depth
self
.
ops_record
=
dict
()
self
.
profiler
=
None
...
...
@@ -101,27 +112,58 @@ class CommProfiler(object):
self
.
pending_metadata
=
None
self
.
warn_flag
=
False
def
enable
(
self
):
dist
.
all_reduce
=
partial
(
all_reduce
,
profiler
=
self
)
dist
.
all_gather
=
partial
(
all_gather
,
profiler
=
self
)
dist
.
reduce_scatter
=
partial
(
reduce_scatter
,
profiler
=
self
)
dist
.
broadcast
=
partial
(
broadcast
,
profiler
=
self
)
dist
.
reduce
=
partial
(
reduce
,
profiler
=
self
)
def
disable
(
self
):
dist
.
all_reduce
=
torch_all_reduce
dist
.
all_gather
=
torch_all_gather
dist
.
reduce_scatter
=
torch_reduce_scatter
dist
.
broadcast
=
torch_broadcast
dist
.
reduce
=
torch_reduce
def
to_tensorboard
(
self
,
writer
:
tb
.
writer
):
writer
.
add_text
(
tag
=
"Collective Communication"
,
text_string
=
self
.
result_list
(
"
\n\n
"
))
def
to_file
(
self
,
filename
:
Path
):
with
open
(
filename
,
"w"
)
as
f
:
f
.
write
(
self
.
result_list
())
def
show
(
self
):
print
(
self
.
result_list
())
def
result_list
(
self
,
sep
:
str
=
"
\n
"
):
res
=
[]
def
append
(
s
:
str
):
res
.
append
(
s
)
res
.
append
(
sep
)
if
self
.
warn_flag
:
print
(
"Warnning: there exists multiple communication operations in the same time.
\n
"
"As a result,
the profiling result is not accurate."
)
print
(
"Collective communication profiling result:"
,
"total cuda time: {}"
.
format
(
_format_time
(
self
.
total_cuda_time
)),
"average bandwith
: {}"
.
format
(
_format_
bandwith
(
self
.
total_comm_vol
,
self
.
total_cuda_time
))
,
"total number of calls: {}"
.
format
(
self
.
total_count
),
"All events:"
,
sep
=
'
\n
'
)
append
(
"Warnning: there exists multiple communication operations in the same time.
As a result,
"
"
the profiling result is not accurate."
)
append
(
"Collective communication profiling result:"
)
append
(
"total cuda time
: {}"
.
format
(
_format_
time
(
self
.
total_cuda_time
))
)
append
(
"average bandwith: {}"
.
format
(
_format_bandwith
(
self
.
total_comm_vol
,
self
.
total_cuda_time
)))
append
(
"total number of calls: {}"
.
format
(
self
.
total_count
))
append
(
"All events:
\n
----------------------------------------"
)
show_list
=
sorted
(
self
.
ops_record
.
items
(),
key
=
lambda
kv
:
-
kv
[
1
].
self_cuda_time
)
for
location
,
event
in
show_list
:
print
(
location
,
"self cuda time: {}"
.
format
(
_format_time
(
event
.
self_cuda_time
)),
"{:.1f}% of total communication time"
.
format
(
event
.
self_cuda_time
/
self
.
total_cuda_time
*
100.0
),
"self communication volme: {}"
.
format
(
_format_memory
(
event
.
self_comm_vol
)),
"average bandwith: {}"
.
format
(
_format_bandwith
(
event
.
self_comm_vol
,
event
.
self_cuda_time
)),
"number of calls: {}"
.
format
(
event
.
self_count
),
"--------------------"
,
sep
=
'
\n
'
)
append
(
location
)
append
(
"self cuda time: {}"
.
format
(
_format_time
(
event
.
self_cuda_time
)))
append
(
"{:.1f}% of total communication time"
.
format
(
event
.
self_cuda_time
/
self
.
total_cuda_time
*
100.0
))
append
(
"self communication volme: {}"
.
format
(
_format_memory
(
event
.
self_comm_vol
)))
append
(
"average bandwith: {}"
.
format
(
_format_bandwith
(
event
.
self_comm_vol
,
event
.
self_cuda_time
)))
append
(
"number of calls: {}"
.
format
(
event
.
self_count
))
append
(
"----------------------------------------"
)
return
''
.
join
(
res
)
@
property
def
has_aync_op
(
self
):
...
...
class CommHandler(object):
    """Communication handler. A dummy handler used to wait on asynchronous
    communication operations issued through the profiled collectives.
    """

    def __init__(self, profiler: CommProfiler):
        super().__init__()
        # The profiler that owns the pending async op.
        self.prof = profiler

    def wait(self):
        # Delegate to the profiler, which waits out its pending async op.
        self.prof.wait_async_op()
COL_COMM_PROF
=
CommProfiler
()
torch_all_reduce
=
dist
.
all_reduce
torch_all_gather
=
dist
.
all_gather
torch_reduce_scatter
=
dist
.
reduce_scatter
torch_broadcast
=
dist
.
broadcast
torch_reduce
=
dist
.
reduce
def
enable_communication_prof
(
depth
:
int
=
0
):
COL_COMM_PROF
.
depth
=
3
+
depth
dist
.
all_reduce
=
all_reduce
dist
.
all_gather
=
all_gather
dist
.
reduce_scatter
=
reduce_scatter
dist
.
broadcast
=
broadcast
dist
.
reduce
=
reduce
def
communication_prof_show
():
COL_COMM_PROF
.
show
()
def async_check(profiler: CommProfiler):
    """Flag and drain any still-pending asynchronous collective.

    Overlapping communications make per-op timing unreliable, so mark the
    profiler's warning flag and wait the pending op out before the next
    collective is issued.
    """
    if profiler.pending_op is None:
        return
    profiler.warn_flag = True
    profiler.wait_async_op()
def all_reduce(tensor: torch.Tensor,
               op: ReduceOp = ReduceOp.SUM,
               group=None,
               async_op: bool = False,
               profiler: CommProfiler = None) -> Optional[CommHandler]:
    """Profiled drop-in replacement for ``torch.distributed.all_reduce``.

    Records the communication volume on ``profiler`` around the real (saved)
    ``torch_all_reduce`` call. Returns a ``CommHandler`` when ``async_op`` is
    True, otherwise ``None``.

    NOTE(review): ``profiler`` defaults to None but is dereferenced
    unconditionally (``async_check``, ``activate_profiler``); in practice it
    is always bound via ``functools.partial`` in ``CommProfiler.enable`` —
    confirm before calling this function directly.
    """
    # Refuse to overlap with a still-pending async op (see async_check).
    async_check(profiler)

    comm_size = dist.get_world_size(group)
    # 2*(n-1)/n volume factor — consistent with a ring-style all-reduce
    # (reduce-scatter + all-gather); presumably what NCCL performs here.
    correction = 2 * (comm_size - 1) / comm_size
    comm_vol = correction * tensor.element_size() * tensor.numel()
    profiler.activate_profiler("ncclKernel_AllReduce_", comm_vol)
    # Invoke the original (unpatched) collective saved at import time.
    profiler.pending_op = torch_all_reduce(tensor, op, group, async_op)

    if async_op:
        # Async path: the handler's wait() routes to profiler.wait_async_op().
        return CommHandler(profiler)

    profiler.close_profiler(group)
def
reduce_scatter
(
output
:
torch
.
Tensor
,
input_list
:
List
[
torch
.
Tensor
],
op
:
ReduceOp
=
ReduceOp
.
SUM
,
group
=
None
,
async_op
:
bool
=
False
)
->
Optional
[
CommHandler
]:
async_check
()
async_op
:
bool
=
False
,
profiler
:
CommProfiler
=
None
)
->
Optional
[
CommHandler
]:
async_check
(
profiler
)
comm_size
=
dist
.
get_world_size
(
group
)
correction
=
(
comm_size
-
1
)
/
comm_size
...
...
@@ -242,20 +265,21 @@ def reduce_scatter(output: torch.Tensor,
for
tensor
in
input_list
:
comm_vol
+=
tensor
.
element_size
()
*
tensor
.
numel
()
comm_vol
*=
correction
COL_COMM_PROF
.
activate_profiler
(
"ncclKernel_ReduceScatter_"
,
comm_vol
)
COL_COMM_PROF
.
pending_op
=
torch_reduce_scatter
(
output
,
input_list
,
op
,
group
,
async_op
)
profiler
.
activate_profiler
(
"ncclKernel_ReduceScatter_"
,
comm_vol
)
profiler
.
pending_op
=
torch_reduce_scatter
(
output
,
input_list
,
op
,
group
,
async_op
)
if
async_op
:
return
CommHandler
()
return
CommHandler
(
profiler
)
COL_COMM_PROF
.
close_profiler
(
group
)
profiler
.
close_profiler
(
group
)
def
all_gather
(
tensor_list
:
List
[
torch
.
Tensor
],
tensor
:
torch
.
Tensor
,
group
=
None
,
async_op
:
bool
=
False
)
->
Optional
[
CommHandler
]:
async_check
()
async_op
:
bool
=
False
,
profiler
:
CommProfiler
=
None
)
->
Optional
[
CommHandler
]:
async_check
(
profiler
)
comm_size
=
dist
.
get_world_size
(
group
)
correction
=
(
comm_size
-
1
)
/
comm_size
...
...
@@ -263,40 +287,45 @@ def all_gather(tensor_list: List[torch.Tensor],
for
ten
in
tensor_list
:
comm_vol
+=
ten
.
element_size
()
*
ten
.
numel
()
comm_vol
*=
correction
COL_COMM_PROF
.
activate_profiler
(
"ncclKernel_AllGather_"
,
comm_vol
)
COL_COMM_PROF
.
pending_op
=
torch_all_gather
(
tensor_list
,
tensor
,
group
,
async_op
)
profiler
.
activate_profiler
(
"ncclKernel_AllGather_"
,
comm_vol
)
profiler
.
pending_op
=
torch_all_gather
(
tensor_list
,
tensor
,
group
,
async_op
)
if
async_op
:
return
CommHandler
()
return
CommHandler
(
profiler
)
COL_COMM_PROF
.
close_profiler
(
group
)
profiler
.
close_profiler
(
group
)
def broadcast(tensor: torch.Tensor,
              src: int,
              group=None,
              async_op: bool = False,
              profiler: CommProfiler = None) -> Optional[CommHandler]:
    """Profiled drop-in replacement for ``torch.distributed.broadcast``.

    NOTE(review): ``profiler`` defaults to None but is dereferenced
    unconditionally; it is bound via ``functools.partial`` in
    ``CommProfiler.enable`` — confirm before calling directly.
    """
    # Refuse to overlap with a still-pending async op (see async_check).
    async_check(profiler)

    # Volume accounted as the whole tensor: element_size * numel.
    comm_vol = 1.0 * tensor.element_size() * tensor.numel()
    profiler.activate_profiler("ncclKernel_Broadcast_", comm_vol)
    # Invoke the original (unpatched) collective saved at import time.
    profiler.pending_op = torch_broadcast(tensor, src, group, async_op)

    if async_op:
        # Async path: the handler's wait() routes to profiler.wait_async_op().
        return CommHandler(profiler)

    profiler.close_profiler(group)
def reduce(tensor: torch.Tensor,
           dst: int,
           op: ReduceOp = ReduceOp.SUM,
           group=None,
           async_op: bool = False,
           profiler: CommProfiler = None) -> Optional[CommHandler]:
    """Profiled drop-in replacement for ``torch.distributed.reduce``.

    NOTE(review): ``profiler`` defaults to None but is dereferenced
    unconditionally; it is bound via ``functools.partial`` in
    ``CommProfiler.enable`` — confirm before calling directly.
    """
    # Refuse to overlap with a still-pending async op (see async_check).
    async_check(profiler)

    # Volume accounted as the whole tensor: element_size * numel.
    comm_vol = 1.0 * tensor.element_size() * tensor.numel()
    profiler.activate_profiler("ncclKernel_Reduce_", comm_vol)
    # Invoke the original (unpatched) collective saved at import time.
    profiler.pending_op = torch_reduce(tensor, dst, op, group, async_op)

    if async_op:
        # Async path: the handler's wait() routes to profiler.wait_async_op().
        return CommHandler(profiler)

    profiler.close_profiler(group)
colossalai/utils/profiler/prof_utils.py
0 → 100644
View file @
425bb0df
from
abc
import
ABC
,
abstractmethod
from
pathlib
import
Path
from
typing
import
Union
,
List
from
colossalai.core
import
global_context
as
gpc
class BaseProfiler(ABC):
    """Common interface for profilers driven by ``ProfilerContext``.

    Concrete subclasses implement the lifecycle hooks (``enable`` /
    ``disable``) and the reporting hooks (``to_tensorboard`` / ``to_file`` /
    ``show``).
    """

    def __init__(self, profiler_name: str, priority: int):
        # ``name`` becomes the log-file prefix when ProfilerContext dumps
        # results; ``priority`` decides activation order (ascending) when
        # several profilers run inside one context.
        self.name = profiler_name
        self.priority = priority

    @abstractmethod
    def enable(self):
        """Start collecting events."""

    @abstractmethod
    def disable(self):
        """Stop collecting events."""

    @abstractmethod
    def to_tensorboard(self, writer):
        """Dump the collected result to a tensorboard ``writer``."""

    @abstractmethod
    def to_file(self, filename: Path):
        """Dump the collected result to ``filename``."""

    @abstractmethod
    def show(self):
        """Print the collected result."""
class ProfilerContext(object):
    """Profiler context manager.

    Enables every registered profiler on entry (in ascending ``priority``
    order) and disables them on exit; afterwards the caller can dump all
    results at once.

    Usage::

        from colossalai.utils.profiler import CommProfiler, ProfilerContext
        from torch.utils.tensorboard import SummaryWriter

        cc_prof = CommProfiler()
        with ProfilerContext([cc_prof]) as prof:
            train()
        writer = SummaryWriter('tb/path')
        prof.to_tensorboard(writer)
        prof.to_file('./prof_logs/')
        prof.show()
    """

    def __init__(self, profilers: "List[BaseProfiler]" = None, enable: bool = True):
        # Whether entering/exiting the context actually toggles the profilers.
        self.enable = enable
        # Fix: tolerate the default ``profilers=None`` instead of crashing in
        # ``sorted`` (TypeError) — treat it as "no profilers".
        # Lower priority value is enabled first.
        self.profilers = sorted(profilers or [], key=lambda prof: prof.priority)

    def __enter__(self):
        if self.enable:
            for prof in self.profilers:
                prof.enable()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.enable:
            for prof in self.profilers:
                prof.disable()

    def to_tensorboard(self, writer):
        """Write every profiler's result to a tensorboard ``SummaryWriter``."""
        from torch.utils.tensorboard import SummaryWriter
        assert isinstance(writer, SummaryWriter), \
            f'torch.utils.tensorboard.SummaryWriter is required, but found {type(writer)}.'
        for prof in self.profilers:
            prof.to_tensorboard(writer)

    def to_file(self, log_dir: "Union[str, Path]"):
        """Write each profiler's result to ``<log_dir>/<name>_rank_<r>.log``,
        creating ``log_dir`` if needed."""
        if isinstance(log_dir, str):
            log_dir = Path(log_dir)

        if not log_dir.exists():
            log_dir.mkdir(parents=True, exist_ok=True)
        for prof in self.profilers:
            log_file = log_dir.joinpath(f'{prof.name}_rank_{gpc.get_global_rank()}.log')
            prof.to_file(log_file)

    def show(self):
        """Print every profiler's result."""
        for prof in self.profilers:
            prof.show()
tests/test_profiler/test_comm_prof.py
deleted
100644 → 0
View file @
d0ae0f22
from
functools
import
partial
import
torch
import
torch.multiprocessing
as
mp
import
torch.distributed
as
dist
import
colossalai
from
colossalai.utils
import
free_port
,
get_current_device
from
colossalai.utils.profiler
import
enable_communication_prof
,
communication_prof_show
# Synthetic workload size: one (BATCH_SIZE, D_MODEL) fp32 tensor per rank.
BATCH_SIZE = 1024
D_MODEL = 1024

# Launch config: 4-way '1d' tensor parallelism (matches world_size=4 below).
CONFIG = dict(parallel=dict(tensor=dict(mode='1d', size=4)))
def run_test(rank, world_size, port):
    """Per-process worker: launch colossalai, issue one of each profiled
    collective, then print the profiling report on rank 0.

    NOTE(review): requires an NCCL backend, i.e. one GPU per rank.
    """
    colossalai.launch(config=CONFIG,
                      rank=rank,
                      world_size=world_size,
                      host='localhost',
                      port=port,
                      backend='nccl')

    inputs = torch.randn(BATCH_SIZE, D_MODEL, dtype=torch.float32, device=get_current_device())
    outputs = torch.empty(world_size, BATCH_SIZE, D_MODEL, dtype=torch.float32, device=get_current_device())
    # One (BATCH_SIZE, D_MODEL) chunk per rank, each a view into ``outputs``.
    outputs_list = list(torch.chunk(outputs, chunks=world_size, dim=0))

    # Monkey-patch torch.distributed collectives with the profiled versions.
    enable_communication_prof()

    # Deliberately overlap an async all_reduce with an all_gather before
    # wait() — presumably to exercise the profiler's overlap warning path.
    op = dist.all_reduce(inputs, async_op=True)
    dist.all_gather(outputs_list, inputs)
    op.wait()
    dist.reduce_scatter(inputs, outputs_list)
    dist.broadcast(inputs, 0)
    dist.reduce(inputs, 0)

    if rank == 0:
        communication_prof_show()
def test_cc_prof():
    """Spawn 4 worker processes and run the communication-profiler smoke test."""
    nprocs = 4
    worker = partial(run_test, world_size=nprocs, port=free_port())
    mp.spawn(worker, nprocs=nprocs)
# Allow running this test module directly (outside pytest).
if __name__ == '__main__':
    test_cc_prof()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment