Commit 60000f73 authored by Thor Johnsen

Add halo correction using new cudnn masking feature

parent 9c16d945
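Context for the change: SpatialBottleneck splits the input across GPUs along the height dimension, so the middle 3x3 convolution needs one halo row from each neighbor. The new spatial_method=3 path runs that convolution over the full local tile and uses cuDNN's masking feature to defer the boundary output rows whose inputs live on a neighbor, then corrects those rows once the halo exchange completes. The thresholds passed to the masked kernel bound the output rows that are computable from local data alone. A minimal sketch of that threshold logic, mirroring the code added to SpatialBottleneck.forward in the diff below (make_thresholds is a hypothetical helper, not part of the commit):

import torch

def make_thresholds(spatial_group_size, spatial_group_rank, H):
    # Rows outside [thresholdTop, thresholdBottom] depend on a halo row
    # from a neighboring GPU and are deferred by the masked 3x3 convolution.
    top = 1 if spatial_group_rank > 0 else 0                               # first locally computable output row
    btm = H - 2 if spatial_group_rank < spatial_group_size - 1 else H - 1  # last locally computable output row
    # int32 CUDA tensors, matching what the diff feeds to forward_out2_mask
    thresholdTop = torch.tensor([top], dtype=torch.int32, device='cuda')
    thresholdBottom = torch.tensor([btm], dtype=torch.int32, device='cuda')
    return thresholdTop, thresholdBottom

For example, with spatial_group_size=4 and a local tile of H=64 rows, rank 0 gets (0, 62), ranks 1 and 2 get (1, 62), and rank 3 gets (1, 63): only rows adjacent to a partner GPU are deferred; tile edges at the true image border use ordinary zero padding as before.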
@@ -220,7 +220,7 @@ class Bottleneck(torch.nn.Module):
 class SpatialBottleneckFunction(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, spatial_group_size, spatial_group_rank, spatial_communicator, spatial_halo_exchanger, spatial_method, explicit_nhwc, stride_1x1, scale, bias, x, *conv):
+    def forward(ctx, spatial_group_size, spatial_group_rank, spatial_communicator, spatial_halo_exchanger, spatial_method, explicit_nhwc, stride_1x1, scale, bias, thresholdTop, thresholdBottom, x, *conv):
         if spatial_group_size > 1:
             stream1 = spatial_halo_exchanger.stream1
             stream2 = spatial_halo_exchanger.stream2
@@ -271,57 +271,75 @@ class SpatialBottleneckFunction(torch.autograd.Function):
                 if spatial_group_rank < spatial_group_size-1:
                     stream2.wait_stream(stream1)
                     with torch.cuda.stream(stream2):
-                        btm_fat_halo = torch.empty((N,3,W,C),dtype=out1.dtype,device=out1.device)
                         if explicit_nhwc:
+                            btm_fat_halo = torch.empty((N,3,W,C),dtype=out1.dtype,device=out1.device)
                             btm_fat_halo[:,0:2,:,:].copy_(out1[:,Hs-2:,:,:])
                             btm_fat_halo[:,2:,:,:].copy_(btm_out1_halo)
                         else:
+                            btm_fat_halo = torch.empty((N,C,3,W),dtype=out1.dtype,device=out1.device)
                             btm_fat_halo[:,:,0:2,:].copy_(out1[:,:,Hs-2:,:])
                             btm_fat_halo[:,:,2:,:].copy_(btm_out1_halo)
                         btm_out2 = fast_bottleneck.forward_out2_halo(explicit_nhwc, btm_fat_halo, args)
                 if spatial_group_rank > 0:
                     with torch.cuda.stream(stream1):
-                        top_fat_halo = torch.empty((N,3,W,C),dtype=out1.dtype,device=out1.device)
                         if explicit_nhwc:
+                            top_fat_halo = torch.empty((N,3,W,C),dtype=out1.dtype,device=out1.device)
                             top_fat_halo[:,:1,:,:].copy_(top_out1_halo)
                             top_fat_halo[:,1:3,:,:].copy_(out1[:,:2,:,:])
                         else:
+                            top_fat_halo = torch.empty((N,C,3,W),dtype=out1.dtype,device=out1.device)
                             top_fat_halo[:,:,:1,:].copy_(top_out1_halo)
                             top_fat_halo[:,:,1:3,:].copy_(out1[:,:,:2,:])
                         top_out2 = fast_bottleneck.forward_out2_halo(explicit_nhwc, top_fat_halo, args)
                 inc.add_delay(10)
-            elif spatial_method == 2:
-                # wait for halo transfer to finish before doing a full convolution of padded x
-                torch.cuda.current_stream().wait_stream(stream1)
-                torch.cuda.current_stream().wait_stream(stream3)
-                fast_bottleneck.forward_out2_pad(explicit_nhwc, stride_1x1, args, outputs, out1_pad)
-            else:
-                assert(False), "spatial_method must be 1 or 2"
+            elif spatial_method != 2 and spatial_method != 3:
+                assert(False), "spatial_method must be 1, 2 or 3"
-        if spatial_group_size <= 1 or spatial_method == 1:
+        if spatial_group_size <= 1:
+            fast_bottleneck.forward_out2(explicit_nhwc, stride_1x1, args, outputs)
+        elif spatial_method == 1:
             fast_bottleneck.forward_out2(explicit_nhwc, stride_1x1, args, outputs)
+        elif spatial_method == 2:
+            # wait for halo transfer to finish before doing a full convolution of padded x
+            torch.cuda.current_stream().wait_stream(stream1)
+            torch.cuda.current_stream().wait_stream(stream3)
+            fast_bottleneck.forward_out2_pad(explicit_nhwc, stride_1x1, args, outputs, out1_pad)
+        elif spatial_method == 3:
+            fast_bottleneck.forward_out2_mask(explicit_nhwc, stride_1x1, args, outputs, thresholdTop, thresholdBottom)
         # compute halo cells for outputs[1] (out2)
-        if spatial_group_size > 1 and spatial_method == 1:
+        if spatial_group_size > 1:
             out2 = outputs[1]
-            if spatial_group_rank > 0:
-                torch.cuda.current_stream().wait_stream(stream1)
-                if explicit_nhwc:
-                    out2[:,:1,:,:].copy_(top_out2)
-                else:
-                    out2[:,:,:1,:].copy_(top_out2)
-            if spatial_group_rank < spatial_group_size-1:
-                torch.cuda.current_stream().wait_stream(stream2)
-                if explicit_nhwc:
-                    out2[:,Hs-1:,:,:].copy_(btm_out2)
-                else:
-                    out2[:,:,Hs-1:,:].copy_(btm_out2)
-            torch.cuda.current_stream().wait_stream(stream3)
+            if explicit_nhwc:
+                top_out2_halo = out2[:,:1,:,:]
+                btm_out2_halo = out2[:,Hs-1:,:,:]
+            else:
+                top_out2_halo = out2[:,:,:1,:]
+                btm_out2_halo = out2[:,:,Hs-1:,:]
+            if spatial_method == 1:
+                if spatial_group_rank > 0:
+                    torch.cuda.current_stream().wait_stream(stream1)
+                    top_out2_halo.copy_(top_out2)
+                if spatial_group_rank < spatial_group_size-1:
+                    torch.cuda.current_stream().wait_stream(stream2)
+                    btm_out2_halo.copy_(btm_out2)
+            elif spatial_method == 3:
+                if spatial_group_rank > 0:
+                    w1by3 = args[2][:,:,2:3,:].contiguous(memory_format=memory_format)
+                    top_out1_halo = top_out1_halo.contiguous(memory_format=memory_format)
+                    top_out2 = fast_bottleneck.forward_out2_halo_corr(explicit_nhwc, top_out1_halo, args, w1by3, top_out2_halo.contiguous(memory_format=memory_format))
+                    top_out2_halo.copy_(top_out2)
+                if spatial_group_rank < spatial_group_size-1:
+                    w1by3 = args[2][:,:,:1,:].contiguous(memory_format=memory_format)
+                    btm_out1_halo = btm_out1_halo.contiguous(memory_format=memory_format)
+                    btm_out2 = fast_bottleneck.forward_out2_halo_corr(explicit_nhwc, btm_out1_halo, args, w1by3, btm_out2_halo.contiguous(memory_format=memory_format))
+                    btm_out2_halo.copy_(btm_out2)
         fast_bottleneck.forward_rest(explicit_nhwc, stride_1x1, args, outputs)
         # save halos for backward pass
         if spatial_group_size > 1:
-            torch.cuda.current_stream().wait_stream(stream3)
+            if spatial_method != 2:
+                torch.cuda.current_stream().wait_stream(stream3)
             ctx.save_for_backward(*(args+outputs+[out1_pad,]))
         else:
             ctx.save_for_backward(*(args+outputs))
@@ -460,7 +478,7 @@ class SpatialBottleneckFunction(torch.autograd.Function):
         fast_bottleneck.backward_rest(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2, grad_out1, wgrad2)
         torch.cuda.current_stream().wait_stream(wgrad2_stream)
-        return (None, None, None, None, None, None, None, None, None, *grads)
+        return (None, None, None, None, None, None, None, None, None, None, None, *grads)

 spatial_bottleneck_function = SpatialBottleneckFunction.apply
@@ -515,6 +533,8 @@ class SpatialBottleneck(torch.nn.Module):
             for w in self.w_conv:
                 kaiming_uniform_(w, a=1)

+        self.thresholdTop, self.thresholdBottom = None, None
+
         # TODO: prevent unsupported case usage
         # support cases
         # native cudnn
@@ -536,6 +556,14 @@ class SpatialBottleneck(torch.nn.Module):
     def forward(self, x):
         if self.use_cudnn:
+            if self.thresholdTop is None:
+                spatial_group_size, spatial_group_rank, _, _, _ = self.spatial_parallel_args
+                if self.explicit_nhwc:
+                    N,H,W,C = list(x.shape)
+                else:
+                    N,C,H,W = list(x.shape)
+                self.thresholdTop = torch.tensor([1 if spatial_group_rank > 0 else 0], dtype=torch.int32, device='cuda')
+                self.thresholdBottom = torch.tensor([H-2 if spatial_group_rank < spatial_group_size - 1 else H-1], dtype=torch.int32, device='cuda')
             # calculate scale/bias from registered buffers
             # TODO: make this better
             s1, b1 = self.bn1.get_scale_bias(self.explicit_nhwc)
@@ -548,7 +576,7 @@ class SpatialBottleneck(torch.nn.Module):
             w_scale.append(s4)
             w_bias.append(b4)
-            out = spatial_bottleneck_function(*self.spatial_parallel_args, self.explicit_nhwc, self.stride, w_scale, w_bias, x, *self.w_conv)
+            out = spatial_bottleneck_function(*self.spatial_parallel_args, self.explicit_nhwc, self.stride, w_scale, w_bias, self.thresholdTop, self.thresholdBottom, x, *self.w_conv)
             return out

         if self.explicit_nhwc:
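After the masked pass, spatial_method=3 patches the two boundary rows: fast_bottleneck.forward_out2_halo_corr combines the neighbor's out1 halo row, a single row of the 3x3 weight (w1by3 in the diff), and the already-computed boundary row of out2. The linear part of such a correction can be checked with plain PyTorch: an output row computed with zero padding instead of the true halo row is missing exactly the halo row convolved with the matching kernel row. A self-contained sketch of that identity (NCHW, plain F.conv2d; the fused CUDA kernel also applies the scale/bias epilog, and which kernel row pairs with which halo depends on its layout conventions, so this illustrates the arithmetic, not the kernel's actual implementation):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, C, K, H, W = 2, 8, 16, 6, 10
x_full = torch.randn(N, C, H + 1, W)       # local tile plus one true halo row on top
halo, tile = x_full[:, :, :1, :], x_full[:, :, 1:, :]
w = torch.randn(K, C, 3, 3)                # the 3x3 convolution weight

# Reference: first output row of the tile, computed with the true halo row.
ref = F.conv2d(x_full, w, padding=(0, 1))[:, :, :1, :]

# Masked/zero-padded pass: the same row without the halo, i.e. what a
# boundary row looks like before the correction.
masked = F.conv2d(tile, w, padding=1)[:, :, :1, :]

# Correction term: halo row convolved with the matching kernel row
# (the analogue of w1by3 in the diff).
w1by3 = w[:, :, :1, :]
corr = F.conv2d(halo, w1by3, padding=(0, 1))

assert torch.allclose(masked + corr, ref, atol=1e-5)

Compared with the spatial_method=1 path above, which rebuilds three-row "fat halo" inputs and reruns the convolution on them, the masked approach computes the bulk of the tile once and leaves only a one-row correction on the critical path after the halo exchange.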