Commit 9c16d945 authored by Thor Johnsen

Bug fixes: correct halo-tensor indexing on the NCHW (non-explicit-NHWC) path, allocate halo buffers with layout-appropriate shapes, and create the wgrad2 stream before the halo exchange.

parent 0c20c455
@@ -276,7 +276,7 @@ class SpatialBottleneckFunction(torch.autograd.Function):
                             btm_fat_halo[:,0:2,:,:].copy_(out1[:,Hs-2:,:,:])
                             btm_fat_halo[:,2:,:,:].copy_(btm_out1_halo)
                         else:
-                            btm_fat_halo[:,:,0:2,:].copy_(out1[:,Hs-2:,:,:])
+                            btm_fat_halo[:,:,0:2,:].copy_(out1[:,:,Hs-2:,:])
                             btm_fat_halo[:,:,2:,:].copy_(btm_out1_halo)
                         btm_out2 = fast_bottleneck.forward_out2_halo(explicit_nhwc, btm_fat_halo, args)
                 if spatial_group_rank > 0:
@@ -287,7 +287,7 @@ class SpatialBottleneckFunction(torch.autograd.Function):
                             top_fat_halo[:,1:3,:,:].copy_(out1[:,:2,:,:])
                         else:
                             top_fat_halo[:,:,:1,:].copy_(top_out1_halo)
-                            top_fat_halo[:,:,1:3,:].copy_(out1[:,:2,:,:])
+                            top_fat_halo[:,:,1:3,:].copy_(out1[:,:,:2,:])
                         top_out2 = fast_bottleneck.forward_out2_halo(explicit_nhwc, top_fat_halo, args)
                     inc.add_delay(10)
             elif spatial_method == 2:
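Both forward-path fixes are the same bug: the NCHW branch sliced `out1` as if H were dim 1, but in channels-first layout H is dim 2. A minimal, self-contained sketch of the difference (hypothetical shapes, not the library's API):

```python
import torch

# Hypothetical sizes; in the real code these come from out1's shape.
N, C, Hs, W = 2, 4, 8, 8
out1_nhwc = torch.randn(N, Hs, W, C)   # explicit NHWC: H is dim 1
out1_nchw = torch.randn(N, C, Hs, W)   # default NCHW:  H is dim 2

rows_nhwc = out1_nhwc[:, Hs-2:, :, :]  # bottom two rows, NHWC
rows_nchw = out1_nchw[:, :, Hs-2:, :]  # bottom two rows, NCHW (the fix)
bug       = out1_nchw[:, Hs-2:, :, :]  # old code: slices C, not H

assert rows_nhwc.shape == (N, 2, W, C)
assert rows_nchw.shape == (N, C, 2, W)
assert bug.shape == (N, 0, Hs, W)      # empty: copy_ into a 2-row halo fails
```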
@@ -368,10 +368,15 @@ class SpatialBottleneckFunction(torch.autograd.Function):
         grads = fast_bottleneck.backward_init(ctx.explicit_nhwc, ctx.stride_1x1, t_list)
         grad_out2 = fast_bottleneck.backward_grad_out2(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads)
+        wgrad2_stream = torch.cuda.Stream()
+        wgrad2_stream.wait_stream(torch.cuda.current_stream())
         # do halo exchange of grad_out2 here
         # compute halo cells for grad_out1
         if ctx.spatial_group_size > 1:
-            N,Hs,W,C = list(grad_out2.shape)
+            if ctx.explicit_nhwc:
+                N,Hs,W,C = list(grad_out2.shape)
+            else:
+                N,C,Hs,W = list(grad_out2.shape)
             relu1 = t_list[12]
             ctx.stream1.wait_stream(torch.cuda.current_stream())
             with torch.cuda.stream(ctx.stream1):
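With the unpack now branching on layout, `Hs` names the local H extent in either memory format, and the wgrad2 stream is created before the halo exchange so the weight-gradient kernel launched later can overlap with halo work. A hedged sketch of both pieces (hypothetical tensor, not the real `t_list` plumbing):

```python
import torch

grad_out2 = torch.randn(2, 16, 8, 8)       # hypothetical NCHW tensor
explicit_nhwc = False
if explicit_nhwc:
    N, Hs, W, C = list(grad_out2.shape)    # (N, H, W, C)
else:
    N, C, Hs, W = list(grad_out2.shape)    # (N, C, H, W)
assert (N, C, Hs, W) == (2, 16, 8, 8)

# Creating the side stream up front, fenced against the current stream,
# lets the wgrad2 launch further down overlap with the halo exchange.
if torch.cuda.is_available():
    wgrad2_stream = torch.cuda.Stream()
    wgrad2_stream.wait_stream(torch.cuda.current_stream())
```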
@@ -380,17 +385,19 @@ class SpatialBottleneckFunction(torch.autograd.Function):
             if ctx.spatial_group_rank < ctx.spatial_group_size-1:
                 ctx.stream2.wait_stream(ctx.stream1)
                 with torch.cuda.stream(ctx.stream2):
-                    btm_fat_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
-                    btm_relu_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
                     if ctx.explicit_nhwc:
+                        btm_fat_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
+                        btm_relu_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
                         btm_fat_halo[:,:2,:,:].copy_(grad_out2[:,Hs-2:,:,:])
                         btm_fat_halo[:,2:,:,:].copy_(btm_halo)
                         btm_relu_halo[:,:2,:,:].copy_(relu1[:,Hs-2:,:,:])
                         btm_relu_halo[:,2:,:,:].zero_()
                     else:
-                        btm_fat_halo[:,:,:2,:].copy_(grad_out2[:,Hs-2:,:,:])
+                        btm_fat_halo = torch.empty((N,C,3,W),dtype=grad_out2.dtype,device=grad_out2.device)
+                        btm_relu_halo = torch.empty((N,C,3,W),dtype=grad_out2.dtype,device=grad_out2.device)
+                        btm_fat_halo[:,:,:2,:].copy_(grad_out2[:,:,Hs-2:,:])
                         btm_fat_halo[:,:,2:,:].copy_(btm_halo)
-                        btm_relu_halo[:,:,:2,:].copy_(relu1[:,Hs-2:,:,:])
+                        btm_relu_halo[:,:,:2,:].copy_(relu1[:,:,Hs-2:,:])
                         btm_relu_halo[:,:,2:,:].zero_()
                     btm_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, btm_fat_halo, btm_relu_halo)
                     if ctx.explicit_nhwc:
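The fat halo is a three-row strip: two rows of local data plus one exchanged row, now allocated as (N,3,W,C) or (N,C,3,W) to match the layout instead of always NHWC. A standalone sketch of the NCHW assembly (random stand-ins for the exchanged tensors):

```python
import torch

N, C, Hs, W = 2, 16, 8, 8                      # hypothetical sizes
grad_out2 = torch.randn(N, C, Hs, W)
relu1     = torch.randn(N, C, Hs, W)
btm_halo  = torch.randn(N, C, 1, W)            # stand-in for the row received from below

btm_fat_halo  = torch.empty((N, C, 3, W), dtype=grad_out2.dtype)
btm_relu_halo = torch.empty((N, C, 3, W), dtype=grad_out2.dtype)
btm_fat_halo[:, :, :2, :].copy_(grad_out2[:, :, Hs-2:, :])  # local bottom rows (H is dim 2)
btm_fat_halo[:, :, 2:, :].copy_(btm_halo)                   # neighbor's row
btm_relu_halo[:, :, :2, :].copy_(relu1[:, :, Hs-2:, :])
btm_relu_halo[:, :, 2:, :].zero_()                          # no relu data for the ghost row
```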
@@ -399,18 +406,20 @@ class SpatialBottleneckFunction(torch.autograd.Function):
                         btm_grad_out1_halo = btm_grad_out1_halo[:,:,1:2,:]
             if ctx.spatial_group_rank > 0:
                 with torch.cuda.stream(ctx.stream1):
-                    top_fat_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
-                    top_relu_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
                     if ctx.explicit_nhwc:
+                        top_fat_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
+                        top_relu_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
                         top_fat_halo[:,:1,:,:].copy_(top_halo)
                         top_fat_halo[:,1:,:,:].copy_(grad_out2[:,:2,:,:])
                         top_relu_halo[:,:1,:,:].zero_()
                         top_relu_halo[:,1:,:,:].copy_(relu1[:,:2,:,:])
                     else:
+                        top_fat_halo = torch.empty((N,C,3,W),dtype=grad_out2.dtype,device=grad_out2.device)
+                        top_relu_halo = torch.empty((N,C,3,W),dtype=grad_out2.dtype,device=grad_out2.device)
                         top_fat_halo[:,:,:1,:].copy_(top_halo)
-                        top_fat_halo[:,:,1:,:].copy_(grad_out2[:,:2,:,:])
+                        top_fat_halo[:,:,1:,:].copy_(grad_out2[:,:,:2,:])
                         top_relu_halo[:,:,:1,:].zero_()
-                        top_relu_halo[:,:,1:,:].copy_(relu1[:,:2,:,:])
+                        top_relu_halo[:,:,1:,:].copy_(relu1[:,:,:2,:])
                     top_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, top_fat_halo, top_relu_halo)
                     if ctx.explicit_nhwc:
                         top_grad_out1_halo = top_grad_out1_halo[:,1:2,:,:]
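Only the middle row of the three-row halo output is kept (`[:,1:2,:,:]` or `[:,:,1:2,:]`): a 3x3 convolution over a three-row strip produces exactly one row whose vertical neighbors are both real data. A small sketch of that effect, with a plain conv2d standing in for the fused halo kernel:

```python
import torch
import torch.nn.functional as F

N, C, W = 2, 4, 8
strip  = torch.randn(N, C, 3, W)            # NCHW fat halo: 3 rows
weight = torch.randn(C, C, 3, 3)            # hypothetical 3x3 weights
out = F.conv2d(strip, weight, padding=1)    # still 3 rows of output
valid = out[:, :, 1:2, :]                   # only the middle row saw real
                                            # data both above and below it
assert valid.shape == (N, C, 1, W)
```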
@@ -418,28 +427,12 @@ class SpatialBottleneckFunction(torch.autograd.Function):
                         top_grad_out1_halo = top_grad_out1_halo[:,:,1:2,:]
             inc.add_delay(10)
-        wgrad2_stream = torch.cuda.Stream()
-        wgrad2_stream.wait_stream(torch.cuda.current_stream())
         with torch.cuda.stream(wgrad2_stream):
             if ctx.spatial_group_size > 1:
                 wgrad2 = fast_bottleneck.backward_wgrad2_pad(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, out1_pad, grad_out2)
             else:
                 wgrad2 = fast_bottleneck.backward_wgrad2(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
-        # compute wgrad2 for internal cells
-        #wgrad2 = fast_bottleneck.backward_wgrad2(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
-        # apply wgrad2 halos
-        #if ctx.spatial_group_size > 1:
-        #    if ctx.spatial_group_rank > 0:
-        #        top_grad2_halo = grad_out2[:,:1,:,:]
-        #        top_wgrad2_halo = fast_bottleneck.backward_wgrad2_halo(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, top_out1_halo, top_grad2_halo)
-        #        wgrad2[:,:1,:,:].add_(top_wgrad2_halo)
-        #    if ctx.spatial_group_rank < ctx.spatial_group_size-1:
-        #        btm_grad2_halo = grad_out2[:,-1:,:,:]
-        #        btm_wgrad2_halo = fast_bottleneck.backward_wgrad2_halo(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, btm_out1_halo, btm_grad2_halo)
-        #        wgrad2[:,-1:,:,:].add_(btm_wgrad2_halo)
         # compute grad_out1 for internal cells
         grad_out1 = fast_bottleneck.backward_grad_out1(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
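Moving the stream creation up to the hunk at line 368 leaves only the launch here, so wgrad2 runs on its side stream concurrently with the grad_out1 work on the default stream. The general producer/side-stream pattern, sketched with a stand-in matmul (not the real backward_wgrad2 call):

```python
import torch

if torch.cuda.is_available():
    main = torch.cuda.current_stream()
    side = torch.cuda.Stream()
    x = torch.randn(1024, 1024, device='cuda')

    side.wait_stream(main)          # see x's producing kernels
    with torch.cuda.stream(side):
        y = x @ x                   # stands in for backward_wgrad2
    main.wait_stream(side)          # order y before main-stream consumers
    y.record_stream(main)           # keep the caching allocator from
                                    # reusing y's memory while main uses it
```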