Commit 2f164a2a authored by Thor Johnsen

First release

parent d6b5ae5d
from .bottleneck import Bottleneck
from .bottleneck import Bottleneck, SpatialBottleneck
import torch
import torch.distributed as dist
from torch import nn
import fast_bottleneck
@@ -212,3 +213,235 @@ class Bottleneck(torch.nn.Module):
out = self.relu(out)
return out
class SpatialBottleneckFunction(torch.autograd.Function):
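# Spatially-parallel version of the bottleneck autograd function: the input is split along H
# across spatial_group_size ranks, and the fused forward is staged as forward_init/out1/out2/rest
# so that a halo exchange for the 3x3 convolution can be inserted between the stages.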
@staticmethod
def forward(ctx, spatial_group_size, local_rank, comm, stream1, nhwc, stride_1x1, scale, bias, x, *conv):
# TODO: clean up order of tensors
args = [x, *conv[0:3], *scale[0:3], *bias[0:3]]
ctx.downsample = len(conv) > 3
if ctx.downsample:
args.append(conv[3])
args.append(scale[3])
args.append(bias[3])
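# args layout: args[0]=x, args[1:4]=conv weights 1-3, args[4:7]=scales 1-3, args[7:10]=biases 1-3,
# plus args[10:13]=downsample conv/scale/bias when a downsample path exists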
# weight buffers are always in NHWC, while the input layout can be explicit NHWC or channels_last;
# here we pass in a flag and let the C++ side handle it
# alternatively, we could put all sizes into a fixed format and pass that in
outputs = fast_bottleneck.forward_init(nhwc, stride_1x1, args)
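# forward_init computes output shapes and allocates out1/out2/out3; forward_out1 and forward_out2
# run the first two fused conv+scale+bias+ReLU stages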
fast_bottleneck.forward_out1(nhwc, stride_1x1, args, outputs)
fast_bottleneck.forward_out2(nhwc, stride_1x1, args, outputs)
# do halo exchange for outputs[0] (out1)
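# Each rank owns a horizontal slice of Hs rows. The 3x3 convolution producing out2 needs one row
# of context above and below the slice, so boundary rows are exchanged via all_gather and the two
# boundary rows of out2 are recomputed from 3-row "fat halos" on a side stream.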
if spatial_group_size > 1:
out1 = outputs[0]
N,Hs,W,C = list(out1.shape)
padded_out1 = torch.empty((N,Hs+2,W,C),dtype=out1.dtype,device=out1.device)
padded_out1[:,1:Hs+1,:,:].copy_(out1)
stream1.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(stream1):
# copy halos to send buffer
send_halos = torch.empty((N,2,W,C),dtype=out1.dtype,device=out1.device)
send_halos[:,:1,:,:].copy_(out1[:,:1,:,:])
send_halos[:,1:,:,:].copy_(out1[:,Hs-1:,:,:])
all_halos = torch.empty((N,2*spatial_group_size,W,C),dtype=out1.dtype,device=out1.device)
all_halos = [all_halos[:,i*2:(i+1)*2,:,:] for i in range(spatial_group_size)]
dist.all_gather(all_halos,send_halos)
padded_out1_top_halo = padded_out1[:,:1,:,:]
if local_rank > 0:
top_halo = all_halos[local_rank-1][:,1:,:,:]
padded_out1_top_halo.copy_(top_halo)
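# recompute the top row of out2 from a 3-row fat halo: neighbor's bottom row + this rank's first two rows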
fat_top_halo = padded_out1[:,:3,:,:]
top_out2 = fast_bottleneck.forward_out2_halo(nhwc, fat_top_halo, args)
else:
padded_out1_top_halo.zero_()
padded_out1_btm_halo = padded_out1[:,Hs+1:,:,:]
if local_rank < spatial_group_size-1:
btm_halo = all_halos[local_rank+1][:,:1,:,:]
padded_out1_btm_halo.copy_(btm_halo)
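# recompute the bottom row of out2 from a 3-row fat halo: this rank's last two rows + neighbor's top row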
fat_btm_halo = padded_out1[:,Hs-1:,:,:]
btm_out2 = fast_bottleneck.forward_out2_halo(nhwc, fat_btm_halo, args)
else:
padded_out1_btm_halo.zero_()
torch.cuda.current_stream().wait_stream(stream1)
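# once stream1 is done, overwrite the top/bottom rows of out2 with the halo-corrected results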
out2 = outputs[1]
if local_rank > 0:
out2[:,:1,:,:].copy_(top_out2)
if local_rank < spatial_group_size-1:
out2[:,Hs-1:,:,:].copy_(btm_out2)
fast_bottleneck.forward_rest(nhwc, stride_1x1, args, outputs)
if spatial_group_size > 1:
ctx.save_for_backward(*(args+outputs+[padded_out1]))
else:
ctx.save_for_backward(*(args+outputs))
# save relu outputs for drelu
ctx.nhwc = nhwc
ctx.stride_1x1 = stride_1x1
ctx.spatial_group_size = spatial_group_size
ctx.local_rank = local_rank
ctx.comm = comm
ctx.stream1 = stream1
return outputs[2]
# backward relu is not exposed, so a multiply with a mask is used instead
# only dgrad is supported
@staticmethod
def backward(ctx, grad_o):
if ctx.spatial_group_size > 1:
outputs = ctx.saved_tensors[-4:-1]
else:
outputs = ctx.saved_tensors[-3:]
if ctx.downsample:
grad_conv3, grad_conv4 = drelu_dscale2(grad_o, outputs[2], ctx.saved_tensors[6], ctx.saved_tensors[11])
else:
grad_conv3, grad_conv4 = drelu_dscale1(grad_o, outputs[2], ctx.saved_tensors[6])
# create input vector for backward
t_list = [*ctx.saved_tensors[0:10]]
t_list.append(grad_conv3)
t_list.append(grad_conv4)
# outputs used for wgrad and generating drelu mask
t_list.append(outputs[0])
t_list.append(outputs[1])
# in case there is downsample
if ctx.downsample:
t_list.append(ctx.saved_tensors[10])
grads = fast_bottleneck.backward_init(ctx.nhwc, ctx.stride_1x1, t_list)
grad_out2 = fast_bottleneck.backward_grad_out2(ctx.nhwc, ctx.stride_1x1, t_list, grads)
# do halo exchange of grad_out2 here
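# NOTE: the halo exchange for grad_out2 (mirroring the forward exchange) is not implemented yet;
# backward_rest consumes grad_out2 as-is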
fast_bottleneck.backward_rest(ctx.nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
return (None, None, None, None, None, None, None, None, *grads)
spatial_bottleneck_function = SpatialBottleneckFunction.apply
class SpatialBottleneck(torch.nn.Module):
# Bottleneck in torchvision places the stride for downsampling at the 3x3 convolution (self.conv2),
# while the original implementation places the stride at the first 1x1 convolution (self.conv1),
# according to "Deep Residual Learning for Image Recognition" (https://arxiv.org/abs/1512.03385).
# This variant is also known as ResNet V1.5 and improves accuracy according to
# https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
# Here we place the stride at the first 1x1 convolution.
def __init__(self, in_channels, bottleneck_channels, out_channels, stride=1, groups=1,
dilation=1, norm_func=None, use_cudnn=False, explicit_nhwc=False,
spatial_group_size=1):
super(SpatialBottleneck, self).__init__()
if groups != 1:
raise RuntimeError('Only support groups == 1')
if dilation != 1:
raise RuntimeError('Only support dilation == 1')
if norm_func == None:
norm_func = FrozenBatchNorm2d
else:
raise RuntimeError('Only support frozen BN now.')
if stride != 1 or in_channels != out_channels:
self.downsample = nn.Sequential(
conv1x1(in_channels, out_channels, stride),
norm_func(out_channels),
)
else:
self.downsample = None
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv1x1(in_channels, bottleneck_channels, stride)
self.conv2 = conv3x3(bottleneck_channels, bottleneck_channels)
self.conv3 = conv1x1(bottleneck_channels, out_channels)
self.relu = nn.ReLU(inplace=True)
self.stride = stride
self.bn1 = norm_func(bottleneck_channels)
self.bn2 = norm_func(bottleneck_channels)
self.bn3 = norm_func(out_channels)
self.use_cudnn = use_cudnn
# setup conv weights
self.w_conv = [self.conv1.weight, self.conv2.weight, self.conv3.weight]
if self.downsample is not None:
self.w_conv.append(self.downsample[0].weight)
# init weight in nchw format before possible transpose
for w in self.w_conv:
kaiming_uniform_(w, a=1)
# TODO: prevent use of unsupported cases
# Supported cases:
#                  native   cudnn
#   normal           yes      no
#   channels_last    yes     yes
#   explicit_nhwc     no     yes
self.explicit_nhwc = explicit_nhwc
if self.explicit_nhwc:
for p in self.parameters():
with torch.no_grad():
p.data = p.data.permute(0,2,3,1).contiguous()
# spatial communicator
self.spatial_group_size = spatial_group_size
if spatial_group_size > 1:
world_size = dist.get_world_size()
num_groups = world_size // spatial_group_size
assert(num_groups*spatial_group_size == world_size), "torch.distributed.get_world_size() must be a multiple of spatial_group_size"
rank = dist.get_rank()
self.local_rank = rank % spatial_group_size
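# torch.distributed.new_group() must be called by every rank, even for groups it does not belong
# to, so loop over all groups and keep the communicator that contains this rank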
for group in range(num_groups):
ranks = list(range(group*spatial_group_size,(group+1)*spatial_group_size))
comm = torch.distributed.new_group(ranks=ranks)
if rank in ranks:
self.communicator = comm
self.stream1 = torch.cuda.Stream()
self.spatial_args = self.spatial_group_size, self.local_rank, self.communicator, self.stream1
else:
self.spatial_args = 1, 0, None, None
return
def forward(self, x):
if self.use_cudnn:
# calculate scale/bias from registered buffers
# TODO: make this better
s1, b1 = self.bn1.get_scale_bias(self.explicit_nhwc)
s2, b2 = self.bn2.get_scale_bias(self.explicit_nhwc)
s3, b3 = self.bn3.get_scale_bias(self.explicit_nhwc)
w_scale = [s1, s2, s3]
w_bias = [b1, b2, b3]
if self.downsample is not None:
s4, b4 = self.downsample[1].get_scale_bias(self.explicit_nhwc)
w_scale.append(s4)
w_bias.append(b4)
out = spatial_bottleneck_function(*self.spatial_args, self.explicit_nhwc, self.stride, w_scale, w_bias, x, *self.w_conv)
return out
if self.explicit_nhwc:
raise RuntimeError('explicit nhwc with native ops is not supported.')
# fallback to native ops
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
@@ -1606,7 +1606,726 @@ std::vector<at::Tensor> bottleneck_backward(bool explicit_nhwc, int stride_1X1,
return outputs;
}
namespace {
struct bottleneck_forward_status {
int64_t dimA[4];
int64_t filterdimA1[4];
int64_t filterdimA2[4];
int64_t filterdimA3[4];
int64_t filterdimA4[4];
int axis[4];
int64_t outdimA0[4];
int64_t outdimA1[4];
int64_t outdimA2[4];
int64_t outdimA3[4];
int64_t outdimA4[4];
int64_t padA[2];
int64_t padA1[2];
int64_t padA2[2]; // halo padding
int64_t dilationA[2];
int64_t convstrideA[2];
int64_t convstride1X1[2];
int64_t outdim0[4]; // halo input shape
int64_t outdim1[4];
int64_t outdim2[4];
int64_t outdim3[4];
int64_t outdim4[4]; // halo output shape
void init(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs) {
dimA[0] = dimA[1] = dimA[2] = dimA[3] = 0;
filterdimA1[0] = filterdimA1[1] = filterdimA1[2] = filterdimA1[3] = 0;
filterdimA2[0] = filterdimA2[1] = filterdimA2[2] = filterdimA2[3] = 0;
filterdimA3[0] = filterdimA3[1] = filterdimA3[2] = filterdimA3[3] = 0;
filterdimA4[0] = filterdimA4[1] = filterdimA4[2] = filterdimA4[3] = 0;
// All dim calculations below follow the order n,c,h,w
if (explicit_nhwc) {
axis[0] = 0;
axis[1] = 3;
axis[2] = 1;
axis[3] = 2;
} else {
axis[0] = 0;
axis[1] = 1;
axis[2] = 2;
axis[3] = 3;
}
for (int dim=0;dim<4;dim++) {
dimA[dim] = inputs[0].size(axis[dim]);
filterdimA1[dim] = inputs[1].size(axis[dim]);
filterdimA2[dim] = inputs[2].size(axis[dim]);
filterdimA3[dim] = inputs[3].size(axis[dim]);
}
if (stride_1X1 != 1 || filterdimA3[0] != dimA[1]) {
for (int dim=0;dim<4;dim++) {
filterdimA4[dim] = inputs[10].size(axis[dim]);
}
}
// output dim in n,c,h,w used by backend
outdimA0[0] = outdimA0[1] = outdimA0[2] = outdimA0[3] = 0;
outdimA1[0] = outdimA1[1] = outdimA1[2] = outdimA1[3] = 0;
outdimA2[0] = outdimA2[1] = outdimA2[2] = outdimA2[3] = 0;
outdimA3[0] = outdimA3[1] = outdimA3[2] = outdimA3[3] = 0;
outdimA4[0] = outdimA4[1] = outdimA4[2] = outdimA4[3] = 0;
// use these fixed values for the test run
padA[0] = 0; padA[1] = 0;
padA1[0] = 1; padA1[1] = 1;
padA2[0] = 0; padA2[1] = 1;
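// padA2 is for the halo convolution: no padding in H (context rows come from the exchanged halo),
// padding of 1 in W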
dilationA[0] = 1; dilationA[1] = 1;
convstrideA[0] = 1; convstrideA[1] = 1;
convstride1X1[0] = stride_1X1; convstride1X1[1] = stride_1X1;
// compute output from pad/stride/dilation
outdimA1[0] = dimA[0];
outdimA1[1] = filterdimA1[0];
for (int dim = 0; dim < 2; dim++) {
outdimA1[dim + 2] = getFwdConvOutputDim(dimA[dim + 2], padA[dim], filterdimA1[dim + 2], convstride1X1[dim], dilationA[dim]);
}
outdimA2[0] = outdimA1[0];
outdimA2[1] = filterdimA2[0];
for (int dim = 0; dim < 2; dim++) {
outdimA2[dim + 2] = getFwdConvOutputDim(outdimA1[dim + 2], padA1[dim], filterdimA2[dim + 2], convstrideA[dim], dilationA[dim]);
}
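// outdimA0 is the 3-row fat-halo input shape and outdimA4 the 1-row halo output shape;
// all other dims match outdimA1 / outdimA2 respectively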
for (int dim = 0; dim < 4; dim++) {
if (dim == 2) {
outdimA0[dim] = 3;
outdimA4[dim] = 1;
} else {
outdimA0[dim] = outdimA1[dim];
outdimA4[dim] = outdimA2[dim];
}
}
outdimA3[0] = outdimA2[0];
outdimA3[1] = filterdimA3[0];
for (int dim = 0; dim < 2; dim++) {
outdimA3[dim + 2] = getFwdConvOutputDim(outdimA2[dim + 2], padA[dim], filterdimA3[dim + 2], convstrideA[dim], dilationA[dim]);
}
// Create output tensor in the correct shape in pytorch's view
outdim1[0] = outdim1[1] = outdim1[2] = outdim1[3] = 0;
outdim2[0] = outdim2[1] = outdim2[2] = outdim2[3] = 0;
outdim3[0] = outdim3[1] = outdim3[2] = outdim3[3] = 0;
if (explicit_nhwc) {
axis[0] = 0;
axis[1] = 2;
axis[2] = 3;
axis[3] = 1;
}
for (int dim=0;dim<4;dim++) {
outdim0[dim] = outdimA0[axis[dim]];
outdim1[dim] = outdimA1[axis[dim]];
outdim2[dim] = outdimA2[axis[dim]];
outdim3[dim] = outdimA3[axis[dim]];
outdim4[dim] = outdimA4[axis[dim]];
}
}
};
bottleneck_forward_status forward_state;
} // end of anonymous namespace
std::vector<at::Tensor> bottleneck_forward_init(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs) {
// NB! bottleneck_forward and bottleneck_backward are NOT thread-safe methods.
// NB! We use a global object to store state.
forward_state.init(explicit_nhwc, stride_1X1, inputs);
// create output vector
std::vector<at::Tensor> outputs;
auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
printf("outdim1 = (%d,%d,%d,%d)\n",forward_state.outdim1[0],forward_state.outdim1[1],forward_state.outdim1[2],forward_state.outdim1[3]);
auto out1 = at::empty(forward_state.outdim1, inputs[0].type(), output_format);
auto out2 = at::empty(forward_state.outdim2, inputs[0].type(), output_format);
auto out3 = at::empty(forward_state.outdim3, inputs[0].type(), output_format);
outputs.push_back(out1);
outputs.push_back(out2);
outputs.push_back(out3);
return outputs;
}
// inputs contains x, conv weights (w), scales (z), biases (b), and optionally the downsample/identity weights (i)
void bottleneck_forward_out1(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs) {
std::cout << std::fixed;
// run
at::Half* x = inputs[0].data_ptr<at::Half>();
at::Half* w = inputs[1].data_ptr<at::Half>();
at::Half* z = inputs[4].data_ptr<at::Half>();
at::Half* b = inputs[7].data_ptr<at::Half>();
auto out1 = outputs[0];
at::Half* y1 = out1.data_ptr<at::Half>();
run_conv_scale_bias_add_activation(forward_state.dimA,
forward_state.padA,
forward_state.convstride1X1,
forward_state.dilationA,
forward_state.filterdimA1,
forward_state.outdimA1,
CUDNN_DATA_HALF,
x,
w,
y1,
z,
b,
nullptr);
DEBUG_MSG("[DEBUG] new relu1 : " << out1.to(at::kFloat).sum().item<float>());
}
// computes a halo output row (top or bottom) of out2 from a fat-halo input
// the fat-halo input is 3 pixels tall in H
at::Tensor bottleneck_forward_out2_halo(bool explicit_nhwc, at::Tensor fat_halo_y1, std::vector<at::Tensor> inputs) {
auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
// run
at::Half* w = inputs[2].data_ptr<at::Half>();
at::Half* z = inputs[5].data_ptr<at::Half>();
at::Half* b = inputs[8].data_ptr<at::Half>();
at::Half* y1 = fat_halo_y1.data_ptr<at::Half>();
auto halo_y2 = at::empty(forward_state.outdim4, inputs[0].type(), output_format);
at::Half* y2 = halo_y2.data_ptr<at::Half>();
run_conv_scale_bias_add_activation(forward_state.outdimA0,
forward_state.padA2,
forward_state.convstrideA,
forward_state.dilationA,
forward_state.filterdimA2,
forward_state.outdimA4,
CUDNN_DATA_HALF,
y1,
w,
y2,
z,
b,
nullptr);
return halo_y2;
}
void bottleneck_forward_out2(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs) {
std::cout << std::fixed;
// from _out1 method
at::Half* x = inputs[0].data_ptr<at::Half>();
auto out1 = outputs[0];
at::Half* y1 = out1.data_ptr<at::Half>();
// run
at::Half* w = inputs[2].data_ptr<at::Half>();
at::Half* z = inputs[5].data_ptr<at::Half>();
at::Half* b = inputs[8].data_ptr<at::Half>();
auto out2 = outputs[1];
at::Half* y2 = out2.data_ptr<at::Half>();
printf("forward_state.outdimA1 = {%d,%d,%d,%d}\n",forward_state.outdimA1[0],forward_state.outdimA1[1],forward_state.outdimA1[2],forward_state.outdimA1[3]);
printf("forward_state.padA1 = {%d,%d}\n",forward_state.padA1[0],forward_state.padA1[1]);
printf("forward_state.convstrideA = {%d,%d}\n",forward_state.convstrideA[0],forward_state.convstrideA[1]);
printf("forward_state.dilationA = {%d,%d}\n",forward_state.dilationA[0],forward_state.dilationA[1]);
printf("forward_state.filterdimA2 = {%d,%d,%d,%d}\n",forward_state.filterdimA2[0],forward_state.filterdimA2[1],forward_state.filterdimA2[2],forward_state.filterdimA2[3]);
printf("forward_state.outdimA2 = {%d,%d,%d,%d}\n",forward_state.outdimA2[0],forward_state.outdimA2[1],forward_state.outdimA2[2],forward_state.outdimA2[3]);
run_conv_scale_bias_add_activation(forward_state.outdimA1,
forward_state.padA1,
forward_state.convstrideA,
forward_state.dilationA,
forward_state.filterdimA2,
forward_state.outdimA2,
CUDNN_DATA_HALF,
y1,
w,
y2,
z,
b,
nullptr);
DEBUG_MSG("[DEBUG] new relu2 : " << out2.to(at::kFloat).sum().item<float>());
}
void bottleneck_forward_rest(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs) {
std::cout << std::fixed;
// from _out1 method
at::Half* x = inputs[0].data_ptr<at::Half>();
// create output of conv3
auto out3 = outputs[2];
at::Half* y3 = out3.data_ptr<at::Half>();
// create output of conv4 that may exist
auto identity = at::empty_like(out3);
at::Half* yi = identity.data_ptr<at::Half>();
at::Half *w, *z, *b;
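// When the block downsamples (stride_1X1 != 1 or the channel count changes), compute the identity
// branch with conv4 + scale + bias; otherwise the identity is simply the input x.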
if (stride_1X1 != 1 || forward_state.filterdimA3[0] != forward_state.dimA[1]){
w = inputs[10].data_ptr<at::Half>();
z = inputs[11].data_ptr<at::Half>();
b = inputs[12].data_ptr<at::Half>();
run_conv_scale_bias(forward_state.dimA,
forward_state.padA,
forward_state.convstride1X1,
forward_state.dilationA,
forward_state.filterdimA4,
forward_state.outdimA3,
CUDNN_DATA_HALF,
x,
w,
yi,
z,
b);
DEBUG_MSG("[DEBUG] new downsample : " << identity.to(at::kFloat).sum().item<float>());
}
else {
yi = x;
}
auto out2 = outputs[1];
at::Half* y2 = out2.data_ptr<at::Half>();
w = inputs[3].data_ptr<at::Half>();
z = inputs[6].data_ptr<at::Half>();
b = inputs[9].data_ptr<at::Half>();
run_conv_scale_bias_add_activation(forward_state.outdimA2,
forward_state.padA,
forward_state.convstrideA,
forward_state.dilationA,
forward_state.filterdimA3,
forward_state.outdimA3,
CUDNN_DATA_HALF,
y2,
w,
y3,
z,
b,
yi);
DEBUG_MSG("[DEBUG] new relu3 : " << out3.to(at::kFloat).sum().item<float>());
}
namespace {
struct bottleneck_backward_state {
int64_t dimA[4];
int64_t filterdimA1[4];
int64_t filterdimA2[4];
int64_t filterdimA3[4];
int64_t filterdimA4[4];
int axis[4];
int64_t outdimA1[4];
int64_t outdimA2[4];
int64_t outdimA3[4];
int64_t padA[2];
int64_t padA1[2];
int64_t dilationA[2];
int64_t convstrideA[2];
int64_t convstride1X1[2];
int64_t outdim1[4];
int64_t outdim2[4];
int64_t outdim3[4];
void init(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs) {
// setup dimensions
dimA[0] = dimA[1] = dimA[2] = dimA[3] = 0;
filterdimA1[0] = filterdimA1[1] = filterdimA1[2] = filterdimA1[3] = 0;
filterdimA2[0] = filterdimA2[1] = filterdimA2[2] = filterdimA2[3] = 0;
filterdimA3[0] = filterdimA3[1] = filterdimA3[2] = filterdimA3[3] = 0;
filterdimA4[0] = filterdimA4[1] = filterdimA4[2] = filterdimA4[3] = 0;
// All dim calculations below follow the order n,c,h,w
if (explicit_nhwc) {
axis[0] = 0;
axis[1] = 3;
axis[2] = 1;
axis[3] = 2;
} else {
axis[0] = 0;
axis[1] = 1;
axis[2] = 2;
axis[3] = 3;
}
for (int dim=0;dim<4;dim++) {
dimA[dim] = inputs[0].size(axis[dim]);
filterdimA1[dim] = inputs[1].size(axis[dim]);
filterdimA2[dim] = inputs[2].size(axis[dim]);
filterdimA3[dim] = inputs[3].size(axis[dim]);
}
if (stride_1X1 != 1 || filterdimA3[0] != dimA[1]) {
for (int dim=0;dim<4;dim++) {
filterdimA4[dim] = inputs[14].size(axis[dim]);
}
}
// output dim in n,c,h,w used by backend
outdimA1[0] = outdimA1[1] = outdimA1[2] = outdimA1[3] = 0;
outdimA2[0] = outdimA2[1] = outdimA2[2] = outdimA2[3] = 0;
outdimA3[0] = outdimA3[1] = outdimA3[2] = outdimA3[3] = 0;
// use these fixed values for the test run
padA[0] = 0; padA[1] = 0;
padA1[0] = 1; padA1[1] = 1;
dilationA[0] = 1; dilationA[1] = 1;
convstrideA[0] = 1; convstrideA[1] = 1;
convstride1X1[0] = stride_1X1; convstride1X1[1] = stride_1X1;
// compute output from pad/stride/dilation
outdimA1[0] = dimA[0];
outdimA1[1] = filterdimA1[0];
for (int dim = 0; dim < 2; dim++) {
outdimA1[dim + 2] = getFwdConvOutputDim(dimA[dim + 2], padA[dim], filterdimA1[dim + 2], convstride1X1[dim], dilationA[dim]);
}
outdimA2[0] = outdimA1[0];
outdimA2[1] = filterdimA2[0];
for (int dim = 0; dim < 2; dim++) {
outdimA2[dim + 2] = getFwdConvOutputDim(outdimA1[dim + 2], padA1[dim], filterdimA2[dim + 2], convstrideA[dim], dilationA[dim]);
}
outdimA3[0] = outdimA2[0];
outdimA3[1] = filterdimA3[0];
for (int dim = 0; dim < 2; dim++) {
outdimA3[dim + 2] = getFwdConvOutputDim(outdimA2[dim + 2], padA[dim], filterdimA3[dim + 2], convstrideA[dim], dilationA[dim]);
}
// Create output tensor in the correct shape in pytorch's view
outdim1[0] = outdim1[1] = outdim1[2] = outdim1[3] = 0;
outdim2[0] = outdim2[1] = outdim2[2] = outdim2[3] = 0;
outdim3[0] = outdim3[1] = outdim3[2] = outdim3[3] = 0;
if (explicit_nhwc) {
axis[0] = 0;
axis[1] = 2;
axis[2] = 3;
axis[3] = 1;
}
for (int dim=0;dim<4;dim++) {
outdim1[dim] = outdimA1[axis[dim]];
outdim2[dim] = outdimA2[axis[dim]];
outdim3[dim] = outdimA3[axis[dim]];
}
}
};
bottleneck_backward_state backward_state;
}
std::vector<at::Tensor> bottleneck_backward_init(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs) {
std::cout << std::fixed;
backward_state.init(explicit_nhwc, stride_1X1, inputs);
// create output vector
std::vector<at::Tensor> outputs;
auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
auto grad_x = at::empty_like(inputs[0]);
auto wgrad1 = at::empty_like(inputs[1]);
auto wgrad2 = at::empty_like(inputs[2]);
auto wgrad3 = at::empty_like(inputs[3]);
outputs.push_back(grad_x);
outputs.push_back(wgrad1);
outputs.push_back(wgrad2);
outputs.push_back(wgrad3);
if (stride_1X1 != 1 || backward_state.filterdimA3[0] != backward_state.dimA[1]) {
auto wgrad4 = at::empty_like(inputs[14]);
outputs.push_back(wgrad4);
}
return outputs;
}
at::Tensor bottleneck_backward_grad_out2(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs) {
bool requires_grad = inputs[0].requires_grad();
std::cout << std::fixed;
auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
// dconv3+drelu2+dscale2
at::Half* conv_in = inputs[13].data_ptr<at::Half>();
at::Half* dy3 = inputs[10].data_ptr<at::Half>();
DEBUG_MSG("[DEBUG] new dconv3 : " << inputs[10].to(at::kFloat).sum().item<float>());
// wgrad
auto wgrad3 = outputs[3];
at::Half* dw3 = wgrad3.data_ptr<at::Half>();
run_dconv(backward_state.outdimA2,
backward_state.padA,
backward_state.convstrideA,
backward_state.dilationA,
backward_state.filterdimA3,
backward_state.outdimA3,
CUDNN_DATA_HALF,
conv_in,
dw3,
dy3,
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);
// dgrad
auto grad_out2 = at::empty(backward_state.outdim2, inputs[0].type(), output_format);
at::Half* dy2 = grad_out2.data_ptr<at::Half>();
at::Half* w = inputs[3].data_ptr<at::Half>();
at::Half* z = inputs[5].data_ptr<at::Half>();
at::Half* relu2 = inputs[13].data_ptr<at::Half>();
run_dconv_drelu_dscale(backward_state.outdimA2,
backward_state.padA,
backward_state.convstrideA,
backward_state.dilationA,
backward_state.filterdimA3,
backward_state.outdimA3,
CUDNN_DATA_HALF,
dy2,
w,
dy3,
z,
relu2);
// do halo exchange of dy2 here
DEBUG_MSG("[DEBUG] new dconv2 : " << grad_out2.to(at::kFloat).sum().item<float>());
return grad_out2;
}
void bottleneck_backward_rest(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor grad_out2) {
bool requires_grad = inputs[0].requires_grad();
std::cout << std::fixed;
auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
// dgrad
at::Half* dy2 = grad_out2.data_ptr<at::Half>();
// dconv2+drelu1+dscale1
at::Half* conv_in = inputs[12].data_ptr<at::Half>();
// wgrad
auto wgrad2 = outputs[2];
at::Half* dw2 = wgrad2.data_ptr<at::Half>();
printf("outdimA1 = (%d,%d,%d,%d)\n",backward_state.outdimA1[0],backward_state.outdimA1[1],backward_state.outdimA1[2],backward_state.outdimA1[3]);
run_dconv(backward_state.outdimA1,
backward_state.padA1,
backward_state.convstrideA,
backward_state.dilationA,
backward_state.filterdimA2,
backward_state.outdimA2,
CUDNN_DATA_HALF,
conv_in,
dw2,
dy2,
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);
// dgrad
auto grad_out1 = at::empty(backward_state.outdim1, inputs[0].type(), output_format);
at::Half* dy1 = grad_out1.data_ptr<at::Half>();
at::Half* w = inputs[2].data_ptr<at::Half>();
at::Half* z = inputs[4].data_ptr<at::Half>();
at::Half* relu1 = inputs[12].data_ptr<at::Half>();
// fused dgrad
run_dconv_drelu_dscale(backward_state.outdimA1,
backward_state.padA1,
backward_state.convstrideA,
backward_state.dilationA,
backward_state.filterdimA2,
backward_state.outdimA2,
CUDNN_DATA_HALF,
dy1,
w,
dy2,
z,
relu1);
/*
// backward strided conv cannot be fused
// if stride == 1 but channel changes, we can fuse here
if (stride_1X1 != 1){
// dgrad
run_dconv(outdimA1,
padA1,
convstride1X1,
dilationA,
filterdimA2,
outdimA2,
CUDNN_DATA_HALF,
dy1,
w,
dy2,
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR);
// mul fused mask
grad_out1.mul_(inputs[15]);
}
else {
at::Half* relu1 = inputs[12].data_ptr<at::Half>();
// fused dgrad
run_dconv_drelu_dscale(outdimA1,
padA1,
convstride1X1,
dilationA,
filterdimA2,
outdimA2,
CUDNN_DATA_HALF,
dy1,
w,
dy2,
z,
relu1);
}
*/
DEBUG_MSG("[DEBUG] new dconv1 : " << grad_out1.to(at::kFloat).sum().item<float>());
// create grads of conv4 that may exist
auto grad_x_conv4 = at::empty_like(inputs[0]);
at::Half* dx_conv4 = grad_x_conv4.data_ptr<at::Half>();
at::Tensor wgrad4;
// x used for dconv1 and dconv4 wgrad
at::Half* x = inputs[0].data_ptr<at::Half>();
if (stride_1X1 != 1 || backward_state.filterdimA3[0] != backward_state.dimA[1]){
w = inputs[14].data_ptr<at::Half>();
at::Half* dy_conv4 = inputs[11].data_ptr<at::Half>();
if (requires_grad) {
run_dconv(backward_state.dimA,
backward_state.padA,
backward_state.convstride1X1,
backward_state.dilationA,
backward_state.filterdimA4,
backward_state.outdimA3,
CUDNN_DATA_HALF,
dx_conv4,
w,
dy_conv4,
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR);
// we don't print here since this grad can't be hooked out in PyTorch alone for comparison, due to the addition with dx
// DEBUG_MSG("[DEBUG] new dx_identity : " << grad_x_conv4.to(at::kFloat).sum().item<float>());
}
// wgrad
wgrad4 = outputs[4];
at::Half* dw4 = wgrad4.data_ptr<at::Half>();
run_dconv(backward_state.dimA,
backward_state.padA,
backward_state.convstride1X1,
backward_state.dilationA,
backward_state.filterdimA4,
backward_state.outdimA3,
CUDNN_DATA_HALF,
x,
dw4,
dy_conv4,
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);
}
else {
// if there is no downsample, dx_conv4 is fork of drelu3
dx_conv4 = inputs[11].data_ptr<at::Half>();
}
// dconv1+add
// wgrad
auto wgrad1 = outputs[1];
at::Half* dw1 = wgrad1.data_ptr<at::Half>();
run_dconv(backward_state.dimA,
backward_state.padA,
backward_state.convstride1X1,
backward_state.dilationA,
backward_state.filterdimA1,
backward_state.outdimA1,
CUDNN_DATA_HALF,
x,
dw1,
dy1,
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);
// dgrad
w = inputs[1].data_ptr<at::Half>();
auto grad_x = outputs[0];
at::Half* dx = grad_x.data_ptr<at::Half>();
// backward strided conv cannot be fused
// if stride == 1 but channel changes, we can fuse here
if (requires_grad){
if (stride_1X1 != 1){
run_dconv(backward_state.dimA,
backward_state.padA,
backward_state.convstride1X1,
backward_state.dilationA,
backward_state.filterdimA1,
backward_state.outdimA1,
CUDNN_DATA_HALF,
dx,
w,
dy1,
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR);
// add 2 together
grad_x.add_(grad_x_conv4);
}
else {
run_dconv_add(backward_state.dimA,
backward_state.padA,
backward_state.convstride1X1,
backward_state.dilationA,
backward_state.filterdimA1,
backward_state.outdimA1,
CUDNN_DATA_HALF,
dx,
w,
dy1,
dx_conv4);
}
}
DEBUG_MSG("[DEBUG] new dx : " << grad_x.to(at::kFloat).sum().item<float>());
DEBUG_MSG("[DEBUG] new wgrad1 : " << wgrad1.to(at::kFloat).sum().item<float>());
DEBUG_MSG("[DEBUG] new wgrad2 : " << wgrad2.to(at::kFloat).sum().item<float>());
DEBUG_MSG("[DEBUG] new wgrad3 : " << wgrad3.to(at::kFloat).sum().item<float>());
if (stride_1X1 != 1 || backward_state.filterdimA3[0] != backward_state.dimA[1]) {
DEBUG_MSG("[DEBUG] new wgrad4 : " << wgrad4.to(at::kFloat).sum().item<float>());
}
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &bottleneck_forward, "Bottleneck block forward");
m.def("backward", &bottleneck_backward, "Bottleneck block backward");
m.def("forward_init", &bottleneck_forward_init, "Bottleneck block init");
m.def("forward_out1", &bottleneck_forward_out1, "Bottleneck block forward");
m.def("forward_out2", &bottleneck_forward_out2, "Bottleneck block forward");
m.def("forward_out2_halo", &bottleneck_forward_out2_halo, "Bottleneck block forward");
m.def("forward_rest", &bottleneck_forward_rest, "Bottleneck block forward");
m.def("backward_init", &bottleneck_backward_init, "Bottleneck block backward init");
m.def("backward_grad_out2", &bottleneck_backward_grad_out2, "Bottleneck block backward");
m.def("backward_rest", &bottleneck_backward_rest, "Bottleneck block backward");
}