Unverified Commit 8f2d1583 authored by Zeqiang Lai's avatar Zeqiang Lai Committed by GitHub
Browse files

Support removing center point for DCNV3 (#101)

* support removing the center point for DCNV3

* dcnv3: update test

* remove center for DCNv3_pytorch

* fix pytorch version

* fix unit test

* add backward compatibility

* bump dcn version to 1.1
parent b64d9ca3
......@@ -79,6 +79,7 @@ _C.MODEL.INTERN_IMAGE.USE_CLIP_PROJECTOR = False
_C.MODEL.INTERN_IMAGE.LEVEL2_POST_NORM = False
_C.MODEL.INTERN_IMAGE.LEVEL2_POST_NORM_BLOCK_IDS = None
_C.MODEL.INTERN_IMAGE.CENTER_FEATURE_SCALE = False
_C.MODEL.INTERN_IMAGE.REMOVE_CENTER = False
......
......@@ -582,7 +582,7 @@ if __name__ == '__main__':
assert has_native_amp, "Please update pytorch(1.6+) to support amp!"
# init distributed env
if 'SLURM_PROCID' in os.environ and int(os.environ['SLURM_NNODES']) != 1:
if 'SLURM_PROCID' in os.environ and int(os.environ['SLURM_TASKS_PER_NODE']) != 1:
print("\nDist init: SLURM")
rank = int(os.environ['SLURM_PROCID'])
gpu = rank % torch.cuda.device_count()
......
......@@ -26,7 +26,8 @@ def build_model(config):
use_clip_projector=config.MODEL.INTERN_IMAGE.USE_CLIP_PROJECTOR, # for InternImage-H/G
level2_post_norm=config.MODEL.INTERN_IMAGE.LEVEL2_POST_NORM, # for InternImage-H/G
level2_post_norm_block_ids=config.MODEL.INTERN_IMAGE.LEVEL2_POST_NORM_BLOCK_IDS, # for InternImage-H/G
center_feature_scale=config.MODEL.INTERN_IMAGE.CENTER_FEATURE_SCALE # for InternImage-H/G
center_feature_scale=config.MODEL.INTERN_IMAGE.CENTER_FEATURE_SCALE, # for InternImage-H/G
remove_center=config.MODEL.INTERN_IMAGE.REMOVE_CENTER,
)
else:
raise NotImplementedError(f"Unkown model: {model_type}")
......
......@@ -359,7 +359,9 @@ class InternImageLayer(nn.Module):
with_cp=False,
dw_kernel_size=None, # for InternImage-H/G
res_post_norm=False, # for InternImage-H/G
center_feature_scale=False): # for InternImage-H/G
center_feature_scale=False, # for InternImage-H/G
remove_center=False, # for InternImage-H/G
):
super().__init__()
self.channels = channels
self.groups = groups
......@@ -379,7 +381,9 @@ class InternImageLayer(nn.Module):
act_layer=act_layer,
norm_layer=norm_layer,
dw_kernel_size=dw_kernel_size, # for InternImage-H/G
center_feature_scale=center_feature_scale) # for InternImage-H/G
center_feature_scale=center_feature_scale, # for InternImage-H/G
remove_center=remove_center, # for InternImage-H/G
)
self.drop_path = DropPath(drop_path) if drop_path > 0. \
else nn.Identity()
self.norm2 = build_norm_layer(channels, 'LN')
......@@ -463,7 +467,9 @@ class InternImageBlock(nn.Module):
dw_kernel_size=None, # for InternImage-H/G
post_norm_block_ids=None, # for InternImage-H/G
res_post_norm=False, # for InternImage-H/G
center_feature_scale=False): # for InternImage-H/G
center_feature_scale=False, # for InternImage-H/G
remove_center=False, # for InternImage-H/G
):
super().__init__()
self.channels = channels
self.depth = depth
......@@ -487,7 +493,8 @@ class InternImageBlock(nn.Module):
with_cp=with_cp,
dw_kernel_size=dw_kernel_size, # for InternImage-H/G
res_post_norm=res_post_norm, # for InternImage-H/G
center_feature_scale=center_feature_scale # for InternImage-H/G
center_feature_scale=center_feature_scale, # for InternImage-H/G
remove_center = remove_center, # for InternImage-H/G
) for i in range(depth)
])
if not self.post_norm or center_feature_scale:
......@@ -567,6 +574,7 @@ class InternImage(nn.Module):
level2_post_norm_block_ids=None, # for InternImage-H/G
res_post_norm=False, # for InternImage-H/G
center_feature_scale=False, # for InternImage-H/G
remove_center=False, # for InternImage-H/G
**kwargs):
super().__init__()
self.core_op = core_op
......@@ -579,6 +587,8 @@ class InternImage(nn.Module):
self.mlp_ratio = mlp_ratio
self.use_clip_projector = use_clip_projector
self.level2_post_norm_block_ids = level2_post_norm_block_ids
self.remove_center = remove_center
print(f'using core type: {core_op}')
print(f'using activation layer: {act_layer}')
print(f'using main norm layer: {norm_layer}')
......@@ -586,6 +596,7 @@ class InternImage(nn.Module):
print(f"level2_post_norm: {level2_post_norm}")
print(f"level2_post_norm_block_ids: {level2_post_norm_block_ids}")
print(f"res_post_norm: {res_post_norm}")
print(f"remove_center: {remove_center}")
in_chans = 3
self.patch_embed = StemLayer(in_chans=in_chans,
......@@ -623,7 +634,8 @@ class InternImage(nn.Module):
dw_kernel_size=dw_kernel_size, # for InternImage-H/G
post_norm_block_ids=post_norm_block_ids, # for InternImage-H/G
res_post_norm=res_post_norm, # for InternImage-H/G
center_feature_scale=center_feature_scale # for InternImage-H/G
center_feature_scale=center_feature_scale, # for InternImage-H/G
remove_center=remove_center, # for InternImage-H/G
)
self.levels.append(level)
......
......@@ -23,7 +23,7 @@ class DCNv3Function(Function):
ctx, input, offset, mask,
kernel_h, kernel_w, stride_h, stride_w,
pad_h, pad_w, dilation_h, dilation_w,
group, group_channels, offset_scale, im2col_step):
group, group_channels, offset_scale, im2col_step, remove_center):
ctx.kernel_h = kernel_h
ctx.kernel_w = kernel_w
ctx.stride_h = stride_h
......@@ -36,11 +36,17 @@ class DCNv3Function(Function):
ctx.group_channels = group_channels
ctx.offset_scale = offset_scale
ctx.im2col_step = im2col_step
ctx.remove_center = remove_center
kwargs = {}
if remove_center:
kwargs['remove_center'] = remove_center
output = DCNv3.dcnv3_forward(
input, offset, mask, kernel_h,
kernel_w, stride_h, stride_w, pad_h,
pad_w, dilation_h, dilation_w, group,
group_channels, offset_scale, ctx.im2col_step)
group_channels, offset_scale, ctx.im2col_step, **kwargs)
ctx.save_for_backward(input, offset, mask)
return output
......@@ -50,20 +56,25 @@ class DCNv3Function(Function):
@custom_bwd
def backward(ctx, grad_output):
input, offset, mask = ctx.saved_tensors
kwargs = {}
if ctx.remove_center:
kwargs['remove_center'] = ctx.remove_center
grad_input, grad_offset, grad_mask = \
DCNv3.dcnv3_backward(
input, offset, mask, ctx.kernel_h,
ctx.kernel_w, ctx.stride_h, ctx.stride_w, ctx.pad_h,
ctx.pad_w, ctx.dilation_h, ctx.dilation_w, ctx.group,
ctx.group_channels, ctx.offset_scale, grad_output.contiguous(), ctx.im2col_step)
ctx.group_channels, ctx.offset_scale, grad_output.contiguous(), ctx.im2col_step, **kwargs)
return grad_input, grad_offset, grad_mask, \
None, None, None, None, None, None, None, None, None, None, None, None
None, None, None, None, None, None, None, None, None, None, None, None, None
@staticmethod
def symbolic(g, input, offset, mask, kernel_h, kernel_w, stride_h,
stride_w, pad_h, pad_w, dilation_h, dilation_w, group,
group_channels, offset_scale, im2col_step):
group_channels, offset_scale, im2col_step, remove_center):
"""Symbolic function for mmdeploy::DCNv3.
Returns:
......@@ -86,6 +97,7 @@ class DCNv3Function(Function):
group_channels_i=int(group_channels),
offset_scale_f=float(offset_scale),
im2col_step_i=int(im2col_step),
remove_center=int(remove_center),
)
......@@ -126,14 +138,14 @@ def _generate_dilation_grids(spatial_shapes, kernel_h, kernel_w, dilation_h, dil
x, y = torch.meshgrid(
torch.linspace(
-((dilation_w * (kernel_w - 1)) // 2),
-((dilation_w * (kernel_w - 1)) // 2) +
(kernel_w - 1) * dilation_w, kernel_w,
-((dilation_w * (kernel_w - 1)) // 2) + (kernel_w - 1) * dilation_w,
kernel_w,
dtype=torch.float32,
device=device),
torch.linspace(
-((dilation_h * (kernel_h - 1)) // 2),
-((dilation_h * (kernel_h - 1)) // 2) +
(kernel_h - 1) * dilation_h, kernel_h,
-((dilation_h * (kernel_h - 1)) // 2) + (kernel_h - 1) * dilation_h,
kernel_h,
dtype=torch.float32,
device=device))
......@@ -145,13 +157,24 @@ def _generate_dilation_grids(spatial_shapes, kernel_h, kernel_w, dilation_h, dil
return grid
def remove_center_sampling_locations(sampling_locations, kernel_w, kernel_h):
    """Drop the kernel-center sampling point for every group.

    The second-to-last axis of ``sampling_locations`` holds the flattened
    per-group kernel grid (``group * kernel_h * kernel_w`` points).  The
    center of the first group's grid sits at index ``(kernel_w*kernel_h-1)//2``
    and repeats with a period of ``kernel_h * kernel_w`` for each subsequent
    group (the caller enforces a square, odd-sized kernel, so the period
    equals ``2*center + 1``).  All such center indices are filtered out.

    :param sampling_locations: 5-D tensor; points live on dim -2
    :param kernel_w: kernel width
    :param kernel_h: kernel height
    :return: tensor with the center point of each group's grid removed
    """
    num_points = sampling_locations.shape[-2]
    center = (kernel_w * kernel_h - 1) // 2
    period = 2 * center + 1  # one kernel's worth of points per group
    keep = [i for i in range(num_points) if (i - center) % period != 0]
    return sampling_locations[:, :, :, keep, :]
def dcnv3_core_pytorch(
input, offset, mask, kernel_h,
kernel_w, stride_h, stride_w, pad_h,
pad_w, dilation_h, dilation_w, group,
group_channels, offset_scale):
group_channels, offset_scale, remove_center):
# for debug and test only,
# need to use cuda version instead
if remove_center and (kernel_h % 2 == 0 or kernel_w % 2 == 0 or kernel_w != kernel_h):
raise ValueError('remove_center is only compatible with square odd kernel size.')
input = F.pad(
input,
[0, 0, pad_h, pad_h, pad_w, pad_w])
......@@ -163,12 +186,15 @@ def dcnv3_core_pytorch(
grid = _generate_dilation_grids(
input.shape, kernel_h, kernel_w, dilation_h, dilation_w, group, input.device)
spatial_norm = torch.tensor([W_in, H_in]).reshape(1, 1, 1, 2).\
repeat(1, 1, 1, group*kernel_h*kernel_w).to(input.device)
repeat(1, 1, 1, group*(kernel_h*kernel_w-remove_center)).to(input.device)
sampling_locations = (ref + grid * offset_scale).repeat(N_, 1, 1, 1, 1).flatten(3, 4) + \
offset * offset_scale / spatial_norm
sampling_locations = (ref + grid * offset_scale).repeat(N_, 1, 1, 1, 1)
if remove_center:
sampling_locations = remove_center_sampling_locations(sampling_locations, kernel_w=kernel_w, kernel_h=kernel_h)
sampling_locations = sampling_locations.flatten(3, 4)
sampling_locations = sampling_locations + offset * offset_scale / spatial_norm
P_ = kernel_h * kernel_w
P_ = kernel_h * kernel_w - remove_center
sampling_grids = 2 * sampling_locations - 1
# N_, H_in, W_in, group*group_channels -> N_, H_in*W_in, group*group_channels -> N_, group*group_channels, H_in*W_in -> N_*group, group_channels, H_in, W_in
input_ = input.view(N_, H_in*W_in, group*group_channels).transpose(1, 2).\
......
......@@ -101,7 +101,9 @@ class DCNv3_pytorch(nn.Module):
offset_scale=1.0,
act_layer='GELU',
norm_layer='LN',
center_feature_scale=False):
center_feature_scale=False,
remove_center=False,
):
"""
DCNv3 Module
:param channels
......@@ -137,6 +139,7 @@ class DCNv3_pytorch(nn.Module):
self.group_channels = channels // group
self.offset_scale = offset_scale
self.center_feature_scale = center_feature_scale
self.remove_center = int(remove_center)
self.dw_conv = nn.Sequential(
nn.Conv2d(
......@@ -154,10 +157,10 @@ class DCNv3_pytorch(nn.Module):
build_act_layer(act_layer))
self.offset = nn.Linear(
channels,
group * kernel_size * kernel_size * 2)
group * (kernel_size * kernel_size - remove_center) * 2)
self.mask = nn.Linear(
channels,
group * kernel_size * kernel_size)
group * (kernel_size * kernel_size - remove_center))
self.input_proj = nn.Linear(channels, channels)
self.output_proj = nn.Linear(channels, channels)
self._reset_parameters()
......@@ -202,7 +205,7 @@ class DCNv3_pytorch(nn.Module):
self.pad, self.pad,
self.dilation, self.dilation,
self.group, self.group_channels,
self.offset_scale)
self.offset_scale, self.remove_center)
if self.center_feature_scale:
center_feature_scale = self.center_feature_scale_module(
x1, self.center_feature_scale_proj_weight, self.center_feature_scale_proj_bias)
......@@ -228,7 +231,9 @@ class DCNv3(nn.Module):
offset_scale=1.0,
act_layer='GELU',
norm_layer='LN',
center_feature_scale=False):
center_feature_scale=False,
remove_center=False,
):
"""
DCNv3 Module
:param channels
......@@ -264,6 +269,10 @@ class DCNv3(nn.Module):
self.group_channels = channels // group
self.offset_scale = offset_scale
self.center_feature_scale = center_feature_scale
self.remove_center = int(remove_center)
if self.remove_center and self.kernel_size % 2 == 0:
raise ValueError('remove_center is only compatible with odd kernel size.')
self.dw_conv = nn.Sequential(
nn.Conv2d(
......@@ -281,10 +290,10 @@ class DCNv3(nn.Module):
build_act_layer(act_layer))
self.offset = nn.Linear(
channels,
group * kernel_size * kernel_size * 2)
group * (kernel_size * kernel_size - remove_center) * 2)
self.mask = nn.Linear(
channels,
group * kernel_size * kernel_size)
group * (kernel_size * kernel_size - remove_center))
self.input_proj = nn.Linear(channels, channels)
self.output_proj = nn.Linear(channels, channels)
self._reset_parameters()
......@@ -321,7 +330,8 @@ class DCNv3(nn.Module):
x1 = self.dw_conv(x1)
offset = self.offset(x1)
mask = self.mask(x1).reshape(N, H, W, self.group, -1)
mask = F.softmax(mask, -1).reshape(N, H, W, -1).type(dtype)
mask = F.softmax(mask, -1)
mask = mask.reshape(N, H, W, -1).type(dtype)
x = DCNv3Function.apply(
x, offset, mask,
......@@ -331,7 +341,8 @@ class DCNv3(nn.Module):
self.dilation, self.dilation,
self.group, self.group_channels,
self.offset_scale,
256)
256,
self.remove_center)
if self.center_feature_scale:
center_feature_scale = self.center_feature_scale_module(
......
......@@ -61,7 +61,7 @@ def get_extensions():
setup(
name="DCNv3",
version="1.0",
version="1.1",
author="InternImage",
url="https://github.com/OpenGVLab/InternImage",
description=
......
......@@ -25,7 +25,7 @@ at::Tensor dcnv3_cuda_forward(const at::Tensor &input, const at::Tensor &offset,
const int pad_w, const int dilation_h,
const int dilation_w, const int group,
const int group_channels,
const float offset_scale, const int im2col_step) {
const float offset_scale, const int im2col_step, const int remove_center) {
AT_ASSERTM(input.is_contiguous(), "input tensor has to be contiguous");
AT_ASSERTM(offset.is_contiguous(), "offset tensor has to be contiguous");
AT_ASSERTM(mask.is_contiguous(), "mask tensor has to be contiguous");
......@@ -61,8 +61,8 @@ at::Tensor dcnv3_cuda_forward(const at::Tensor &input, const at::Tensor &offset,
width_out, group * group_channels});
auto per_input_size = height_in * width_in * group * group_channels;
auto per_offset_size =
height_out * width_out * group * kernel_h * kernel_w * 2;
auto per_mask_size = height_out * width_out * group * kernel_h * kernel_w;
height_out * width_out * group * (kernel_h * kernel_w - remove_center) * 2;
auto per_mask_size = height_out * width_out * group * (kernel_h * kernel_w - remove_center);
for (int n = 0; n < batch / im2col_step_; ++n) {
auto columns = output_n.select(0, n);
// AT_DISPATCH_FLOATING_TYPES(
......@@ -77,7 +77,7 @@ at::Tensor dcnv3_cuda_forward(const at::Tensor &input, const at::Tensor &offset,
columns.data<scalar_t>(), kernel_h, kernel_w, stride_h,
stride_w, pad_h, pad_w, dilation_h, dilation_w, group,
group_channels, batch_n, height_in, width_in, height_out,
width_out, offset_scale);
width_out, offset_scale, remove_center);
}));
}
......@@ -91,7 +91,7 @@ dcnv3_cuda_backward(const at::Tensor &input, const at::Tensor &offset,
const int pad_h, const int pad_w, const int dilation_h,
const int dilation_w, const int group,
const int group_channels, const float offset_scale,
const at::Tensor &grad_output, const int im2col_step) {
const at::Tensor &grad_output, const int im2col_step, const int remove_center) {
AT_ASSERTM(input.is_contiguous(), "input tensor has to be contiguous");
AT_ASSERTM(offset.is_contiguous(), "offset tensor has to be contiguous");
......@@ -135,8 +135,8 @@ dcnv3_cuda_backward(const at::Tensor &input, const at::Tensor &offset,
const int batch_n = im2col_step_;
auto per_input_size = height_in * width_in * group * group_channels;
auto per_offset_size =
height_out * width_out * group * kernel_h * kernel_w * 2;
auto per_mask_size = height_out * width_out * group * kernel_h * kernel_w;
height_out * width_out * group * (kernel_h * kernel_w - remove_center) * 2;
auto per_mask_size = height_out * width_out * group * (kernel_h * kernel_w - remove_center);
auto grad_output_n =
grad_output.view({batch / im2col_step_, batch_n, height_out * width_out,
group, group_channels});
......@@ -155,7 +155,7 @@ dcnv3_cuda_backward(const at::Tensor &input, const at::Tensor &offset,
mask.data<scalar_t>() + n * im2col_step_ * per_mask_size,
kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w,
dilation_h, dilation_w, group, group_channels, batch_n,
height_in, width_in, height_out, width_out, offset_scale,
height_in, width_in, height_out, width_out, offset_scale, remove_center,
grad_input.data<opmath_t>() +
n * im2col_step_ * per_input_size,
grad_offset.data<opmath_t>() +
......
......@@ -19,7 +19,7 @@ at::Tensor dcnv3_cuda_forward(const at::Tensor &input, const at::Tensor &offset,
const int pad_w, const int dilation_h,
const int dilation_w, const int group,
const int group_channels,
const float offset_scale, const int im2col_step);
const float offset_scale, const int im2col_step, const int remove_center);
std::vector<at::Tensor>
dcnv3_cuda_backward(const at::Tensor &input, const at::Tensor &offset,
......@@ -28,4 +28,4 @@ dcnv3_cuda_backward(const at::Tensor &input, const at::Tensor &offset,
const int pad_h, const int pad_w, const int dilation_h,
const int dilation_w, const int group,
const int group_channels, const float offset_scale,
const at::Tensor &grad_output, const int im2col_step);
const at::Tensor &grad_output, const int im2col_step, const int remove_center);
......@@ -221,7 +221,7 @@ __global__ void dcnv3_im2col_gpu_kernel(
const int pad_w, const int dilation_h, const int dilation_w,
const int group, const int group_channels, const int height_in,
const int width_in, const int height_out, const int width_out,
const opmath_t offset_scale) {
const opmath_t offset_scale, const int remove_center) {
CUDA_KERNEL_LOOP(index, num_kernels) {
int _temp = index;
const int c_col = _temp % group_channels;
......@@ -239,7 +239,7 @@ __global__ void dcnv3_im2col_gpu_kernel(
const int input_size = height_in * width_in;
scalar_t *data_col_ptr = data_col + index;
const int kernel_size = kernel_h * kernel_w;
const int kernel_size = kernel_h * kernel_w - remove_center;
int data_weight_ptr = sampling_index * kernel_size;
int data_loc_w_ptr = data_weight_ptr << 1;
const int qid_stride = group * group_channels;
......@@ -250,8 +250,14 @@ __global__ void dcnv3_im2col_gpu_kernel(
p0_w - ((dilation_w * (kernel_w - 1)) >> 1) * offset_scale;
const opmath_t p0_h_ =
p0_h - ((dilation_h * (kernel_h - 1)) >> 1) * offset_scale;
const int center_h = kernel_h / 2;
const int center_w = kernel_w / 2;
for (int i = 0; i < kernel_w; ++i) {
for (int j = 0; j < kernel_h; ++j) {
// if not remove center, or remove center and not the center
if (i!=center_w || j!=center_h || !remove_center) {
const opmath_t offset_w = data_offset[data_loc_w_ptr];
const opmath_t offset_h = data_offset[data_loc_w_ptr + 1];
const opmath_t loc_w =
......@@ -270,6 +276,7 @@ __global__ void dcnv3_im2col_gpu_kernel(
data_loc_w_ptr += 2;
}
}
}
*data_col_ptr = col;
}
}
......@@ -283,7 +290,7 @@ __global__ void dcnv3_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(
const int pad_w, const int dilation_h, const int dilation_w,
const int group, const int group_channels, const int height_in,
const int width_in, const int height_out, const int width_out,
const opmath_t offset_scale, opmath_t *grad_im, opmath_t *grad_offset,
const opmath_t offset_scale, const int remove_center, opmath_t *grad_im, opmath_t *grad_offset,
opmath_t *grad_mask) {
CUDA_KERNEL_LOOP(index, num_kernels) {
__shared__ opmath_t cache_grad_offset[blockSize * 2];
......@@ -305,7 +312,7 @@ __global__ void dcnv3_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(
const opmath_t top_grad = grad_col[index];
const int input_size = height_in * width_in;
const int kernel_size = kernel_h * kernel_w;
const int kernel_size = kernel_h * kernel_w - remove_center;
int data_weight_ptr = sampling_index * kernel_size;
int data_loc_w_ptr = data_weight_ptr << 1;
const int grad_sampling_ptr = data_weight_ptr;
......@@ -319,8 +326,14 @@ __global__ void dcnv3_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(
p0_w - ((dilation_w * (kernel_w - 1)) >> 1) * offset_scale;
const opmath_t p0_h_ =
p0_h - ((dilation_h * (kernel_h - 1)) >> 1) * offset_scale;
const int center_h = kernel_h / 2;
const int center_w = kernel_w / 2;
for (int i = 0; i < kernel_w; ++i) {
for (int j = 0; j < kernel_h; ++j) {
// if not remove center, or remove center and not the center
if (i!=center_w || j!=center_h || !remove_center) {
const opmath_t offset_w = data_offset[data_loc_w_ptr];
const opmath_t offset_h = data_offset[data_loc_w_ptr + 1];
const opmath_t loc_w =
......@@ -367,6 +380,7 @@ __global__ void dcnv3_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(
}
}
}
}
}
template <typename scalar_t, unsigned int blockSize>
......@@ -377,7 +391,7 @@ __global__ void dcnv3_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(
const int pad_w, const int dilation_h, const int dilation_w,
const int group, const int group_channels, const int height_in,
const int width_in, const int height_out, const int width_out,
const opmath_t offset_scale, opmath_t *grad_im, opmath_t *grad_offset,
const opmath_t offset_scale, const int remove_center, opmath_t *grad_im, opmath_t *grad_offset,
opmath_t *grad_mask) {
CUDA_KERNEL_LOOP(index, num_kernels) {
__shared__ opmath_t cache_grad_offset[blockSize * 2];
......@@ -399,7 +413,7 @@ __global__ void dcnv3_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(
const opmath_t top_grad = grad_col[index];
const int input_size = height_in * width_in;
const int kernel_size = kernel_h * kernel_w;
const int kernel_size = kernel_h * kernel_w - remove_center;
int data_weight_ptr = sampling_index * kernel_size;
int data_loc_w_ptr = data_weight_ptr << 1;
const int grad_sampling_ptr = data_weight_ptr;
......@@ -413,8 +427,14 @@ __global__ void dcnv3_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(
p0_w - ((dilation_w * (kernel_w - 1)) >> 1) * offset_scale;
const opmath_t p0_h_ =
p0_h - ((dilation_h * (kernel_h - 1)) >> 1) * offset_scale;
const int center_h = kernel_h / 2;
const int center_w = kernel_w / 2;
for (int i = 0; i < kernel_w; ++i) {
for (int j = 0; j < kernel_h; ++j) {
// if not remove center, or remove center and not the center
if (i!=center_w || j!=center_h || !remove_center) {
const opmath_t offset_w = data_offset[data_loc_w_ptr];
const opmath_t offset_h = data_offset[data_loc_w_ptr + 1];
const opmath_t loc_w =
......@@ -463,6 +483,7 @@ __global__ void dcnv3_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(
}
}
}
}
}
template <typename scalar_t>
......@@ -473,7 +494,7 @@ __global__ void dcnv3_col2im_gpu_kernel_shm_reduce_v1(
const int pad_w, const int dilation_h, const int dilation_w,
const int group, const int group_channels, const int height_in,
const int width_in, const int height_out, const int width_out,
const opmath_t offset_scale, opmath_t *grad_im, opmath_t *grad_offset,
const opmath_t offset_scale, const int remove_center, opmath_t *grad_im, opmath_t *grad_offset,
opmath_t *grad_mask) {
CUDA_KERNEL_LOOP(index, num_kernels) {
extern __shared__ int _s[];
......@@ -496,7 +517,7 @@ __global__ void dcnv3_col2im_gpu_kernel_shm_reduce_v1(
const opmath_t top_grad = grad_col[index];
const int input_size = height_in * width_in;
const int kernel_size = kernel_h * kernel_w;
const int kernel_size = kernel_h * kernel_w - remove_center;
int data_weight_ptr = sampling_index * kernel_size;
int data_loc_w_ptr = data_weight_ptr << 1;
const int grad_sampling_ptr = data_weight_ptr;
......@@ -510,8 +531,14 @@ __global__ void dcnv3_col2im_gpu_kernel_shm_reduce_v1(
p0_w - ((dilation_w * (kernel_w - 1)) >> 1) * offset_scale;
const opmath_t p0_h_ =
p0_h - ((dilation_h * (kernel_h - 1)) >> 1) * offset_scale;
const int center_h = kernel_h / 2;
const int center_w = kernel_w / 2;
for (int i = 0; i < kernel_w; ++i) {
for (int j = 0; j < kernel_h; ++j) {
// if not remove center, or remove center and not the center
if (i!=center_w || j!=center_h || !remove_center) {
const opmath_t offset_w = data_offset[data_loc_w_ptr];
const opmath_t offset_h = data_offset[data_loc_w_ptr + 1];
const opmath_t loc_w =
......@@ -558,6 +585,7 @@ __global__ void dcnv3_col2im_gpu_kernel_shm_reduce_v1(
}
}
}
}
}
template <typename scalar_t>
......@@ -568,7 +596,7 @@ __global__ void dcnv3_col2im_gpu_kernel_shm_reduce_v2(
const int pad_w, const int dilation_h, const int dilation_w,
const int group, const int group_channels, const int height_in,
const int width_in, const int height_out, const int width_out,
const opmath_t offset_scale, opmath_t *grad_im, opmath_t *grad_offset,
const opmath_t offset_scale, const int remove_center, opmath_t *grad_im, opmath_t *grad_offset,
opmath_t *grad_mask) {
CUDA_KERNEL_LOOP(index, num_kernels) {
extern __shared__ int _s[];
......@@ -591,7 +619,7 @@ __global__ void dcnv3_col2im_gpu_kernel_shm_reduce_v2(
const opmath_t top_grad = grad_col[index];
const int input_size = height_in * width_in;
const int kernel_size = kernel_h * kernel_w;
const int kernel_size = kernel_h * kernel_w - remove_center;
int data_weight_ptr = sampling_index * kernel_size;
int data_loc_w_ptr = data_weight_ptr << 1;
const int grad_sampling_ptr = data_weight_ptr;
......@@ -605,8 +633,14 @@ __global__ void dcnv3_col2im_gpu_kernel_shm_reduce_v2(
p0_w - ((dilation_w * (kernel_w - 1)) >> 1) * offset_scale;
const opmath_t p0_h_ =
p0_h - ((dilation_h * (kernel_h - 1)) >> 1) * offset_scale;
const int center_h = kernel_h / 2;
const int center_w = kernel_w / 2;
for (int i = 0; i < kernel_w; ++i) {
for (int j = 0; j < kernel_h; ++j) {
// if not remove center, or remove center and not the center
if (i!=center_w || j!=center_h || !remove_center) {
const opmath_t offset_w = data_offset[data_loc_w_ptr];
const opmath_t offset_h = data_offset[data_loc_w_ptr + 1];
const opmath_t loc_w =
......@@ -664,6 +698,7 @@ __global__ void dcnv3_col2im_gpu_kernel_shm_reduce_v2(
}
}
}
}
}
template <typename scalar_t>
......@@ -674,7 +709,7 @@ __global__ void dcnv3_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(
const int pad_w, const int dilation_h, const int dilation_w,
const int group, const int group_channels, const int height_in,
const int width_in, const int height_out, const int width_out,
const opmath_t offset_scale, opmath_t *grad_im, opmath_t *grad_offset,
const opmath_t offset_scale, const int remove_center, opmath_t *grad_im, opmath_t *grad_offset,
opmath_t *grad_mask) {
CUDA_KERNEL_LOOP(index, num_kernels) {
extern __shared__ int _s[];
......@@ -697,7 +732,7 @@ __global__ void dcnv3_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(
const opmath_t top_grad = grad_col[index];
const int input_size = height_in * width_in;
const int kernel_size = kernel_h * kernel_w;
const int kernel_size = kernel_h * kernel_w - remove_center;
int data_weight_ptr = sampling_index * kernel_size;
int data_loc_w_ptr = data_weight_ptr << 1;
const int grad_sampling_ptr = data_weight_ptr;
......@@ -711,8 +746,14 @@ __global__ void dcnv3_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(
p0_w - ((dilation_w * (kernel_w - 1)) >> 1) * offset_scale;
const opmath_t p0_h_ =
p0_h - ((dilation_h * (kernel_h - 1)) >> 1) * offset_scale;
const int center_h = kernel_h / 2;
const int center_w = kernel_w / 2;
for (int i = 0; i < kernel_w; ++i) {
for (int j = 0; j < kernel_h; ++j) {
// if not remove center, or remove center and not the center
if (i!=center_w || j!=center_h || !remove_center) {
const opmath_t offset_w = data_offset[data_loc_w_ptr];
const opmath_t offset_h = data_offset[data_loc_w_ptr + 1];
const opmath_t loc_w =
......@@ -770,6 +811,7 @@ __global__ void dcnv3_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(
}
}
}
}
}
template <typename scalar_t>
......@@ -780,7 +822,7 @@ __global__ void dcnv3_col2im_gpu_kernel_gm(
const int pad_w, const int dilation_h, const int dilation_w,
const int group, const int group_channels, const int height_in,
const int width_in, const int height_out, const int width_out,
const opmath_t offset_scale, opmath_t *grad_im, opmath_t *grad_offset,
const opmath_t offset_scale, const int remove_center, opmath_t *grad_im, opmath_t *grad_offset,
opmath_t *grad_mask) {
CUDA_KERNEL_LOOP(index, num_kernels) {
int _temp = index;
......@@ -799,7 +841,7 @@ __global__ void dcnv3_col2im_gpu_kernel_gm(
const opmath_t top_grad = grad_col[index];
const int input_size = height_in * width_in;
const int kernel_size = kernel_h * kernel_w;
const int kernel_size = kernel_h * kernel_w - remove_center;
int data_weight_ptr = sampling_index * kernel_size;
int data_loc_w_ptr = data_weight_ptr << 1;
const int grad_sampling_ptr = data_weight_ptr;
......@@ -813,8 +855,14 @@ __global__ void dcnv3_col2im_gpu_kernel_gm(
p0_w - ((dilation_w * (kernel_w - 1)) >> 1) * offset_scale;
const opmath_t p0_h_ =
p0_h - ((dilation_h * (kernel_h - 1)) >> 1) * offset_scale;
const int center_h = kernel_h / 2;
const int center_w = kernel_w / 2;
for (int i = 0; i < kernel_w; ++i) {
for (int j = 0; j < kernel_h; ++j) {
// if not remove center, or remove center and not the center
if (i!=center_w || j!=center_h || !remove_center) {
const opmath_t offset_w = data_offset[data_loc_w_ptr];
const opmath_t offset_h = data_offset[data_loc_w_ptr + 1];
const opmath_t loc_w =
......@@ -836,6 +884,7 @@ __global__ void dcnv3_col2im_gpu_kernel_gm(
}
}
}
}
}
template <typename scalar_t>
......@@ -848,7 +897,7 @@ void dcnv3_im2col_cuda(cudaStream_t stream, const scalar_t *data_im,
const int group, const int group_channels,
const int batch_n, const int height_in,
const int width_in, const int height_out,
const int width_out, const opmath_t offset_scale) {
const int width_out, const opmath_t offset_scale, const int remove_center) {
const int num_kernels =
batch_n * height_out * width_out * group * group_channels;
const int num_actual_kernels =
......@@ -859,7 +908,7 @@ void dcnv3_im2col_cuda(cudaStream_t stream, const scalar_t *data_im,
stream>>>(num_kernels, data_im, data_offset, data_mask, data_col,
kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w,
dilation_h, dilation_w, group, group_channels, height_in,
width_in, height_out, width_out, offset_scale);
width_in, height_out, width_out, offset_scale, remove_center);
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
......@@ -875,8 +924,8 @@ void dcnv3_col2im_cuda(
const int pad_w, const int dilation_h, const int dilation_w,
const int group, const int group_channels, const int batch_n,
const int height_in, const int width_in, const int height_out,
const int width_out, const opmath_t offset_scale, opmath_t *grad_im,
opmath_t *grad_offset, opmath_t *grad_mask) {
const int width_out, const opmath_t offset_scale, const int remove_center,
opmath_t *grad_im, opmath_t *grad_offset, opmath_t *grad_mask) {
const int num_threads =
(group_channels > CUDA_NUM_THREADS) ? CUDA_NUM_THREADS : group_channels;
const int num_kernels =
......@@ -891,7 +940,7 @@ void dcnv3_col2im_cuda(
num_kernels, grad_col, data_im, data_offset, data_mask,
kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w,
dilation_h, dilation_w, group, group_channels, height_in,
width_in, height_out, width_out, offset_scale, grad_im,
width_in, height_out, width_out, offset_scale, remove_center, grad_im,
grad_offset, grad_mask);
} else {
dcnv3_col2im_gpu_kernel_gm<scalar_t>
......@@ -900,7 +949,7 @@ void dcnv3_col2im_cuda(
data_mask, kernel_h, kernel_w, stride_h, stride_w,
pad_h, pad_w, dilation_h, dilation_w, group,
group_channels, height_in, width_in, height_out,
width_out, offset_scale, grad_im, grad_offset,
width_out, offset_scale, remove_center, grad_im, grad_offset,
grad_mask);
}
} else {
......@@ -912,7 +961,7 @@ void dcnv3_col2im_cuda(
data_mask, kernel_h, kernel_w, stride_h, stride_w,
pad_h, pad_w, dilation_h, dilation_w, group,
group_channels, height_in, width_in, height_out,
width_out, offset_scale, grad_im, grad_offset,
width_out, offset_scale, remove_center, grad_im, grad_offset,
grad_mask);
break;
case 2:
......@@ -922,7 +971,7 @@ void dcnv3_col2im_cuda(
data_mask, kernel_h, kernel_w, stride_h, stride_w,
pad_h, pad_w, dilation_h, dilation_w, group,
group_channels, height_in, width_in, height_out,
width_out, offset_scale, grad_im, grad_offset,
width_out, offset_scale, remove_center, grad_im, grad_offset,
grad_mask);
break;
case 4:
......@@ -932,7 +981,7 @@ void dcnv3_col2im_cuda(
data_mask, kernel_h, kernel_w, stride_h, stride_w,
pad_h, pad_w, dilation_h, dilation_w, group,
group_channels, height_in, width_in, height_out,
width_out, offset_scale, grad_im, grad_offset,
width_out, offset_scale, remove_center, grad_im, grad_offset,
grad_mask);
break;
case 8:
......@@ -942,7 +991,7 @@ void dcnv3_col2im_cuda(
data_mask, kernel_h, kernel_w, stride_h, stride_w,
pad_h, pad_w, dilation_h, dilation_w, group,
group_channels, height_in, width_in, height_out,
width_out, offset_scale, grad_im, grad_offset,
width_out, offset_scale, remove_center, grad_im, grad_offset,
grad_mask);
break;
case 16:
......@@ -952,7 +1001,7 @@ void dcnv3_col2im_cuda(
data_mask, kernel_h, kernel_w, stride_h, stride_w,
pad_h, pad_w, dilation_h, dilation_w, group,
group_channels, height_in, width_in, height_out,
width_out, offset_scale, grad_im, grad_offset,
width_out, offset_scale, remove_center, grad_im, grad_offset,
grad_mask);
break;
case 32:
......@@ -962,7 +1011,7 @@ void dcnv3_col2im_cuda(
data_mask, kernel_h, kernel_w, stride_h, stride_w,
pad_h, pad_w, dilation_h, dilation_w, group,
group_channels, height_in, width_in, height_out,
width_out, offset_scale, grad_im, grad_offset,
width_out, offset_scale, remove_center, grad_im, grad_offset,
grad_mask);
break;
case 64:
......@@ -972,7 +1021,7 @@ void dcnv3_col2im_cuda(
data_mask, kernel_h, kernel_w, stride_h, stride_w,
pad_h, pad_w, dilation_h, dilation_w, group,
group_channels, height_in, width_in, height_out,
width_out, offset_scale, grad_im, grad_offset,
width_out, offset_scale, remove_center, grad_im, grad_offset,
grad_mask);
break;
case 128:
......@@ -982,7 +1031,7 @@ void dcnv3_col2im_cuda(
data_mask, kernel_h, kernel_w, stride_h, stride_w,
pad_h, pad_w, dilation_h, dilation_w, group,
group_channels, height_in, width_in, height_out,
width_out, offset_scale, grad_im, grad_offset,
width_out, offset_scale, remove_center, grad_im, grad_offset,
grad_mask);
break;
case 256:
......@@ -992,7 +1041,7 @@ void dcnv3_col2im_cuda(
data_mask, kernel_h, kernel_w, stride_h, stride_w,
pad_h, pad_w, dilation_h, dilation_w, group,
group_channels, height_in, width_in, height_out,
width_out, offset_scale, grad_im, grad_offset,
width_out, offset_scale, remove_center, grad_im, grad_offset,
grad_mask);
break;
case 512:
......@@ -1002,7 +1051,7 @@ void dcnv3_col2im_cuda(
data_mask, kernel_h, kernel_w, stride_h, stride_w,
pad_h, pad_w, dilation_h, dilation_w, group,
group_channels, height_in, width_in, height_out,
width_out, offset_scale, grad_im, grad_offset,
width_out, offset_scale, remove_center, grad_im, grad_offset,
grad_mask);
break;
case 1024:
......@@ -1013,7 +1062,7 @@ void dcnv3_col2im_cuda(
data_mask, kernel_h, kernel_w, stride_h, stride_w,
pad_h, pad_w, dilation_h, dilation_w, group,
group_channels, height_in, width_in, height_out,
width_out, offset_scale, grad_im, grad_offset,
width_out, offset_scale, remove_center, grad_im, grad_offset,
grad_mask);
break;
default:
......@@ -1025,7 +1074,7 @@ void dcnv3_col2im_cuda(
kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w,
dilation_h, dilation_w, group, group_channels,
height_in, width_in, height_out, width_out,
offset_scale, grad_im, grad_offset, grad_mask);
offset_scale, remove_center, grad_im, grad_offset, grad_mask);
} else {
dcnv3_col2im_gpu_kernel_shm_reduce_v2<scalar_t>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
......@@ -1034,7 +1083,7 @@ void dcnv3_col2im_cuda(
kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w,
dilation_h, dilation_w, group, group_channels,
height_in, width_in, height_out, width_out,
offset_scale, grad_im, grad_offset, grad_mask);
offset_scale, remove_center, grad_im, grad_offset, grad_mask);
}
}
}
......
......@@ -23,13 +23,13 @@ at::Tensor dcnv3_forward(const at::Tensor &input, const at::Tensor &offset,
const int stride_w, const int pad_h, const int pad_w,
const int dilation_h, const int dilation_w,
const int group, const int group_channels,
const float offset_scale, const int im2col_step) {
const float offset_scale, const int im2col_step, const int remove_center) {
if (input.type().is_cuda()) {
#ifdef WITH_CUDA
return dcnv3_cuda_forward(input, offset, mask, kernel_h, kernel_w,
stride_h, stride_w, pad_h, pad_w, dilation_h,
dilation_w, group, group_channels,
offset_scale, im2col_step);
offset_scale, im2col_step, remove_center);
#else
AT_ERROR("Not compiled with GPU support");
#endif
......@@ -44,13 +44,13 @@ dcnv3_backward(const at::Tensor &input, const at::Tensor &offset,
const int pad_w, const int dilation_h, const int dilation_w,
const int group, const int group_channels,
const float offset_scale, const at::Tensor &grad_output,
const int im2col_step) {
const int im2col_step, const int remove_center) {
if (input.type().is_cuda()) {
#ifdef WITH_CUDA
return dcnv3_cuda_backward(input, offset, mask, kernel_h, kernel_w,
stride_h, stride_w, pad_h, pad_w, dilation_h,
dilation_w, group, group_channels,
offset_scale, grad_output, im2col_step);
offset_scale, grad_output, im2col_step, remove_center);
#else
AT_ERROR("Not compiled with GPU support");
#endif
......
......@@ -19,7 +19,8 @@ from functions.dcnv3_func import DCNv3Function, dcnv3_core_pytorch
H_in, W_in = 8, 8
N, M, D = 2, 4, 16
Kh, Kw = 3, 3
P = Kh * Kw
remove_center = False
P = Kh * Kw - remove_center
offset_scale = 2.0
pad = 1
dilation = 1
......@@ -42,7 +43,7 @@ def check_forward_equal_with_pytorch_double():
input.double(),
offset.double(),
mask.double(),
Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale).detach().cpu()
Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale, remove_center).detach().cpu()
im2col_step = 2
output_cuda = DCNv3Function.apply(
......@@ -50,7 +51,7 @@ def check_forward_equal_with_pytorch_double():
offset.double(),
mask.double(),
Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale,
im2col_step).detach().cpu()
im2col_step, remove_center).detach().cpu()
fwdok = torch.allclose(output_cuda, output_pytorch)
max_abs_err = (output_cuda - output_pytorch).abs().max()
......@@ -72,7 +73,7 @@ def check_forward_equal_with_pytorch_float():
input,
offset,
mask,
Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale).detach().cpu()
Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale, remove_center).detach().cpu()
im2col_step = 2
output_cuda = DCNv3Function.apply(
......@@ -80,7 +81,7 @@ def check_forward_equal_with_pytorch_float():
offset,
mask,
Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale,
im2col_step).detach().cpu()
im2col_step, remove_center).detach().cpu()
fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3)
max_abs_err = (output_cuda - output_pytorch).abs().max()
......@@ -111,7 +112,7 @@ def check_backward_equal_with_pytorch_double(channels=4, grad_input=True, grad_o
input0.double(),
offset0.double(),
mask0.double(),
Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale)
Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale, remove_center)
output_pytorch.sum().backward()
input1 = input0.detach()
......@@ -127,7 +128,7 @@ def check_backward_equal_with_pytorch_double(channels=4, grad_input=True, grad_o
offset1.double(),
mask1.double(),
Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale,
im2col_step)
im2col_step, remove_center)
output_cuda.sum().backward()
print(f'>>> backward double: channels {D}')
......@@ -174,7 +175,7 @@ def check_backward_equal_with_pytorch_float(channels=4, grad_input=True, grad_of
input0,
offset0,
mask0,
Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale)
Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale, remove_center)
output_pytorch.sum().backward()
input1 = input0.detach()
......@@ -190,7 +191,7 @@ def check_backward_equal_with_pytorch_float(channels=4, grad_input=True, grad_of
offset1,
mask1,
Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale,
im2col_step)
im2col_step, remove_center)
output_cuda.sum().backward()
print(f'>>> backward float: channels {D}')
......@@ -237,7 +238,7 @@ def check_time_cost(im2col_step=128):
offset,
mask,
Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, 1.0,
im2col_step)
im2col_step, remove_center)
torch.cuda.synchronize()
start = time.time()
for i in range(repeat):
......@@ -246,7 +247,7 @@ def check_time_cost(im2col_step=128):
offset,
mask,
Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, 1.0,
im2col_step)
im2col_step, remove_center)
torch.cuda.synchronize()
print(f'foward time cost: {(time.time() - start) / repeat}')
......
......@@ -66,7 +66,27 @@ def load_checkpoint(config, model, optimizer, lr_scheduler, scaler, logger):
checkpoint = torch.load(config.MODEL.RESUME, map_location='cpu')
print('resuming model')
msg = model.load_state_dict(checkpoint['model'], strict=False)
model_checkpoint = checkpoint['model']
if config.MODEL.INTERN_IMAGE.REMOVE_CENTER:
for k, v in model_checkpoint.items():
if 'dcn.mask.bias' in k:
idx = list(range(v.shape[0]))
idx = [i for i in idx if i != 4 and (i - 4) % 9 != 0]
model_checkpoint[k] = v[idx]
if 'dcn.offset.bias' in k:
idx = list(range(v.shape[0]))
idx = [i for i in idx if i != 8 and (i - 8) % 18 != 0 and i != 9 and (i - 9) % 18 != 0]
model_checkpoint[k] = v[idx]
if 'dcn.mask.weight' in k:
idx = list(range(v.shape[0]))
idx = [i for i in idx if i != 4 and (i - 4) % 9 != 0]
model_checkpoint[k] = v[idx, :]
if 'dcn.offset.weight' in k:
idx = list(range(v.shape[0]))
idx = [i for i in idx if i != 8 and (i - 8) % 18 != 0 and i != 9 and (i - 9) % 18 != 0]
model_checkpoint[k] = v[idx, :]
msg = model.load_state_dict(model_checkpoint, strict=False)
logger.info(msg)
max_accuracy = 0.0
if not config.EVAL_MODE and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
......@@ -239,6 +259,25 @@ def load_pretrained(config, model, logger):
map22kto1k, :]
state_dict['head.bias'] = state_dict['head.bias'][map22kto1k]
if config.MODEL.INTERN_IMAGE.REMOVE_CENTER:
for k, v in state_dict.items():
if 'dcn.mask.bias' in k:
idx = list(range(v.shape[0]))
idx = [i for i in idx if i != 4 and (i - 4) % 9 != 0]
state_dict[k] = v[idx]
if 'dcn.offset.bias' in k:
idx = list(range(v.shape[0]))
idx = [i for i in idx if i != 8 and (i - 8) % 18 != 0 and i != 9 and (i - 9) % 18 != 0]
state_dict[k] = v[idx]
if 'dcn.mask.weight' in k:
idx = list(range(v.shape[0]))
idx = [i for i in idx if i != 4 and (i - 4) % 9 != 0]
state_dict[k] = v[idx, :]
if 'dcn.offset.weight' in k:
idx = list(range(v.shape[0]))
idx = [i for i in idx if i != 8 and (i - 8) % 18 != 0 and i != 9 and (i - 9) % 18 != 0]
state_dict[k] = v[idx, :]
msg = model.load_state_dict(state_dict, strict=False)
logger.warning(msg)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment