Commit 0c20c455 authored by Thor Johnsen's avatar Thor Johnsen
Browse files

Some fixes to better support native nhwc

parent 34df0f79
......@@ -220,7 +220,7 @@ class Bottleneck(torch.nn.Module):
class SpatialBottleneckFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, spatial_group_size, spatial_group_rank, spatial_communicator, spatial_halo_exchanger, spatial_method, nhwc, stride_1x1, scale, bias, x, *conv):
def forward(ctx, spatial_group_size, spatial_group_rank, spatial_communicator, spatial_halo_exchanger, spatial_method, explicit_nhwc, stride_1x1, scale, bias, x, *conv):
if spatial_group_size > 1:
stream1 = spatial_halo_exchanger.stream1
stream2 = spatial_halo_exchanger.stream2
......@@ -234,64 +234,91 @@ class SpatialBottleneckFunction(torch.autograd.Function):
args.append(scale[3])
args.append(bias[3])
# weight buffers are always in nhwc while shape can be nhwc or channels_last
# weight buffers are always in explicit_nhwc while shape can be explicit_nhwc or channels_last
# here we pass in flag and let c++ handle it
# alternatively, we can put all sizes into a fixed format and pass it in
outputs = fast_bottleneck.forward_init(nhwc, stride_1x1, args)
fast_bottleneck.forward_out1(nhwc, stride_1x1, args, outputs)
outputs = fast_bottleneck.forward_init(explicit_nhwc, stride_1x1, args)
fast_bottleneck.forward_out1(explicit_nhwc, stride_1x1, args, outputs)
if spatial_group_size > 1:
out1 = outputs[0]
# TODO: This assumes explicit nhwc
N,Hs,W,C = list(out1.shape)
out1_pad = torch.empty([N,Hs+2,W,C], dtype=out1.dtype, device='cuda')
if explicit_nhwc:
N,Hs,W,C = list(out1.shape)
memory_format = torch.contiguous_format
out1_pad = torch.empty([N,Hs+2,W,C], dtype=out1.dtype, device='cuda')
else:
N,C,Hs,W = list(out1.shape)
memory_format = torch.channels_last if out1.is_contiguous(memory_format=torch.channels_last) else torch.contiguous_format
out1_pad = torch.empty([N,C,Hs+2,W], dtype=out1.dtype, device='cuda', memory_format=memory_format)
stream1.wait_stream(torch.cuda.current_stream())
stream3.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(stream3):
out1_pad[:,1:Hs+1,:,:].copy_(out1)
if explicit_nhwc:
out1_pad[:,1:Hs+1,:,:].copy_(out1)
else:
out1_pad[:,:,1:Hs+1,:].copy_(out1)
with torch.cuda.stream(stream1):
top_out1_halo = out1_pad[:,:1,:,:]
btm_out1_halo = out1_pad[:,Hs+1:Hs+2,:,:]
spatial_halo_exchanger.left_right_halo_exchange(out1[:,:1,:,:], out1[:,Hs-1:,:,:], top_out1_halo, btm_out1_halo)
if explicit_nhwc:
top_out1_halo = out1_pad[:,:1,:,:]
btm_out1_halo = out1_pad[:,Hs+1:Hs+2,:,:]
spatial_halo_exchanger.left_right_halo_exchange(out1[:,:1,:,:], out1[:,Hs-1:,:,:], top_out1_halo, btm_out1_halo)
else:
top_out1_halo = out1_pad[:,:,:1,:]
btm_out1_halo = out1_pad[:,:,Hs+1:Hs+2,:]
spatial_halo_exchanger.left_right_halo_exchange(out1[:,:,:1,:], out1[:,:,Hs-1:,:], top_out1_halo, btm_out1_halo)
if spatial_method == 1:
# overlap mid convolution with halo transfer
if spatial_group_rank < spatial_group_size-1:
stream2.wait_stream(stream1)
with torch.cuda.stream(stream2):
btm_fat_halo = torch.empty((N,3,W,C),dtype=out1.dtype,device=out1.device)
btm_fat_halo[:,0:2,:,:].copy_(out1[:,Hs-2:,:,:])
btm_fat_halo[:,2:,:,:].copy_(btm_out1_halo)
btm_out2 = fast_bottleneck.forward_out2_halo(nhwc, btm_fat_halo, args)
if explicit_nhwc:
btm_fat_halo[:,0:2,:,:].copy_(out1[:,Hs-2:,:,:])
btm_fat_halo[:,2:,:,:].copy_(btm_out1_halo)
else:
btm_fat_halo[:,:,0:2,:].copy_(out1[:,Hs-2:,:,:])
btm_fat_halo[:,:,2:,:].copy_(btm_out1_halo)
btm_out2 = fast_bottleneck.forward_out2_halo(explicit_nhwc, btm_fat_halo, args)
if spatial_group_rank > 0:
with torch.cuda.stream(stream1):
top_fat_halo = torch.empty((N,3,W,C),dtype=out1.dtype,device=out1.device)
top_fat_halo[:,:1,:,:].copy_(top_out1_halo)
top_fat_halo[:,1:3,:,:].copy_(out1[:,:2,:,:])
top_out2 = fast_bottleneck.forward_out2_halo(nhwc, top_fat_halo, args)
if explicit_nhwc:
top_fat_halo[:,:1,:,:].copy_(top_out1_halo)
top_fat_halo[:,1:3,:,:].copy_(out1[:,:2,:,:])
else:
top_fat_halo[:,:,:1,:].copy_(top_out1_halo)
top_fat_halo[:,:,1:3,:].copy_(out1[:,:2,:,:])
top_out2 = fast_bottleneck.forward_out2_halo(explicit_nhwc, top_fat_halo, args)
inc.add_delay(10)
elif spatial_method == 2:
# wait for halo transfer to finish before doing a full convolution of padded x
torch.cuda.current_stream().wait_stream(stream1)
torch.cuda.current_stream().wait_stream(stream3)
fast_bottleneck.forward_out2_pad(nhwc, stride_1x1, args, outputs, out1_pad)
fast_bottleneck.forward_out2_pad(explicit_nhwc, stride_1x1, args, outputs, out1_pad)
else:
assert(False), "spatial_method must be 1 or 2"
if spatial_group_size <= 1 or spatial_method == 1:
fast_bottleneck.forward_out2(nhwc, stride_1x1, args, outputs)
fast_bottleneck.forward_out2(explicit_nhwc, stride_1x1, args, outputs)
# compute halo cells for outputs[1] (out2)
if spatial_group_size > 1 and spatial_method == 1:
out2 = outputs[1]
if spatial_group_rank > 0:
torch.cuda.current_stream().wait_stream(stream1)
out2[:,:1,:,:].copy_(top_out2)
if explicit_nhwc:
out2[:,:1,:,:].copy_(top_out2)
else:
out2[:,:,:1,:].copy_(top_out2)
if spatial_group_rank < spatial_group_size-1:
torch.cuda.current_stream().wait_stream(stream2)
out2[:,Hs-1:,:,:].copy_(btm_out2)
if explicit_nhwc:
out2[:,Hs-1:,:,:].copy_(btm_out2)
else:
out2[:,:,Hs-1:,:].copy_(btm_out2)
torch.cuda.current_stream().wait_stream(stream3)
fast_bottleneck.forward_rest(nhwc, stride_1x1, args, outputs)
fast_bottleneck.forward_rest(explicit_nhwc, stride_1x1, args, outputs)
# save halos for backward pass
if spatial_group_size > 1:
torch.cuda.current_stream().wait_stream(stream3)
......@@ -299,7 +326,7 @@ class SpatialBottleneckFunction(torch.autograd.Function):
else:
ctx.save_for_backward(*(args+outputs))
# save relu outputs for drelu
ctx.nhwc = nhwc
ctx.explicit_nhwc = explicit_nhwc
ctx.stride_1x1 = stride_1x1
ctx.spatial_group_size = spatial_group_size
if spatial_group_size > 1:
......@@ -339,8 +366,8 @@ class SpatialBottleneckFunction(torch.autograd.Function):
if ctx.downsample:
t_list.append(ctx.saved_tensors[10])
grads = fast_bottleneck.backward_init(ctx.nhwc, ctx.stride_1x1, t_list)
grad_out2 = fast_bottleneck.backward_grad_out2(ctx.nhwc, ctx.stride_1x1, t_list, grads)
grads = fast_bottleneck.backward_init(ctx.explicit_nhwc, ctx.stride_1x1, t_list)
grad_out2 = fast_bottleneck.backward_grad_out2(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads)
# do halo exchange of grad_out2 here
# compute halo cells for grad_out1
if ctx.spatial_group_size > 1:
......@@ -355,48 +382,66 @@ class SpatialBottleneckFunction(torch.autograd.Function):
with torch.cuda.stream(ctx.stream2):
btm_fat_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
btm_relu_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
btm_fat_halo[:,:2,:,:].copy_(grad_out2[:,Hs-2:,:,:])
btm_fat_halo[:,2:,:,:].copy_(btm_halo)
btm_relu_halo[:,:2,:,:].copy_(relu1[:,Hs-2:,:,:])
btm_relu_halo[:,2:,:,:].zero_()
btm_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(ctx.nhwc, ctx.stride_1x1, t_list, grads, btm_fat_halo, btm_relu_halo)
btm_grad_out1_halo = btm_grad_out1_halo[:,1:2,:,:]
if ctx.explicit_nhwc:
btm_fat_halo[:,:2,:,:].copy_(grad_out2[:,Hs-2:,:,:])
btm_fat_halo[:,2:,:,:].copy_(btm_halo)
btm_relu_halo[:,:2,:,:].copy_(relu1[:,Hs-2:,:,:])
btm_relu_halo[:,2:,:,:].zero_()
else:
btm_fat_halo[:,:,:2,:].copy_(grad_out2[:,Hs-2:,:,:])
btm_fat_halo[:,:,2:,:].copy_(btm_halo)
btm_relu_halo[:,:,:2,:].copy_(relu1[:,Hs-2:,:,:])
btm_relu_halo[:,:,2:,:].zero_()
btm_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, btm_fat_halo, btm_relu_halo)
if ctx.explicit_nhwc:
btm_grad_out1_halo = btm_grad_out1_halo[:,1:2,:,:]
else:
btm_grad_out1_halo = btm_grad_out1_halo[:,:,1:2,:]
if ctx.spatial_group_rank > 0:
with torch.cuda.stream(ctx.stream1):
top_fat_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
top_relu_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
top_fat_halo[:,:1,:,:].copy_(top_halo)
top_fat_halo[:,1:,:,:].copy_(grad_out2[:,:2,:,:])
top_relu_halo[:,:1,:,:].zero_()
top_relu_halo[:,1:,:,:].copy_(relu1[:,:2,:,:])
top_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(ctx.nhwc, ctx.stride_1x1, t_list, grads, top_fat_halo, top_relu_halo)
top_grad_out1_halo = top_grad_out1_halo[:,1:2,:,:]
if ctx.explicit_nhwc:
top_fat_halo[:,:1,:,:].copy_(top_halo)
top_fat_halo[:,1:,:,:].copy_(grad_out2[:,:2,:,:])
top_relu_halo[:,:1,:,:].zero_()
top_relu_halo[:,1:,:,:].copy_(relu1[:,:2,:,:])
else:
top_fat_halo[:,:,:1,:].copy_(top_halo)
top_fat_halo[:,:,1:,:].copy_(grad_out2[:,:2,:,:])
top_relu_halo[:,:,:1,:].zero_()
top_relu_halo[:,:,1:,:].copy_(relu1[:,:2,:,:])
top_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, top_fat_halo, top_relu_halo)
if ctx.explicit_nhwc:
top_grad_out1_halo = top_grad_out1_halo[:,1:2,:,:]
else:
top_grad_out1_halo = top_grad_out1_halo[:,:,1:2,:]
inc.add_delay(10)
wgrad2_stream = torch.cuda.Stream()
wgrad2_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(wgrad2_stream):
if ctx.spatial_group_size > 1:
wgrad2 = fast_bottleneck.backward_wgrad2_pad(ctx.nhwc, ctx.stride_1x1, t_list, grads, out1_pad, grad_out2)
wgrad2 = fast_bottleneck.backward_wgrad2_pad(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, out1_pad, grad_out2)
else:
wgrad2 = fast_bottleneck.backward_wgrad2(ctx.nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
wgrad2 = fast_bottleneck.backward_wgrad2(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
# compute wgrad2 for internal cells
#wgrad2 = fast_bottleneck.backward_wgrad2(ctx.nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
#wgrad2 = fast_bottleneck.backward_wgrad2(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
# apply wgrad2 halos
#if ctx.spatial_group_size > 1:
# if ctx.spatial_group_rank > 0:
# top_grad2_halo = grad_out2[:,:1,:,:]
# top_wgrad2_halo = fast_bottleneck.backward_wgrad2_halo(ctx.nhwc, ctx.stride_1x1, t_list, grads, top_out1_halo, top_grad2_halo)
# top_wgrad2_halo = fast_bottleneck.backward_wgrad2_halo(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, top_out1_halo, top_grad2_halo)
# wgrad2[:,:1,:,:].add_(top_wgrad2_halo)
# if ctx.spatial_group_rank < ctx.spatial_group_size-1:
# btm_grad2_halo = grad_out2[:,-1:,:,:]
# btm_wgrad2_halo = fast_bottleneck.backward_wgrad2_halo(ctx.nhwc, ctx.stride_1x1, t_list, grads, btm_out1_halo, btm_grad2_halo)
# btm_wgrad2_halo = fast_bottleneck.backward_wgrad2_halo(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, btm_out1_halo, btm_grad2_halo)
# wgrad2[:,-1:,:,:].add_(btm_wgrad2_halo)
# compute grad_out1 for internal cells
grad_out1 = fast_bottleneck.backward_grad_out1(ctx.nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
grad_out1 = fast_bottleneck.backward_grad_out1(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
# apply halo cells to grad_out1
if ctx.spatial_group_size > 1:
......@@ -406,14 +451,20 @@ class SpatialBottleneckFunction(torch.autograd.Function):
#print("w.shape = %s, z.shape = %s, relu1.shape = %s" % (str(list(w.shape)), str(list(z.shape)), str(list(relu1.shape))))
if ctx.spatial_group_rank > 0:
torch.cuda.current_stream().wait_stream(ctx.stream1)
grad_out1[:,:1,:,:].copy_(top_grad_out1_halo)
if ctx.explicit_nhwc:
grad_out1[:,:1,:,:].copy_(top_grad_out1_halo)
else:
grad_out1[:,:,:1,:].copy_(top_grad_out1_halo)
#print("ctx.spatial_group_rank = %d, apply grad_out1 top halo (grad_out1.shape = %s)" % (ctx.spatial_group_rank, str(list(grad_out1.shape))))
if ctx.spatial_group_rank < ctx.spatial_group_size-1:
torch.cuda.current_stream().wait_stream(ctx.stream2)
grad_out1[:,Hs-1:,:,:].copy_(btm_grad_out1_halo)
if ctx.explicit_nhwc:
grad_out1[:,Hs-1:,:,:].copy_(btm_grad_out1_halo)
else:
grad_out1[:,:,Hs-1:,:].copy_(btm_grad_out1_halo)
#print("ctx.spatial_group_rank = %d, apply grad_out1 btm halo (grad_out1.shape = %s)" % (ctx.spatial_group_rank, str(list(grad_out1.shape))))
fast_bottleneck.backward_rest(ctx.nhwc, ctx.stride_1x1, t_list, grads, grad_out2, grad_out1, wgrad2)
fast_bottleneck.backward_rest(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2, grad_out1, wgrad2)
torch.cuda.current_stream().wait_stream(wgrad2_stream)
return (None, None, None, None, None, None, None, None, None, *grads)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment