Commit 22e2caf6 authored by yhcao6

rearrange formal parameters to follow pytorch style, optimize modulated dcn

parent 03c38bdc
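
For orientation, here is a minimal usage sketch (not part of the commit) of the reworked, PyTorch-style interfaces that the hunks below introduce: `stride`/`padding` now default to 1/0 as in `nn.Conv2d`, `DeformConv` accepts (and asserts) `bias=None`, and `ModulatedDeformConv` takes `bias=True` and forwards a `with_bias` flag to the CUDA op. The import path, tensor shapes, and the CUDA device are assumptions for illustration; the ops require a GPU build of the extension.

```python
import torch
# hypothetical import path; use wherever these modules live in your checkout
from deform_conv import DeformConv, ModulatedDeformConv

# stride/padding now default to 1/0, matching nn.Conv2d
dcn = DeformConv(64, 64, kernel_size=3, padding=1, deformable_groups=1).cuda()
mdcn = ModulatedDeformConv(64, 64, kernel_size=3, padding=1, bias=True).cuda()

x = torch.randn(2, 64, 32, 32, device='cuda')
offset = torch.zeros(2, 2 * 3 * 3, 32, 32, device='cuda')  # 2 * kh * kw offset channels
mask = torch.ones(2, 3 * 3, 32, 32, device='cuda')         # kh * kw modulation channels

y1 = dcn(x, offset)         # DeformConv still has no bias (bias=None is asserted)
y2 = mdcn(x, offset, mask)  # bias handling is governed by with_bias inside the CUDA op
```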
......@@ -29,10 +29,11 @@ class DeformConvFunction(Function):
ctx.save_for_backward(input, offset, weight)
output = input.new(*DeformConvFunction._output_size(
input, weight, ctx.padding, ctx.dilation, ctx.stride))
output = input.new_empty(
DeformConvFunction._output_size(input, weight, ctx.padding,
ctx.dilation, ctx.stride))
ctx.bufs_ = [input.new(), input.new()] # columns, ones
ctx.bufs_ = [input.new_empty(0), input.new_empty(0)] # columns, ones
if not input.is_cuda:
raise NotImplementedError
......@@ -110,24 +111,26 @@ class ModulatedDeformConvFunction(Function):
stride,
padding,
dilation=1,
deformable_groups=1):
deformable_groups=1,
with_bias=False):
ctx.stride = stride
ctx.padding = padding
ctx.dilation = dilation
ctx.deformable_groups = deformable_groups
ctx.with_bias = with_bias
if not input.is_cuda:
raise NotImplementedError
if weight.requires_grad or mask.requires_grad or offset.requires_grad \
or input.requires_grad:
ctx.save_for_backward(input, offset, mask, weight, bias)
output = input.new(
*ModulatedDeformConvFunction._infer_shape(ctx, input, weight))
ctx._bufs = [input.new(), input.new()]
output = input.new_empty(
ModulatedDeformConvFunction._infer_shape(ctx, input, weight))
ctx._bufs = [input.new_empty(0), input.new_empty(0)]
deform_conv_cuda.modulated_deform_conv_cuda_forward(
input, weight, bias, ctx._bufs[0], offset, mask, output,
ctx._bufs[1], weight.shape[2], weight.shape[3], ctx.stride,
ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
ctx.deformable_groups)
ctx.deformable_groups, ctx.with_bias)
return output
@staticmethod
......@@ -145,10 +148,10 @@ class ModulatedDeformConvFunction(Function):
grad_input, grad_weight, grad_bias, grad_offset, grad_mask,
grad_output, weight.shape[2], weight.shape[3], ctx.stride,
ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
ctx.deformable_groups)
ctx.deformable_groups, ctx.with_bias)
return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias,
None, None, None, None)
None, None, None, None, None)
@staticmethod
def _infer_shape(ctx, input, weight):
......
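
The extra `None` appended to the tuple returned by `ModulatedDeformConvFunction.backward` above follows from a general autograd rule: `backward` must return exactly one value per argument of `forward`, with `None` for non-differentiable or non-tensor arguments such as the new `with_bias` flag. A self-contained illustration (not code from the commit):

```python
import torch
from torch.autograd import Function

class ScaleShift(Function):
    @staticmethod
    def forward(ctx, x, scale, use_shift):
        ctx.scale = scale
        return x * scale + (1.0 if use_shift else 0.0)

    @staticmethod
    def backward(ctx, grad_output):
        # one return value per forward argument: x, scale, use_shift
        return grad_output * ctx.scale, None, None

x = torch.randn(3, requires_grad=True)
ScaleShift.apply(x, 2.0, True).sum().backward()
print(x.grad)  # every element equals 2.0
```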
......@@ -32,20 +32,20 @@ class DeformRoIPoolingFunction(Function):
if not data.is_cuda:
raise NotImplementedError
output = data.new(
*DeformRoIPoolingFunction._infer_shape(ctx, data, rois))
output_count = data.new(
*DeformRoIPoolingFunction._infer_shape(ctx, data, rois))
output = data.new_empty(
DeformRoIPoolingFunction._infer_shape(ctx, data, rois))
output_count = data.new_empty(
DeformRoIPoolingFunction._infer_shape(ctx, data, rois))
deform_pool_cuda.deform_psroi_pooling_cuda_forward(
data, rois, offset, output, output_count, ctx.no_trans,
ctx.spatial_scale, ctx.output_dim, ctx.group_size, ctx.pooled_size,
ctx.part_size, ctx.sample_per_part, ctx.trans_std)
# if data.requires_grad or rois.requires_grad or offset.requires_grad:
# ctx.save_for_backward(data, rois, offset, output_count)
ctx.data = data
ctx.rois = rois
ctx.offset = offset
if data.requires_grad or rois.requires_grad or offset.requires_grad:
ctx.save_for_backward(data, rois, offset)
# ctx.data = data
# ctx.rois = rois
# ctx.offset = offset
ctx.output_count = output_count
return output
......@@ -55,10 +55,10 @@ class DeformRoIPoolingFunction(Function):
if not grad_output.is_cuda:
raise NotImplementedError
# data, rois, offset, output_count = ctx.saved_tensors
data = ctx.data
rois = ctx.rois
offset = ctx.offset
data, rois, offset = ctx.saved_tensors
# data = ctx.data
# rois = ctx.rois
# offset = ctx.offset
output_count = ctx.output_count
grad_input = torch.zeros_like(data)
grad_offset = torch.zeros_like(offset)
......
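
The `DeformRoIPoolingFunction` hunks above replace raw attribute stashing (`ctx.data = data`, etc.) with `ctx.save_for_backward(...)` / `ctx.saved_tensors`, which lets autograd detect in-place modification of saved tensors and manage their lifetime instead of holding bare references. A minimal sketch of the two patterns (an illustration, not code from the commit):

```python
import torch
from torch.autograd import Function

class Square(Function):
    @staticmethod
    def forward(ctx, x):
        # Preferred: save_for_backward tracks tensor versions and frees the
        # saved tensor correctly once backward has run.
        ctx.save_for_backward(x)
        # Discouraged (the pattern removed in the hunk above): ctx.x = x keeps
        # a raw reference that bypasses those checks.
        return x * x

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        return 2 * x * grad_output

x = torch.randn(4, requires_grad=True)
Square.apply(x).sum().backward()
```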
......@@ -16,7 +16,9 @@ class DeformConv(nn.Module):
stride=1,
padding=0,
dilation=1,
deformable_groups=1):
deformable_groups=1,
bias=None):
assert bias is None
super(DeformConv, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
......@@ -40,8 +42,7 @@ class DeformConv(nn.Module):
def forward(self, input, offset):
return deform_conv(input, offset, self.weight, self.stride,
self.padding, self.dilation,
self.deformable_groups)
self.padding, self.dilation, self.deformable_groups)
class ModulatedDeformConv(nn.Module):
......@@ -50,11 +51,11 @@ class ModulatedDeformConv(nn.Module):
in_channels,
out_channels,
kernel_size,
stride,
padding,
stride=1,
padding=0,
dilation=1,
deformable_groups=1,
bias=False):
bias=True):
super(ModulatedDeformConv, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
......@@ -63,13 +64,15 @@ class ModulatedDeformConv(nn.Module):
self.padding = padding
self.dilation = dilation
self.deformable_groups = deformable_groups
self.with_bias = bias
self.weight = nn.Parameter(
torch.Tensor(out_channels, in_channels, *self.kernel_size))
self.bias = nn.Parameter(torch.zeros(out_channels))
if bias:
self.bias = nn.Parameter(torch.Tensor(out_channels))
else:
self.bias = nn.Parameter(torch.Tensor([0])) # fake tensor
self.reset_parameters()
if not bias:
self.bias.requires_grad = False
def reset_parameters(self):
n = self.in_channels
......@@ -77,12 +80,14 @@ class ModulatedDeformConv(nn.Module):
n *= k
stdv = 1. / math.sqrt(n)
self.weight.data.uniform_(-stdv, stdv)
self.bias.data.zero_()
if self.with_bias:
self.bias.data.zero_()
def forward(self, input, offset, mask):
return modulated_deform_conv(input, offset, mask, self.weight,
self.bias, self.stride, self.padding,
self.dilation, self.deformable_groups)
self.dilation, self.deformable_groups,
self.with_bias)
class ModulatedDeformConvPack(ModulatedDeformConv):
......@@ -91,8 +96,8 @@ class ModulatedDeformConvPack(ModulatedDeformConv):
in_channels,
out_channels,
kernel_size,
stride,
padding,
stride=1,
padding=0,
dilation=1,
deformable_groups=1,
bias=True):
......@@ -121,4 +126,5 @@ class ModulatedDeformConvPack(ModulatedDeformConv):
mask = torch.sigmoid(mask)
return modulated_deform_conv(input, offset, mask, self.weight,
self.bias, self.stride, self.padding,
self.dilation, self.deformable_groups)
self.dilation, self.deformable_groups,
self.with_bias)
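
A note on the bias handling introduced above: `ModulatedDeformConv` always registers a `bias` parameter so the fixed CUDA call signature can be satisfied, but when `bias=False` it registers a one-element placeholder, freezes it with `requires_grad=False`, and passes `with_bias=False` so the kernel skips it. A minimal sketch of that pattern in plain PyTorch (an assumption-labelled illustration, not the commit's code):

```python
import torch
import torch.nn as nn

class BiasOptional(nn.Module):
    def __init__(self, out_channels, bias=True):
        super(BiasOptional, self).__init__()
        self.with_bias = bias
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channels))
        else:
            # placeholder so there is always a tensor to hand to the kernel
            self.bias = nn.Parameter(torch.zeros(1))
            self.bias.requires_grad = False

    def forward(self, x):
        # a real kernel would receive self.with_bias and skip the add itself
        return x + self.bias if self.with_bias else x

m = BiasOptional(4, bias=False)
y = m(torch.randn(2, 4))  # the placeholder bias is ignored and never gets gradients
```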
......@@ -26,7 +26,7 @@ class DeformRoIPooling(nn.Module):
def forward(self, data, rois, offset):
if self.no_trans:
offset = data.new()
offset = data.new_empty(0)
return deform_roi_pooling(
data, rois, offset, self.spatial_scale, self.out_size,
self.output_dim, self.no_trans, self.group_size, self.part_size,
......@@ -74,10 +74,10 @@ class ModulatedDeformRoIPoolingPack(DeformRoIPooling):
def forward(self, data, rois):
if self.no_trans:
offset = data.new()
offset = data.new_empty(0)
else:
n = rois.shape[0]
offset = data.new()
offset = data.new_empty(0)
x = deform_roi_pooling(data, rois, offset, self.spatial_scale,
self.out_size, self.output_dim, True,
self.group_size, self.part_size,
......@@ -129,10 +129,10 @@ class DeformRoIPoolingPack(DeformRoIPooling):
def forward(self, data, rois):
if self.no_trans:
offset = data.new()
offset = data.new_empty(0)
else:
n = rois.shape[0]
offset = data.new()
offset = data.new_empty(0)
x = deform_roi_pooling(data, rois, offset, self.spatial_scale,
self.out_size, self.output_dim, True,
self.group_size, self.part_size,
......
......@@ -421,7 +421,7 @@ void modulated_deform_conv_cuda_forward(at::Tensor input, at::Tensor weight,
const int stride_h, const int stride_w,
const int pad_h, const int pad_w,
const int dilation_h, const int dilation_w,
const int deformable_group)
const int deformable_group, const bool with_bias)
{
AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
......@@ -454,28 +454,24 @@ void modulated_deform_conv_cuda_forward(at::Tensor input, at::Tensor weight,
}
// resize output
output = output.view({batch, channels_out, height_out, width_out});
output = output.view({batch, channels_out, height_out, width_out}).zero_();
// resize temporary columns
columns = at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out}, input.type());
for (int b = 0; b < batch; b++)
{
// Do Bias first:
// M,N,K are dims of matrix A and B
// (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
// (N x 1) (1 x M)
output[b] = output[b].flatten(1).addmm_(bias.view({-1, 1}), ones.view({1, -1}), 0.0f, 1.0f).view_as(output[b]);
modulated_deformable_im2col_cuda(input[b], offset[b], mask[b],
1, channels, height, width,
height_out, width_out, kernel_h, kernel_w,
pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
deformable_group, columns);
//(k * m) x (m * n)
// Y = WC
output[b] = output[b].flatten(1).addmm_(weight.flatten(1), columns).view_as(output[b]);
}
if (with_bias){
output += bias.view({1, bias.size(0), 1, 1});
}
}
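
The forward-pass change above replaces the per-sample `addmm_` that seeded `output[b]` with `bias * ones^T` by zero-initializing the output and adding the bias once as a broadcast over `(N, C, H, W)`, gated by `with_bias`. The two formulations are numerically equivalent, as this small sketch (illustration only, not code from the commit) checks:

```python
import torch

N, C, H, W = 2, 8, 5, 5
out = torch.randn(N, C, H, W)
bias = torch.randn(C)

# old style: seed each sample with a rank-1 update, out[b] += bias @ ones^T
old = out.clone()
ones = torch.ones(H * W)
for b in range(N):
    old[b] = (old[b].flatten(1) + bias.view(-1, 1) @ ones.view(1, -1)).view_as(old[b])

# new style: zero-init plus a single broadcast add at the end (gated by with_bias)
new = out + bias.view(1, C, 1, 1)

print(torch.allclose(old, new))  # True
```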
void modulated_deform_conv_cuda_backward(at::Tensor input, at::Tensor weight,
......@@ -489,7 +485,7 @@ void modulated_deform_conv_cuda_backward(at::Tensor input, at::Tensor weight,
int stride_h, int stride_w,
int pad_h, int pad_w,
int dilation_h, int dilation_w,
int deformable_group)
int deformable_group, const bool with_bias)
{
AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
......@@ -551,7 +547,9 @@ void modulated_deform_conv_cuda_backward(at::Tensor input, at::Tensor weight,
grad_weight = grad_weight.flatten(1).addmm_(grad_output[b].flatten(1), columns.transpose(0, 1)).view_as(grad_weight);
grad_bias = grad_bias.view({-1, 1}).addmm_(grad_output[b].flatten(1), ones.view({-1, 1})).view(-1);
if (with_bias){
grad_bias = grad_bias.view({-1, 1}).addmm_(grad_output[b].flatten(1), ones.view({-1, 1})).view(-1);
}
}
}
......