Commit 834b1d01 authored by Thor Johnsen

Concatenate out1 with halos for backward

parent e5d0be82
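Note: the change below replaces the two separately saved halo tensors (top_out1_halo, btm_out1_halo) with a single padded copy of out1. A minimal sketch of the new layout, assuming explicit NHWC and a per-rank height of Hs rows; shapes and tensor names here are illustrative, not taken verbatim from the commit:

import torch

# Illustrative CPU-side shapes only; the real code runs on CUDA in explicit NHWC.
N, Hs, W, C = 1, 4, 8, 16
out1 = torch.randn(N, Hs, W, C)
top_halo = torch.randn(N, 1, W, C)   # stand-in for the row received from the rank above
btm_halo = torch.randn(N, 1, W, C)   # stand-in for the row received from the rank below

out1_pad = torch.empty(N, Hs + 2, W, C, dtype=out1.dtype)
out1_pad[:, 1:Hs + 1].copy_(out1)     # interior rows: local out1
out1_pad[:, :1].copy_(top_halo)       # row 0: top halo
out1_pad[:, Hs + 1:].copy_(btm_halo)  # row Hs+1: bottom halo
# out1_pad is the single tensor the forward pass now saves for backward.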
@@ -220,10 +220,11 @@ class Bottleneck(torch.nn.Module):
class SpatialBottleneckFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, spatial_group_size, spatial_group_rank, spatial_communicator, spatial_halo_exchanger, spatial_stream, nhwc, stride_1x1, scale, bias, x, *conv):
def forward(ctx, spatial_group_size, spatial_group_rank, spatial_communicator, spatial_halo_exchanger, spatial_method, nhwc, stride_1x1, scale, bias, x, *conv):
if spatial_group_size > 1:
stream1 = spatial_halo_exchanger.stream1
stream2 = spatial_halo_exchanger.stream2
stream3 = spatial_halo_exchanger.stream3
# TODO: clean up order of tensors
args = [x, *conv[0:3], *scale[0:3], *bias[0:3]]
@@ -239,13 +240,19 @@ class SpatialBottleneckFunction(torch.autograd.Function):
outputs = fast_bottleneck.forward_init(nhwc, stride_1x1, args)
fast_bottleneck.forward_out1(nhwc, stride_1x1, args, outputs)
# do halo exchange for outputs[0] (out1)
if spatial_group_size > 1:
out1 = outputs[0]
# TODO: This assumes explicit nhwc
N,Hs,W,C = list(out1.shape)
out1_pad = torch.empty([N,Hs+2,W,C], dtype=out1.dtype, device='cuda')
stream1.wait_stream(torch.cuda.current_stream())
stream3.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(stream3):
out1_pad[:,1:Hs+1,:,:].copy_(out1)
with torch.cuda.stream(stream1):
top_out1_halo, btm_out1_halo = spatial_halo_exchanger.left_right_halo_exchange(out1[:,:1,:,:], out1[:,Hs-1:,:,:])
top_out1_halo = out1_pad[:,:1,:,:]
btm_out1_halo = out1_pad[:,Hs+1:Hs+2,:,:]
spatial_halo_exchanger.left_right_halo_exchange(out1[:,:1,:,:], out1[:,Hs-1:,:,:], top_out1_halo, btm_out1_halo)
if spatial_group_rank < spatial_group_size-1:
stream2.wait_stream(stream1)
with torch.cuda.stream(stream2):
@@ -253,12 +260,18 @@ class SpatialBottleneckFunction(torch.autograd.Function):
btm_fat_halo[:,0:2,:,:].copy_(out1[:,Hs-2:,:,:])
btm_fat_halo[:,2:,:,:].copy_(btm_out1_halo)
btm_out2 = fast_bottleneck.forward_out2_halo(nhwc, btm_fat_halo, args)
else:
with torch.cuda.stream(stream2):
btm_out1_halo.zero_()
if spatial_group_rank > 0:
with torch.cuda.stream(stream1):
top_fat_halo = torch.empty((N,3,W,C),dtype=out1.dtype,device=out1.device)
top_fat_halo[:,:1,:,:].copy_(top_out1_halo)
top_fat_halo[:,1:3,:,:].copy_(out1[:,:2,:,:])
top_out2 = fast_bottleneck.forward_out2_halo(nhwc, top_fat_halo, args)
else:
with torch.cuda.stream(stream1):
top_out1_halo.zero_()
inc.add_delay(10)
fast_bottleneck.forward_out2(nhwc, stride_1x1, args, outputs)
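For readability, a hedged sketch of the stream overlap used in the forward path above: the bulk copy of out1 into the interior of out1_pad runs on stream3 while the halo exchange writes directly into out1_pad's border rows on stream1. The exchange itself is stubbed out here; the real code calls spatial_halo_exchanger.left_right_halo_exchange. Assumes a CUDA device is available:

import torch

def exchange_stub(top_send, btm_send, top_recv, btm_recv):
    # stand-in for the real peer-to-peer halo exchange; copies locally for illustration
    top_recv.copy_(btm_send)
    btm_recv.copy_(top_send)

N, Hs, W, C = 1, 4, 8, 16
out1 = torch.randn(N, Hs, W, C, device='cuda')
out1_pad = torch.empty(N, Hs + 2, W, C, dtype=out1.dtype, device='cuda')
stream1, stream3 = torch.cuda.Stream(), torch.cuda.Stream()
stream1.wait_stream(torch.cuda.current_stream())
stream3.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(stream3):               # bulk copy of the interior rows
    out1_pad[:, 1:Hs + 1].copy_(out1)
with torch.cuda.stream(stream1):               # halos land directly in the border rows
    exchange_stub(out1[:, :1], out1[:, Hs - 1:],
                  out1_pad[:, :1], out1_pad[:, Hs + 1:Hs + 2])
torch.cuda.current_stream().wait_stream(stream1)
torch.cuda.current_stream().wait_stream(stream3)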
@@ -272,11 +285,12 @@ class SpatialBottleneckFunction(torch.autograd.Function):
if spatial_group_rank < spatial_group_size-1:
torch.cuda.current_stream().wait_stream(stream2)
out2[:,Hs-1:,:,:].copy_(btm_out2)
torch.cuda.current_stream().wait_stream(stream3)
fast_bottleneck.forward_rest(nhwc, stride_1x1, args, outputs)
# save halos for backward pass
if spatial_group_size > 1:
ctx.save_for_backward(*(args+outputs+[top_out1_halo,btm_out1_halo]))
ctx.save_for_backward(*(args+outputs+[out1_pad,]))
else:
ctx.save_for_backward(*(args+outputs))
# save relu outputs for drelu
@@ -286,8 +300,10 @@ class SpatialBottleneckFunction(torch.autograd.Function):
if spatial_group_size > 1:
ctx.spatial_group_rank = spatial_group_rank
ctx.spatial_halo_exchanger = spatial_halo_exchanger
ctx.spatial_method = spatial_method
ctx.stream1 = stream1
ctx.stream2 = stream2
ctx.stream3 = stream3
return outputs[2]
# backward relu is not exposed, MUL with mask used now
@@ -295,9 +311,8 @@ class SpatialBottleneckFunction(torch.autograd.Function):
@staticmethod
def backward(ctx, grad_o):
if ctx.spatial_group_size > 1:
top_out1_halo = ctx.saved_tensors[-2]
btm_out1_halo = ctx.saved_tensors[-1]
outputs = ctx.saved_tensors[-5:-2]
out1_pad = ctx.saved_tensors[-1]
outputs = ctx.saved_tensors[-4:-1]
else:
outputs = ctx.saved_tensors[-3:]
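A hedged note on the index change above: the forward pass used to append two halo tensors after the three outputs, so outputs sat at saved_tensors[-5:-2]; it now appends a single out1_pad, so outputs sit at [-4:-1] and out1_pad at [-1]. A toy illustration with placeholder strings:

saved = ['arg0', 'arg1', 'out0', 'out1', 'out2', 'out1_pad']  # args + outputs + [out1_pad]
outputs = saved[-4:-1]
out1_pad = saved[-1]
assert outputs == ['out0', 'out1', 'out2'] and out1_pad == 'out1_pad'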
@@ -353,19 +368,24 @@ class SpatialBottleneckFunction(torch.autograd.Function):
top_grad_out1_halo = top_grad_out1_halo[:,1:2,:,:]
inc.add_delay(10)
if ctx.spatial_group_size > 1:
wgrad2 = fast_bottleneck.backward_wgrad2_pad(ctx.nhwc, ctx.stride_1x1, t_list, grads, out1_pad, grad_out2)
else:
wgrad2 = fast_bottleneck.backward_wgrad2(ctx.nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
# compute wgrad2 for internal cells
wgrad2 = fast_bottleneck.backward_wgrad2(ctx.nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
#wgrad2 = fast_bottleneck.backward_wgrad2(ctx.nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
# apply wgrad2 halos
if ctx.spatial_group_size > 1:
if ctx.spatial_group_rank > 0:
top_grad2_halo = grad_out2[:,:1,:,:]
top_wgrad2_halo = fast_bottleneck.backward_wgrad2_halo(ctx.nhwc, ctx.stride_1x1, t_list, grads, top_out1_halo, top_grad2_halo)
wgrad2[:,:1,:,:].add_(top_wgrad2_halo)
if ctx.spatial_group_rank < ctx.spatial_group_size-1:
btm_grad2_halo = grad_out2[:,-1:,:,:]
btm_wgrad2_halo = fast_bottleneck.backward_wgrad2_halo(ctx.nhwc, ctx.stride_1x1, t_list, grads, btm_out1_halo, btm_grad2_halo)
wgrad2[:,-1:,:,:].add_(btm_wgrad2_halo)
#if ctx.spatial_group_size > 1:
# if ctx.spatial_group_rank > 0:
# top_grad2_halo = grad_out2[:,:1,:,:]
# top_wgrad2_halo = fast_bottleneck.backward_wgrad2_halo(ctx.nhwc, ctx.stride_1x1, t_list, grads, top_out1_halo, top_grad2_halo)
# wgrad2[:,:1,:,:].add_(top_wgrad2_halo)
# if ctx.spatial_group_rank < ctx.spatial_group_size-1:
# btm_grad2_halo = grad_out2[:,-1:,:,:]
# btm_wgrad2_halo = fast_bottleneck.backward_wgrad2_halo(ctx.nhwc, ctx.stride_1x1, t_list, grads, btm_out1_halo, btm_grad2_halo)
# wgrad2[:,-1:,:,:].add_(btm_wgrad2_halo)
# compute grad_out1 for internal cells
grad_out1 = fast_bottleneck.backward_grad_out1(ctx.nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
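A hedged check (not part of the commit) of the idea behind backward_wgrad2_pad: a single 3x3 weight-gradient over the H-padded out1_pad, zero-padding only in W, subsumes the previous "interior wgrad plus per-halo corrections". NCHW layout and F.conv2d are used here purely for illustration:

import torch
import torch.nn.functional as F

N, H, W, Cin, Cout = 1, 6, 8, 4, 4
x_pad = torch.randn(N, Cin, H + 2, W)     # out1 plus one halo row above and below
grad_out = torch.randn(N, Cout, H, W)     # grad_out2 for the local rows
w = torch.zeros(Cout, Cin, 3, 3, requires_grad=True)

F.conv2d(x_pad, w, padding=(0, 1)).backward(grad_out)   # halo rows supply the H context
wgrad_padded = w.grad.clone()

w.grad = None
F.conv2d(x_pad[:, :, 1:H + 1], w, padding=1).backward(grad_out)  # interior-only wgrad
wgrad_interior = w.grad.clone()
# wgrad_padded - wgrad_interior is exactly the contribution of the two halo rows,
# which the old code accumulated separately via backward_wgrad2_halo.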
@@ -456,7 +476,7 @@ class SpatialBottleneck(torch.nn.Module):
# spatial communicator
if spatial_parallel_args is None:
self.spatial_parallel_args = (1, 0, None, None, None)
self.spatial_parallel_args = (1, 0, None, None, 0)
else:
self.spatial_parallel_args = spatial_parallel_args
return
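A hedged reading of the default tuple above, matching the positional arguments of SpatialBottleneckFunction.forward: the last slot is now spatial_method (defaulting to 0) rather than a stream object.

# (spatial_group_size, spatial_group_rank, spatial_communicator,
#  spatial_halo_exchanger, spatial_method)
default_spatial_parallel_args = (1, 0, None, None, 0)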
@@ -12,6 +12,7 @@ class HaloExchanger(object):
def __init__(self):
self.stream1 = torch.cuda.Stream()
self.stream2 = torch.cuda.Stream()
self.stream3 = torch.cuda.Stream()
class HaloExchangerNoComm(HaloExchanger):
def __init__(self, world_size, spatial_group_size, rank, comm):
@@ -1936,6 +1936,7 @@ struct bottleneck_backward_state {
int axis[4];
int64_t outdimA1[4]; // grad_out1
int64_t outdimA1b[4]; // out1_pad
int64_t outdimA2[4]; // grad_out2
int64_t outdimA3[4];
int64_t outdimA1h[4]; // output: grad_out1 halo (H=3)
@@ -1953,6 +1954,7 @@ struct bottleneck_backward_state {
int64_t filterdim2hh[4]; // Cin,1,3,Cout
int64_t outdim1[4];
int64_t outdim1b[4];
int64_t outdim2[4];
int64_t outdim3[4];
int64_t outdim1h[4];
@@ -2001,6 +2003,7 @@ struct bottleneck_backward_state {
// output dim in n,c,h,w used by backend
outdimA1[0] = outdimA1[1] = outdimA1[2] = outdimA1[3] = 0;
outdimA1b[0] = outdimA1b[1] = outdimA1b[2] = outdimA1b[3] = 0;
outdimA2[0] = outdimA2[1] = outdimA2[2] = outdimA2[3] = 0;
outdimA3[0] = outdimA3[1] = outdimA3[2] = outdimA3[3] = 0;
outdimA1h[0] = outdimA1h[1] = outdimA1h[2] = outdimA1h[3] = 0;
@@ -2022,6 +2025,13 @@ struct bottleneck_backward_state {
for (int dim = 0; dim < 2; dim++) {
outdimA1[dim + 2] = getFwdConvOutputDim(dimA[dim + 2], padA[dim], filterdimA1[dim + 2], convstride1X1[dim], dilationA[dim]);
}
for (int dim = 0; dim < 4; dim++) {
if (dim == 2) {
outdimA1b[dim] = outdimA1[dim] + 2;
} else {
outdimA1b[dim] = outdimA1[dim];
}
}
outdimA2[0] = outdimA1[0];
outdimA2[1] = filterdimA2[0];
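A hedged worked example of the outdimA1b computation above (n,c,h,w order): out1_pad keeps out1's dimensions except for two extra rows in H. Values are hypothetical:

outdimA1 = [8, 64, 25, 200]   # hypothetical per-rank grad_out1/out1 shape
outdimA1b = [d + 2 if i == 2 else d for i, d in enumerate(outdimA1)]
assert outdimA1b == [8, 64, 27, 200]   # H grows by 2 for the halo rows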
@@ -2051,6 +2061,7 @@ struct bottleneck_backward_state {
// Create output tensor in the correct shape in pytorch's view
outdim1[0] = outdim1[1] = outdim1[2] = outdim1[3] = 0;
outdim1b[0] = outdim1b[1] = outdim1b[2] = outdim1b[3] = 0;
outdim2[0] = outdim2[1] = outdim2[2] = outdim2[3] = 0;
outdim3[0] = outdim3[1] = outdim3[2] = outdim3[3] = 0;
outdim1h[0] = outdim1h[1] = outdim1h[2] = outdim1h[3] = 0;
@@ -2063,6 +2074,7 @@ struct bottleneck_backward_state {
}
for (int dim=0;dim<4;dim++) {
outdim1[dim] = outdimA1[axis[dim]];
outdim1b[dim] = outdimA1b[axis[dim]];
outdim2[dim] = outdimA2[axis[dim]];
outdim3[dim] = outdimA3[axis[dim]];
outdim1h[dim] = outdimA1h[axis[dim]];
@@ -2234,6 +2246,39 @@ at::Tensor bottleneck_backward_grad_out1_halo(bool explicit_nhwc, int stride_1X1,
return grad_out1_halo;
}
at::Tensor bottleneck_backward_wgrad2_pad(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor input, at::Tensor grad_out2) {
std::cout << std::fixed;
auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
// dgrad
at::Half* dy2 = grad_out2.data_ptr<at::Half>();
// dconv2+drelu1+dscale1
at::Half* conv_in = input.data_ptr<at::Half>();
// wgrad
auto wgrad2 = outputs[2];
at::Half* dw2 = wgrad2.data_ptr<at::Half>();
//printf("outdimA1b = (%d,%d,%d,%d)\n",backward_state.outdimA1b[0],backward_state.outdimA1b[1],backward_state.outdimA1b[2],backward_state.outdimA1b[3]);
//printf("backward_state.padA2 = {%d,%d}\n",backward_state.padA2[0],backward_state.padA2[1]);
run_dconv(backward_state.outdimA1b, // conv_in.shape (including H halos)
backward_state.padA2, // 0, 1
backward_state.convstrideA,
backward_state.dilationA,
backward_state.filterdimA2, // dw2.shape
backward_state.outdimA2, // dy2.shape
CUDNN_DATA_HALF,
conv_in,
dw2,
dy2,
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);
DEBUG_MSG("[DEBUG] new wgrad2 : " << wgrad2.to(at::kFloat).sum().item<float>());
return wgrad2;
}
at::Tensor bottleneck_backward_wgrad2(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor grad_out2) {
bool requires_grad = inputs[0].requires_grad();
@@ -2480,6 +2525,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("backward_grad_out2", &bottleneck_backward_grad_out2, "Bottleneck block backward");
m.def("backward_grad_out1", &bottleneck_backward_grad_out1, "Bottleneck block backward");
m.def("backward_grad_out1_halo", &bottleneck_backward_grad_out1_halo, "Bottleneck block backward");
m.def("backward_wgrad2_pad", &bottleneck_backward_wgrad2_pad, "Bottleneck block backward");
m.def("backward_wgrad2", &bottleneck_backward_wgrad2, "Bottleneck block backward");
m.def("backward_wgrad2_halo", &bottleneck_backward_wgrad2_halo, "Bottleneck block backward");
m.def("backward_rest", &bottleneck_backward_rest, "Bottleneck block backward");