Commit 34df0f79 authored by Thor Johnsen

wgrad2 in parallel stream, optional mode to wait for halo transfer

parent 834b1d01
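Note on the change: the backward pass below moves the wgrad2 convolution onto its own CUDA stream so it can overlap with the data-gradient work, and the forward pass gains an optional mode (spatial_method == 2) that waits for the halo transfer and convolves a padded copy of out1. A minimal sketch of the fork/join stream pattern, with illustrative names (backward_with_side_stream, main_work, wgrad_work) that are not part of the diff:

import torch

def backward_with_side_stream(main_work, wgrad_work):
    # Sketch of the pattern used for wgrad2 below: fork a side stream, launch the
    # independent weight-gradient kernel there, keep issuing the rest of the
    # backward pass on the main stream, and only join right before the gradients
    # are returned, so the two streams can actually overlap.
    side = torch.cuda.Stream()
    side.wait_stream(torch.cuda.current_stream())   # fork: side stream sees prior work
    with torch.cuda.stream(side):
        wgrad = wgrad_work()                        # runs concurrently with main_work
    grads = main_work()                             # issued on the current stream
    torch.cuda.current_stream().wait_stream(side)   # join before results are used
    return grads, wgrad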
@@ -253,31 +253,35 @@ class SpatialBottleneckFunction(torch.autograd.Function):
             top_out1_halo = out1_pad[:,:1,:,:]
             btm_out1_halo = out1_pad[:,Hs+1:Hs+2,:,:]
             spatial_halo_exchanger.left_right_halo_exchange(out1[:,:1,:,:], out1[:,Hs-1:,:,:], top_out1_halo, btm_out1_halo)
-            if spatial_group_rank < spatial_group_size-1:
-                stream2.wait_stream(stream1)
-                with torch.cuda.stream(stream2):
-                    btm_fat_halo = torch.empty((N,3,W,C),dtype=out1.dtype,device=out1.device)
-                    btm_fat_halo[:,0:2,:,:].copy_(out1[:,Hs-2:,:,:])
-                    btm_fat_halo[:,2:,:,:].copy_(btm_out1_halo)
-                    btm_out2 = fast_bottleneck.forward_out2_halo(nhwc, btm_fat_halo, args)
-            else:
-                with torch.cuda.stream(stream2):
-                    btm_out1_halo.zero_()
-            if spatial_group_rank > 0:
-                with torch.cuda.stream(stream1):
-                    top_fat_halo = torch.empty((N,3,W,C),dtype=out1.dtype,device=out1.device)
-                    top_fat_halo[:,:1,:,:].copy_(top_out1_halo)
-                    top_fat_halo[:,1:3,:,:].copy_(out1[:,:2,:,:])
-                    top_out2 = fast_bottleneck.forward_out2_halo(nhwc, top_fat_halo, args)
-            else:
-                with torch.cuda.stream(stream1):
-                    top_out1_halo.zero_()
-            inc.add_delay(10)
-        fast_bottleneck.forward_out2(nhwc, stride_1x1, args, outputs)
+            if spatial_method == 1:
+                # overlap mid convolution with halo transfer
+                if spatial_group_rank < spatial_group_size-1:
+                    stream2.wait_stream(stream1)
+                    with torch.cuda.stream(stream2):
+                        btm_fat_halo = torch.empty((N,3,W,C),dtype=out1.dtype,device=out1.device)
+                        btm_fat_halo[:,0:2,:,:].copy_(out1[:,Hs-2:,:,:])
+                        btm_fat_halo[:,2:,:,:].copy_(btm_out1_halo)
+                        btm_out2 = fast_bottleneck.forward_out2_halo(nhwc, btm_fat_halo, args)
+                if spatial_group_rank > 0:
+                    with torch.cuda.stream(stream1):
+                        top_fat_halo = torch.empty((N,3,W,C),dtype=out1.dtype,device=out1.device)
+                        top_fat_halo[:,:1,:,:].copy_(top_out1_halo)
+                        top_fat_halo[:,1:3,:,:].copy_(out1[:,:2,:,:])
+                        top_out2 = fast_bottleneck.forward_out2_halo(nhwc, top_fat_halo, args)
+                inc.add_delay(10)
+            elif spatial_method == 2:
+                # wait for halo transfer to finish before doing a full convolution of padded x
+                torch.cuda.current_stream().wait_stream(stream1)
+                torch.cuda.current_stream().wait_stream(stream3)
+                fast_bottleneck.forward_out2_pad(nhwc, stride_1x1, args, outputs, out1_pad)
+            else:
+                assert(False), "spatial_method must be 1 or 2"
+        if spatial_group_size <= 1 or spatial_method == 1:
+            fast_bottleneck.forward_out2(nhwc, stride_1x1, args, outputs)
         # compute halo cells for outputs[1] (out2)
-        if spatial_group_size > 1:
+        if spatial_group_size > 1 and spatial_method == 1:
             out2 = outputs[1]
             if spatial_group_rank > 0:
                 torch.cuda.current_stream().wait_stream(stream1)
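The hunk above selects between two modes via spatial_method: mode 1 overlaps the mid 3x3 convolution with the halo transfer and computes the two halo-dependent output rows separately, while mode 2 waits for the transfer and runs a single convolution over the padded activation. A rough sketch of the idea using plain torch.nn.functional.conv2d and hypothetical NCHW tensor names, not the fast_bottleneck kernels:

import torch
import torch.nn.functional as F

def mid_conv(x_pad, weight, spatial_method, halo_stream):
    # x_pad: NCHW activation holding this rank's H rows plus one halo row at the
    # top and one at the bottom; the halo rows are filled asynchronously on
    # halo_stream. weight is a 3x3 kernel, stride 1.
    if spatial_method == 1:
        # Mode 1: start the conv on the interior rows immediately (it does not
        # touch the halos), then compute the two boundary output rows once the
        # halo transfer has finished.
        y_main = F.conv2d(x_pad[:, :, 1:-1, :], weight, padding=(0, 1))  # output rows 1..H-2
        torch.cuda.current_stream().wait_stream(halo_stream)
        y_top = F.conv2d(x_pad[:, :, 0:3, :], weight, padding=(0, 1))    # output row 0
        y_btm = F.conv2d(x_pad[:, :, -3:, :], weight, padding=(0, 1))    # output row H-1
        return torch.cat([y_top, y_main, y_btm], dim=2)
    elif spatial_method == 2:
        # Mode 2: wait for the halo transfer, then do one conv over the padded input.
        torch.cuda.current_stream().wait_stream(halo_stream)
        return F.conv2d(x_pad, weight, padding=(0, 1))
    else:
        raise ValueError("spatial_method must be 1 or 2")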
@@ -290,6 +294,7 @@ class SpatialBottleneckFunction(torch.autograd.Function):
         fast_bottleneck.forward_rest(nhwc, stride_1x1, args, outputs)
         # save halos for backward pass
         if spatial_group_size > 1:
+            torch.cuda.current_stream().wait_stream(stream3)
             ctx.save_for_backward(*(args+outputs+[out1_pad,]))
         else:
             ctx.save_for_backward(*(args+outputs))
@@ -368,10 +373,13 @@ class SpatialBottleneckFunction(torch.autograd.Function):
                 top_grad_out1_halo = top_grad_out1_halo[:,1:2,:,:]
             inc.add_delay(10)
-        if ctx.spatial_group_size > 1:
-            wgrad2 = fast_bottleneck.backward_wgrad2_pad(ctx.nhwc, ctx.stride_1x1, t_list, grads, out1_pad, grad_out2)
-        else:
-            wgrad2 = fast_bottleneck.backward_wgrad2(ctx.nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
+        wgrad2_stream = torch.cuda.Stream()
+        wgrad2_stream.wait_stream(torch.cuda.current_stream())
+        with torch.cuda.stream(wgrad2_stream):
+            if ctx.spatial_group_size > 1:
+                wgrad2 = fast_bottleneck.backward_wgrad2_pad(ctx.nhwc, ctx.stride_1x1, t_list, grads, out1_pad, grad_out2)
+            else:
+                wgrad2 = fast_bottleneck.backward_wgrad2(ctx.nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
         # compute wgrad2 for internal cells
         #wgrad2 = fast_bottleneck.backward_wgrad2(ctx.nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
@@ -406,6 +414,7 @@ class SpatialBottleneckFunction(torch.autograd.Function):
                 #print("ctx.spatial_group_rank = %d, apply grad_out1 btm halo (grad_out1.shape = %s)" % (ctx.spatial_group_rank, str(list(grad_out1.shape))))
         fast_bottleneck.backward_rest(ctx.nhwc, ctx.stride_1x1, t_list, grads, grad_out2, grad_out1, wgrad2)
+        torch.cuda.current_stream().wait_stream(wgrad2_stream)
         return (None, None, None, None, None, None, None, None, None, *grads)
......
@@ -161,8 +161,8 @@ def n_way_spatial(halex, gt_bottleneck, gt, explicit_nhwc, world_size, rank, fp3
     spatial_group_rank = rank
     spatial_communicator = None
     spatial_halo_exchanger = halex
-    spatial_stream = None # Not in use
-    spatial_parallel_args = (spatial_group_size, spatial_group_rank, spatial_communicator, spatial_halo_exchanger, spatial_stream)
+    spatial_method = 2 # 1 -> overlap halo and main conv, 2 -> wait for halo, conv on padded x
+    spatial_parallel_args = (spatial_group_size, spatial_group_rank, spatial_communicator, spatial_halo_exchanger, spatial_method)
     spatial_bottleneck = spatial_parallel_bottleneck(C, dtype, explicit_nhwc, gt_bottleneck, spatial_parallel_args)
     with torch.no_grad():
@@ -217,8 +217,14 @@ def main():
     peer_pool = PeerMemoryPool(rank, world_size, spatial_group_size, 64*1024*1024, 2*1024*1024)
     #halex = HaloExchangerAllGather(world_size, spatial_group_size, rank, spatial_communicator)
-    #halex = HaloExchangerSendRecv(world_size, spatial_group_size, rank, spatial_communicator)
-    halex = HaloExchangerPeer(world_size, spatial_group_size, rank, spatial_communicator, peer_pool, explicit_nhwc, numSM=1)
+    halex = HaloExchangerSendRecv(world_size, spatial_group_size, rank, spatial_communicator)
+    #halex = HaloExchangerPeer(world_size, spatial_group_size, rank, spatial_communicator, peer_pool, explicit_nhwc, numSM=1)
+    #print("halex.signals = %s" % (str(halex.signals)))
+    # Make sure peer memory halo exchanger has finished initializing flags on all ranks before proceeding
+    #torch.cuda.synchronize()
+    #torch.distributed.barrier()
     bt2 = n_way_spatial(halex, gt_bottleneck, gt, explicit_nhwc, world_size, rank, fp32_reduce=True)
     compare(gt, bt2)
......
@@ -9,14 +9,17 @@ import peer_memory as pm
 # NB! This is only useful for performance testing.
 # NB! Do not use for actual production runs
 class HaloExchanger(object):
-    def __init__(self):
+    def __init__(self, spatial_group_size, rank):
         self.stream1 = torch.cuda.Stream()
         self.stream2 = torch.cuda.Stream()
         self.stream3 = torch.cuda.Stream()
+        spatial_rank = rank % spatial_group_size
+        self.left_zero = True if spatial_rank == 0 else False
+        self.right_zero = True if spatial_rank == spatial_group_size - 1 else False

 class HaloExchangerNoComm(HaloExchanger):
     def __init__(self, world_size, spatial_group_size, rank, comm):
-        super(HaloExchangerNoComm, self).__init__()
+        super(HaloExchangerNoComm, self).__init__(spatial_group_size, rank)
     def left_right_halo_exchange(self, left_output_halo, right_output_halo, left_input_halo=None, right_input_halo=None):
         if left_input_halo is None:
@@ -27,7 +30,7 @@ class HaloExchangerNoComm(HaloExchanger):
 class HaloExchangerAllGather(HaloExchanger):
     def __init__(self, world_size, spatial_group_size, rank, comm):
-        super(HaloExchangerAllGather, self).__init__()
+        super(HaloExchangerAllGather, self).__init__(spatial_group_size, rank)
         self.spatial_group_size = spatial_group_size
         self.local_rank = rank % spatial_group_size
         self.comm = comm
@@ -43,14 +46,24 @@ class HaloExchangerAllGather(HaloExchanger):
         ag_left_input_halo = all_halos[(self.spatial_group_size+self.local_rank-1)%self.spatial_group_size][:,Hh:,:,:]
         ag_right_input_halo = all_halos[(self.local_rank+1)%self.spatial_group_size][:,:Hh,:,:]
         if left_input_halo is None:
+            if self.left_zero:
+                ag_left_input_halo.zero_()
+            if self.right_zero:
+                ag_right_input_halo.zero_()
             return ag_left_input_halo, ag_right_input_halo
         else:
-            left_input_halo.copy_(ag_left_input_halo)
-            right_input_halo.copy_(ag_right_input_halo)
+            if self.left_zero:
+                left_input_halo.zero_()
+            else:
+                left_input_halo.copy_(ag_left_input_halo)
+            if self.right_zero:
+                right_input_halo.zero_()
+            else:
+                right_input_halo.copy_(ag_right_input_halo)

 class HaloExchangerSendRecv(HaloExchanger):
     def __init__(self, world_size, spatial_group_size, rank, comm):
-        super(HaloExchangerSendRecv, self).__init__()
+        super(HaloExchangerSendRecv, self).__init__(spatial_group_size, rank)
         self.world_size = world_size
         self.spatial_group_size = spatial_group_size
         nccl_id = inc.get_unique_nccl_id(1).cuda()
@@ -60,14 +73,14 @@ class HaloExchangerSendRecv(HaloExchanger):
     def left_right_halo_exchange(self, left_output_halo, right_output_halo, left_input_halo=None, right_input_halo=None):
         if left_input_halo is None:
-            left_input_halo, right_input_halo = inc.left_right_halo_exchange(self.handle, left_output_halo, right_output_halo, self.spatial_group_size)
+            left_input_halo, right_input_halo = inc.left_right_halo_exchange(self.handle, self.left_zero, self.right_zero, left_output_halo, right_output_halo, self.spatial_group_size)
             return left_input_halo, right_input_halo
         else:
-            inc.left_right_halo_exchange_inplace(self.handle, left_output_halo, right_output_halo, left_input_halo, right_input_halo, self.spatial_group_size)
+            inc.left_right_halo_exchange_inplace(self.handle, self.left_zero, self.right_zero, left_output_halo, right_output_halo, left_input_halo, right_input_halo, self.spatial_group_size)

 class HaloExchangerPeer(HaloExchanger):
     def __init__(self, world_size, spatial_group_size, rank, comm, peer_pool, explicit_nhwc, numSM=1):
-        super(HaloExchangerPeer, self).__init__()
+        super(HaloExchangerPeer, self).__init__(spatial_group_size, rank)
         self.diagnostics = False
         self.spatial_group_size = spatial_group_size
         self.peer_rank = rank % spatial_group_size
@@ -93,6 +106,11 @@ class HaloExchangerPeer(HaloExchanger):
             right_output_halo, right_tx[self.peer_rank], left_tx[self.right_neighbor], right_input_halo,
             self.signals[self.left_neighbor], self.signals[self.right_neighbor], self.signals[self.peer_rank]
             )
+        # TODO: Add to push_pull_halos_1d kernel
+        if self.left_zero:
+            left_input_halo.zero_()
+        if self.right_zero:
+            right_input_halo.zero_()
         if not inplace:
             return left_input_halo, right_input_halo
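The exchanger changes above centralize boundary handling: the first and last rank of the spatial group have no neighbor on one side, so their incoming halo is zeroed (left_zero / right_zero) instead of copied, which previously had to be done by the caller. A hedged sketch of the same behavior using plain torch.distributed point-to-point calls and an illustrative group_ranks argument; the real exchangers use the NCCL wrapper or peer-memory kernels:

import torch
import torch.distributed as dist

def left_right_halo_exchange(left_out, right_out, group_ranks):
    # Interior ranks receive real halos from their neighbors; the first/last
    # rank in the spatial group zero-fill the side that has no neighbor.
    rank = dist.get_rank()
    idx = group_ranks.index(rank)
    left_in = torch.empty_like(right_out)
    right_in = torch.empty_like(left_out)
    ops = []
    if idx > 0:
        ops.append(dist.P2POp(dist.isend, left_out, group_ranks[idx - 1]))
        ops.append(dist.P2POp(dist.irecv, left_in, group_ranks[idx - 1]))
    else:
        left_in.zero_()          # no upper neighbor -> zero halo (left_zero)
    if idx < len(group_ranks) - 1:
        ops.append(dist.P2POp(dist.isend, right_out, group_ranks[idx + 1]))
        ops.append(dist.P2POp(dist.irecv, right_in, group_ranks[idx + 1]))
    else:
        right_in.zero_()         # no lower neighbor -> zero halo (right_zero)
    if ops:
        for req in dist.batch_isend_irecv(ops):
            req.wait()
    return left_in, right_in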
......
@@ -1620,6 +1620,7 @@ struct bottleneck_forward_status {
   int64_t outdimA0[4];
   int64_t outdimA1[4];
+  int64_t outdimA1b[4]; // out1_pad
   int64_t outdimA2[4];
   int64_t outdimA3[4];
   int64_t outdimA4[4];
@@ -1633,6 +1634,7 @@ struct bottleneck_forward_status {
   int64_t outdim0[4]; // halo input shape
   int64_t outdim1[4];
+  int64_t outdim1b[4];
   int64_t outdim2[4];
   int64_t outdim3[4];
   int64_t outdim4[4]; // halo output shape
@@ -1672,6 +1674,7 @@ struct bottleneck_forward_status {
     // output dim in n,c,h,w used by backend
     outdimA0[0] = outdimA0[1] = outdimA0[2] = outdimA0[3] = 0;
     outdimA1[0] = outdimA1[1] = outdimA1[2] = outdimA1[3] = 0;
+    outdimA1b[0] = outdimA1b[1] = outdimA1b[2] = outdimA1b[3] = 0;
     outdimA2[0] = outdimA2[1] = outdimA2[2] = outdimA2[3] = 0;
     outdimA3[0] = outdimA3[1] = outdimA3[2] = outdimA3[3] = 0;
     outdimA4[0] = outdimA4[1] = outdimA4[2] = outdimA4[3] = 0;
@@ -1690,6 +1693,13 @@ struct bottleneck_forward_status {
     for (int dim = 0; dim < 2; dim++) {
       outdimA1[dim + 2] = getFwdConvOutputDim(dimA[dim + 2], padA[dim], filterdimA1[dim + 2], convstride1X1[dim], dilationA[dim]);
     }
+    for (int dim = 0; dim < 4; dim++) {
+      if (dim == 2) {
+        outdimA1b[dim] = outdimA1[dim] + 2;
+      } else {
+        outdimA1b[dim] = outdimA1[dim];
+      }
+    }
     outdimA2[0] = outdimA1[0];
     outdimA2[1] = filterdimA2[0];
@@ -1715,6 +1725,7 @@ struct bottleneck_forward_status {
     // Create output tensor in the correct shape in pytorch's view
     outdim1[0] = outdim1[1] = outdim1[2] = outdim1[3] = 0;
+    outdim1b[0] = outdim1b[1] = outdim1b[2] = outdim1b[3] = 0;
     outdim2[0] = outdim2[1] = outdim2[2] = outdim2[3] = 0;
     outdim3[0] = outdim3[1] = outdim3[2] = outdim3[3] = 0;
     if (explicit_nhwc) {
@@ -1726,6 +1737,7 @@ struct bottleneck_forward_status {
       for (int dim=0;dim<4;dim++) {
         outdim0[dim] = outdimA0[axis[dim]];
         outdim1[dim] = outdimA1[axis[dim]];
+        outdim1b[dim] = outdimA1b[axis[dim]];
         outdim2[dim] = outdimA2[axis[dim]];
         outdim3[dim] = outdimA3[axis[dim]];
         outdim4[dim] = outdimA4[axis[dim]];
@@ -1859,6 +1871,44 @@ void bottleneck_forward_out2(bool explicit_nhwc, int stride_1X1, std::vector<at:
   DEBUG_MSG("[DEBUG] new relu2 : " << out2.to(at::kFloat).sum().item<float>());
 }

+void bottleneck_forward_out2_pad(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor out1_pad) {
+  std::cout << std::fixed;
+
+  // from _out1 method
+  at::Half* x = inputs[0].data_ptr<at::Half>();
+  auto out1 = outputs[0];
+  at::Half* y1 = out1_pad.data_ptr<at::Half>();
+
+  // run
+  at::Half* w = inputs[2].data_ptr<at::Half>();
+  at::Half* z = inputs[5].data_ptr<at::Half>();
+  at::Half* b = inputs[8].data_ptr<at::Half>();
+  auto out2 = outputs[1];
+  at::Half* y2 = out2.data_ptr<at::Half>();
+
+  //printf("forward_state.outdimA1 = {%d,%d,%d,%d}\n",forward_state.outdimA1[0],forward_state.outdimA1[1],forward_state.outdimA1[2],forward_state.outdimA1[3]);
+  //printf("forward_state.padA1 = {%d,%d}\n",forward_state.padA1[0],forward_state.padA1[1]);
+  //printf("forward_state.convstrideA = {%d,%d}\n",forward_state.convstrideA[0],forward_state.convstrideA[1]);
+  //printf("forward_state.dilationA = {%d,%d}\n",forward_state.dilationA[0],forward_state.dilationA[1]);
+  //printf("forward_state.filterdimA2 = {%d,%d,%d,%d}\n",forward_state.filterdimA2[0],forward_state.filterdimA2[1],forward_state.filterdimA2[2],forward_state.filterdimA2[3]);
+  //printf("forward_state.outdimA2 = {%d,%d,%d,%d}\n",forward_state.outdimA2[0],forward_state.outdimA2[1],forward_state.outdimA2[2],forward_state.outdimA2[3]);
+  run_conv_scale_bias_add_activation(forward_state.outdimA1b,
+                                     forward_state.padA2,
+                                     forward_state.convstrideA,
+                                     forward_state.dilationA,
+                                     forward_state.filterdimA2,
+                                     forward_state.outdimA2,
+                                     CUDNN_DATA_HALF,
+                                     y1,
+                                     w,
+                                     y2,
+                                     z,
+                                     b,
+                                     nullptr);
+
+  DEBUG_MSG("[DEBUG] new relu2 : " << out2.to(at::kFloat).sum().item<float>());
+}
+
 void bottleneck_forward_rest(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs) {
   std::cout << std::fixed;
@@ -2520,6 +2570,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("forward_out1", &bottleneck_forward_out1, "Bottleneck block forward");
   m.def("forward_out2", &bottleneck_forward_out2, "Bottleneck block forward");
   m.def("forward_out2_halo", &bottleneck_forward_out2_halo, "Bottleneck block forward");
+  m.def("forward_out2_pad", &bottleneck_forward_out2_pad, "Bottleneck block forward");
   m.def("forward_rest", &bottleneck_forward_rest, "Bottleneck block forward");
   m.def("backward_init", &bottleneck_backward_init, "Bottleneck block backward init");
   m.def("backward_grad_out2", &bottleneck_backward_grad_out2, "Bottleneck block backward");
......
@@ -100,7 +100,7 @@ class NcclCommWrapper
       });
     }
-    void left_right_halo_exchange_inplace(at::Tensor left_output_halo, at::Tensor right_output_halo, at::Tensor left_input_halo, at::Tensor right_input_halo, int group_size)
+    void left_right_halo_exchange_inplace(bool left_zero, bool right_zero, at::Tensor left_output_halo, at::Tensor right_output_halo, at::Tensor left_input_halo, at::Tensor right_input_halo, int group_size)
     {
       auto stream = at::cuda::getCurrentCUDAStream();
       ncclGroupStart();
@@ -132,16 +132,18 @@ class NcclCommWrapper
         });
       }
       ncclGroupEnd();
+      if (left_zero) left_input_halo.zero_();
+      if (right_zero) right_input_halo.zero_();
     }
-    std::vector<at::Tensor> left_right_halo_exchange(at::Tensor left_output_halo, at::Tensor right_output_halo, int group_size)
+    std::vector<at::Tensor> left_right_halo_exchange(bool left_zero, bool right_zero, at::Tensor left_output_halo, at::Tensor right_output_halo, int group_size)
     {
       // after halo exchange:
       // left_output_halo of rank+1 ends up in right_input_halo of rank
      // right_output_halo of rank-1 ends up in left_input_halo of rank
       auto right_input_halo = torch::empty_like(left_output_halo);
       auto left_input_halo = torch::empty_like(right_output_halo);
-      left_right_halo_exchange_inplace(left_output_halo, right_output_halo, left_input_halo, right_input_halo, group_size);
+      left_right_halo_exchange_inplace(left_zero, right_zero, left_output_halo, right_output_halo, left_input_halo, right_input_halo, group_size);
       return {left_input_halo, right_input_halo};
     }
 };
@@ -195,18 +197,18 @@ void nccl_recv(int handle, at::Tensor input, int sender)
   communicator.recv(input, sender);
 }

-void left_right_halo_exchange_inplace(int handle, at::Tensor left_output_halo, at::Tensor right_output_halo, at::Tensor left_input_halo, at::Tensor right_input_halo, int group_size)
+void left_right_halo_exchange_inplace(int handle, bool left_zero, bool right_zero, at::Tensor left_output_halo, at::Tensor right_output_halo, at::Tensor left_input_halo, at::Tensor right_input_halo, int group_size)
 {
   assert(handle >= 0 && handle < nccl_comms.size());
   class NcclCommWrapper& communicator = nccl_comms[handle];
-  return communicator.left_right_halo_exchange_inplace(left_output_halo, right_output_halo, left_input_halo, right_input_halo, group_size);
+  return communicator.left_right_halo_exchange_inplace(left_zero, right_zero, left_output_halo, right_output_halo, left_input_halo, right_input_halo, group_size);
 }

-std::vector<at::Tensor> left_right_halo_exchange(int handle, at::Tensor left_output_halo, at::Tensor right_output_halo, int group_size)
+std::vector<at::Tensor> left_right_halo_exchange(int handle, bool left_zero, bool right_zero, at::Tensor left_output_halo, at::Tensor right_output_halo, int group_size)
 {
   assert(handle >= 0 && handle < nccl_comms.size());
   class NcclCommWrapper& communicator = nccl_comms[handle];
-  return communicator.left_right_halo_exchange(left_output_halo, right_output_halo, group_size);
+  return communicator.left_right_halo_exchange(left_zero, right_zero, left_output_halo, right_output_halo, group_size);
 }

 void add_delay(int delay)
......
@@ -38,6 +38,8 @@ void nccl_recv(
     );
 void left_right_halo_exchange_inplace(
     int handle,
+    bool left_zero,
+    bool right_zero,
     at::Tensor left_output_halo,
     at::Tensor right_output_halo,
     at::Tensor left_input_halo,
@@ -45,6 +47,8 @@ void left_right_halo_exchange_inplace(
     int group_size);
 std::vector<at::Tensor> left_right_halo_exchange(
     int handle,
+    bool left_zero,
+    bool right_zero,
     at::Tensor left_output_halo,
     at::Tensor right_output_halo,
     int group_size
......
@@ -239,6 +239,7 @@ int64_t allocate_raw(int64_t size)
 {
   float* ptr = 0L;
   cudaMalloc(&ptr, size);
+  cudaMemset(ptr, 0, size);
   return (int64_t)ptr;
 }
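The added cudaMemset zero-fills freshly allocated peer memory, which presumably also backs the signal flags used by HaloExchangerPeer (see the commented-out flag-initialization notes in the test above). A small PyTorch analogue of the same precaution, purely for illustration:

import torch

# Illustration only (not the apex code): cudaMalloc, like torch.empty, returns
# uninitialized memory, so a synchronization flag carved out of it could already
# look "set" before any rank has written it. Zeroing at allocation time gives
# every rank a well-defined starting state.
raw = torch.empty(16, dtype=torch.int32, device="cuda")  # contents undefined
raw.zero_()                                              # analogue of the added cudaMemset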
......
@@ -53,7 +53,7 @@ def nccl_halo_ex(peer_rank, peer_group_size, y, half_halo, explicit_nhwc, H_spli
         btm_inp_halo.copy_(top_inp_halos[btm_rank])

-def single_test(peer_rank, peer_group_size, halo_ex, C, H, W, half_halo, dtype, memory_format, H_split, numSM=1):
+def single_test(peer_rank, peer_group_size, halo_ex, C, H, W, half_halo, dtype, memory_format, H_split, num_steps, numSM=1):
     if memory_format == 1:
         # 1 -> explicit nhwc
         explicit_nhwc = True
@@ -77,10 +77,23 @@ def single_test(peer_rank, peer_group_size, halo_ex, C, H, W, half_halo, dtype,
     if memory_format == 2:
         y = y.to(memory_format=torch.channels_last)
     ym = y[:,:,:,half_halo:W+half_halo]
-    y2 = y.clone()
-    halo_ex(y, H_split, explicit_nhwc, numSM)
-    nccl_halo_ex(peer_rank, peer_group_size, y2, half_halo, explicit_nhwc, H_split)
-    is_equal = torch.all(torch.eq(y,y2))
+    y3 = y.clone()
+    list_y = []
+    for step in range(num_steps):
+        halo_ex(y, H_split, explicit_nhwc, numSM)
+        list_y.append(y.clone())
+        y.copy_(y3)
+    halo_ex.peer_pool.reset()
+    torch.distributed.barrier()
+    y2 = y3.clone()
+    list_y2 = []
+    for step in range(num_steps):
+        nccl_halo_ex(peer_rank, peer_group_size, y2, half_halo, explicit_nhwc, H_split)
+        list_y2.append(y2.clone())
+        y2.copy_(y3)
+    is_equal = [torch.all(torch.eq(yy,yy2)) for yy,yy2 in zip(list_y,list_y2)]
+    is_equal = torch.tensor(is_equal, dtype=torch.bool)
+    is_equal = torch.all(is_equal)
     if peer_rank == 0:
         if memory_format == 1:
             memory_format_str = "explicit_nhwc"
@@ -99,26 +112,26 @@ def single_test(peer_rank, peer_group_size, halo_ex, C, H, W, half_halo, dtype,
     torch.distributed.barrier()

-def H_split_tests(N, C, H, W, half_halo, rank, world_size, halo_ex):
+def H_split_tests(N, C, H, W, half_halo, rank, world_size, halo_ex, num_steps):
     Hr = 8*world_size
     Hp = ((H + Hr - 1) // Hr) * 8
     for i in range(4):
         div = int(pow(2,i))
-        single_test(rank, world_size, halo_ex, C*div, Hp//div, W//div, half_halo, torch.float16, 1, True)
-        single_test(rank, world_size, halo_ex, C*div, Hp//div, W//div, half_halo, torch.float16, 2, True)
-        single_test(rank, world_size, halo_ex, C*div, Hp//div, W//div, half_halo, torch.float16, 3, True)
+        single_test(rank, world_size, halo_ex, C*div, Hp//div, W//div, half_halo, torch.float16, 1, True, num_steps)
+        single_test(rank, world_size, halo_ex, C*div, Hp//div, W//div, half_halo, torch.float16, 2, True, num_steps)
+        single_test(rank, world_size, halo_ex, C*div, Hp//div, W//div, half_halo, torch.float16, 3, True, num_steps)

-def W_split_tests(N, C, H, W, half_halo, rank, world_size, halo_ex):
+def W_split_tests(N, C, H, W, half_halo, rank, world_size, halo_ex, num_steps):
     Wr = 8*world_size
     Wp = ((W + Wr - 1) // Wr) * 8
     for i in range(4):
         div = int(pow(2,i))
-        single_test(rank, world_size, halo_ex, C*div, H//div, Wp//div, half_halo, torch.float16, 1, False)
-        single_test(rank, world_size, halo_ex, C*div, H//div, Wp//div, half_halo, torch.float16, 2, False)
-        single_test(rank, world_size, halo_ex, C*div, H//div, Wp//div, half_halo, torch.float16, 3, False)
+        single_test(rank, world_size, halo_ex, C*div, H//div, Wp//div, half_halo, torch.float16, 1, False, num_steps)
+        single_test(rank, world_size, halo_ex, C*div, H//div, Wp//div, half_halo, torch.float16, 2, False, num_steps)
+        single_test(rank, world_size, halo_ex, C*div, H//div, Wp//div, half_halo, torch.float16, 3, False, num_steps)

 def main():
@@ -130,11 +143,13 @@ def main():
     torch.cuda.set_device(rank)
     pool = PeerMemoryPool(rank, world_size, world_size, 64*1024, 2*1024*1024)

+    num_steps = 100
     half_halo = 1
     halo_ex = PeerHaloExchanger1d(rank, world_size, pool, half_halo)

-    H_split_tests(1,64,336,200, half_halo,rank,world_size,halo_ex)
-    W_split_tests(1,64,200,336, half_halo,rank,world_size,halo_ex)
+    H_split_tests(1,64,336,200, half_halo,rank,world_size,halo_ex,num_steps)
+    W_split_tests(1,64,200,336, half_halo,rank,world_size,halo_ex,num_steps)

 if __name__ == "__main__":
......