Commit 88914a50 authored by Thor Johnsen

Add halo correction kernel for bprop

parent 705aa35d
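This commit extends the spatially-parallel bottleneck, where the input to the 3x3 convolution is split by rows across ranks: each rank's edge output rows are missing the contribution from the neighbouring rank's boundary row (the halo), and the forward pass already patches those rows with a halo correction kernel. The change adds the matching correction for the backward pass. Below is a minimal, self-contained sketch of the underlying idea in plain PyTorch and NCHW layout (not the fast_bottleneck kernels): the edge rows computed with zero padding at the split boundary can be repaired afterwards by adding the neighbour's halo row pushed through a single kernel row.

```python
# Illustrative sketch only (plain PyTorch, NCHW), not the apex/fast_bottleneck kernels.
# Shows that correcting the edge rows with the neighbour's halo row reproduces the
# unsplit 3x3 convolution exactly.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, C, K, H, W = 1, 4, 8, 16, 16
x = torch.randn(N, C, H, W)
w = torch.randn(K, C, 3, 3)

y_full = F.conv2d(x, w, padding=1)                        # reference: conv over the full input

x_top, x_btm = x[:, :, :H // 2, :], x[:, :, H // 2:, :]   # split rows across two "ranks"
y_top = F.conv2d(x_top, w, padding=1)                     # edge rows at the split are wrong:
y_btm = F.conv2d(x_btm, w, padding=1)                     # they saw zero padding, not the neighbour

top_halo = x_top[:, :, -1:, :]                            # last row of the top chunk
btm_halo = x_btm[:, :, :1, :]                             # first row of the bottom chunk

# Halo correction: add the neighbour's row pushed through the matching kernel row.
y_top[:, :, -1:, :] += F.conv2d(btm_halo, w[:, :, 2:3, :], padding=(0, 1))
y_btm[:, :, :1, :]  += F.conv2d(top_halo, w[:, :, 0:1, :], padding=(0, 1))

assert torch.allclose(torch.cat([y_top, y_btm], dim=2), y_full, atol=1e-4)
```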
@@ -268,17 +268,6 @@ class SpatialBottleneckFunction(torch.autograd.Function):
spatial_halo_exchanger.left_right_halo_exchange(out1[:,:,:1,:], out1[:,:,Hs-1:,:], top_out1_halo, btm_out1_halo)
if spatial_method == 1:
# overlap mid convolution with halo transfer
if spatial_group_rank > 0:
with torch.cuda.stream(stream1):
if explicit_nhwc:
top_fat_halo = torch.empty((N,3,W,C),dtype=out1.dtype,device=out1.device)
top_fat_halo[:,:1,:,:].copy_(top_out1_halo)
top_fat_halo[:,1:3,:,:].copy_(out1[:,:2,:,:])
else:
top_fat_halo = torch.empty((N,C,3,W),dtype=out1.dtype,device=out1.device)
top_fat_halo[:,:,:1,:].copy_(top_out1_halo)
top_fat_halo[:,:,1:3,:].copy_(out1[:,:,:2,:])
top_out2 = fast_bottleneck.forward_out2_halo(explicit_nhwc, top_fat_halo, args)
if spatial_group_rank < spatial_group_size-1:
stream2.wait_stream(stream1)
with torch.cuda.stream(stream2):
@@ -291,6 +280,17 @@ class SpatialBottleneckFunction(torch.autograd.Function):
btm_fat_halo[:,:,0:2,:].copy_(out1[:,:,Hs-2:,:])
btm_fat_halo[:,:,2:,:].copy_(btm_out1_halo)
btm_out2 = fast_bottleneck.forward_out2_halo(explicit_nhwc, btm_fat_halo, args)
if spatial_group_rank > 0:
with torch.cuda.stream(stream1):
if explicit_nhwc:
top_fat_halo = torch.empty((N,3,W,C),dtype=out1.dtype,device=out1.device)
top_fat_halo[:,:1,:,:].copy_(top_out1_halo)
top_fat_halo[:,1:3,:,:].copy_(out1[:,:2,:,:])
else:
top_fat_halo = torch.empty((N,C,3,W),dtype=out1.dtype,device=out1.device)
top_fat_halo[:,:,:1,:].copy_(top_out1_halo)
top_fat_halo[:,:,1:3,:].copy_(out1[:,:,:2,:])
top_out2 = fast_bottleneck.forward_out2_halo(explicit_nhwc, top_fat_halo, args)
inc.add_delay(10)
elif spatial_method != 2 and spatial_method != 3:
assert(False), "spatial_method must be 1, 2 or 3"
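For spatial_method 1 the halo exchange is overlapped with the main convolution, and the two edge output rows are then recomputed from a 3-row "fat halo": the received halo row stacked with the chunk's own two boundary rows. A hedged sketch of what forward_out2_halo amounts to for the top edge, in plain PyTorch and NCHW layout (the real kernel runs on the bottleneck's internal tensors and supports explicit NHWC); the middle output row of the 3-row strip is the corrected first row of out2.

```python
# Conceptual sketch of the spatial_method == 1 "fat halo" recompute (plain PyTorch,
# NCHW); not the forward_out2_halo kernel itself.
import torch
import torch.nn.functional as F

N, C, K, W = 1, 4, 8, 16
w = torch.randn(K, C, 3, 3)                 # weights of the 3x3 convolution (out1 -> out2)
top_out1_halo = torch.randn(N, C, 1, W)     # last out1 row received from the rank above
out1_chunk = torch.randn(N, C, 6, W)        # this rank's out1 rows

# Fat halo: neighbour's row stacked on top of our first two rows.
top_fat_halo = torch.cat([top_out1_halo, out1_chunk[:, :, :2, :]], dim=2)

# 3x3 conv with padding 1 over the 3-row strip; its middle output row is the
# corrected first row of out2 on this rank.
top_out2 = F.conv2d(top_fat_halo, w, padding=1)[:, :, 1:2, :]
```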
@@ -329,13 +329,6 @@ class SpatialBottleneckFunction(torch.autograd.Function):
# to wait for out2_mask to finish, but itself has to finish before
# the first kernel of _forward_rest can launch.
# At least we can overlap the two halo correction kernels.
if spatial_group_rank > 0:
stream1.wait_stream(torch.cuda.current_stream()) # wait for *_out2_mask to finish
with torch.cuda.stream(stream1):
w1by3 = args[2][:,:1,:,:].clone()
top_out1_halo = top_out1_halo.clone()
top_out2 = fast_bottleneck.forward_out2_halo_corr(explicit_nhwc, top_out1_halo, args, w1by3, top_out2_halo.clone())
top_out2_halo.copy_(top_out2)
if spatial_group_rank < spatial_group_size-1:
stream2.wait_stream(torch.cuda.current_stream()) # wait for *_out2_mask to finish
with torch.cuda.stream(stream2):
@@ -344,9 +337,16 @@ class SpatialBottleneckFunction(torch.autograd.Function):
btm_out2 = fast_bottleneck.forward_out2_halo_corr(explicit_nhwc, btm_out1_halo, args, w1by3, btm_out2_halo.clone())
btm_out2_halo.copy_(btm_out2)
if spatial_group_rank > 0:
torch.cuda.current_stream().wait_stream(stream1)
stream1.wait_stream(torch.cuda.current_stream()) # wait for *_out2_mask to finish
with torch.cuda.stream(stream1):
w1by3 = args[2][:,:1,:,:].clone()
top_out1_halo = top_out1_halo.clone()
top_out2 = fast_bottleneck.forward_out2_halo_corr(explicit_nhwc, top_out1_halo, args, w1by3, top_out2_halo.clone())
top_out2_halo.copy_(top_out2)
if spatial_group_rank < spatial_group_size-1:
torch.cuda.current_stream().wait_stream(stream2)
if spatial_group_rank > 0:
torch.cuda.current_stream().wait_stream(stream1)
fast_bottleneck.forward_rest(explicit_nhwc, stride_1x1, args, outputs)
# save halos for backward pass
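As the comments above note, each *_out2_halo_corr launch has to wait for the masked main convolution (*_out2_mask) and must itself finish before _forward_rest launches, so the only available overlap is between the top and bottom correction kernels on their side streams. A minimal sketch of that ordering with hypothetical placeholders (main_conv_masked, top_correction, btm_correction, rest_of_bottleneck are stand-ins, not fast_bottleneck entry points); it assumes a CUDA device is present.

```python
# Minimal sketch of the stream ordering used above (hypothetical placeholders).
import torch

def main_conv_masked():     pass   # stand-in for the masked out2 convolution
def top_correction():       pass   # stand-in for the top halo correction kernel
def btm_correction():       pass   # stand-in for the bottom halo correction kernel
def rest_of_bottleneck():   pass   # stand-in for fast_bottleneck.forward_rest

stream1, stream2 = torch.cuda.Stream(), torch.cuda.Stream()

main_conv_masked()                               # launched on the current stream

stream1.wait_stream(torch.cuda.current_stream()) # corrections must see its output
stream2.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(stream1):
    top_correction()                             # the two corrections overlap each other
with torch.cuda.stream(stream2):
    btm_correction()

torch.cuda.current_stream().wait_stream(stream1) # the rest of the bottleneck must see
torch.cuda.current_stream().wait_stream(stream2) # both corrected edge rows
rest_of_bottleneck()
```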
@@ -365,6 +365,8 @@ class SpatialBottleneckFunction(torch.autograd.Function):
ctx.spatial_group_rank = spatial_group_rank
ctx.spatial_halo_exchanger = spatial_halo_exchanger
ctx.spatial_method = spatial_method
ctx.thresholdTop = thresholdTop
ctx.thresholdBottom = thresholdBottom
ctx.stream1 = stream1
ctx.stream2 = stream2
ctx.stream3 = stream3
@@ -414,50 +416,55 @@ class SpatialBottleneckFunction(torch.autograd.Function):
with torch.cuda.stream(ctx.stream1):
top_halo, btm_halo = ctx.spatial_halo_exchanger.left_right_halo_exchange(grad_out2[:,:1,:,:], grad_out2[:,Hs-1:,:,:])
# copy halos to send buffer
if ctx.spatial_group_rank < ctx.spatial_group_size-1:
ctx.stream2.wait_stream(ctx.stream1)
with torch.cuda.stream(ctx.stream2):
if ctx.explicit_nhwc:
btm_fat_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
btm_relu_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
btm_fat_halo[:,:2,:,:].copy_(grad_out2[:,Hs-2:,:,:])
btm_fat_halo[:,2:,:,:].copy_(btm_halo)
btm_relu_halo[:,:2,:,:].copy_(relu1[:,Hs-2:,:,:])
btm_relu_halo[:,2:,:,:].zero_()
else:
btm_fat_halo = torch.empty((N,C,3,W),dtype=grad_out2.dtype,device=grad_out2.device)
btm_relu_halo = torch.empty((N,C,3,W),dtype=grad_out2.dtype,device=grad_out2.device)
btm_fat_halo[:,:,:2,:].copy_(grad_out2[:,:,Hs-2:,:])
btm_fat_halo[:,:,2:,:].copy_(btm_halo)
btm_relu_halo[:,:,:2,:].copy_(relu1[:,:,Hs-2:,:])
btm_relu_halo[:,:,2:,:].zero_()
btm_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, btm_fat_halo, btm_relu_halo)
if ctx.explicit_nhwc:
btm_grad_out1_halo = btm_grad_out1_halo[:,1:2,:,:]
else:
btm_grad_out1_halo = btm_grad_out1_halo[:,:,1:2,:]
if ctx.spatial_group_rank > 0:
with torch.cuda.stream(ctx.stream1):
if ctx.explicit_nhwc:
top_fat_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
top_relu_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
top_fat_halo[:,:1,:,:].copy_(top_halo)
top_fat_halo[:,1:,:,:].copy_(grad_out2[:,:2,:,:])
top_relu_halo[:,:1,:,:].zero_()
top_relu_halo[:,1:,:,:].copy_(relu1[:,:2,:,:])
else:
top_fat_halo = torch.empty((N,C,3,W),dtype=grad_out2.dtype,device=grad_out2.device)
top_relu_halo = torch.empty((N,C,3,W),dtype=grad_out2.dtype,device=grad_out2.device)
top_fat_halo[:,:,:1,:].copy_(top_halo)
top_fat_halo[:,:,1:,:].copy_(grad_out2[:,:,:2,:])
top_relu_halo[:,:,:1,:].zero_()
top_relu_halo[:,:,1:,:].copy_(relu1[:,:,:2,:])
top_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, top_fat_halo, top_relu_halo)
if ctx.explicit_nhwc:
top_grad_out1_halo = top_grad_out1_halo[:,1:2,:,:]
else:
top_grad_out1_halo = top_grad_out1_halo[:,:,1:2,:]
inc.add_delay(10)
if ctx.spatial_method == 1 or ctx.spatial_method == 2:
# 1 -> halo recompute approach
# 2 -> wait for concatenated halos, then do single conv on full input (not implemented yet for bprop)
if ctx.spatial_group_rank < ctx.spatial_group_size-1:
ctx.stream2.wait_stream(ctx.stream1)
with torch.cuda.stream(ctx.stream2):
if ctx.explicit_nhwc:
btm_fat_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
btm_fat_halo[:,:2,:,:].copy_(grad_out2[:,Hs-2:,:,:])
btm_fat_halo[:,2:,:,:].copy_(btm_halo)
btm_fat_relu_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
btm_fat_relu_halo[:,:2,:,:].copy_(relu1[:,Hs-2:,:,:])
btm_fat_relu_halo[:,2:,:,:].zero_()
else:
btm_fat_halo = torch.empty((N,C,3,W),dtype=grad_out2.dtype,device=grad_out2.device)
btm_fat_halo[:,:,:2,:].copy_(grad_out2[:,:,Hs-2:,:])
btm_fat_halo[:,:,2:,:].copy_(btm_halo)
btm_fat_relu_halo = torch.empty((N,C,3,W),dtype=grad_out2.dtype,device=grad_out2.device)
btm_fat_relu_halo[:,:,:2,:].copy_(relu1[:,:,Hs-2:,:])
btm_fat_relu_halo[:,:,2:,:].zero_()
btm_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, btm_fat_halo, btm_fat_relu_halo)
if ctx.explicit_nhwc:
btm_grad_out1_halo = btm_grad_out1_halo[:,1:2,:,:]
else:
btm_grad_out1_halo = btm_grad_out1_halo[:,:,1:2,:]
if ctx.spatial_group_rank > 0:
with torch.cuda.stream(ctx.stream1):
if ctx.explicit_nhwc:
top_fat_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
top_fat_halo[:,:1,:,:].copy_(top_halo)
top_fat_halo[:,1:,:,:].copy_(grad_out2[:,:2,:,:])
top_fat_relu_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
top_fat_relu_halo[:,:1,:,:].zero_()
top_fat_relu_halo[:,1:,:,:].copy_(relu1[:,:2,:,:])
else:
top_fat_halo = torch.empty((N,C,3,W),dtype=grad_out2.dtype,device=grad_out2.device)
top_fat_halo[:,:,:1,:].copy_(top_halo)
top_fat_halo[:,:,1:,:].copy_(grad_out2[:,:,:2,:])
top_fat_relu_halo = torch.empty((N,C,3,W),dtype=grad_out2.dtype,device=grad_out2.device)
top_fat_relu_halo[:,:,:1,:].zero_()
top_fat_relu_halo[:,:,1:,:].copy_(relu1[:,:,:2,:])
top_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, top_fat_halo, top_fat_relu_halo)
if ctx.explicit_nhwc:
top_grad_out1_halo = top_grad_out1_halo[:,1:2,:,:]
else:
top_grad_out1_halo = top_grad_out1_halo[:,:,1:2,:]
inc.add_delay(10)
elif ctx.spatial_method != 3:
assert(False), "spatial_method must be 1, 2 or 3"
with torch.cuda.stream(wgrad2_stream):
if ctx.spatial_group_size > 1:
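For the halo-recompute path (spatial_method 1), backward_grad_out1_halo re-runs the data gradient of the 3x3 convolution on a 3-row strip of grad_out2 (the rank's two boundary rows plus the exchanged halo row) together with the matching relu1 rows, and only the middle row of the result is kept as the corrected boundary row of grad_out1. Conceptually this is a transposed convolution gated by the ReLU mask; a rough sketch in plain PyTorch and NCHW layout (the real kernel works on explicit-NHWC tensors and the bottleneck's saved activations):

```python
# Rough sketch of backward_grad_out1_halo for the bottom boundary (plain PyTorch,
# NCHW): dgrad of the 3x3 conv over a 3-row strip of grad_out2, gated by relu1 > 0.
import torch
import torch.nn.functional as F

N, C, K, W = 1, 4, 8, 16
w = torch.randn(K, C, 3, 3)                        # conv2 weights (out1 -> out2)
btm_fat_halo = torch.randn(N, K, 3, W)             # rows: [grad_out2[Hs-2], grad_out2[Hs-1], neighbour halo]
btm_fat_relu_halo = torch.randn(N, C, 3, W).relu() # rows: [relu1[Hs-2], relu1[Hs-1], zeros]
btm_fat_relu_halo[:, :, 2:, :] = 0

# Gradient w.r.t. the conv input is a transposed conv with the same weights;
# the ReLU backward is a mask on relu1 > 0.
dgrad = F.conv_transpose2d(btm_fat_halo, w, padding=1)
btm_grad_out1_halo = dgrad * (btm_fat_relu_halo > 0)

# Only the middle row is kept: the corrected last row of grad_out1 on this rank.
btm_grad_out1_halo = btm_grad_out1_halo[:, :, 1:2, :]
```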
@@ -466,7 +473,10 @@ class SpatialBottleneckFunction(torch.autograd.Function):
wgrad2 = fast_bottleneck.backward_wgrad2(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
# compute grad_out1 for internal cells
grad_out1 = fast_bottleneck.backward_grad_out1(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
if ctx.spatial_group_size <= 1 or ctx.spatial_method == 1 or ctx.spatial_method == 2:
grad_out1 = fast_bottleneck.backward_grad_out1(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
elif ctx.spatial_group_size > 1 and ctx.spatial_method == 3:
grad_out1 = fast_bottleneck.backward_grad_out1_mask(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2, ctx.thresholdTop, ctx.thresholdBottom)
# apply halo cells to grad_out1
if ctx.spatial_group_size > 1:
@@ -474,20 +484,51 @@ class SpatialBottleneckFunction(torch.autograd.Function):
z = t_list[4]
relu1 = t_list[12]
#print("w.shape = %s, z.shape = %s, relu1.shape = %s" % (str(list(w.shape)), str(list(z.shape)), str(list(relu1.shape))))
if ctx.spatial_group_rank > 0:
torch.cuda.current_stream().wait_stream(ctx.stream1)
if ctx.explicit_nhwc:
grad_out1[:,:1,:,:].copy_(top_grad_out1_halo)
else:
grad_out1[:,:,:1,:].copy_(top_grad_out1_halo)
#print("ctx.spatial_group_rank = %d, apply grad_out1 top halo (grad_out1.shape = %s)" % (ctx.spatial_group_rank, str(list(grad_out1.shape))))
if ctx.spatial_group_rank < ctx.spatial_group_size-1:
torch.cuda.current_stream().wait_stream(ctx.stream2)
if ctx.explicit_nhwc:
grad_out1[:,Hs-1:,:,:].copy_(btm_grad_out1_halo)
else:
grad_out1[:,:,Hs-1:,:].copy_(btm_grad_out1_halo)
#print("ctx.spatial_group_rank = %d, apply grad_out1 btm halo (grad_out1.shape = %s)" % (ctx.spatial_group_rank, str(list(grad_out1.shape))))
if ctx.spatial_method == 1 or ctx.spatial_method == 2:
if ctx.spatial_group_rank < ctx.spatial_group_size-1:
torch.cuda.current_stream().wait_stream(ctx.stream2)
if ctx.explicit_nhwc:
grad_out1[:,Hs-1:,:,:].copy_(btm_grad_out1_halo)
else:
grad_out1[:,:,Hs-1:,:].copy_(btm_grad_out1_halo)
#print("ctx.spatial_group_rank = %d, apply grad_out1 btm halo (grad_out1.shape = %s)" % (ctx.spatial_group_rank, str(list(grad_out1.shape))))
if ctx.spatial_group_rank > 0:
torch.cuda.current_stream().wait_stream(ctx.stream1)
if ctx.explicit_nhwc:
grad_out1[:,:1,:,:].copy_(top_grad_out1_halo)
else:
grad_out1[:,:,:1,:].copy_(top_grad_out1_halo)
#print("ctx.spatial_group_rank = %d, apply grad_out1 top halo (grad_out1.shape = %s)" % (ctx.spatial_group_rank, str(list(grad_out1.shape))))
elif ctx.spatial_method == 3:
if ctx.spatial_group_rank < ctx.spatial_group_size-1:
if ctx.explicit_nhwc:
btm_relu_halo = relu1[:,Hs-1:,:,:].clone()
btm_grad_out1 = grad_out1[:,Hs-1:,:,:]
else:
btm_relu_halo = relu1[:,:,Hs-1:,:].clone()
btm_grad_out1 = grad_out1[:,:,Hs-1:,:]
w1by3 = w[:,:1,:,:].clone()
ctx.stream1.wait_stream(ctx.stream2) # wait for halo transfers to finish
ctx.stream2.wait_stream(torch.cuda.current_stream()) # wait for backward_grad_out1_mask to finish before launching halo correction kernel
with torch.cuda.stream(ctx.stream1):
btm_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo_corr(ctx.explicit_nhwc, ctx.stride_1x1, t_list, w1by3, grads, btm_halo, btm_relu_halo, btm_grad_out1.clone())
btm_grad_out1.copy_(btm_grad_out1_halo)
if ctx.spatial_group_rank > 0:
if ctx.explicit_nhwc:
top_relu_halo = relu1[:,:1,:,:].clone()
top_grad_out1 = grad_out1[:,:1,:,:]
else:
top_relu_halo = relu1[:,:,:1,:].clone()
top_grad_out1 = grad_out1[:,:,:1,:]
w1by3 = w[:,2:,:,:].clone()
ctx.stream1.wait_stream(torch.cuda.current_stream()) # wait for backward_grad_out1_mask to finish before launching halo correction kernel
with torch.cuda.stream(ctx.stream1):
top_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo_corr(ctx.explicit_nhwc, ctx.stride_1x1, t_list, w1by3, grads, top_halo, top_relu_halo, top_grad_out1.clone())
top_grad_out1.copy_(top_grad_out1_halo)
if ctx.spatial_group_rank < ctx.spatial_group_size-1:
torch.cuda.current_stream().wait_stream(ctx.stream2) # wait for halo correction to finish
if ctx.spatial_group_rank > 0:
torch.cuda.current_stream().wait_stream(ctx.stream1)
fast_bottleneck.backward_rest(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2, grad_out1, wgrad2)
torch.cuda.current_stream().wait_stream(wgrad2_stream)
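For spatial_method 3, grad_out1 is first produced by backward_grad_out1_mask, and the new backward_grad_out1_halo_corr kernel then fixes the single boundary row on each side. Assuming the masked dgrad treats the neighbour's grad_out2 row as zero (which is what the correction appears to compensate for), the bottom boundary row is missing exactly one term: the neighbour's halo row pushed back through kernel row 0 (w[:,:1,:,:] above; the top boundary uses kernel row 2, w[:,2:,:,:]). A rough plain-PyTorch sketch of that correction in NCHW layout, with the ReLU mask applied to the corrected row; the real kernel fuses this and works on explicit-NHWC tensors.

```python
# Rough sketch of the new backward halo correction (backward_grad_out1_halo_corr)
# for the bottom boundary row (plain PyTorch, NCHW); illustrative only.
import torch
import torch.nn.functional as F

N, C, K, W = 1, 4, 8, 16
w = torch.randn(K, C, 3, 3)                     # conv2 weights (out1 -> out2)
btm_halo = torch.randn(N, K, 1, W)              # neighbour's first grad_out2 row
btm_relu_halo = torch.randn(N, C, 1, W).relu()  # this rank's last relu1 row
btm_grad_out1 = torch.randn(N, C, 1, W)         # last grad_out1 row from the masked dgrad

# The masked dgrad saw no contribution from the neighbour, so the boundary row is
# missing one term: the halo row pushed back through a single kernel row
# (row 0 for the bottom boundary, row 2 for the top boundary).
w1by3 = w[:, :, 0:1, :]                                        # (K, C, 1, 3)
missing = F.conv_transpose2d(btm_halo, w1by3, padding=(0, 1))  # (N, C, 1, W)

# Add the missing term and apply the ReLU mask for that row.
corrected_btm_grad_out1 = (btm_grad_out1 + missing) * (btm_relu_halo > 0)
```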
......
@@ -161,7 +161,7 @@ def n_way_spatial(halex, gt_bottleneck, gt, explicit_nhwc, world_size, rank, fp3
spatial_group_rank = rank
spatial_communicator = None
spatial_halo_exchanger = halex
spatial_method = 2 # 1 -> overlap halo and main conv, 2 -> wait for halo, conv on padded x
spatial_method = 3 # 1 -> overlap halo and main conv, 2 -> wait for halo, conv on padded x, 3 -> masked main conv plus separate halo correction kernels
spatial_parallel_args = (spatial_group_size, spatial_group_rank, spatial_communicator, spatial_halo_exchanger, spatial_method)
spatial_bottleneck = spatial_parallel_bottleneck(C, dtype, explicit_nhwc, gt_bottleneck, spatial_parallel_args)
......