"git@developer.sourcefind.cn:ox696c/ktransformers.git" did not exist on "0cd81fa85c0ffe3f44a21ecc7e7bdd5a15dbdabf"
Commit 705aa35d authored by Thor Johnsen

Fix halo correction kernel

parent 60000f73
@@ -268,6 +268,17 @@ class SpatialBottleneckFunction(torch.autograd.Function):
                 spatial_halo_exchanger.left_right_halo_exchange(out1[:,:,:1,:], out1[:,:,Hs-1:,:], top_out1_halo, btm_out1_halo)
             if spatial_method == 1:
                 # overlap mid convolution with halo transfer
+                if spatial_group_rank > 0:
+                    with torch.cuda.stream(stream1):
+                        if explicit_nhwc:
+                            top_fat_halo = torch.empty((N,3,W,C),dtype=out1.dtype,device=out1.device)
+                            top_fat_halo[:,:1,:,:].copy_(top_out1_halo)
+                            top_fat_halo[:,1:3,:,:].copy_(out1[:,:2,:,:])
+                        else:
+                            top_fat_halo = torch.empty((N,C,3,W),dtype=out1.dtype,device=out1.device)
+                            top_fat_halo[:,:,:1,:].copy_(top_out1_halo)
+                            top_fat_halo[:,:,1:3,:].copy_(out1[:,:,:2,:])
+                        top_out2 = fast_bottleneck.forward_out2_halo(explicit_nhwc, top_fat_halo, args)
                 if spatial_group_rank < spatial_group_size-1:
                     stream2.wait_stream(stream1)
                     with torch.cuda.stream(stream2):
@@ -280,17 +291,6 @@ class SpatialBottleneckFunction(torch.autograd.Function):
                             btm_fat_halo[:,:,0:2,:].copy_(out1[:,:,Hs-2:,:])
                             btm_fat_halo[:,:,2:,:].copy_(btm_out1_halo)
                         btm_out2 = fast_bottleneck.forward_out2_halo(explicit_nhwc, btm_fat_halo, args)
-                if spatial_group_rank > 0:
-                    with torch.cuda.stream(stream1):
-                        if explicit_nhwc:
-                            top_fat_halo = torch.empty((N,3,W,C),dtype=out1.dtype,device=out1.device)
-                            top_fat_halo[:,:1,:,:].copy_(top_out1_halo)
-                            top_fat_halo[:,1:3,:,:].copy_(out1[:,:2,:,:])
-                        else:
-                            top_fat_halo = torch.empty((N,C,3,W),dtype=out1.dtype,device=out1.device)
-                            top_fat_halo[:,:,:1,:].copy_(top_out1_halo)
-                            top_fat_halo[:,:,1:3,:].copy_(out1[:,:,:2,:])
-                        top_out2 = fast_bottleneck.forward_out2_halo(explicit_nhwc, top_fat_halo, args)
                 inc.add_delay(10)
             elif spatial_method != 2 and spatial_method != 3:
                 assert(False), "spatial_method must be 1, 2 or 3"
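For readers unfamiliar with the spatial_method == 1 path touched above: the first output row of the mid 3x3 convolution on a rank depends on one row of out1 owned by the neighbouring rank, so the received halo row is stacked with the two local boundary rows into a 3-row "fat halo" and convolved on a side stream while the halo transfer and bulk work proceed. Below is a minimal sketch of that assembly for the NCHW (non-explicit_nhwc) case; the tensor sizes are made up and torch.nn.functional.conv2d is only a stand-in for fast_bottleneck.forward_out2_halo, so treat it as illustration, not the apex kernel.

import torch
import torch.nn.functional as F

# Sketch only: NCHW fat-halo assembly (requires a CUDA device).
N, C, Hs, W = 2, 8, 16, 16
out1 = torch.randn(N, C, Hs, W, device="cuda")           # this rank's slab after the first conv
top_out1_halo = torch.randn(N, C, 1, W, device="cuda")   # row received from the rank above
w_mid = torch.randn(C, C, 3, 3, device="cuda")           # 3x3 mid-conv weight (args[2] in the diff)

stream1 = torch.cuda.Stream()
stream1.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(stream1):
    # Received halo row on top, first two local rows below: exactly the
    # 3 input rows the 3x3 convolution needs for the corrected first output row.
    top_fat_halo = torch.empty((N, C, 3, W), dtype=out1.dtype, device=out1.device)
    top_fat_halo[:, :, :1, :].copy_(top_out1_halo)
    top_fat_halo[:, :, 1:3, :].copy_(out1[:, :, :2, :])
    top_out2 = F.conv2d(top_fat_halo, w_mid, padding=(0, 1))  # -> (N, C, 1, W)
torch.cuda.current_stream().wait_stream(stream1)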
@@ -324,21 +324,35 @@ class SpatialBottleneckFunction(torch.autograd.Function):
                     torch.cuda.current_stream().wait_stream(stream2)
                     btm_out2_halo.copy_(btm_out2)
             elif spatial_method == 3:
+                # Note
+                # out2 halo correction cannot overlap with anything since it has
+                # to wait for out2_mask to finish, but itself has to finish before
+                # the first kernel of _forward_rest can launch.
+                # At least we can overlap the two halo correction kernels.
                 if spatial_group_rank > 0:
-                    w1by3 = args[2][:,:,2:3,:].contiguous(memory_format=torch.preserve)
-                    top_out1_halo = top_out1_halo.contiguous(memory_format=memory_format)
-                    top_out2 = fast_bottleneck.forward_out2_halo_corr(explicit_nhwc, top_out1_halo, args, w1by3, top_out2_halo.contiguous(memory_format=memory_format))
-                    top_out2_halo.copy_(top_out2)
+                    stream1.wait_stream(torch.cuda.current_stream()) # wait for *_out2_mask to finish
+                    with torch.cuda.stream(stream1):
+                        w1by3 = args[2][:,:1,:,:].clone()
+                        top_out1_halo = top_out1_halo.clone()
+                        top_out2 = fast_bottleneck.forward_out2_halo_corr(explicit_nhwc, top_out1_halo, args, w1by3, top_out2_halo.clone())
+                        top_out2_halo.copy_(top_out2)
                 if spatial_group_rank < spatial_group_size-1:
-                    w1by3 = args[2][:,:,:1,:].contiguous(memory_format=torch.preserve)
-                    btm_out1_halo = btm_out1_halo.contiguous(memory_format=memory_format)
-                    btm_out2 = fast_bottleneck.forward_out2_halo_corr(explicit_nhwc, btm_out1_halo, args, w1by3, btm_out2_halo.contiguous(memory_format=memory_format))
-                    btm_out2_halo.copy_(btm_out2)
+                    stream2.wait_stream(torch.cuda.current_stream()) # wait for *_out2_mask to finish
+                    with torch.cuda.stream(stream2):
+                        w1by3 = args[2][:,2:3,:,:].clone()
+                        btm_out1_halo = btm_out1_halo.clone()
+                        btm_out2 = fast_bottleneck.forward_out2_halo_corr(explicit_nhwc, btm_out1_halo, args, w1by3, btm_out2_halo.clone())
+                        btm_out2_halo.copy_(btm_out2)
+                if spatial_group_rank > 0:
+                    torch.cuda.current_stream().wait_stream(stream1)
+                if spatial_group_rank < spatial_group_size-1:
+                    torch.cuda.current_stream().wait_stream(stream2)
         fast_bottleneck.forward_rest(explicit_nhwc, stride_1x1, args, outputs)
         # save halos for backward pass
         if spatial_group_size > 1:
             if spatial_method != 2:
+                # make sure copy of mid-section of out1 into out1_pad is done before exiting
                 torch.cuda.current_stream().wait_stream(stream3)
             ctx.save_for_backward(*(args+outputs+[out1_pad,]))
         else:
...
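The spatial_method == 3 change above follows the fork/join pattern described by the new comment: each halo-correction kernel first waits for the *_out2_mask work queued on the default stream, the two corrections then run concurrently on stream1 and stream2, and the default stream joins both before forward_rest may launch. The sketch below shows only that synchronization pattern; halo_corr, the tensor shapes, and the weight slices are assumptions made for illustration, standing in for fast_bottleneck.forward_out2_halo_corr and the real arguments.

import torch

# Sketch only: fork/join of the two halo-correction kernels (requires a CUDA device).
def halo_corr(out1_halo, weight_slice, out2_halo):
    # placeholder for fast_bottleneck.forward_out2_halo_corr
    return out2_halo * 1.0

N, C, W = 2, 8, 16
top_out1_halo = torch.randn(N, C, 1, W, device="cuda")
btm_out1_halo = torch.randn(N, C, 1, W, device="cuda")
top_out2_halo = torch.randn(N, C, 1, W, device="cuda")
btm_out2_halo = torch.randn(N, C, 1, W, device="cuda")
w = torch.randn(C, C, 3, 3, device="cuda")               # stands in for args[2]

stream1, stream2 = torch.cuda.Stream(), torch.cuda.Stream()

# fork: each correction waits for the default-stream (*_out2_mask) work ...
stream1.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(stream1):
    top_out2_halo.copy_(halo_corr(top_out1_halo.clone(), w[:, :1, :, :].clone(), top_out2_halo.clone()))

stream2.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(stream2):
    btm_out2_halo.copy_(halo_corr(btm_out1_halo.clone(), w[:, 2:3, :, :].clone(), btm_out2_halo.clone()))

# ... join: the default stream must see both corrections finish before the
# next kernel (forward_rest in the real code) is allowed to launch.
torch.cuda.current_stream().wait_stream(stream1)
torch.cuda.current_stream().wait_stream(stream2)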