Commit 834b1d01 authored by Thor Johnsen's avatar Thor Johnsen
Browse files

Concatenate out1 with halos for backward

parent e5d0be82
...@@ -220,10 +220,11 @@ class Bottleneck(torch.nn.Module): ...@@ -220,10 +220,11 @@ class Bottleneck(torch.nn.Module):
class SpatialBottleneckFunction(torch.autograd.Function): class SpatialBottleneckFunction(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, spatial_group_size, spatial_group_rank, spatial_communicator, spatial_halo_exchanger, spatial_stream, nhwc, stride_1x1, scale, bias, x, *conv): def forward(ctx, spatial_group_size, spatial_group_rank, spatial_communicator, spatial_halo_exchanger, spatial_method, nhwc, stride_1x1, scale, bias, x, *conv):
if spatial_group_size > 1: if spatial_group_size > 1:
stream1 = spatial_halo_exchanger.stream1 stream1 = spatial_halo_exchanger.stream1
stream2 = spatial_halo_exchanger.stream2 stream2 = spatial_halo_exchanger.stream2
stream3 = spatial_halo_exchanger.stream3
# TODO: clean up order of tensors # TODO: clean up order of tensors
args = [x, *conv[0:3], *scale[0:3], *bias[0:3]] args = [x, *conv[0:3], *scale[0:3], *bias[0:3]]
...@@ -239,13 +240,19 @@ class SpatialBottleneckFunction(torch.autograd.Function): ...@@ -239,13 +240,19 @@ class SpatialBottleneckFunction(torch.autograd.Function):
outputs = fast_bottleneck.forward_init(nhwc, stride_1x1, args) outputs = fast_bottleneck.forward_init(nhwc, stride_1x1, args)
fast_bottleneck.forward_out1(nhwc, stride_1x1, args, outputs) fast_bottleneck.forward_out1(nhwc, stride_1x1, args, outputs)
# do halo exchange for outputs[0] (out1)
if spatial_group_size > 1: if spatial_group_size > 1:
out1 = outputs[0] out1 = outputs[0]
# TODO: This assumes explicit nhwc
N,Hs,W,C = list(out1.shape) N,Hs,W,C = list(out1.shape)
out1_pad = torch.empty([N,Hs+2,W,C], dtype=out1.dtype, device='cuda')
stream1.wait_stream(torch.cuda.current_stream()) stream1.wait_stream(torch.cuda.current_stream())
stream3.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(stream3):
out1_pad[:,1:Hs+1,:,:].copy_(out1)
with torch.cuda.stream(stream1): with torch.cuda.stream(stream1):
top_out1_halo, btm_out1_halo = spatial_halo_exchanger.left_right_halo_exchange(out1[:,:1,:,:], out1[:,Hs-1:,:,:]) top_out1_halo = out1_pad[:,:1,:,:]
btm_out1_halo = out1_pad[:,Hs+1:Hs+2,:,:]
spatial_halo_exchanger.left_right_halo_exchange(out1[:,:1,:,:], out1[:,Hs-1:,:,:], top_out1_halo, btm_out1_halo)
if spatial_group_rank < spatial_group_size-1: if spatial_group_rank < spatial_group_size-1:
stream2.wait_stream(stream1) stream2.wait_stream(stream1)
with torch.cuda.stream(stream2): with torch.cuda.stream(stream2):
...@@ -253,12 +260,18 @@ class SpatialBottleneckFunction(torch.autograd.Function): ...@@ -253,12 +260,18 @@ class SpatialBottleneckFunction(torch.autograd.Function):
btm_fat_halo[:,0:2,:,:].copy_(out1[:,Hs-2:,:,:]) btm_fat_halo[:,0:2,:,:].copy_(out1[:,Hs-2:,:,:])
btm_fat_halo[:,2:,:,:].copy_(btm_out1_halo) btm_fat_halo[:,2:,:,:].copy_(btm_out1_halo)
btm_out2 = fast_bottleneck.forward_out2_halo(nhwc, btm_fat_halo, args) btm_out2 = fast_bottleneck.forward_out2_halo(nhwc, btm_fat_halo, args)
else:
with torch.cuda.stream(stream2):
btm_out1_halo.zero_()
if spatial_group_rank > 0: if spatial_group_rank > 0:
with torch.cuda.stream(stream1): with torch.cuda.stream(stream1):
top_fat_halo = torch.empty((N,3,W,C),dtype=out1.dtype,device=out1.device) top_fat_halo = torch.empty((N,3,W,C),dtype=out1.dtype,device=out1.device)
top_fat_halo[:,:1,:,:].copy_(top_out1_halo) top_fat_halo[:,:1,:,:].copy_(top_out1_halo)
top_fat_halo[:,1:3,:,:].copy_(out1[:,:2,:,:]) top_fat_halo[:,1:3,:,:].copy_(out1[:,:2,:,:])
top_out2 = fast_bottleneck.forward_out2_halo(nhwc, top_fat_halo, args) top_out2 = fast_bottleneck.forward_out2_halo(nhwc, top_fat_halo, args)
else:
with torch.cuda.stream(stream1):
top_out1_halo.zero_()
inc.add_delay(10) inc.add_delay(10)
fast_bottleneck.forward_out2(nhwc, stride_1x1, args, outputs) fast_bottleneck.forward_out2(nhwc, stride_1x1, args, outputs)
...@@ -272,11 +285,12 @@ class SpatialBottleneckFunction(torch.autograd.Function): ...@@ -272,11 +285,12 @@ class SpatialBottleneckFunction(torch.autograd.Function):
if spatial_group_rank < spatial_group_size-1: if spatial_group_rank < spatial_group_size-1:
torch.cuda.current_stream().wait_stream(stream2) torch.cuda.current_stream().wait_stream(stream2)
out2[:,Hs-1:,:,:].copy_(btm_out2) out2[:,Hs-1:,:,:].copy_(btm_out2)
torch.cuda.current_stream().wait_stream(stream3)
fast_bottleneck.forward_rest(nhwc, stride_1x1, args, outputs) fast_bottleneck.forward_rest(nhwc, stride_1x1, args, outputs)
# save halos for backward pass # save halos for backward pass
if spatial_group_size > 1: if spatial_group_size > 1:
ctx.save_for_backward(*(args+outputs+[top_out1_halo,btm_out1_halo])) ctx.save_for_backward(*(args+outputs+[out1_pad,]))
else: else:
ctx.save_for_backward(*(args+outputs)) ctx.save_for_backward(*(args+outputs))
# save relu outputs for drelu # save relu outputs for drelu
...@@ -286,8 +300,10 @@ class SpatialBottleneckFunction(torch.autograd.Function): ...@@ -286,8 +300,10 @@ class SpatialBottleneckFunction(torch.autograd.Function):
if spatial_group_size > 1: if spatial_group_size > 1:
ctx.spatial_group_rank = spatial_group_rank ctx.spatial_group_rank = spatial_group_rank
ctx.spatial_halo_exchanger = spatial_halo_exchanger ctx.spatial_halo_exchanger = spatial_halo_exchanger
ctx.spatial_method = spatial_method
ctx.stream1 = stream1 ctx.stream1 = stream1
ctx.stream2 = stream2 ctx.stream2 = stream2
ctx.stream3 = stream3
return outputs[2] return outputs[2]
# backward relu is not exposed, MUL with mask used now # backward relu is not exposed, MUL with mask used now
...@@ -295,9 +311,8 @@ class SpatialBottleneckFunction(torch.autograd.Function): ...@@ -295,9 +311,8 @@ class SpatialBottleneckFunction(torch.autograd.Function):
@staticmethod @staticmethod
def backward(ctx, grad_o): def backward(ctx, grad_o):
if ctx.spatial_group_size > 1: if ctx.spatial_group_size > 1:
top_out1_halo = ctx.saved_tensors[-2] out1_pad = ctx.saved_tensors[-1]
btm_out1_halo = ctx.saved_tensors[-1] outputs = ctx.saved_tensors[-4:-1]
outputs = ctx.saved_tensors[-5:-2]
else: else:
outputs = ctx.saved_tensors[-3:] outputs = ctx.saved_tensors[-3:]
...@@ -353,19 +368,24 @@ class SpatialBottleneckFunction(torch.autograd.Function): ...@@ -353,19 +368,24 @@ class SpatialBottleneckFunction(torch.autograd.Function):
top_grad_out1_halo = top_grad_out1_halo[:,1:2,:,:] top_grad_out1_halo = top_grad_out1_halo[:,1:2,:,:]
inc.add_delay(10) inc.add_delay(10)
if ctx.spatial_group_size > 1:
wgrad2 = fast_bottleneck.backward_wgrad2_pad(ctx.nhwc, ctx.stride_1x1, t_list, grads, out1_pad, grad_out2)
else:
wgrad2 = fast_bottleneck.backward_wgrad2(ctx.nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
# compute wgrad2 for internal cells # compute wgrad2 for internal cells
wgrad2 = fast_bottleneck.backward_wgrad2(ctx.nhwc, ctx.stride_1x1, t_list, grads, grad_out2) #wgrad2 = fast_bottleneck.backward_wgrad2(ctx.nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
# apply wgrad2 halos # apply wgrad2 halos
if ctx.spatial_group_size > 1: #if ctx.spatial_group_size > 1:
if ctx.spatial_group_rank > 0: # if ctx.spatial_group_rank > 0:
top_grad2_halo = grad_out2[:,:1,:,:] # top_grad2_halo = grad_out2[:,:1,:,:]
top_wgrad2_halo = fast_bottleneck.backward_wgrad2_halo(ctx.nhwc, ctx.stride_1x1, t_list, grads, top_out1_halo, top_grad2_halo) # top_wgrad2_halo = fast_bottleneck.backward_wgrad2_halo(ctx.nhwc, ctx.stride_1x1, t_list, grads, top_out1_halo, top_grad2_halo)
wgrad2[:,:1,:,:].add_(top_wgrad2_halo) # wgrad2[:,:1,:,:].add_(top_wgrad2_halo)
if ctx.spatial_group_rank < ctx.spatial_group_size-1: # if ctx.spatial_group_rank < ctx.spatial_group_size-1:
btm_grad2_halo = grad_out2[:,-1:,:,:] # btm_grad2_halo = grad_out2[:,-1:,:,:]
btm_wgrad2_halo = fast_bottleneck.backward_wgrad2_halo(ctx.nhwc, ctx.stride_1x1, t_list, grads, btm_out1_halo, btm_grad2_halo) # btm_wgrad2_halo = fast_bottleneck.backward_wgrad2_halo(ctx.nhwc, ctx.stride_1x1, t_list, grads, btm_out1_halo, btm_grad2_halo)
wgrad2[:,-1:,:,:].add_(btm_wgrad2_halo) # wgrad2[:,-1:,:,:].add_(btm_wgrad2_halo)
# compute grad_out1 for internal cells # compute grad_out1 for internal cells
grad_out1 = fast_bottleneck.backward_grad_out1(ctx.nhwc, ctx.stride_1x1, t_list, grads, grad_out2) grad_out1 = fast_bottleneck.backward_grad_out1(ctx.nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
...@@ -456,7 +476,7 @@ class SpatialBottleneck(torch.nn.Module): ...@@ -456,7 +476,7 @@ class SpatialBottleneck(torch.nn.Module):
# spatial communicator # spatial communicator
if spatial_parallel_args is None: if spatial_parallel_args is None:
self.spatial_parallel_args = (1, 0, None, None, None) self.spatial_parallel_args = (1, 0, None, None, 0)
else: else:
self.spatial_parallel_args = spatial_parallel_args self.spatial_parallel_args = spatial_parallel_args
return return
......
...@@ -12,6 +12,7 @@ class HaloExchanger(object): ...@@ -12,6 +12,7 @@ class HaloExchanger(object):
def __init__(self): def __init__(self):
self.stream1 = torch.cuda.Stream() self.stream1 = torch.cuda.Stream()
self.stream2 = torch.cuda.Stream() self.stream2 = torch.cuda.Stream()
self.stream3 = torch.cuda.Stream()
class HaloExchangerNoComm(HaloExchanger): class HaloExchangerNoComm(HaloExchanger):
def __init__(self, world_size, spatial_group_size, rank, comm): def __init__(self, world_size, spatial_group_size, rank, comm):
......
...@@ -1936,6 +1936,7 @@ struct bottleneck_backward_state { ...@@ -1936,6 +1936,7 @@ struct bottleneck_backward_state {
int axis[4]; int axis[4];
int64_t outdimA1[4]; // grad_out1 int64_t outdimA1[4]; // grad_out1
int64_t outdimA1b[4]; // out1_pad
int64_t outdimA2[4]; // grad_out2 int64_t outdimA2[4]; // grad_out2
int64_t outdimA3[4]; int64_t outdimA3[4];
int64_t outdimA1h[4]; // output: grad_out1 halo (H=3) int64_t outdimA1h[4]; // output: grad_out1 halo (H=3)
...@@ -1953,6 +1954,7 @@ struct bottleneck_backward_state { ...@@ -1953,6 +1954,7 @@ struct bottleneck_backward_state {
int64_t filterdim2hh[4]; // Cin,1,3,Cout int64_t filterdim2hh[4]; // Cin,1,3,Cout
int64_t outdim1[4]; int64_t outdim1[4];
int64_t outdim1b[4];
int64_t outdim2[4]; int64_t outdim2[4];
int64_t outdim3[4]; int64_t outdim3[4];
int64_t outdim1h[4]; int64_t outdim1h[4];
...@@ -2001,6 +2003,7 @@ struct bottleneck_backward_state { ...@@ -2001,6 +2003,7 @@ struct bottleneck_backward_state {
// output dim in n,c,h,w used by backend // output dim in n,c,h,w used by backend
outdimA1[0] = outdimA1[1] = outdimA1[2] = outdimA1[3] = 0; outdimA1[0] = outdimA1[1] = outdimA1[2] = outdimA1[3] = 0;
outdimA1b[0] = outdimA1b[1] = outdimA1b[2] = outdimA1b[3] = 0;
outdimA2[0] = outdimA2[1] = outdimA2[2] = outdimA2[3] = 0; outdimA2[0] = outdimA2[1] = outdimA2[2] = outdimA2[3] = 0;
outdimA3[0] = outdimA3[1] = outdimA3[2] = outdimA3[3] = 0; outdimA3[0] = outdimA3[1] = outdimA3[2] = outdimA3[3] = 0;
outdimA1h[0] = outdimA1h[1] = outdimA1h[2] = outdimA1h[3] = 0; outdimA1h[0] = outdimA1h[1] = outdimA1h[2] = outdimA1h[3] = 0;
...@@ -2022,6 +2025,13 @@ struct bottleneck_backward_state { ...@@ -2022,6 +2025,13 @@ struct bottleneck_backward_state {
for (int dim = 0; dim < 2; dim++) { for (int dim = 0; dim < 2; dim++) {
outdimA1[dim + 2] = getFwdConvOutputDim(dimA[dim + 2], padA[dim], filterdimA1[dim + 2], convstride1X1[dim], dilationA[dim]); outdimA1[dim + 2] = getFwdConvOutputDim(dimA[dim + 2], padA[dim], filterdimA1[dim + 2], convstride1X1[dim], dilationA[dim]);
} }
for (int dim = 0; dim < 4; dim++) {
if (dim == 2) {
outdimA1b[dim] = outdimA1[dim] + 2;
} else {
outdimA1b[dim] = outdimA1[dim];
}
}
outdimA2[0] = outdimA1[0]; outdimA2[0] = outdimA1[0];
outdimA2[1] = filterdimA2[0]; outdimA2[1] = filterdimA2[0];
...@@ -2051,6 +2061,7 @@ struct bottleneck_backward_state { ...@@ -2051,6 +2061,7 @@ struct bottleneck_backward_state {
// Create output tensor in the correct shape in pytorch's view // Create output tensor in the correct shape in pytorch's view
outdim1[0] = outdim1[1] = outdim1[2] = outdim1[3] = 0; outdim1[0] = outdim1[1] = outdim1[2] = outdim1[3] = 0;
outdim1b[0] = outdim1b[1] = outdim1b[2] = outdim1b[3] = 0;
outdim2[0] = outdim2[1] = outdim2[2] = outdim2[3] = 0; outdim2[0] = outdim2[1] = outdim2[2] = outdim2[3] = 0;
outdim3[0] = outdim3[1] = outdim3[2] = outdim3[3] = 0; outdim3[0] = outdim3[1] = outdim3[2] = outdim3[3] = 0;
outdim1h[0] = outdim1h[1] = outdim1h[2] = outdim1h[3] = 0; outdim1h[0] = outdim1h[1] = outdim1h[2] = outdim1h[3] = 0;
...@@ -2063,6 +2074,7 @@ struct bottleneck_backward_state { ...@@ -2063,6 +2074,7 @@ struct bottleneck_backward_state {
} }
for (int dim=0;dim<4;dim++) { for (int dim=0;dim<4;dim++) {
outdim1[dim] = outdimA1[axis[dim]]; outdim1[dim] = outdimA1[axis[dim]];
outdim1b[dim] = outdimA1b[axis[dim]];
outdim2[dim] = outdimA2[axis[dim]]; outdim2[dim] = outdimA2[axis[dim]];
outdim3[dim] = outdimA3[axis[dim]]; outdim3[dim] = outdimA3[axis[dim]];
outdim1h[dim] = outdimA1h[axis[dim]]; outdim1h[dim] = outdimA1h[axis[dim]];
...@@ -2234,6 +2246,39 @@ at::Tensor bottleneck_backward_grad_out1_halo(bool explicit_nhwc, int stride_1X1 ...@@ -2234,6 +2246,39 @@ at::Tensor bottleneck_backward_grad_out1_halo(bool explicit_nhwc, int stride_1X1
return grad_out1_halo; return grad_out1_halo;
} }
// Compute wgrad2 (filter gradient of the 3x3 conv) from the halo-padded out1
// tensor. The caller passes `input` = out1 concatenated with one top and one
// bottom halo row (H+2 rows, shape backward_state.outdimA1b) so a single
// backward-filter convolution accounts for neighbor-rank halo contributions,
// instead of computing interior wgrad2 and patching it with per-halo terms.
//
// Parameters:
//   explicit_nhwc, stride_1X1, inputs — kept for signature parity with
//     bottleneck_backward_wgrad2; not used here (geometry comes from the
//     precomputed backward_state).
//   outputs   — gradient buffers; outputs[2] is the wgrad2 accumulator.
//   input     — halo-padded out1 (fp16). NOTE(review): assumes explicit
//     NHWC layout with H padded by 2 — confirm against the Python caller.
//   grad_out2 — gradient w.r.t. out2 (fp16), shape backward_state.outdimA2.
// Returns outputs[2] after the convolution writes dw2 into it.
at::Tensor bottleneck_backward_wgrad2_pad(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor input, at::Tensor grad_out2) {
  std::cout << std::fixed;  // fixed-point float formatting for DEBUG_MSG output

  // dgrad: incoming gradient w.r.t. out2
  at::Half* dy2 = grad_out2.data_ptr<at::Half>();

  // conv input: halo-padded out1
  at::Half* conv_in = input.data_ptr<at::Half>();

  // wgrad destination buffer
  auto wgrad2 = outputs[2];
  at::Half* dw2 = wgrad2.data_ptr<at::Half>();

  run_dconv(backward_state.outdimA1b,  // conv_in shape (H includes the 2 halo rows)
            backward_state.padA2,      // pad 0 in H (halos supply the border), 1 in W
            backward_state.convstrideA,
            backward_state.dilationA,
            backward_state.filterdimA2, // dw2 shape
            backward_state.outdimA2,    // dy2 shape
            CUDNN_DATA_HALF,
            conv_in,
            dw2,
            dy2,
            CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);

  DEBUG_MSG("[DEBUG] new wgrad2 : " << wgrad2.to(at::kFloat).sum().item<float>());
  return wgrad2;
}
at::Tensor bottleneck_backward_wgrad2(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor grad_out2) { at::Tensor bottleneck_backward_wgrad2(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor grad_out2) {
bool requires_grad = inputs[0].requires_grad(); bool requires_grad = inputs[0].requires_grad();
...@@ -2480,6 +2525,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -2480,6 +2525,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("backward_grad_out2", &bottleneck_backward_grad_out2, "Bottleneck block backward"); m.def("backward_grad_out2", &bottleneck_backward_grad_out2, "Bottleneck block backward");
m.def("backward_grad_out1", &bottleneck_backward_grad_out1, "Bottleneck block backward"); m.def("backward_grad_out1", &bottleneck_backward_grad_out1, "Bottleneck block backward");
m.def("backward_grad_out1_halo", &bottleneck_backward_grad_out1_halo, "Bottleneck block backward"); m.def("backward_grad_out1_halo", &bottleneck_backward_grad_out1_halo, "Bottleneck block backward");
m.def("backward_wgrad2_pad", &bottleneck_backward_wgrad2_pad, "Bottleneck block backward");
m.def("backward_wgrad2", &bottleneck_backward_wgrad2, "Bottleneck block backward"); m.def("backward_wgrad2", &bottleneck_backward_wgrad2, "Bottleneck block backward");
m.def("backward_wgrad2_halo", &bottleneck_backward_wgrad2_halo, "Bottleneck block backward"); m.def("backward_wgrad2_halo", &bottleneck_backward_wgrad2_halo, "Bottleneck block backward");
m.def("backward_rest", &bottleneck_backward_rest, "Bottleneck block backward"); m.def("backward_rest", &bottleneck_backward_rest, "Bottleneck block backward");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment