Commit 22e2caf6 authored by yhcao6

rearrange formal parameters to follow pytorch style, optimize modulated dcn

parent 03c38bdc
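The net effect on the Python API: ModulatedDeformConv (and its Pack variant) now takes nn.Conv2d-style defaults (stride=1, padding=0, dilation=1, bias=True), and a with_bias flag is threaded down to the CUDA op so the bias is only applied when requested. A minimal usage sketch under stated assumptions; the import path and tensor shapes below are illustrative, not taken from this diff:

    import torch
    from deform_conv import ModulatedDeformConv  # import path is an assumption

    # PyTorch-style defaults after this commit: stride=1, padding=0,
    # dilation=1, bias=True, so only non-default arguments are spelled out.
    dcn = ModulatedDeformConv(64, 64, kernel_size=3, padding=1).cuda()

    x = torch.randn(2, 64, 32, 32, device='cuda')
    # Standard DCNv2 layout: offset has 2 * deformable_groups * kh * kw
    # channels, mask has deformable_groups * kh * kw channels.
    offset = torch.zeros(2, 2 * 1 * 3 * 3, 32, 32, device='cuda')
    mask = torch.ones(2, 1 * 3 * 3, 32, 32, device='cuda')

    y = dcn(x, offset, mask)   # forward(input, offset, mask), per the diff below
    print(y.shape)             # expected: torch.Size([2, 64, 32, 32])

With zero offsets and a mask of ones this degenerates to an ordinary 3x3 convolution, which makes it a convenient smoke test for the new defaults.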
@@ -29,10 +29,11 @@ class DeformConvFunction(Function):
         ctx.save_for_backward(input, offset, weight)
-        output = input.new(*DeformConvFunction._output_size(
-            input, weight, ctx.padding, ctx.dilation, ctx.stride))
-        ctx.bufs_ = [input.new(), input.new()]  # columns, ones
+        output = input.new_empty(
+            DeformConvFunction._output_size(input, weight, ctx.padding,
+                                            ctx.dilation, ctx.stride))
+        ctx.bufs_ = [input.new_empty(0), input.new_empty(0)]  # columns, ones
         if not input.is_cuda:
             raise NotImplementedError
@@ -110,24 +111,26 @@ class ModulatedDeformConvFunction(Function):
                 stride,
                 padding,
                 dilation=1,
-                deformable_groups=1):
+                deformable_groups=1,
+                with_bias=False):
         ctx.stride = stride
         ctx.padding = padding
         ctx.dilation = dilation
         ctx.deformable_groups = deformable_groups
+        ctx.with_bias = with_bias
         if not input.is_cuda:
             raise NotImplementedError
         if weight.requires_grad or mask.requires_grad or offset.requires_grad \
                 or input.requires_grad:
             ctx.save_for_backward(input, offset, mask, weight, bias)
-        output = input.new(
-            *ModulatedDeformConvFunction._infer_shape(ctx, input, weight))
-        ctx._bufs = [input.new(), input.new()]
+        output = input.new_empty(
+            ModulatedDeformConvFunction._infer_shape(ctx, input, weight))
+        ctx._bufs = [input.new_empty(0), input.new_empty(0)]
         deform_conv_cuda.modulated_deform_conv_cuda_forward(
             input, weight, bias, ctx._bufs[0], offset, mask, output,
             ctx._bufs[1], weight.shape[2], weight.shape[3], ctx.stride,
             ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
-            ctx.deformable_groups)
+            ctx.deformable_groups, ctx.with_bias)
         return output

     @staticmethod
@@ -145,10 +148,10 @@ class ModulatedDeformConvFunction(Function):
             grad_input, grad_weight, grad_bias, grad_offset, grad_mask,
             grad_output, weight.shape[2], weight.shape[3], ctx.stride,
             ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
-            ctx.deformable_groups)
+            ctx.deformable_groups, ctx.with_bias)
         return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias,
-                None, None, None, None)
+                None, None, None, None, None)

     @staticmethod
     def _infer_shape(ctx, input, weight):
...
...@@ -32,20 +32,20 @@ class DeformRoIPoolingFunction(Function): ...@@ -32,20 +32,20 @@ class DeformRoIPoolingFunction(Function):
if not data.is_cuda: if not data.is_cuda:
raise NotImplementedError raise NotImplementedError
output = data.new( output = data.new_empty(
*DeformRoIPoolingFunction._infer_shape(ctx, data, rois)) DeformRoIPoolingFunction._infer_shape(ctx, data, rois))
output_count = data.new( output_count = data.new_empty(
*DeformRoIPoolingFunction._infer_shape(ctx, data, rois)) DeformRoIPoolingFunction._infer_shape(ctx, data, rois))
deform_pool_cuda.deform_psroi_pooling_cuda_forward( deform_pool_cuda.deform_psroi_pooling_cuda_forward(
data, rois, offset, output, output_count, ctx.no_trans, data, rois, offset, output, output_count, ctx.no_trans,
ctx.spatial_scale, ctx.output_dim, ctx.group_size, ctx.pooled_size, ctx.spatial_scale, ctx.output_dim, ctx.group_size, ctx.pooled_size,
ctx.part_size, ctx.sample_per_part, ctx.trans_std) ctx.part_size, ctx.sample_per_part, ctx.trans_std)
# if data.requires_grad or rois.requires_grad or offset.requires_grad: if data.requires_grad or rois.requires_grad or offset.requires_grad:
# ctx.save_for_backward(data, rois, offset, output_count) ctx.save_for_backward(data, rois, offset)
ctx.data = data # ctx.data = data
ctx.rois = rois # ctx.rois = rois
ctx.offset = offset # ctx.offset = offset
ctx.output_count = output_count ctx.output_count = output_count
return output return output
...@@ -55,10 +55,10 @@ class DeformRoIPoolingFunction(Function): ...@@ -55,10 +55,10 @@ class DeformRoIPoolingFunction(Function):
if not grad_output.is_cuda: if not grad_output.is_cuda:
raise NotImplementedError raise NotImplementedError
# data, rois, offset, output_count = ctx.saved_tensors data, rois, offset = ctx.saved_tensors
data = ctx.data # data = ctx.data
rois = ctx.rois # rois = ctx.rois
offset = ctx.offset # offset = ctx.offset
output_count = ctx.output_count output_count = ctx.output_count
grad_input = torch.zeros_like(data) grad_input = torch.zeros_like(data)
grad_offset = torch.zeros_like(offset) grad_offset = torch.zeros_like(offset)
...
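The two hunks above also change how DeformRoIPoolingFunction hands tensors to its backward pass: instead of stashing data, rois and offset as plain attributes on ctx, they now go through ctx.save_for_backward / ctx.saved_tensors, which is the pattern autograd expects for input tensors needed in backward (the graph then manages their lifetime and can detect in-place modification). A toy sketch of that pattern, with purely illustrative names:

    import torch
    from torch.autograd import Function

    class Square(Function):
        """Toy Function showing the save_for_backward pattern adopted above."""

        @staticmethod
        def forward(ctx, x):
            # Inputs needed by backward are registered with autograd,
            # rather than stashed as plain attributes (ctx.x = x).
            ctx.save_for_backward(x)
            return x * x

        @staticmethod
        def backward(ctx, grad_output):
            x, = ctx.saved_tensors        # retrieved the same way as in the diff
            return grad_output * 2 * x

    x = torch.randn(3, requires_grad=True)
    Square.apply(x).sum().backward()
    print(torch.allclose(x.grad, 2 * x.detach()))  # True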
@@ -16,7 +16,9 @@ class DeformConv(nn.Module):
                  stride=1,
                  padding=0,
                  dilation=1,
-                 deformable_groups=1):
+                 deformable_groups=1,
+                 bias=None):
+        assert bias is None
         super(DeformConv, self).__init__()
         self.in_channels = in_channels
         self.out_channels = out_channels

@@ -40,8 +42,7 @@ class DeformConv(nn.Module):
     def forward(self, input, offset):
         return deform_conv(input, offset, self.weight, self.stride,
-                           self.padding, self.dilation,
-                           self.deformable_groups)
+                           self.padding, self.dilation, self.deformable_groups)


class ModulatedDeformConv(nn.Module):
@@ -50,11 +51,11 @@ class ModulatedDeformConv(nn.Module):
                  in_channels,
                  out_channels,
                  kernel_size,
-                 stride,
-                 padding,
+                 stride=1,
+                 padding=0,
                  dilation=1,
                  deformable_groups=1,
-                 bias=False):
+                 bias=True):
         super(ModulatedDeformConv, self).__init__()
         self.in_channels = in_channels
         self.out_channels = out_channels

@@ -63,13 +64,15 @@ class ModulatedDeformConv(nn.Module):
         self.padding = padding
         self.dilation = dilation
         self.deformable_groups = deformable_groups
+        self.with_bias = bias
         self.weight = nn.Parameter(
             torch.Tensor(out_channels, in_channels, *self.kernel_size))
-        self.bias = nn.Parameter(torch.zeros(out_channels))
+        if bias:
+            self.bias = nn.Parameter(torch.Tensor(out_channels))
+        else:
+            self.bias = nn.Parameter(torch.Tensor([0]))  # fake tensor
         self.reset_parameters()
-        if not bias:
-            self.bias.requires_grad = False

     def reset_parameters(self):
         n = self.in_channels
@@ -77,12 +80,14 @@ class ModulatedDeformConv(nn.Module):
             n *= k
         stdv = 1. / math.sqrt(n)
         self.weight.data.uniform_(-stdv, stdv)
-        self.bias.data.zero_()
+        if self.with_bias:
+            self.bias.data.zero_()

     def forward(self, input, offset, mask):
         return modulated_deform_conv(input, offset, mask, self.weight,
                                      self.bias, self.stride, self.padding,
-                                     self.dilation, self.deformable_groups)
+                                     self.dilation, self.deformable_groups,
+                                     self.with_bias)


class ModulatedDeformConvPack(ModulatedDeformConv):
@@ -91,8 +96,8 @@ class ModulatedDeformConvPack(ModulatedDeformConv):
                  in_channels,
                  out_channels,
                  kernel_size,
-                 stride,
-                 padding,
+                 stride=1,
+                 padding=0,
                  dilation=1,
                  deformable_groups=1,
                  bias=True):

@@ -121,4 +126,5 @@ class ModulatedDeformConvPack(ModulatedDeformConv):
         mask = torch.sigmoid(mask)
         return modulated_deform_conv(input, offset, mask, self.weight,
                                      self.bias, self.stride, self.padding,
-                                     self.dilation, self.deformable_groups)
+                                     self.dilation, self.deformable_groups,
+                                     self.with_bias)
...@@ -26,7 +26,7 @@ class DeformRoIPooling(nn.Module): ...@@ -26,7 +26,7 @@ class DeformRoIPooling(nn.Module):
def forward(self, data, rois, offset): def forward(self, data, rois, offset):
if self.no_trans: if self.no_trans:
offset = data.new() offset = data.new_empty(0)
return deform_roi_pooling( return deform_roi_pooling(
data, rois, offset, self.spatial_scale, self.out_size, data, rois, offset, self.spatial_scale, self.out_size,
self.output_dim, self.no_trans, self.group_size, self.part_size, self.output_dim, self.no_trans, self.group_size, self.part_size,
...@@ -74,10 +74,10 @@ class ModulatedDeformRoIPoolingPack(DeformRoIPooling): ...@@ -74,10 +74,10 @@ class ModulatedDeformRoIPoolingPack(DeformRoIPooling):
def forward(self, data, rois): def forward(self, data, rois):
if self.no_trans: if self.no_trans:
offset = data.new() offset = data.new_empty(0)
else: else:
n = rois.shape[0] n = rois.shape[0]
offset = data.new() offset = data.new_empty(0)
x = deform_roi_pooling(data, rois, offset, self.spatial_scale, x = deform_roi_pooling(data, rois, offset, self.spatial_scale,
self.out_size, self.output_dim, True, self.out_size, self.output_dim, True,
self.group_size, self.part_size, self.group_size, self.part_size,
...@@ -129,10 +129,10 @@ class DeformRoIPoolingPack(DeformRoIPooling): ...@@ -129,10 +129,10 @@ class DeformRoIPoolingPack(DeformRoIPooling):
def forward(self, data, rois): def forward(self, data, rois):
if self.no_trans: if self.no_trans:
offset = data.new() offset = data.new_empty(0)
else: else:
n = rois.shape[0] n = rois.shape[0]
offset = data.new() offset = data.new_empty(0)
x = deform_roi_pooling(data, rois, offset, self.spatial_scale, x = deform_roi_pooling(data, rois, offset, self.spatial_scale,
self.out_size, self.output_dim, True, self.out_size, self.output_dim, True,
self.group_size, self.part_size, self.group_size, self.part_size,
......
...@@ -421,7 +421,7 @@ void modulated_deform_conv_cuda_forward(at::Tensor input, at::Tensor weight, ...@@ -421,7 +421,7 @@ void modulated_deform_conv_cuda_forward(at::Tensor input, at::Tensor weight,
const int stride_h, const int stride_w, const int stride_h, const int stride_w,
const int pad_h, const int pad_w, const int pad_h, const int pad_w,
const int dilation_h, const int dilation_w, const int dilation_h, const int dilation_w,
const int deformable_group) const int deformable_group, const bool with_bias)
{ {
AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous"); AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous"); AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
...@@ -454,28 +454,24 @@ void modulated_deform_conv_cuda_forward(at::Tensor input, at::Tensor weight, ...@@ -454,28 +454,24 @@ void modulated_deform_conv_cuda_forward(at::Tensor input, at::Tensor weight,
} }
// resize output // resize output
output = output.view({batch, channels_out, height_out, width_out}); output = output.view({batch, channels_out, height_out, width_out}).zero_();
// resize temporary columns // resize temporary columns
columns = at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out}, input.type()); columns = at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out}, input.type());
for (int b = 0; b < batch; b++) for (int b = 0; b < batch; b++)
{ {
// Do Bias first:
// M,N,K are dims of matrix A and B
// (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
// (N x 1) (1 x M)
output[b] = output[b].flatten(1).addmm_(bias.view({-1, 1}), ones.view({1, -1}), 0.0f, 1.0f).view_as(output[b]);
modulated_deformable_im2col_cuda(input[b], offset[b], mask[b], modulated_deformable_im2col_cuda(input[b], offset[b], mask[b],
1, channels, height, width, 1, channels, height, width,
height_out, width_out, kernel_h, kernel_w, height_out, width_out, kernel_h, kernel_w,
pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
deformable_group, columns); deformable_group, columns);
//(k * m) x (m * n)
// Y = WC
output[b] = output[b].flatten(1).addmm_(weight.flatten(1), columns).view_as(output[b]); output[b] = output[b].flatten(1).addmm_(weight.flatten(1), columns).view_as(output[b]);
} }
if (with_bias){
output += bias.view({1, bias.size(0), 1, 1});
}
} }
void modulated_deform_conv_cuda_backward(at::Tensor input, at::Tensor weight, void modulated_deform_conv_cuda_backward(at::Tensor input, at::Tensor weight,
...@@ -489,7 +485,7 @@ void modulated_deform_conv_cuda_backward(at::Tensor input, at::Tensor weight, ...@@ -489,7 +485,7 @@ void modulated_deform_conv_cuda_backward(at::Tensor input, at::Tensor weight,
int stride_h, int stride_w, int stride_h, int stride_w,
int pad_h, int pad_w, int pad_h, int pad_w,
int dilation_h, int dilation_w, int dilation_h, int dilation_w,
int deformable_group) int deformable_group, const bool with_bias)
{ {
AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous"); AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous"); AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
...@@ -551,7 +547,9 @@ void modulated_deform_conv_cuda_backward(at::Tensor input, at::Tensor weight, ...@@ -551,7 +547,9 @@ void modulated_deform_conv_cuda_backward(at::Tensor input, at::Tensor weight,
grad_weight = grad_weight.flatten(1).addmm_(grad_output[b].flatten(1), columns.transpose(0, 1)).view_as(grad_weight); grad_weight = grad_weight.flatten(1).addmm_(grad_output[b].flatten(1), columns.transpose(0, 1)).view_as(grad_weight);
grad_bias = grad_bias.view({-1, 1}).addmm_(grad_output[b].flatten(1), ones.view({-1, 1})).view(-1); if (with_bias){
grad_bias = grad_bias.view({-1, 1}).addmm_(grad_output[b].flatten(1), ones.view({-1, 1})).view(-1);
}
} }
} }
...
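The CUDA forward rewrite above is the "optimize modulated dcn" part of the commit: the bias is no longer folded into every per-sample addmm_ against a ones vector; instead the output is zeroed once and, when with_bias is set, the bias is added after the loop as a single broadcast. A small PyTorch sketch (shapes and names are illustrative, not taken from the kernel) showing the two formulations agree:

    import torch

    # Shapes mirror one sample of the CUDA loop: C_out channels over H*W locations.
    c_out, hw = 4, 6
    out = torch.zeros(c_out, hw)
    bias = torch.randn(c_out)
    ones = torch.ones(hw)

    # Old form: bias folded in via addmm_ against a ones vector (per sample).
    old = out.clone().addmm_(bias.view(-1, 1), ones.view(1, -1))

    # New form: one broadcast add after the matmul loop, gated by with_bias.
    new = out.clone() + bias.view(-1, 1)

    print(torch.allclose(old, new))  # True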