OpenDAS / ColossalAI · Commits

Commit 73bff112
Authored Mar 04, 2022 by 1SAA; committed by Frank Lee on Mar 11, 2022

Added profiler communication operations
Fixed bug for learning rate scheduler
Parent: d275b98b

Showing 6 changed files, with 368 additions and 18 deletions:

- colossalai/communication/__init__.py (+20 −11)
- colossalai/trainer/hooks/_lr_scheduler_hook.py (+1 −0)
- colossalai/utils/__init__.py (+4 −7)
- colossalai/utils/profiler/__init__.py (new file, +1 −0)
- colossalai/utils/profiler/comm_profiler.py (new file, +302 −0)
- tests/test_profiler/test_comm_prof.py (new file, +40 −0)
colossalai/communication/__init__.py (+20 −11)

The import lines are unchanged; the diff reflows `__all__` to one name per line. The file after this commit:

```python
from .collective import all_gather, reduce_scatter, all_reduce, broadcast, reduce
from .p2p import (send_forward, send_forward_recv_forward, send_backward_recv_forward, send_backward,
                  send_backward_recv_backward, send_forward_recv_backward,
                  send_forward_backward_recv_forward_backward, recv_forward, recv_backward)
from .ring import ring_forward
from .utils import send_tensor_meta, recv_tensor_meta

__all__ = [
    'all_gather',
    'reduce_scatter',
    'all_reduce',
    'broadcast',
    'reduce',
    'send_forward',
    'send_forward_recv_forward',
    'send_forward_backward_recv_forward_backward',
    'send_backward',
    'send_backward_recv_backward',
    'send_backward_recv_forward',
    'send_forward_recv_backward',
    'recv_backward',
    'recv_forward',
    'ring_forward',
    'send_tensor_meta',
    'recv_tensor_meta',
]
```
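Since `__all__` only re-exports names already imported above, callers can pull the collective and point-to-point helpers straight from the package root. A minimal illustration (hypothetical caller, not part of the commit):

```python
# These names are re-exported by colossalai/communication/__init__.py,
# so call sites do not need to know the submodule layout.
from colossalai.communication import all_gather, ring_forward, send_tensor_meta
```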
colossalai/trainer/hooks/_lr_scheduler_hook.py (+1 −0)

A single line is added so that the hook verifies the trainer's metric states before registering its metric:

```diff
@@ -29,6 +29,7 @@ class LRSchedulerHook(MetricHook):
         self.store_lr_in_state = store_lr_in_state

     def after_hook_is_attached(self, trainer):
+        self._check_metric_states_initialization(trainer)
         trainer.states['metrics']['train']['LR'] = LearningRateMetric(epoch_only=self.by_epoch,
                                                                       initial_lr=self.lr_scheduler.get_last_lr()[0])
```
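The added call runs before the hook writes into `trainer.states['metrics']['train']`, which could previously fail when the nested metric dictionaries had not been set up yet; this is the learning-rate-scheduler bug named in the commit message. As a rough, hypothetical sketch of what such a guard might do (the real `_check_metric_states_initialization` lives on the `MetricHook` base class and is not shown in this diff):

```python
def _check_metric_states_initialization(self, trainer):
    # Hypothetical sketch only, not the actual ColossalAI implementation:
    # make sure the nested metric dicts exist before hooks store metrics.
    if 'metrics' not in trainer.states:
        trainer.states['metrics'] = {'train': {}, 'test': {}}
```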
colossalai/utils/__init__.py (+4 −7)

The `from .common import (...)` block is reflowed; no names are added or removed. The top of the file after this commit:

```python
from .activation_checkpoint import checkpoint
from .common import (clip_grad_norm_fp32, conditional_context, copy_tensor_parallel_attributes, count_zeros_fp32,
                     free_port, is_dp_rank_0, is_model_parallel_parameter, is_moe_parallel_parameter,
                     is_no_pp_or_last_stage, is_tp_rank_0, is_using_ddp, is_using_pp, is_using_sequence,
                     multi_tensor_applier, param_is_not_tensor_parallel_duplicate, print_rank_0,
                     switch_virtual_pipeline_parallel_rank, sync_model_param)
from .cuda import empty_cache, get_current_device, set_to_cuda, synchronize
from .data_sampler import DataParallelSampler, get_dataloader
```

(The rest of the file is unchanged and collapsed in the diff.)
colossalai/utils/profiler/__init__.py (new file, +1 −0)

```python
from .comm_profiler import enable_communication_prof, communication_prof_show
```
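These two exported names are the profiler's whole public surface. A minimal usage sketch (illustrative, not part of the commit; assumes `torch.distributed` is already initialized with a NCCL backend and `t` is a CUDA tensor):

```python
import torch.distributed as dist
from colossalai.utils.profiler import enable_communication_prof, communication_prof_show

enable_communication_prof()       # monkey-patches torch.distributed collectives
dist.all_reduce(t)                # profiled transparently at this call site
dist.broadcast(t, 0)
communication_prof_show()         # prints per-call-site time, volume and bandwidth
```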
colossalai/utils/profiler/comm_profiler.py (new file, +302 −0)

Code-location lookup and human-readable formatting helpers:
```python
import inspect
from typing import List, Optional

import torch
import torch.distributed as dist
from torch.autograd.profiler import profile
from torch.distributed import ReduceOp

from colossalai.utils import get_current_device


def _get_code_location(depth: int):
    """Walk up the call stack and render "file(line): function" entries,
    used as the grouping key for communication events."""
    ret = ""
    stack = inspect.stack()    # capture once instead of re-walking the stack per frame
    for i in range(3, min(len(stack), depth + 1)):
        upper_frame = stack[i]
        function_name = stack[i - 1].function
        ret += "{}({}): {}\n".format(upper_frame.filename, upper_frame.lineno, function_name)
    return ret


# copied from a newer PyTorch version to support older releases
def _format_time(time_us):
    """Defines how to format time in FunctionEvent."""
    US_IN_SECOND = 1000.0 * 1000.0
    US_IN_MS = 1000.0
    if time_us >= US_IN_SECOND:
        return '{:.3f}s'.format(time_us / US_IN_SECOND)
    if time_us >= US_IN_MS:
        return '{:.3f}ms'.format(time_us / US_IN_MS)
    return '{:.3f}us'.format(time_us)


# copied from a newer PyTorch version to support older releases
def _format_memory(nbytes):
    """Returns a formatted memory size string."""
    KB = 1024
    MB = 1024 * KB
    GB = 1024 * MB
    if abs(nbytes) >= GB:
        return '{:.2f} Gb'.format(nbytes * 1.0 / GB)
    elif abs(nbytes) >= MB:
        return '{:.2f} Mb'.format(nbytes * 1.0 / MB)
    elif abs(nbytes) >= KB:
        return '{:.2f} Kb'.format(nbytes * 1.0 / KB)
    else:
        return str(nbytes) + ' b'


def _format_bandwidth(volume: float, time_us: int):
    """Render a transfer volume (bytes) over a duration (microseconds) as Mb/s or Gb/s."""
    sec_div_mb = (1000.0 / 1024.0)**2    # bytes/us -> MB/s conversion factor
    mb_per_sec = volume / time_us * sec_div_mb
    if mb_per_sec >= 1024.0:
        return '{:.3f} Gb/s'.format(mb_per_sec / 1024.0)
    else:
        return '{:.3f} Mb/s'.format(mb_per_sec)
```
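The `(1000.0 / 1024.0) ** 2` factor in `_format_bandwidth` converts bytes per microsecond into MB/s: multiply by 10⁶ µs/s, then divide by 1024² bytes/MB. A quick standalone sanity check (not part of the commit):

```python
# 64 MiB moved in exactly one second should come out as 64 MB/s.
vol_bytes = 64 * 1024 * 1024
time_us = 1_000_000
mb_per_sec = vol_bytes / time_us * (1000.0 / 1024.0) ** 2
assert abs(mb_per_sec - 64.0) < 1e-9
```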
The event record and the profiler itself:

```python
class CommEvent(object):
    """Communication event. Records the call count, communication volume
    and CUDA time of a communication op at a single call site."""

    def __init__(self, count: int = 0, comm_vol: float = 0., cuda_time: int = 0):
        self.self_count = count
        self.self_comm_vol = comm_vol
        self.self_cuda_time = cuda_time

    def add(self, rhs):
        self.self_count += rhs.self_count
        self.self_comm_vol += rhs.self_comm_vol
        self.self_cuda_time += rhs.self_cuda_time


class CommProfiler(object):
    """Communication profiler. Records all communication events."""

    def __init__(self, total_count: int = 0, total_comm_vol: float = 0, total_cuda_time: int = 0, prof_depth: int = 3):
        super().__init__()
        self.total_count = total_count
        self.total_comm_vol = total_comm_vol
        self.total_cuda_time = total_cuda_time
        self.depth = prof_depth
        self.ops_record = dict()
        self.profiler = None
        self.pending_op = None
        self.pending_metadata = None
        self.warn_flag = False

    def reset(self):
        self.total_count = 0
        self.total_comm_vol = 0
        self.total_cuda_time = 0
        self.ops_record = dict()
        self.profiler = None
        self.pending_op = None
        self.pending_metadata = None
        self.warn_flag = False

    def show(self):
        if self.warn_flag:
            print("Warning: multiple communication operations were in flight at the same time.\n"
                  "As a result, the profiling result is not accurate.")

        print("Collective communication profiling result:",
              "total cuda time: {}".format(_format_time(self.total_cuda_time)),
              "average bandwidth: {}".format(_format_bandwidth(self.total_comm_vol, self.total_cuda_time)),
              "total number of calls: {}".format(self.total_count),
              "All events:",
              sep='\n')

        # largest self CUDA time first
        show_list = sorted(self.ops_record.items(), key=lambda kv: -kv[1].self_cuda_time)
        for location, event in show_list:
            print(location,
                  "self cuda time: {}".format(_format_time(event.self_cuda_time)),
                  "{:.1f}% of total communication time".format(event.self_cuda_time / self.total_cuda_time * 100.0),
                  "self communication volume: {}".format(_format_memory(event.self_comm_vol)),
                  "average bandwidth: {}".format(_format_bandwidth(event.self_comm_vol, event.self_cuda_time)),
                  "number of calls: {}".format(event.self_count),
                  "--------------------",
                  sep='\n')
```
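For reference, `show()` prints one block per call site, sorted by self CUDA time. An illustrative rendering with made-up numbers (only one of several call sites shown):

```text
Collective communication profiling result:
total cuda time: 10.000ms
average bandwidth: 600.000 Mb/s
total number of calls: 5
All events:
/path/to/train.py(42): forward
self cuda time: 8.000ms
80.0% of total communication time
self communication volume: 6.00 Mb
average bandwidth: 750.000 Mb/s
number of calls: 2
--------------------
```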
Profiler start/stop plumbing, the async handler, the patch entry points, and the first wrapped collective:

```python
    @property
    def has_async_op(self):
        return self.pending_op is not None

    def activate_profiler(self, kn: str, vol: float):
        self.pending_metadata = (kn, _get_code_location(self.depth), vol)
        self.profiler = profile(enabled=True, use_cuda=True, use_cpu=True, use_kineto=True)
        self.profiler.__enter__()

    def close_profiler(self, group=None):
        assert self.profiler is not None, "There is no running dist op"
        kernel_name, code_location, vol = self.pending_metadata
        self.profiler.__exit__(None, None, None)

        if self.profiler.enabled:
            assert_flag = 0
            current_comm_event = None
            events = self.profiler.function_events
            for event in events:
                if kernel_name in event.name:
                    assert assert_flag == 0, "Multiple dist ops have been called"
                    current_comm_event = CommEvent(1, vol, event.self_cuda_time_total)
                    assert_flag += 1

            assert current_comm_event is not None, "dist op has not been found"

            # agree on the minimum measured kernel time across ranks
            buffer = torch.tensor([current_comm_event.self_cuda_time], device=get_current_device())
            torch_all_reduce(buffer, op=ReduceOp.MIN, group=group)
            current_comm_event.self_cuda_time = buffer.item()

            self.total_count += current_comm_event.self_count
            self.total_comm_vol += current_comm_event.self_comm_vol
            self.total_cuda_time += current_comm_event.self_cuda_time
            if code_location in self.ops_record:
                self.ops_record[code_location].add(current_comm_event)
            else:
                self.ops_record[code_location] = current_comm_event

        self.profiler = None
        self.pending_op = None
        self.pending_metadata = None

    def wait_async_op(self):
        if self.pending_op is not None:
            op = self.pending_op
            op.wait()
            self.close_profiler()


class CommHandler(object):
    """Communication handler. A dummy handler returned to the caller of an
    async op so that its wait() also closes the profiler."""

    def __init__(self):
        super().__init__()
        self.prof = COL_COMM_PROF

    def wait(self):
        self.prof.wait_async_op()


COL_COMM_PROF = CommProfiler()

# keep handles to the original torch.distributed collectives before patching
torch_all_reduce = dist.all_reduce
torch_all_gather = dist.all_gather
torch_reduce_scatter = dist.reduce_scatter
torch_broadcast = dist.broadcast
torch_reduce = dist.reduce


def enable_communication_prof(depth: int = 0):
    COL_COMM_PROF.depth = 3 + depth
    # monkey-patch torch.distributed so existing call sites are profiled transparently
    dist.all_reduce = all_reduce
    dist.all_gather = all_gather
    dist.reduce_scatter = reduce_scatter
    dist.broadcast = broadcast
    dist.reduce = reduce


def communication_prof_show():
    COL_COMM_PROF.show()


def async_check():
    # starting a new op while another is still pending makes per-op timing unreliable
    if COL_COMM_PROF.pending_op is not None:
        COL_COMM_PROF.warn_flag = True
        COL_COMM_PROF.wait_async_op()


def all_reduce(tensor: torch.Tensor,
               op: ReduceOp = ReduceOp.SUM,
               group=None,
               async_op: bool = False) -> Optional[CommHandler]:
    async_check()

    comm_size = dist.get_world_size(group)
    # a ring all-reduce moves 2 * (n - 1) / n of the tensor per device
    correction = 2 * (comm_size - 1) / comm_size
    comm_vol = correction * tensor.element_size() * tensor.numel()
    COL_COMM_PROF.activate_profiler("ncclKernel_AllReduce_", comm_vol)
    COL_COMM_PROF.pending_op = torch_all_reduce(tensor, op, group, async_op)

    if async_op:
        return CommHandler()

    COL_COMM_PROF.close_profiler(group)
```
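The `correction` factor models a ring all-reduce, in which each of the `n` ranks sends and receives `2 * (n - 1) / n` of the tensor's bytes. A worked example (standalone illustration, not part of the commit):

```python
import torch

comm_size = 4                                   # e.g. 4 GPUs
tensor = torch.empty(1024, 1024)                # fp32: 4 bytes per element
correction = 2 * (comm_size - 1) / comm_size    # 2 * 3 / 4 = 1.5
comm_vol = correction * tensor.element_size() * tensor.numel()
print(comm_vol)                                 # 1.5 * 4 * 1048576 = 6291456 bytes (~6 Mb)
```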
The remaining wrappers follow the same pattern, each with its own volume formula:

```python
def reduce_scatter(output: torch.Tensor,
                   input_list: List[torch.Tensor],
                   op: ReduceOp = ReduceOp.SUM,
                   group=None,
                   async_op: bool = False) -> Optional[CommHandler]:
    async_check()

    comm_size = dist.get_world_size(group)
    correction = (comm_size - 1) / comm_size
    comm_vol = 0
    for tensor in input_list:
        comm_vol += tensor.element_size() * tensor.numel()
    comm_vol *= correction
    COL_COMM_PROF.activate_profiler("ncclKernel_ReduceScatter_", comm_vol)
    COL_COMM_PROF.pending_op = torch_reduce_scatter(output, input_list, op, group, async_op)

    if async_op:
        return CommHandler()

    COL_COMM_PROF.close_profiler(group)


def all_gather(tensor_list: List[torch.Tensor],
               tensor: torch.Tensor,
               group=None,
               async_op: bool = False) -> Optional[CommHandler]:
    async_check()

    comm_size = dist.get_world_size(group)
    correction = (comm_size - 1) / comm_size
    comm_vol = 0
    for ten in tensor_list:
        comm_vol += ten.element_size() * ten.numel()
    comm_vol *= correction
    COL_COMM_PROF.activate_profiler("ncclKernel_AllGather_", comm_vol)
    COL_COMM_PROF.pending_op = torch_all_gather(tensor_list, tensor, group, async_op)

    if async_op:
        return CommHandler()

    COL_COMM_PROF.close_profiler(group)


def broadcast(tensor: torch.Tensor, src: int, group=None, async_op: bool = False) -> Optional[CommHandler]:
    async_check()

    comm_vol = 1.0 * tensor.element_size() * tensor.numel()
    COL_COMM_PROF.activate_profiler("ncclKernel_Broadcast_", comm_vol)
    COL_COMM_PROF.pending_op = torch_broadcast(tensor, src, group, async_op)

    if async_op:
        return CommHandler()

    COL_COMM_PROF.close_profiler(group)


def reduce(tensor: torch.Tensor,
           dst: int,
           op: ReduceOp = ReduceOp.SUM,
           group=None,
           async_op: bool = False) -> Optional[CommHandler]:
    async_check()

    comm_vol = 1.0 * tensor.element_size() * tensor.numel()
    COL_COMM_PROF.activate_profiler("ncclKernel_Reduce_", comm_vol)
    COL_COMM_PROF.pending_op = torch_reduce(tensor, dst, op, group, async_op)

    if async_op:
        return CommHandler()

    COL_COMM_PROF.close_profiler(group)
```
tests/test_profiler/test_comm_prof.py (new file, +40 −0)
```python
from functools import partial

import torch
import torch.multiprocessing as mp
import torch.distributed as dist

import colossalai
from colossalai.utils import free_port, get_current_device
from colossalai.utils.profiler import enable_communication_prof, communication_prof_show

BATCH_SIZE = 1024
D_MODEL = 1024
CONFIG = dict(parallel=dict(tensor=dict(mode='1d', size=4)))


def run_test(rank, world_size, port):
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    inputs = torch.randn(BATCH_SIZE, D_MODEL, dtype=torch.float32, device=get_current_device())
    outputs = torch.empty(world_size, BATCH_SIZE, D_MODEL, dtype=torch.float32, device=get_current_device())
    outputs_list = list(torch.chunk(outputs, chunks=world_size, dim=0))

    enable_communication_prof()

    op = dist.all_reduce(inputs, async_op=True)
    dist.all_gather(outputs_list, inputs)
    op.wait()
    dist.reduce_scatter(inputs, outputs_list)
    dist.broadcast(inputs, 0)
    dist.reduce(inputs, 0)

    if rank == 0:
        communication_prof_show()


def test_cc_prof():
    world_size = 4
    run_func = partial(run_test, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_cc_prof()
```
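The test exercises every patched collective once, including one async `all_reduce` whose returned handler is waited on via `op.wait()` (this is what `CommHandler` exists for). It can be run directly with `python tests/test_profiler/test_comm_prof.py`, assuming a host with at least 4 CUDA devices, since the config requests a size-4 1D tensor-parallel group.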