"src/vscode:/vscode.git/clone" did not exist on "8d36d5adb1edb8eaaa40a29ef5510f51c503f19e"
Commit 0c20c455 authored by Thor Johnsen

Some fixes to better support native nhwc

parent 34df0f79
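
For context, a minimal sketch (not part of the commit) of the two NHWC flavours the new `explicit_nhwc` flag distinguishes: an explicit-NHWC tensor whose shape is literally [N, H, W, C], versus a native NHWC (channels_last) tensor that keeps the [N, C, H, W] shape and only changes the strides. This is why the same height slice appears as `t[:,h0:h1,:,:]` in one branch and `t[:,:,h0:h1,:]` in the other throughout the diff below.

```python
import torch

# Explicit NHWC: the shape itself is [N, H, W, C].
x_explicit = torch.randn(8, 56, 56, 64, dtype=torch.half, device='cuda')

# Native NHWC (channels_last): shape stays [N, C, H, W], only the strides change.
x_native = torch.randn(8, 64, 56, 56, dtype=torch.half, device='cuda')
x_native = x_native.to(memory_format=torch.channels_last)

print(x_explicit.shape)                                            # torch.Size([8, 56, 56, 64])
print(x_native.shape)                                              # torch.Size([8, 64, 56, 56])
print(x_native.is_contiguous(memory_format=torch.channels_last))   # True
```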
@@ -220,7 +220,7 @@ class Bottleneck(torch.nn.Module):
 
 class SpatialBottleneckFunction(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, spatial_group_size, spatial_group_rank, spatial_communicator, spatial_halo_exchanger, spatial_method, nhwc, stride_1x1, scale, bias, x, *conv):
+    def forward(ctx, spatial_group_size, spatial_group_rank, spatial_communicator, spatial_halo_exchanger, spatial_method, explicit_nhwc, stride_1x1, scale, bias, x, *conv):
         if spatial_group_size > 1:
             stream1 = spatial_halo_exchanger.stream1
             stream2 = spatial_halo_exchanger.stream2
@@ -234,64 +234,91 @@ class SpatialBottleneckFunction(torch.autograd.Function):
             args.append(scale[3])
             args.append(bias[3])
 
-        # weight buffers are always in nhwc while shape can be nhwc or channels_last
+        # weight buffers are always in explicit_nhwc while shape can be explicit_nhwc or channels_last
         # here we pass in flag and let c++ handle it
         # alternatively, we can put all sizes into a fixed format and pass it in
-        outputs = fast_bottleneck.forward_init(nhwc, stride_1x1, args)
-        fast_bottleneck.forward_out1(nhwc, stride_1x1, args, outputs)
+        outputs = fast_bottleneck.forward_init(explicit_nhwc, stride_1x1, args)
+        fast_bottleneck.forward_out1(explicit_nhwc, stride_1x1, args, outputs)
         if spatial_group_size > 1:
             out1 = outputs[0]
-            # TODO: This assumes explicit nhwc
-            N,Hs,W,C = list(out1.shape)
-            out1_pad = torch.empty([N,Hs+2,W,C], dtype=out1.dtype, device='cuda')
+            if explicit_nhwc:
+                N,Hs,W,C = list(out1.shape)
+                memory_format = torch.contiguous_format
+                out1_pad = torch.empty([N,Hs+2,W,C], dtype=out1.dtype, device='cuda')
+            else:
+                N,C,Hs,W = list(out1.shape)
+                memory_format = torch.channels_last if out1.is_contiguous(memory_format=torch.channels_last) else torch.contiguous_format
+                out1_pad = torch.empty([N,C,Hs+2,W], dtype=out1.dtype, device='cuda', memory_format=memory_format)
             stream1.wait_stream(torch.cuda.current_stream())
             stream3.wait_stream(torch.cuda.current_stream())
             with torch.cuda.stream(stream3):
-                out1_pad[:,1:Hs+1,:,:].copy_(out1)
+                if explicit_nhwc:
+                    out1_pad[:,1:Hs+1,:,:].copy_(out1)
+                else:
+                    out1_pad[:,:,1:Hs+1,:].copy_(out1)
             with torch.cuda.stream(stream1):
-                top_out1_halo = out1_pad[:,:1,:,:]
-                btm_out1_halo = out1_pad[:,Hs+1:Hs+2,:,:]
-                spatial_halo_exchanger.left_right_halo_exchange(out1[:,:1,:,:], out1[:,Hs-1:,:,:], top_out1_halo, btm_out1_halo)
+                if explicit_nhwc:
+                    top_out1_halo = out1_pad[:,:1,:,:]
+                    btm_out1_halo = out1_pad[:,Hs+1:Hs+2,:,:]
+                    spatial_halo_exchanger.left_right_halo_exchange(out1[:,:1,:,:], out1[:,Hs-1:,:,:], top_out1_halo, btm_out1_halo)
+                else:
+                    top_out1_halo = out1_pad[:,:,:1,:]
+                    btm_out1_halo = out1_pad[:,:,Hs+1:Hs+2,:]
+                    spatial_halo_exchanger.left_right_halo_exchange(out1[:,:,:1,:], out1[:,:,Hs-1:,:], top_out1_halo, btm_out1_halo)
             if spatial_method == 1:
                 # overlap mid convolution with halo transfer
                 if spatial_group_rank < spatial_group_size-1:
                     stream2.wait_stream(stream1)
                     with torch.cuda.stream(stream2):
                         btm_fat_halo = torch.empty((N,3,W,C),dtype=out1.dtype,device=out1.device)
-                        btm_fat_halo[:,0:2,:,:].copy_(out1[:,Hs-2:,:,:])
-                        btm_fat_halo[:,2:,:,:].copy_(btm_out1_halo)
-                        btm_out2 = fast_bottleneck.forward_out2_halo(nhwc, btm_fat_halo, args)
+                        if explicit_nhwc:
+                            btm_fat_halo[:,0:2,:,:].copy_(out1[:,Hs-2:,:,:])
+                            btm_fat_halo[:,2:,:,:].copy_(btm_out1_halo)
+                        else:
+                            btm_fat_halo[:,:,0:2,:].copy_(out1[:,Hs-2:,:,:])
+                            btm_fat_halo[:,:,2:,:].copy_(btm_out1_halo)
+                        btm_out2 = fast_bottleneck.forward_out2_halo(explicit_nhwc, btm_fat_halo, args)
                 if spatial_group_rank > 0:
                     with torch.cuda.stream(stream1):
                         top_fat_halo = torch.empty((N,3,W,C),dtype=out1.dtype,device=out1.device)
-                        top_fat_halo[:,:1,:,:].copy_(top_out1_halo)
-                        top_fat_halo[:,1:3,:,:].copy_(out1[:,:2,:,:])
-                        top_out2 = fast_bottleneck.forward_out2_halo(nhwc, top_fat_halo, args)
+                        if explicit_nhwc:
+                            top_fat_halo[:,:1,:,:].copy_(top_out1_halo)
+                            top_fat_halo[:,1:3,:,:].copy_(out1[:,:2,:,:])
+                        else:
+                            top_fat_halo[:,:,:1,:].copy_(top_out1_halo)
+                            top_fat_halo[:,:,1:3,:].copy_(out1[:,:2,:,:])
+                        top_out2 = fast_bottleneck.forward_out2_halo(explicit_nhwc, top_fat_halo, args)
                 inc.add_delay(10)
             elif spatial_method == 2:
                 # wait for halo transfer to finish before doing a full convolution of padded x
                 torch.cuda.current_stream().wait_stream(stream1)
                 torch.cuda.current_stream().wait_stream(stream3)
-                fast_bottleneck.forward_out2_pad(nhwc, stride_1x1, args, outputs, out1_pad)
+                fast_bottleneck.forward_out2_pad(explicit_nhwc, stride_1x1, args, outputs, out1_pad)
             else:
                 assert(False), "spatial_method must be 1 or 2"
         if spatial_group_size <= 1 or spatial_method == 1:
-            fast_bottleneck.forward_out2(nhwc, stride_1x1, args, outputs)
+            fast_bottleneck.forward_out2(explicit_nhwc, stride_1x1, args, outputs)
         # compute halo cells for outputs[1] (out2)
         if spatial_group_size > 1 and spatial_method == 1:
             out2 = outputs[1]
             if spatial_group_rank > 0:
                 torch.cuda.current_stream().wait_stream(stream1)
-                out2[:,:1,:,:].copy_(top_out2)
+                if explicit_nhwc:
+                    out2[:,:1,:,:].copy_(top_out2)
+                else:
+                    out2[:,:,:1,:].copy_(top_out2)
             if spatial_group_rank < spatial_group_size-1:
                 torch.cuda.current_stream().wait_stream(stream2)
-                out2[:,Hs-1:,:,:].copy_(btm_out2)
+                if explicit_nhwc:
+                    out2[:,Hs-1:,:,:].copy_(btm_out2)
+                else:
+                    out2[:,:,Hs-1:,:].copy_(btm_out2)
             torch.cuda.current_stream().wait_stream(stream3)
-        fast_bottleneck.forward_rest(nhwc, stride_1x1, args, outputs)
+        fast_bottleneck.forward_rest(explicit_nhwc, stride_1x1, args, outputs)
         # save halos for backward pass
         if spatial_group_size > 1:
             torch.cuda.current_stream().wait_stream(stream3)
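
The pattern above repeats throughout the commit: every slice along the height dimension indexes dim 1 for explicit NHWC and dim 2 for channels_last. A hypothetical helper (not in the commit) that captures the rule:

```python
# Hypothetical helper, for illustration only: take `rows` halo rows from the
# top or bottom of a 4-D activation in either layout.
def halo_rows(t, rows, explicit_nhwc, top=True):
    if explicit_nhwc:      # shape [N, H, W, C]: height is dim 1
        return t[:, :rows] if top else t[:, -rows:]
    else:                  # shape [N, C, H, W]: height is dim 2
        return t[:, :, :rows] if top else t[:, :, -rows:]
```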
@@ -299,7 +326,7 @@ class SpatialBottleneckFunction(torch.autograd.Function):
         else:
             ctx.save_for_backward(*(args+outputs))
         # save relu outputs for drelu
-        ctx.nhwc = nhwc
+        ctx.explicit_nhwc = explicit_nhwc
         ctx.stride_1x1 = stride_1x1
         ctx.spatial_group_size = spatial_group_size
         if spatial_group_size > 1:
@@ -339,8 +366,8 @@ class SpatialBottleneckFunction(torch.autograd.Function):
 
         if ctx.downsample:
             t_list.append(ctx.saved_tensors[10])
-        grads = fast_bottleneck.backward_init(ctx.nhwc, ctx.stride_1x1, t_list)
-        grad_out2 = fast_bottleneck.backward_grad_out2(ctx.nhwc, ctx.stride_1x1, t_list, grads)
+        grads = fast_bottleneck.backward_init(ctx.explicit_nhwc, ctx.stride_1x1, t_list)
+        grad_out2 = fast_bottleneck.backward_grad_out2(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads)
         # do halo exchange of grad_out2 here
         # compute halo cells for grad_out1
         if ctx.spatial_group_size > 1:
@@ -355,48 +382,66 @@ class SpatialBottleneckFunction(torch.autograd.Function):
                 with torch.cuda.stream(ctx.stream2):
                     btm_fat_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
                     btm_relu_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
-                    btm_fat_halo[:,:2,:,:].copy_(grad_out2[:,Hs-2:,:,:])
-                    btm_fat_halo[:,2:,:,:].copy_(btm_halo)
-                    btm_relu_halo[:,:2,:,:].copy_(relu1[:,Hs-2:,:,:])
-                    btm_relu_halo[:,2:,:,:].zero_()
-                    btm_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(ctx.nhwc, ctx.stride_1x1, t_list, grads, btm_fat_halo, btm_relu_halo)
-                    btm_grad_out1_halo = btm_grad_out1_halo[:,1:2,:,:]
+                    if ctx.explicit_nhwc:
+                        btm_fat_halo[:,:2,:,:].copy_(grad_out2[:,Hs-2:,:,:])
+                        btm_fat_halo[:,2:,:,:].copy_(btm_halo)
+                        btm_relu_halo[:,:2,:,:].copy_(relu1[:,Hs-2:,:,:])
+                        btm_relu_halo[:,2:,:,:].zero_()
+                    else:
+                        btm_fat_halo[:,:,:2,:].copy_(grad_out2[:,Hs-2:,:,:])
+                        btm_fat_halo[:,:,2:,:].copy_(btm_halo)
+                        btm_relu_halo[:,:,:2,:].copy_(relu1[:,Hs-2:,:,:])
+                        btm_relu_halo[:,:,2:,:].zero_()
+                    btm_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, btm_fat_halo, btm_relu_halo)
+                    if ctx.explicit_nhwc:
+                        btm_grad_out1_halo = btm_grad_out1_halo[:,1:2,:,:]
+                    else:
+                        btm_grad_out1_halo = btm_grad_out1_halo[:,:,1:2,:]
             if ctx.spatial_group_rank > 0:
                 with torch.cuda.stream(ctx.stream1):
                     top_fat_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
                     top_relu_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
-                    top_fat_halo[:,:1,:,:].copy_(top_halo)
-                    top_fat_halo[:,1:,:,:].copy_(grad_out2[:,:2,:,:])
-                    top_relu_halo[:,:1,:,:].zero_()
-                    top_relu_halo[:,1:,:,:].copy_(relu1[:,:2,:,:])
-                    top_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(ctx.nhwc, ctx.stride_1x1, t_list, grads, top_fat_halo, top_relu_halo)
-                    top_grad_out1_halo = top_grad_out1_halo[:,1:2,:,:]
+                    if ctx.explicit_nhwc:
+                        top_fat_halo[:,:1,:,:].copy_(top_halo)
+                        top_fat_halo[:,1:,:,:].copy_(grad_out2[:,:2,:,:])
+                        top_relu_halo[:,:1,:,:].zero_()
+                        top_relu_halo[:,1:,:,:].copy_(relu1[:,:2,:,:])
+                    else:
+                        top_fat_halo[:,:,:1,:].copy_(top_halo)
+                        top_fat_halo[:,:,1:,:].copy_(grad_out2[:,:2,:,:])
+                        top_relu_halo[:,:,:1,:].zero_()
+                        top_relu_halo[:,:,1:,:].copy_(relu1[:,:2,:,:])
+                    top_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, top_fat_halo, top_relu_halo)
+                    if ctx.explicit_nhwc:
+                        top_grad_out1_halo = top_grad_out1_halo[:,1:2,:,:]
+                    else:
+                        top_grad_out1_halo = top_grad_out1_halo[:,:,1:2,:]
             inc.add_delay(10)
 
         wgrad2_stream = torch.cuda.Stream()
         wgrad2_stream.wait_stream(torch.cuda.current_stream())
         with torch.cuda.stream(wgrad2_stream):
             if ctx.spatial_group_size > 1:
-                wgrad2 = fast_bottleneck.backward_wgrad2_pad(ctx.nhwc, ctx.stride_1x1, t_list, grads, out1_pad, grad_out2)
+                wgrad2 = fast_bottleneck.backward_wgrad2_pad(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, out1_pad, grad_out2)
             else:
-                wgrad2 = fast_bottleneck.backward_wgrad2(ctx.nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
+                wgrad2 = fast_bottleneck.backward_wgrad2(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
         # compute wgrad2 for internal cells
-        #wgrad2 = fast_bottleneck.backward_wgrad2(ctx.nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
+        #wgrad2 = fast_bottleneck.backward_wgrad2(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
         # apply wgrad2 halos
         #if ctx.spatial_group_size > 1:
         #    if ctx.spatial_group_rank > 0:
         #        top_grad2_halo = grad_out2[:,:1,:,:]
-        #        top_wgrad2_halo = fast_bottleneck.backward_wgrad2_halo(ctx.nhwc, ctx.stride_1x1, t_list, grads, top_out1_halo, top_grad2_halo)
+        #        top_wgrad2_halo = fast_bottleneck.backward_wgrad2_halo(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, top_out1_halo, top_grad2_halo)
         #        wgrad2[:,:1,:,:].add_(top_wgrad2_halo)
         #    if ctx.spatial_group_rank < ctx.spatial_group_size-1:
         #        btm_grad2_halo = grad_out2[:,-1:,:,:]
-        #        btm_wgrad2_halo = fast_bottleneck.backward_wgrad2_halo(ctx.nhwc, ctx.stride_1x1, t_list, grads, btm_out1_halo, btm_grad2_halo)
+        #        btm_wgrad2_halo = fast_bottleneck.backward_wgrad2_halo(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, btm_out1_halo, btm_grad2_halo)
         #        wgrad2[:,-1:,:,:].add_(btm_wgrad2_halo)
         # compute grad_out1 for internal cells
-        grad_out1 = fast_bottleneck.backward_grad_out1(ctx.nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
+        grad_out1 = fast_bottleneck.backward_grad_out1(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
         # apply halo cells to grad_out1
         if ctx.spatial_group_size > 1:
@@ -406,14 +451,20 @@ class SpatialBottleneckFunction(torch.autograd.Function):
             #print("w.shape = %s, z.shape = %s, relu1.shape = %s" % (str(list(w.shape)), str(list(z.shape)), str(list(relu1.shape))))
             if ctx.spatial_group_rank > 0:
                 torch.cuda.current_stream().wait_stream(ctx.stream1)
-                grad_out1[:,:1,:,:].copy_(top_grad_out1_halo)
+                if ctx.explicit_nhwc:
+                    grad_out1[:,:1,:,:].copy_(top_grad_out1_halo)
+                else:
+                    grad_out1[:,:,:1,:].copy_(top_grad_out1_halo)
                 #print("ctx.spatial_group_rank = %d, apply grad_out1 top halo (grad_out1.shape = %s)" % (ctx.spatial_group_rank, str(list(grad_out1.shape))))
             if ctx.spatial_group_rank < ctx.spatial_group_size-1:
                 torch.cuda.current_stream().wait_stream(ctx.stream2)
-                grad_out1[:,Hs-1:,:,:].copy_(btm_grad_out1_halo)
+                if ctx.explicit_nhwc:
+                    grad_out1[:,Hs-1:,:,:].copy_(btm_grad_out1_halo)
+                else:
+                    grad_out1[:,:,Hs-1:,:].copy_(btm_grad_out1_halo)
                 #print("ctx.spatial_group_rank = %d, apply grad_out1 btm halo (grad_out1.shape = %s)" % (ctx.spatial_group_rank, str(list(grad_out1.shape))))
 
-        fast_bottleneck.backward_rest(ctx.nhwc, ctx.stride_1x1, t_list, grads, grad_out2, grad_out1, wgrad2)
+        fast_bottleneck.backward_rest(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2, grad_out1, wgrad2)
         torch.cuda.current_stream().wait_stream(wgrad2_stream)
         return (None, None, None, None, None, None, None, None, None, *grads)
...
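
For readers unfamiliar with what `spatial_halo_exchanger.left_right_halo_exchange` does conceptually, below is a simplified sketch using plain `torch.distributed` point-to-point calls. It assumes a process group is already initialized and that each rank owns one horizontal slab of the activation; it is illustrative only and is not the API this module uses.

```python
import torch
import torch.distributed as dist

def toy_halo_exchange(top_send, btm_send, top_recv, btm_recv, rank, world_size):
    # Send my edge rows to the neighbouring ranks and receive their edge rows
    # into the halo padding of my slab. Receive buffers should be contiguous.
    reqs = []
    if rank > 0:                      # neighbour above
        reqs.append(dist.isend(top_send.contiguous(), dst=rank - 1))
        reqs.append(dist.irecv(top_recv, src=rank - 1))
    if rank < world_size - 1:         # neighbour below
        reqs.append(dist.isend(btm_send.contiguous(), dst=rank + 1))
        reqs.append(dist.irecv(btm_recv, src=rank + 1))
    for req in reqs:
        req.wait()
```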