"lib/bindings/python/vscode:/vscode.git/clone" did not exist on "b760c5694df723227dc016e7f35c0fb66955e0d3"
Commit 73bff112 authored by 1SAA's avatar 1SAA Committed by Frank Lee
Browse files

Added profiler communication operations

Fixed a bug in the learning rate scheduler hook
parent d275b98b
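In short, the profiler saves the original torch.distributed collectives, swaps in wrappers that time each call with torch.autograd.profiler, and aggregates time and volume per call site. A minimal usage sketch (mirroring the test at the end of this diff; assumes an initialized NCCL process group and a CUDA device):

from colossalai.utils.profiler import enable_communication_prof, communication_prof_show
import torch
import torch.distributed as dist

enable_communication_prof()   # patch dist.all_reduce, all_gather, reduce_scatter, broadcast, reduce
x = torch.ones(1024, device='cuda')
dist.all_reduce(x)            # timed and recorded transparently
communication_prof_show()     # print total and per-call-site time, volume and bandwidth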
from .collective import all_gather, reduce_scatter, all_reduce, broadcast, reduce
from .p2p import (send_forward, send_forward_recv_forward, send_backward_recv_forward, send_backward,
send_backward_recv_backward, send_forward_recv_backward, send_forward_backward_recv_forward_backward,
recv_forward, recv_backward)
from .ring import ring_forward
from .utils import send_tensor_meta, recv_tensor_meta
__all__ = [
'all_gather',
'reduce_scatter',
'all_reduce',
'broadcast',
'reduce',
'send_forward',
'send_forward_recv_forward',
'send_forward_backward_recv_forward_backward',
'send_backward',
'send_backward_recv_backward',
'send_backward_recv_forward',
'send_forward_recv_backward',
'recv_backward',
'recv_forward',
'ring_forward',
'send_tensor_meta',
'recv_tensor_meta',
]
@@ -29,6 +29,7 @@ class LRSchedulerHook(MetricHook):
self.store_lr_in_state = store_lr_in_state
def after_hook_is_attached(self, trainer):
self._check_metric_states_initialization(trainer)
trainer.states['metrics']['train']['LR'] = LearningRateMetric(epoch_only=self.by_epoch,
initial_lr=self.lr_scheduler.get_last_lr()[0])
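Note: get_last_lr() returns the learning rates most recently computed by the scheduler's step(), one value per parameter group, so [0] picks the first group's rate as the metric's initial value:

last_lrs = lr_scheduler.get_last_lr()   # e.g. [0.001] for a single param group
initial_lr = last_lrs[0]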
from .activation_checkpoint import checkpoint
from .common import (clip_grad_norm_fp32, conditional_context, copy_tensor_parallel_attributes, count_zeros_fp32,
free_port, is_dp_rank_0, is_model_parallel_parameter, is_moe_parallel_parameter,
is_no_pp_or_last_stage, is_tp_rank_0, is_using_ddp, is_using_pp, is_using_sequence,
multi_tensor_applier, param_is_not_tensor_parallel_duplicate, print_rank_0,
switch_virtual_pipeline_parallel_rank, sync_model_param)
from .cuda import empty_cache, get_current_device, set_to_cuda, synchronize
from .data_sampler import DataParallelSampler, get_dataloader
from .comm_profiler import enable_communication_prof, communication_prof_show
import inspect
import torch
from torch.autograd.profiler import profile
import torch.distributed as dist
from torch.distributed import ReduceOp
from colossalai.utils import get_current_device
from typing import List, Optional
def _get_code_location(depth: int):
"""Walks up the call stack and records a "file(line): caller" entry per frame."""
ret = ""
length = len(inspect.stack())
# frames 0-2 belong to the profiler machinery itself, so start at frame 3
for i in range(3, min(length, depth + 1)):
upper_frame = inspect.stack()[i]
function_name = inspect.stack()[i - 1].function
info = upper_frame.filename + "(" + str(upper_frame.lineno) + "): " + function_name + "\n"
ret += info
return ret
# copied from high version pytorch to support low version
def _format_time(time_us):
"""Defines how to format time in FunctionEvent"""
US_IN_SECOND = 1000.0 * 1000.0
US_IN_MS = 1000.0
if time_us >= US_IN_SECOND:
return '{:.3f}s'.format(time_us / US_IN_SECOND)
if time_us >= US_IN_MS:
return '{:.3f}ms'.format(time_us / US_IN_MS)
return '{:.3f}us'.format(time_us)
# copied from high version pytorch to support low version
def _format_memory(nbytes):
"""Returns a formatted memory size string"""
KB = 1024
MB = 1024 * KB
GB = 1024 * MB
if abs(nbytes) >= GB:
return '{:.2f} GB'.format(nbytes * 1.0 / GB)
elif abs(nbytes) >= MB:
return '{:.2f} MB'.format(nbytes * 1.0 / MB)
elif abs(nbytes) >= KB:
return '{:.2f} KB'.format(nbytes * 1.0 / KB)
else:
return str(nbytes) + ' B'
def _format_bandwidth(volume: float, time_us: int):
# convert bytes per microsecond to MB/s: (1e6 us/s) / (1024**2 bytes/MB)
sec_div_mb = (1000.0 / 1024.0)**2
mb_per_sec = volume / time_us * sec_div_mb
if mb_per_sec >= 1024.0:
return '{:.3f} GB/s'.format(mb_per_sec / 1024.0)
else:
return '{:.3f} MB/s'.format(mb_per_sec)
class CommEvent(object):
"""Communication Event. Used for communication time and communication
volume recording.
"""
def __init__(self, count: int = 0, comm_vol: float = 0., cuda_time: int = 0):
self.self_count = count
self.self_comm_vol = comm_vol
self.self_cuda_time = cuda_time
def add(self, rhs):
self.self_count += rhs.self_count
self.self_comm_vol += rhs.self_comm_vol
self.self_cuda_time += rhs.self_cuda_time
class CommProfiler(object):
"""Communication profiler. Records all communication events.
"""
def __init__(self, total_count: int = 0, total_comm_vol: float = 0, total_cuda_time: int = 0, prof_depth: int = 3):
super().__init__()
self.total_count = total_count
self.total_comm_vol = total_comm_vol
self.total_cuda_time = total_cuda_time
self.depth = prof_depth
self.ops_record = dict()
self.profiler = None
self.pending_op = None
self.pending_metadata = None
self.warn_flag = False
def reset(self):
self.total_count = 0
self.total_comm_vol = 0
self.total_cuda_time = 0
self.ops_record = dict()
self.profiler = None
self.pending_op = None
self.pending_metadata = None
self.warn_flag = False
def show(self):
if self.warn_flag:
print("Warnning: there exists multiple communication operations in the same time.\n"
"As a result, the profiling result is not accurate.")
print("Collective communication profiling result:",
"total cuda time: {}".format(_format_time(self.total_cuda_time)),
"average bandwith: {}".format(_format_bandwith(self.total_comm_vol, self.total_cuda_time)),
"total number of calls: {}".format(self.total_count),
"All events:",
sep='\n')
show_list = sorted(self.ops_record.items(), key=lambda kv: -kv[1].self_cuda_time)
for location, event in show_list:
print(location,
"self cuda time: {}".format(_format_time(event.self_cuda_time)),
"{:.1f}% of total communication time".format(event.self_cuda_time / self.total_cuda_time * 100.0),
"self communication volme: {}".format(_format_memory(event.self_comm_vol)),
"average bandwith: {}".format(_format_bandwith(event.self_comm_vol, event.self_cuda_time)),
"number of calls: {}".format(event.self_count),
"--------------------",
sep='\n')
@property
def has_async_op(self):
return self.pending_op is not None
def activate_profiler(self, kn: str, vol: float):
self.pending_metadata = (kn, _get_code_location(self.depth), vol)
self.profiler = profile(enabled=True, use_cuda=True, use_cpu=True, use_kineto=True)
self.profiler.__enter__()
def close_profiler(self, group=None):
assert self.profiler is not None, "There is no running dist op"
kernel_name, code_location, vol = self.pending_metadata
self.profiler.__exit__(None, None, None)
if self.profiler.enabled:
assert_flag = 0
current_comm_event = None
events = self.profiler.function_events
for event in events:
if kernel_name in event.name:
assert assert_flag == 0, "Multiple dist ops have been called"
current_comm_event = CommEvent(1, vol, event.self_cuda_time_total)
assert_flag += 1
assert current_comm_event is not None, "dist op was not found"
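# each rank's kernel time includes waiting for slower ranks to join the collective;
# take the minimum across the group as the best estimate of pure communication time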
buffer = torch.tensor([current_comm_event.self_cuda_time], device=get_current_device())
torch_all_reduce(buffer, op=ReduceOp.MIN, group=group)
current_comm_event.self_cuda_time = buffer.item()
self.total_count += current_comm_event.self_count
self.total_comm_vol += current_comm_event.self_comm_vol
self.total_cuda_time += current_comm_event.self_cuda_time
if code_location in self.ops_record:
self.ops_record[code_location].add(current_comm_event)
else:
self.ops_record[code_location] = current_comm_event
self.profiler = None
self.pending_op = None
self.pending_metadata = None
def wait_async_op(self):
if self.pending_op is not None:
op = self.pending_op
op.wait()
self.close_profiler()
class CommHandler(object):
"""Communication handler. A dummy handler to wait aync operations.
"""
def __init__(self):
super().__init__()
self.prof = COL_COMM_PROF
def wait(self):
self.prof.wait_async_op()
COL_COMM_PROF = CommProfiler()
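# keep references to the original torch.distributed collectives;
# enable_communication_prof() swaps in the profiled wrappers defined below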
torch_all_reduce = dist.all_reduce
torch_all_gather = dist.all_gather
torch_reduce_scatter = dist.reduce_scatter
torch_broadcast = dist.broadcast
torch_reduce = dist.reduce
def enable_communication_prof(depth: int = 0):
COL_COMM_PROF.depth = 3 + depth
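# _get_code_location() starts reading the stack at frame 3, so offset the requested depth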
dist.all_reduce = all_reduce
dist.all_gather = all_gather
dist.reduce_scatter = reduce_scatter
dist.broadcast = broadcast
dist.reduce = reduce
def communication_prof_show():
COL_COMM_PROF.show()
def async_check():
if COL_COMM_PROF.pending_op is not None:
COL_COMM_PROF.warn_flag = True
COL_COMM_PROF.wait_async_op()
def all_reduce(tensor: torch.Tensor,
op: ReduceOp = ReduceOp.SUM,
group=None,
async_op: bool = False) -> Optional[CommHandler]:
async_check()
comm_size = dist.get_world_size(group)
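# ring all-reduce transfers each element roughly twice: (n-1)/n of the data in the
# reduce-scatter phase plus (n-1)/n in the all-gather phase, hence the factor 2(n-1)/n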
correction = 2 * (comm_size - 1) / comm_size
comm_vol = correction * tensor.element_size() * tensor.numel()
COL_COMM_PROF.activate_profiler("ncclKernel_AllReduce_", comm_vol)
COL_COMM_PROF.pending_op = torch_all_reduce(tensor, op, group, async_op)
if async_op:
return CommHandler()
COL_COMM_PROF.close_profiler(group)
def reduce_scatter(output: torch.Tensor,
input_list: List[torch.Tensor],
op: ReduceOp = ReduceOp.SUM,
group=None,
async_op: bool = False) -> Optional[CommHandler]:
async_check()
comm_size = dist.get_world_size(group)
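# each rank keeps 1/n of the reduced result, so (n-1)/n of the volume crosses the network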
correction = (comm_size - 1) / comm_size
comm_vol = 0
for tensor in input_list:
comm_vol += tensor.element_size() * tensor.numel()
comm_vol *= correction
COL_COMM_PROF.activate_profiler("ncclKernel_ReduceScatter_", comm_vol)
COL_COMM_PROF.pending_op = torch_reduce_scatter(output, input_list, op, group, async_op)
if async_op:
return CommHandler()
COL_COMM_PROF.close_profiler(group)
def all_gather(tensor_list: List[torch.Tensor],
tensor: torch.Tensor,
group=None,
async_op: bool = False) -> Optional[CommHandler]:
async_check()
comm_size = dist.get_world_size(group)
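# as with reduce-scatter, each rank already holds its own 1/n shard, so scale by (n-1)/n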
correction = (comm_size - 1) / comm_size
comm_vol = 0
for ten in tensor_list:
comm_vol += ten.element_size() * ten.numel()
comm_vol *= correction
COL_COMM_PROF.activate_profiler("ncclKernel_AllGather_", comm_vol)
COL_COMM_PROF.pending_op = torch_all_gather(tensor_list, tensor, group, async_op)
if async_op:
return CommHandler()
COL_COMM_PROF.close_profiler(group)
def broadcast(tensor: torch.Tensor, src: int, group=None, async_op: bool = False) -> Optional[CommHandler]:
async_check()
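# broadcast sends the whole tensor as-is, so the volume is just its byte size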
comm_vol = 1.0 * tensor.element_size() * tensor.numel()
COL_COMM_PROF.activate_profiler("ncclKernel_Broadcast_", comm_vol)
COL_COMM_PROF.pending_op = torch_broadcast(tensor, src, group, async_op)
if async_op:
return CommHandler()
COL_COMM_PROF.close_profiler(group)
def reduce(tensor: torch.Tensor,
dst: int,
op: ReduceOp = ReduceOp.SUM,
group=None,
async_op: bool = False) -> Optional[CommHandler]:
async_check()
comm_vol = 1.0 * tensor.element_size() * tensor.numel()
COL_COMM_PROF.activate_profiler("ncclKernel_Reduce_", comm_vol)
COL_COMM_PROF.pending_op = torch_reduce(tensor, dst, op, group, async_op)
if async_op:
return CommHandler()
COL_COMM_PROF.close_profiler(group)
from functools import partial
import torch
import torch.multiprocessing as mp
import torch.distributed as dist
import colossalai
from colossalai.utils import free_port, get_current_device
from colossalai.utils.profiler import enable_communication_prof, communication_prof_show
BATCH_SIZE = 1024
D_MODEL = 1024
CONFIG = dict(parallel=dict(tensor=dict(mode='1d', size=4)))
def run_test(rank, world_size, port):
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
inputs = torch.randn(BATCH_SIZE, D_MODEL, dtype=torch.float32, device=get_current_device())
outputs = torch.empty(world_size, BATCH_SIZE, D_MODEL, dtype=torch.float32, device=get_current_device())
outputs_list = list(torch.chunk(outputs, chunks=world_size, dim=0))
enable_communication_prof()
op = dist.all_reduce(inputs, async_op=True)
dist.all_gather(outputs_list, inputs)
op.wait()
dist.reduce_scatter(inputs, outputs_list)
dist.broadcast(inputs, 0)
dist.reduce(inputs, 0)
if rank == 0:
communication_prof_show()
def test_cc_prof():
world_size = 4
run_func = partial(run_test, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_cc_prof()