Commit 4aeb24cb authored by Thor Johnsen

Bug fix: gate the debug delay kernel behind a new use_delay_kernel argument instead of a commented-out inc.add_delay(10) call, thread the flag through SpatialBottleneckFunction.forward/backward and SpatialBottleneck.spatial_parallel_args, and add the matching extra None to the gradient tuple returned by backward.

parent c70f0e32
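
For context before the diff: this commit grows the spatial_parallel_args tuple from five to six elements, the new last element being the use_delay_kernel flag. A minimal sketch of building the tuple follows; the field order is taken from the new forward() signature in the diff, while the multi-GPU variant (rank, communicator, and halo-exchanger objects) is a hypothetical illustration not shown in this commit:

# Field order, per the forward() signature in this commit:
# (spatial_group_size, spatial_group_rank, spatial_communicator,
#  spatial_halo_exchanger, spatial_method, use_delay_kernel)

# Single-GPU default, matching the fallback this commit installs:
spatial_parallel_args = (1, 0, None, None, 0, False)

# Hypothetical multi-GPU variant with the delay kernel enabled
# (communicator/halo-exchanger construction is outside this diff):
# spatial_parallel_args = (2, rank, communicator, halo_exchanger, 1, True)
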
@@ -257,7 +257,7 @@ class Bottleneck(torch.nn.Module):
 class SpatialBottleneckFunction(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, spatial_group_size, spatial_group_rank, spatial_communicator, spatial_halo_exchanger, spatial_method, explicit_nhwc, stride_1x1, scale, bias, thresholdTop, thresholdBottom, x, *conv):
+    def forward(ctx, spatial_group_size, spatial_group_rank, spatial_communicator, spatial_halo_exchanger, spatial_method, use_delay_kernel, explicit_nhwc, stride_1x1, scale, bias, thresholdTop, thresholdBottom, x, *conv):
         if spatial_group_size > 1:
             stream1 = spatial_halo_exchanger.stream1
             stream2 = spatial_halo_exchanger.stream2
@@ -328,7 +328,7 @@ class SpatialBottleneckFunction(torch.autograd.Function):
                         top_fat_halo[:,:,:1,:].copy_(top_out1_halo)
                         top_fat_halo[:,:,1:3,:].copy_(out1[:,:,:2,:])
                         top_out2 = fast_bottleneck.forward_out2_halo(explicit_nhwc, top_fat_halo, args)
-                #inc.add_delay(10)
+                if use_delay_kernel: inc.add_delay(10)
             elif spatial_method != 2 and spatial_method != 3:
                 assert(False), "spatial_method must be 1, 2 or 3"
@@ -416,6 +416,7 @@ class SpatialBottleneckFunction(torch.autograd.Function):
         ctx.spatial_group_rank = spatial_group_rank
         ctx.spatial_halo_exchanger = spatial_halo_exchanger
         ctx.spatial_method = spatial_method
+        ctx.use_delay_kernel = use_delay_kernel
         ctx.thresholdTop = thresholdTop
         ctx.thresholdBottom = thresholdBottom
         ctx.stream1 = stream1
@@ -518,7 +519,7 @@ class SpatialBottleneckFunction(torch.autograd.Function):
                             top_grad_out1_halo = top_grad_out1_halo[:,1:2,:,:]
                         else:
                             top_grad_out1_halo = top_grad_out1_halo[:,:,1:2,:]
-                #inc.add_delay(10)
+                if ctx.use_delay_kernel: inc.add_delay(10)
             elif ctx.spatial_method != 3:
                 assert(False), "spatial_method must be 1, 2 or 3"
@@ -583,7 +584,7 @@ class SpatialBottleneckFunction(torch.autograd.Function):
             fast_bottleneck.backward_rest(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2, grad_out1, wgrad2)
             torch.cuda.current_stream().wait_stream(wgrad2_stream)
-        return (None, None, None, None, None, None, None, None, None, None, None, *grads)
+        return (None, None, None, None, None, None, None, None, None, None, None, None, *grads)
spatial_bottleneck_function = SpatialBottleneckFunction.apply
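
The extra None in the return tuple above is the core of the fix: torch.autograd.Function.backward must return exactly one value per argument that forward received, with None for non-differentiable inputs, and forward gained the use_delay_kernel argument. A minimal, self-contained sketch of that invariant (names hypothetical, unrelated to this repository):

import torch

class Scale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, flag, x):    # two inputs: a non-tensor flag and a tensor
        ctx.flag = flag           # stash the flag, as the commit does with use_delay_kernel
        return x * 2

    @staticmethod
    def backward(ctx, grad_out):  # hence two outputs: None for the flag
        return None, grad_out * 2

y = Scale.apply(True, torch.ones(3, requires_grad=True))
y.sum().backward()  # returning too few values from backward raises a RuntimeError
                    # about an incorrect number of gradients
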
@@ -655,7 +656,7 @@ class SpatialBottleneck(torch.nn.Module):
         # spatial communicator
         if spatial_parallel_args is None:
-            self.spatial_parallel_args = (1, 0, None, None, 0)
+            self.spatial_parallel_args = (1, 0, None, None, 0, False)
         else:
             self.spatial_parallel_args = spatial_parallel_args
         return
@@ -684,7 +685,7 @@ class SpatialBottleneck(torch.nn.Module):
     def forward(self, x):
         if self.use_cudnn:
             if self.thresholdTop is None:
-                spatial_group_size, spatial_group_rank, _, _, _ = self.spatial_parallel_args
+                spatial_group_size, spatial_group_rank, _, _, _, _ = self.spatial_parallel_args
                 if self.explicit_nhwc:
                     N,H,W,C = list(x.shape)
                 else: