Commit 60000f73 authored by Thor Johnsen

Add halo correction using new cudnn masking feature

parent 9c16d945
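In outline, the new spatial_method == 3 path lets each rank run the full out2 3x3 convolution over its height shard, with thresholdTop/thresholdBottom masking the fused epilogue on the shard-edge rows whose halo contribution has not arrived yet; forward_out2_halo_corr then folds the neighbor's out1 halo row in through a 1x3 slice of the 3x3 weight and finishes those rows. The snippet below is a minimal single-process sketch of that idea in plain PyTorch, not the commit's cuDNN path: masked_conv and halo_corr are hypothetical stand-ins for forward_out2_mask and forward_out2_halo_corr, the "mask defers the ReLU epilogue" reading is an assumption, and the kernel-row/halo pairing follows F.conv2d's cross-correlation convention rather than the extension's layout.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, C, K, H, W = 2, 4, 8, 8, 6
x = torch.randn(N, C, H, W)
w = torch.randn(K, C, 3, 3)

# Reference: what a single GPU computes (3x3 conv, zero padding, ReLU epilogue).
ref = torch.relu(F.conv2d(x, w, padding=1))

# Two-way spatial split along H; each "rank" owns one shard of Hs rows.
Hs = H // 2
shards = [x[:, :, :Hs, :], x[:, :, Hs:, :]]

def masked_conv(shard, t_top, t_btm):
    # Stand-in for forward_out2_mask (assumed semantics): full convolution over
    # the shard, but the ReLU epilogue only covers rows t_top..t_btm; rows
    # outside still miss their halo term, so their activation is deferred.
    out = F.conv2d(shard, w, padding=1)
    out[:, :, t_top:t_btm + 1, :] = torch.relu(out[:, :, t_top:t_btm + 1, :])
    return out

def halo_corr(out, halo, k_row, o_row):
    # Stand-in for forward_out2_halo_corr (assumed semantics): add the
    # neighbor's halo row through a 1x3 slice of the 3x3 weight, then apply
    # the deferred epilogue to the corrected row.
    w1by3 = w[:, :, k_row:k_row + 1, :]
    row = out[:, :, o_row:o_row + 1, :] + F.conv2d(halo, w1by3, padding=(0, 1))
    out[:, :, o_row:o_row + 1, :] = torch.relu(row)

# Rank 0: thresholdTop=0, thresholdBottom=Hs-2 -> only the bottom row is deferred.
y0 = masked_conv(shards[0], 0, Hs - 2)
halo_corr(y0, shards[1][:, :, :1, :], k_row=2, o_row=Hs - 1)

# Rank 1: thresholdTop=1, thresholdBottom=Hs-1 -> only the top row is deferred.
y1 = masked_conv(shards[1], 1, Hs - 1)
halo_corr(y1, shards[0][:, :, Hs - 1:, :], k_row=0, o_row=0)

# Stitching the shards back together reproduces the single-GPU result.
assert torch.allclose(torch.cat([y0, y1], dim=2), ref, atol=1e-4)

The reason a mask is needed at all is that the epilogue is not post-correctable: once ReLU has been applied to a row, the halo term can no longer be added, so edge rows must stay pre-activation until they are corrected.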
@@ -220,7 +220,7 @@ class Bottleneck(torch.nn.Module):
 class SpatialBottleneckFunction(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, spatial_group_size, spatial_group_rank, spatial_communicator, spatial_halo_exchanger, spatial_method, explicit_nhwc, stride_1x1, scale, bias, x, *conv):
+    def forward(ctx, spatial_group_size, spatial_group_rank, spatial_communicator, spatial_halo_exchanger, spatial_method, explicit_nhwc, stride_1x1, scale, bias, thresholdTop, thresholdBottom, x, *conv):
         if spatial_group_size > 1:
             stream1 = spatial_halo_exchanger.stream1
             stream2 = spatial_halo_exchanger.stream2
@@ -271,56 +271,74 @@ class SpatialBottleneckFunction(torch.autograd.Function):
                 if spatial_group_rank < spatial_group_size-1:
                     stream2.wait_stream(stream1)
                     with torch.cuda.stream(stream2):
-                        btm_fat_halo = torch.empty((N,3,W,C),dtype=out1.dtype,device=out1.device)
                         if explicit_nhwc:
+                            btm_fat_halo = torch.empty((N,3,W,C),dtype=out1.dtype,device=out1.device)
                             btm_fat_halo[:,0:2,:,:].copy_(out1[:,Hs-2:,:,:])
                             btm_fat_halo[:,2:,:,:].copy_(btm_out1_halo)
                         else:
+                            btm_fat_halo = torch.empty((N,C,3,W),dtype=out1.dtype,device=out1.device)
                             btm_fat_halo[:,:,0:2,:].copy_(out1[:,:,Hs-2:,:])
                             btm_fat_halo[:,:,2:,:].copy_(btm_out1_halo)
                         btm_out2 = fast_bottleneck.forward_out2_halo(explicit_nhwc, btm_fat_halo, args)
                 if spatial_group_rank > 0:
                     with torch.cuda.stream(stream1):
-                        top_fat_halo = torch.empty((N,3,W,C),dtype=out1.dtype,device=out1.device)
                         if explicit_nhwc:
+                            top_fat_halo = torch.empty((N,3,W,C),dtype=out1.dtype,device=out1.device)
                             top_fat_halo[:,:1,:,:].copy_(top_out1_halo)
                             top_fat_halo[:,1:3,:,:].copy_(out1[:,:2,:,:])
                         else:
+                            top_fat_halo = torch.empty((N,C,3,W),dtype=out1.dtype,device=out1.device)
                             top_fat_halo[:,:,:1,:].copy_(top_out1_halo)
                             top_fat_halo[:,:,1:3,:].copy_(out1[:,:,:2,:])
                         top_out2 = fast_bottleneck.forward_out2_halo(explicit_nhwc, top_fat_halo, args)
                 inc.add_delay(10)
-            elif spatial_method == 2:
-                # wait for halo transfer to finish before doing a full convolution of padded x
-                torch.cuda.current_stream().wait_stream(stream1)
-                torch.cuda.current_stream().wait_stream(stream3)
-                fast_bottleneck.forward_out2_pad(explicit_nhwc, stride_1x1, args, outputs, out1_pad)
-            else:
-                assert(False), "spatial_method must be 1 or 2"
-        if spatial_group_size <= 1 or spatial_method == 1:
-            fast_bottleneck.forward_out2(explicit_nhwc, stride_1x1, args, outputs)
+            elif spatial_method != 2 and spatial_method != 3:
+                assert(False), "spatial_method must be 1, 2 or 3"
+        if spatial_group_size <= 1:
+            fast_bottleneck.forward_out2(explicit_nhwc, stride_1x1, args, outputs)
+        elif spatial_method == 1:
+            fast_bottleneck.forward_out2(explicit_nhwc, stride_1x1, args, outputs)
+        elif spatial_method == 2:
+            # wait for halo transfer to finish before doing a full convolution of padded x
+            torch.cuda.current_stream().wait_stream(stream1)
+            torch.cuda.current_stream().wait_stream(stream3)
+            fast_bottleneck.forward_out2_pad(explicit_nhwc, stride_1x1, args, outputs, out1_pad)
+        elif spatial_method == 3:
+            fast_bottleneck.forward_out2_mask(explicit_nhwc, stride_1x1, args, outputs, thresholdTop, thresholdBottom)
         # compute halo cells for outputs[1] (out2)
-        if spatial_group_size > 1 and spatial_method == 1:
+        if spatial_group_size > 1:
             out2 = outputs[1]
-            if spatial_group_rank > 0:
-                torch.cuda.current_stream().wait_stream(stream1)
-                if explicit_nhwc:
-                    out2[:,:1,:,:].copy_(top_out2)
-                else:
-                    out2[:,:,:1,:].copy_(top_out2)
-            if spatial_group_rank < spatial_group_size-1:
-                torch.cuda.current_stream().wait_stream(stream2)
-                if explicit_nhwc:
-                    out2[:,Hs-1:,:,:].copy_(btm_out2)
-                else:
-                    out2[:,:,Hs-1:,:].copy_(btm_out2)
-        torch.cuda.current_stream().wait_stream(stream3)
+            if explicit_nhwc:
+                top_out2_halo = out2[:,:1,:,:]
+                btm_out2_halo = out2[:,Hs-1:,:,:]
+            else:
+                top_out2_halo = out2[:,:,:1,:]
+                btm_out2_halo = out2[:,:,Hs-1:,:]
+            if spatial_method == 1:
+                if spatial_group_rank > 0:
+                    torch.cuda.current_stream().wait_stream(stream1)
+                    top_out2_halo.copy_(top_out2)
+                if spatial_group_rank < spatial_group_size-1:
+                    torch.cuda.current_stream().wait_stream(stream2)
+                    btm_out2_halo.copy_(btm_out2)
+            elif spatial_method == 3:
+                if spatial_group_rank > 0:
+                    w1by3 = args[2][:,:,2:3,:].contiguous(memory_format=torch.preserve_format)
+                    top_out1_halo = top_out1_halo.contiguous(memory_format=memory_format)
+                    top_out2 = fast_bottleneck.forward_out2_halo_corr(explicit_nhwc, top_out1_halo, args, w1by3, top_out2_halo.contiguous(memory_format=memory_format))
+                    top_out2_halo.copy_(top_out2)
+                if spatial_group_rank < spatial_group_size-1:
+                    w1by3 = args[2][:,:,:1,:].contiguous(memory_format=torch.preserve_format)
+                    btm_out1_halo = btm_out1_halo.contiguous(memory_format=memory_format)
+                    btm_out2 = fast_bottleneck.forward_out2_halo_corr(explicit_nhwc, btm_out1_halo, args, w1by3, btm_out2_halo.contiguous(memory_format=memory_format))
+                    btm_out2_halo.copy_(btm_out2)
         fast_bottleneck.forward_rest(explicit_nhwc, stride_1x1, args, outputs)
         # save halos for backward pass
         if spatial_group_size > 1:
-            torch.cuda.current_stream().wait_stream(stream3)
+            if spatial_method != 2:
+                torch.cuda.current_stream().wait_stream(stream3)
             ctx.save_for_backward(*(args+outputs+[out1_pad,]))
         else:
@@ -460,7 +478,7 @@ class SpatialBottleneckFunction(torch.autograd.Function):
         fast_bottleneck.backward_rest(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2, grad_out1, wgrad2)
         torch.cuda.current_stream().wait_stream(wgrad2_stream)
-        return (None, None, None, None, None, None, None, None, None, *grads)
+        return (None, None, None, None, None, None, None, None, None, None, None, *grads)

 spatial_bottleneck_function = SpatialBottleneckFunction.apply
@@ -515,6 +533,8 @@ class SpatialBottleneck(torch.nn.Module):
         for w in self.w_conv:
             kaiming_uniform_(w, a=1)
+        self.thresholdTop, self.thresholdBottom = None, None
+
         # TODO: prevent unsupported case usage
         # support cases
         # native cudnn
@@ -536,6 +556,14 @@ class SpatialBottleneck(torch.nn.Module):
     def forward(self, x):
         if self.use_cudnn:
+            if self.thresholdTop is None:
+                spatial_group_size, spatial_group_rank, _, _, _ = self.spatial_parallel_args
+                if self.explicit_nhwc:
+                    N,H,W,C = list(x.shape)
+                else:
+                    N,C,H,W = list(x.shape)
+                self.thresholdTop = torch.tensor([1 if spatial_group_rank > 0 else 0], dtype=torch.int32, device='cuda')
+                self.thresholdBottom = torch.tensor([H-2 if spatial_group_rank < spatial_group_size - 1 else H-1], dtype=torch.int32, device='cuda')
             # calculate scale/bias from registered buffers
             # TODO: make this better
             s1, b1 = self.bn1.get_scale_bias(self.explicit_nhwc)
@@ -548,7 +576,7 @@ class SpatialBottleneck(torch.nn.Module):
             w_scale.append(s4)
             w_bias.append(b4)
-            out = spatial_bottleneck_function(*self.spatial_parallel_args, self.explicit_nhwc, self.stride, w_scale, w_bias, x, *self.w_conv)
+            out = spatial_bottleneck_function(*self.spatial_parallel_args, self.explicit_nhwc, self.stride, w_scale, w_bias, self.thresholdTop, self.thresholdBottom, x, *self.w_conv)
             return out
         if self.explicit_nhwc:
...
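For reference, here is a standalone illustration of the threshold values the lazy initializer in SpatialBottleneck.forward produces per rank. The make_thresholds helper is hypothetical and uses CPU tensors for portability; the module itself allocates them with device='cuda'.

import torch

def make_thresholds(spatial_group_size, spatial_group_rank, H, device='cpu'):
    # Mirrors the lazy initialization above: interior ranks defer one row at
    # each shard edge; the first and last rank keep their outermost row,
    # which has no neighbor to wait for.
    thresholdTop = torch.tensor([1 if spatial_group_rank > 0 else 0],
                                dtype=torch.int32, device=device)
    thresholdBottom = torch.tensor(
        [H - 2 if spatial_group_rank < spatial_group_size - 1 else H - 1],
        dtype=torch.int32, device=device)
    return thresholdTop, thresholdBottom

for rank in range(4):
    t, b = make_thresholds(4, rank, H=56)
    print(f"rank {rank}: epilogue rows {int(t)}..{int(b)} of 0..55")
# rank 0: rows 0..54   (row 55 awaits halo correction)
# rank 1: rows 1..54   (rows 0 and 55 await correction)
# rank 2: rows 1..54
# rank 3: rows 1..55   (row 0 awaits correction)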