Unverified Commit 8f2d1583 authored by Zeqiang Lai's avatar Zeqiang Lai Committed by GitHub
Browse files

Support removing center point for DCNV3 (#101)

* support remove center point for DCNV3

* dcnv3: update test

* remove center for DCNv3_pytorch

* fix pytorch version

* fix unit test

* add backward compatibility

* bump dcn version to 1.1
parent b64d9ca3
...@@ -79,6 +79,7 @@ _C.MODEL.INTERN_IMAGE.USE_CLIP_PROJECTOR = False ...@@ -79,6 +79,7 @@ _C.MODEL.INTERN_IMAGE.USE_CLIP_PROJECTOR = False
_C.MODEL.INTERN_IMAGE.LEVEL2_POST_NORM = False _C.MODEL.INTERN_IMAGE.LEVEL2_POST_NORM = False
_C.MODEL.INTERN_IMAGE.LEVEL2_POST_NORM_BLOCK_IDS = None _C.MODEL.INTERN_IMAGE.LEVEL2_POST_NORM_BLOCK_IDS = None
_C.MODEL.INTERN_IMAGE.CENTER_FEATURE_SCALE = False _C.MODEL.INTERN_IMAGE.CENTER_FEATURE_SCALE = False
_C.MODEL.INTERN_IMAGE.REMOVE_CENTER = False
......
...@@ -582,7 +582,7 @@ if __name__ == '__main__': ...@@ -582,7 +582,7 @@ if __name__ == '__main__':
assert has_native_amp, "Please update pytorch(1.6+) to support amp!" assert has_native_amp, "Please update pytorch(1.6+) to support amp!"
# init distributed env # init distributed env
if 'SLURM_PROCID' in os.environ and int(os.environ['SLURM_NNODES']) != 1: if 'SLURM_PROCID' in os.environ and int(os.environ['SLURM_TASKS_PER_NODE']) != 1:
print("\nDist init: SLURM") print("\nDist init: SLURM")
rank = int(os.environ['SLURM_PROCID']) rank = int(os.environ['SLURM_PROCID'])
gpu = rank % torch.cuda.device_count() gpu = rank % torch.cuda.device_count()
......
...@@ -26,7 +26,8 @@ def build_model(config): ...@@ -26,7 +26,8 @@ def build_model(config):
use_clip_projector=config.MODEL.INTERN_IMAGE.USE_CLIP_PROJECTOR, # for InternImage-H/G use_clip_projector=config.MODEL.INTERN_IMAGE.USE_CLIP_PROJECTOR, # for InternImage-H/G
level2_post_norm=config.MODEL.INTERN_IMAGE.LEVEL2_POST_NORM, # for InternImage-H/G level2_post_norm=config.MODEL.INTERN_IMAGE.LEVEL2_POST_NORM, # for InternImage-H/G
level2_post_norm_block_ids=config.MODEL.INTERN_IMAGE.LEVEL2_POST_NORM_BLOCK_IDS, # for InternImage-H/G level2_post_norm_block_ids=config.MODEL.INTERN_IMAGE.LEVEL2_POST_NORM_BLOCK_IDS, # for InternImage-H/G
center_feature_scale=config.MODEL.INTERN_IMAGE.CENTER_FEATURE_SCALE # for InternImage-H/G center_feature_scale=config.MODEL.INTERN_IMAGE.CENTER_FEATURE_SCALE, # for InternImage-H/G
remove_center=config.MODEL.INTERN_IMAGE.REMOVE_CENTER,
) )
else: else:
raise NotImplementedError(f"Unkown model: {model_type}") raise NotImplementedError(f"Unkown model: {model_type}")
......
...@@ -359,7 +359,9 @@ class InternImageLayer(nn.Module): ...@@ -359,7 +359,9 @@ class InternImageLayer(nn.Module):
with_cp=False, with_cp=False,
dw_kernel_size=None, # for InternImage-H/G dw_kernel_size=None, # for InternImage-H/G
res_post_norm=False, # for InternImage-H/G res_post_norm=False, # for InternImage-H/G
center_feature_scale=False): # for InternImage-H/G center_feature_scale=False, # for InternImage-H/G
remove_center=False, # for InternImage-H/G
):
super().__init__() super().__init__()
self.channels = channels self.channels = channels
self.groups = groups self.groups = groups
...@@ -379,7 +381,9 @@ class InternImageLayer(nn.Module): ...@@ -379,7 +381,9 @@ class InternImageLayer(nn.Module):
act_layer=act_layer, act_layer=act_layer,
norm_layer=norm_layer, norm_layer=norm_layer,
dw_kernel_size=dw_kernel_size, # for InternImage-H/G dw_kernel_size=dw_kernel_size, # for InternImage-H/G
center_feature_scale=center_feature_scale) # for InternImage-H/G center_feature_scale=center_feature_scale, # for InternImage-H/G
remove_center=remove_center, # for InternImage-H/G
)
self.drop_path = DropPath(drop_path) if drop_path > 0. \ self.drop_path = DropPath(drop_path) if drop_path > 0. \
else nn.Identity() else nn.Identity()
self.norm2 = build_norm_layer(channels, 'LN') self.norm2 = build_norm_layer(channels, 'LN')
...@@ -463,7 +467,9 @@ class InternImageBlock(nn.Module): ...@@ -463,7 +467,9 @@ class InternImageBlock(nn.Module):
dw_kernel_size=None, # for InternImage-H/G dw_kernel_size=None, # for InternImage-H/G
post_norm_block_ids=None, # for InternImage-H/G post_norm_block_ids=None, # for InternImage-H/G
res_post_norm=False, # for InternImage-H/G res_post_norm=False, # for InternImage-H/G
center_feature_scale=False): # for InternImage-H/G center_feature_scale=False, # for InternImage-H/G
remove_center=False, # for InternImage-H/G
):
super().__init__() super().__init__()
self.channels = channels self.channels = channels
self.depth = depth self.depth = depth
...@@ -487,8 +493,9 @@ class InternImageBlock(nn.Module): ...@@ -487,8 +493,9 @@ class InternImageBlock(nn.Module):
with_cp=with_cp, with_cp=with_cp,
dw_kernel_size=dw_kernel_size, # for InternImage-H/G dw_kernel_size=dw_kernel_size, # for InternImage-H/G
res_post_norm=res_post_norm, # for InternImage-H/G res_post_norm=res_post_norm, # for InternImage-H/G
center_feature_scale=center_feature_scale # for InternImage-H/G center_feature_scale=center_feature_scale, # for InternImage-H/G
) for i in range(depth) remove_center = remove_center, # for InternImage-H/G
) for i in range(depth)
]) ])
if not self.post_norm or center_feature_scale: if not self.post_norm or center_feature_scale:
self.norm = build_norm_layer(channels, 'LN') self.norm = build_norm_layer(channels, 'LN')
...@@ -567,6 +574,7 @@ class InternImage(nn.Module): ...@@ -567,6 +574,7 @@ class InternImage(nn.Module):
level2_post_norm_block_ids=None, # for InternImage-H/G level2_post_norm_block_ids=None, # for InternImage-H/G
res_post_norm=False, # for InternImage-H/G res_post_norm=False, # for InternImage-H/G
center_feature_scale=False, # for InternImage-H/G center_feature_scale=False, # for InternImage-H/G
remove_center=False, # for InternImage-H/G
**kwargs): **kwargs):
super().__init__() super().__init__()
self.core_op = core_op self.core_op = core_op
...@@ -579,6 +587,8 @@ class InternImage(nn.Module): ...@@ -579,6 +587,8 @@ class InternImage(nn.Module):
self.mlp_ratio = mlp_ratio self.mlp_ratio = mlp_ratio
self.use_clip_projector = use_clip_projector self.use_clip_projector = use_clip_projector
self.level2_post_norm_block_ids = level2_post_norm_block_ids self.level2_post_norm_block_ids = level2_post_norm_block_ids
self.remove_center = remove_center
print(f'using core type: {core_op}') print(f'using core type: {core_op}')
print(f'using activation layer: {act_layer}') print(f'using activation layer: {act_layer}')
print(f'using main norm layer: {norm_layer}') print(f'using main norm layer: {norm_layer}')
...@@ -586,6 +596,7 @@ class InternImage(nn.Module): ...@@ -586,6 +596,7 @@ class InternImage(nn.Module):
print(f"level2_post_norm: {level2_post_norm}") print(f"level2_post_norm: {level2_post_norm}")
print(f"level2_post_norm_block_ids: {level2_post_norm_block_ids}") print(f"level2_post_norm_block_ids: {level2_post_norm_block_ids}")
print(f"res_post_norm: {res_post_norm}") print(f"res_post_norm: {res_post_norm}")
print(f"remove_center: {remove_center}")
in_chans = 3 in_chans = 3
self.patch_embed = StemLayer(in_chans=in_chans, self.patch_embed = StemLayer(in_chans=in_chans,
...@@ -623,7 +634,8 @@ class InternImage(nn.Module): ...@@ -623,7 +634,8 @@ class InternImage(nn.Module):
dw_kernel_size=dw_kernel_size, # for InternImage-H/G dw_kernel_size=dw_kernel_size, # for InternImage-H/G
post_norm_block_ids=post_norm_block_ids, # for InternImage-H/G post_norm_block_ids=post_norm_block_ids, # for InternImage-H/G
res_post_norm=res_post_norm, # for InternImage-H/G res_post_norm=res_post_norm, # for InternImage-H/G
center_feature_scale=center_feature_scale # for InternImage-H/G center_feature_scale=center_feature_scale, # for InternImage-H/G
remove_center=remove_center, # for InternImage-H/G
) )
self.levels.append(level) self.levels.append(level)
......
...@@ -23,7 +23,7 @@ class DCNv3Function(Function): ...@@ -23,7 +23,7 @@ class DCNv3Function(Function):
ctx, input, offset, mask, ctx, input, offset, mask,
kernel_h, kernel_w, stride_h, stride_w, kernel_h, kernel_w, stride_h, stride_w,
pad_h, pad_w, dilation_h, dilation_w, pad_h, pad_w, dilation_h, dilation_w,
group, group_channels, offset_scale, im2col_step): group, group_channels, offset_scale, im2col_step, remove_center):
ctx.kernel_h = kernel_h ctx.kernel_h = kernel_h
ctx.kernel_w = kernel_w ctx.kernel_w = kernel_w
ctx.stride_h = stride_h ctx.stride_h = stride_h
...@@ -36,11 +36,17 @@ class DCNv3Function(Function): ...@@ -36,11 +36,17 @@ class DCNv3Function(Function):
ctx.group_channels = group_channels ctx.group_channels = group_channels
ctx.offset_scale = offset_scale ctx.offset_scale = offset_scale
ctx.im2col_step = im2col_step ctx.im2col_step = im2col_step
ctx.remove_center = remove_center
kwargs = {}
if remove_center:
kwargs['remove_center'] = remove_center
output = DCNv3.dcnv3_forward( output = DCNv3.dcnv3_forward(
input, offset, mask, kernel_h, input, offset, mask, kernel_h,
kernel_w, stride_h, stride_w, pad_h, kernel_w, stride_h, stride_w, pad_h,
pad_w, dilation_h, dilation_w, group, pad_w, dilation_h, dilation_w, group,
group_channels, offset_scale, ctx.im2col_step) group_channels, offset_scale, ctx.im2col_step, **kwargs)
ctx.save_for_backward(input, offset, mask) ctx.save_for_backward(input, offset, mask)
return output return output
...@@ -50,20 +56,25 @@ class DCNv3Function(Function): ...@@ -50,20 +56,25 @@ class DCNv3Function(Function):
@custom_bwd @custom_bwd
def backward(ctx, grad_output): def backward(ctx, grad_output):
input, offset, mask = ctx.saved_tensors input, offset, mask = ctx.saved_tensors
kwargs = {}
if ctx.remove_center:
kwargs['remove_center'] = ctx.remove_center
grad_input, grad_offset, grad_mask = \ grad_input, grad_offset, grad_mask = \
DCNv3.dcnv3_backward( DCNv3.dcnv3_backward(
input, offset, mask, ctx.kernel_h, input, offset, mask, ctx.kernel_h,
ctx.kernel_w, ctx.stride_h, ctx.stride_w, ctx.pad_h, ctx.kernel_w, ctx.stride_h, ctx.stride_w, ctx.pad_h,
ctx.pad_w, ctx.dilation_h, ctx.dilation_w, ctx.group, ctx.pad_w, ctx.dilation_h, ctx.dilation_w, ctx.group,
ctx.group_channels, ctx.offset_scale, grad_output.contiguous(), ctx.im2col_step) ctx.group_channels, ctx.offset_scale, grad_output.contiguous(), ctx.im2col_step, **kwargs)
return grad_input, grad_offset, grad_mask, \ return grad_input, grad_offset, grad_mask, \
None, None, None, None, None, None, None, None, None, None, None, None None, None, None, None, None, None, None, None, None, None, None, None, None
@staticmethod @staticmethod
def symbolic(g, input, offset, mask, kernel_h, kernel_w, stride_h, def symbolic(g, input, offset, mask, kernel_h, kernel_w, stride_h,
stride_w, pad_h, pad_w, dilation_h, dilation_w, group, stride_w, pad_h, pad_w, dilation_h, dilation_w, group,
group_channels, offset_scale, im2col_step): group_channels, offset_scale, im2col_step, remove_center):
"""Symbolic function for mmdeploy::DCNv3. """Symbolic function for mmdeploy::DCNv3.
Returns: Returns:
...@@ -86,6 +97,7 @@ class DCNv3Function(Function): ...@@ -86,6 +97,7 @@ class DCNv3Function(Function):
group_channels_i=int(group_channels), group_channels_i=int(group_channels),
offset_scale_f=float(offset_scale), offset_scale_f=float(offset_scale),
im2col_step_i=int(im2col_step), im2col_step_i=int(im2col_step),
remove_center=int(remove_center),
) )
...@@ -126,14 +138,14 @@ def _generate_dilation_grids(spatial_shapes, kernel_h, kernel_w, dilation_h, dil ...@@ -126,14 +138,14 @@ def _generate_dilation_grids(spatial_shapes, kernel_h, kernel_w, dilation_h, dil
x, y = torch.meshgrid( x, y = torch.meshgrid(
torch.linspace( torch.linspace(
-((dilation_w * (kernel_w - 1)) // 2), -((dilation_w * (kernel_w - 1)) // 2),
-((dilation_w * (kernel_w - 1)) // 2) + -((dilation_w * (kernel_w - 1)) // 2) + (kernel_w - 1) * dilation_w,
(kernel_w - 1) * dilation_w, kernel_w, kernel_w,
dtype=torch.float32, dtype=torch.float32,
device=device), device=device),
torch.linspace( torch.linspace(
-((dilation_h * (kernel_h - 1)) // 2), -((dilation_h * (kernel_h - 1)) // 2),
-((dilation_h * (kernel_h - 1)) // 2) + -((dilation_h * (kernel_h - 1)) // 2) + (kernel_h - 1) * dilation_h,
(kernel_h - 1) * dilation_h, kernel_h, kernel_h,
dtype=torch.float32, dtype=torch.float32,
device=device)) device=device))
...@@ -145,13 +157,24 @@ def _generate_dilation_grids(spatial_shapes, kernel_h, kernel_w, dilation_h, dil ...@@ -145,13 +157,24 @@ def _generate_dilation_grids(spatial_shapes, kernel_h, kernel_w, dilation_h, dil
return grid return grid
def remove_center_sampling_locations(sampling_locations, kernel_w, kernel_h):
idx = list(range(sampling_locations.shape[-2]))
C = (kernel_w * kernel_h - 1)//2
idx = [i for i in idx if i != C and (i-C) % (C*2+1) != 0]
sampling_locations = sampling_locations[:,:,:,idx, :]
return sampling_locations
def dcnv3_core_pytorch( def dcnv3_core_pytorch(
input, offset, mask, kernel_h, input, offset, mask, kernel_h,
kernel_w, stride_h, stride_w, pad_h, kernel_w, stride_h, stride_w, pad_h,
pad_w, dilation_h, dilation_w, group, pad_w, dilation_h, dilation_w, group,
group_channels, offset_scale): group_channels, offset_scale, remove_center):
# for debug and test only, # for debug and test only,
# need to use cuda version instead # need to use cuda version instead
if remove_center and (kernel_h % 2 == 0 or kernel_w % 2 == 0 or kernel_w != kernel_h):
raise ValueError('remove_center is only compatible with square odd kernel size.')
input = F.pad( input = F.pad(
input, input,
[0, 0, pad_h, pad_h, pad_w, pad_w]) [0, 0, pad_h, pad_h, pad_w, pad_w])
...@@ -163,12 +186,15 @@ def dcnv3_core_pytorch( ...@@ -163,12 +186,15 @@ def dcnv3_core_pytorch(
grid = _generate_dilation_grids( grid = _generate_dilation_grids(
input.shape, kernel_h, kernel_w, dilation_h, dilation_w, group, input.device) input.shape, kernel_h, kernel_w, dilation_h, dilation_w, group, input.device)
spatial_norm = torch.tensor([W_in, H_in]).reshape(1, 1, 1, 2).\ spatial_norm = torch.tensor([W_in, H_in]).reshape(1, 1, 1, 2).\
repeat(1, 1, 1, group*kernel_h*kernel_w).to(input.device) repeat(1, 1, 1, group*(kernel_h*kernel_w-remove_center)).to(input.device)
sampling_locations = (ref + grid * offset_scale).repeat(N_, 1, 1, 1, 1).flatten(3, 4) + \ sampling_locations = (ref + grid * offset_scale).repeat(N_, 1, 1, 1, 1)
offset * offset_scale / spatial_norm if remove_center:
sampling_locations = remove_center_sampling_locations(sampling_locations, kernel_w=kernel_w, kernel_h=kernel_h)
sampling_locations = sampling_locations.flatten(3, 4)
sampling_locations = sampling_locations + offset * offset_scale / spatial_norm
P_ = kernel_h * kernel_w P_ = kernel_h * kernel_w - remove_center
sampling_grids = 2 * sampling_locations - 1 sampling_grids = 2 * sampling_locations - 1
# N_, H_in, W_in, group*group_channels -> N_, H_in*W_in, group*group_channels -> N_, group*group_channels, H_in*W_in -> N_*group, group_channels, H_in, W_in # N_, H_in, W_in, group*group_channels -> N_, H_in*W_in, group*group_channels -> N_, group*group_channels, H_in*W_in -> N_*group, group_channels, H_in, W_in
input_ = input.view(N_, H_in*W_in, group*group_channels).transpose(1, 2).\ input_ = input.view(N_, H_in*W_in, group*group_channels).transpose(1, 2).\
......
...@@ -101,7 +101,9 @@ class DCNv3_pytorch(nn.Module): ...@@ -101,7 +101,9 @@ class DCNv3_pytorch(nn.Module):
offset_scale=1.0, offset_scale=1.0,
act_layer='GELU', act_layer='GELU',
norm_layer='LN', norm_layer='LN',
center_feature_scale=False): center_feature_scale=False,
remove_center=False,
):
""" """
DCNv3 Module DCNv3 Module
:param channels :param channels
...@@ -137,6 +139,7 @@ class DCNv3_pytorch(nn.Module): ...@@ -137,6 +139,7 @@ class DCNv3_pytorch(nn.Module):
self.group_channels = channels // group self.group_channels = channels // group
self.offset_scale = offset_scale self.offset_scale = offset_scale
self.center_feature_scale = center_feature_scale self.center_feature_scale = center_feature_scale
self.remove_center = int(remove_center)
self.dw_conv = nn.Sequential( self.dw_conv = nn.Sequential(
nn.Conv2d( nn.Conv2d(
...@@ -154,10 +157,10 @@ class DCNv3_pytorch(nn.Module): ...@@ -154,10 +157,10 @@ class DCNv3_pytorch(nn.Module):
build_act_layer(act_layer)) build_act_layer(act_layer))
self.offset = nn.Linear( self.offset = nn.Linear(
channels, channels,
group * kernel_size * kernel_size * 2) group * (kernel_size * kernel_size - remove_center) * 2)
self.mask = nn.Linear( self.mask = nn.Linear(
channels, channels,
group * kernel_size * kernel_size) group * (kernel_size * kernel_size - remove_center))
self.input_proj = nn.Linear(channels, channels) self.input_proj = nn.Linear(channels, channels)
self.output_proj = nn.Linear(channels, channels) self.output_proj = nn.Linear(channels, channels)
self._reset_parameters() self._reset_parameters()
...@@ -202,7 +205,7 @@ class DCNv3_pytorch(nn.Module): ...@@ -202,7 +205,7 @@ class DCNv3_pytorch(nn.Module):
self.pad, self.pad, self.pad, self.pad,
self.dilation, self.dilation, self.dilation, self.dilation,
self.group, self.group_channels, self.group, self.group_channels,
self.offset_scale) self.offset_scale, self.remove_center)
if self.center_feature_scale: if self.center_feature_scale:
center_feature_scale = self.center_feature_scale_module( center_feature_scale = self.center_feature_scale_module(
x1, self.center_feature_scale_proj_weight, self.center_feature_scale_proj_bias) x1, self.center_feature_scale_proj_weight, self.center_feature_scale_proj_bias)
...@@ -217,18 +220,20 @@ class DCNv3_pytorch(nn.Module): ...@@ -217,18 +220,20 @@ class DCNv3_pytorch(nn.Module):
class DCNv3(nn.Module): class DCNv3(nn.Module):
def __init__( def __init__(
self, self,
channels=64, channels=64,
kernel_size=3, kernel_size=3,
dw_kernel_size=None, dw_kernel_size=None,
stride=1, stride=1,
pad=1, pad=1,
dilation=1, dilation=1,
group=4, group=4,
offset_scale=1.0, offset_scale=1.0,
act_layer='GELU', act_layer='GELU',
norm_layer='LN', norm_layer='LN',
center_feature_scale=False): center_feature_scale=False,
remove_center=False,
):
""" """
DCNv3 Module DCNv3 Module
:param channels :param channels
...@@ -264,7 +269,11 @@ class DCNv3(nn.Module): ...@@ -264,7 +269,11 @@ class DCNv3(nn.Module):
self.group_channels = channels // group self.group_channels = channels // group
self.offset_scale = offset_scale self.offset_scale = offset_scale
self.center_feature_scale = center_feature_scale self.center_feature_scale = center_feature_scale
self.remove_center = int(remove_center)
if self.remove_center and self.kernel_size % 2 == 0:
raise ValueError('remove_center is only compatible with odd kernel size.')
self.dw_conv = nn.Sequential( self.dw_conv = nn.Sequential(
nn.Conv2d( nn.Conv2d(
channels, channels,
...@@ -281,10 +290,10 @@ class DCNv3(nn.Module): ...@@ -281,10 +290,10 @@ class DCNv3(nn.Module):
build_act_layer(act_layer)) build_act_layer(act_layer))
self.offset = nn.Linear( self.offset = nn.Linear(
channels, channels,
group * kernel_size * kernel_size * 2) group * (kernel_size * kernel_size - remove_center) * 2)
self.mask = nn.Linear( self.mask = nn.Linear(
channels, channels,
group * kernel_size * kernel_size) group * (kernel_size * kernel_size - remove_center))
self.input_proj = nn.Linear(channels, channels) self.input_proj = nn.Linear(channels, channels)
self.output_proj = nn.Linear(channels, channels) self.output_proj = nn.Linear(channels, channels)
self._reset_parameters() self._reset_parameters()
...@@ -321,8 +330,9 @@ class DCNv3(nn.Module): ...@@ -321,8 +330,9 @@ class DCNv3(nn.Module):
x1 = self.dw_conv(x1) x1 = self.dw_conv(x1)
offset = self.offset(x1) offset = self.offset(x1)
mask = self.mask(x1).reshape(N, H, W, self.group, -1) mask = self.mask(x1).reshape(N, H, W, self.group, -1)
mask = F.softmax(mask, -1).reshape(N, H, W, -1).type(dtype) mask = F.softmax(mask, -1)
mask = mask.reshape(N, H, W, -1).type(dtype)
x = DCNv3Function.apply( x = DCNv3Function.apply(
x, offset, mask, x, offset, mask,
self.kernel_size, self.kernel_size, self.kernel_size, self.kernel_size,
...@@ -331,7 +341,8 @@ class DCNv3(nn.Module): ...@@ -331,7 +341,8 @@ class DCNv3(nn.Module):
self.dilation, self.dilation, self.dilation, self.dilation,
self.group, self.group_channels, self.group, self.group_channels,
self.offset_scale, self.offset_scale,
256) 256,
self.remove_center)
if self.center_feature_scale: if self.center_feature_scale:
center_feature_scale = self.center_feature_scale_module( center_feature_scale = self.center_feature_scale_module(
......
...@@ -61,7 +61,7 @@ def get_extensions(): ...@@ -61,7 +61,7 @@ def get_extensions():
setup( setup(
name="DCNv3", name="DCNv3",
version="1.0", version="1.1",
author="InternImage", author="InternImage",
url="https://github.com/OpenGVLab/InternImage", url="https://github.com/OpenGVLab/InternImage",
description= description=
......
...@@ -25,7 +25,7 @@ at::Tensor dcnv3_cuda_forward(const at::Tensor &input, const at::Tensor &offset, ...@@ -25,7 +25,7 @@ at::Tensor dcnv3_cuda_forward(const at::Tensor &input, const at::Tensor &offset,
const int pad_w, const int dilation_h, const int pad_w, const int dilation_h,
const int dilation_w, const int group, const int dilation_w, const int group,
const int group_channels, const int group_channels,
const float offset_scale, const int im2col_step) { const float offset_scale, const int im2col_step, const int remove_center) {
AT_ASSERTM(input.is_contiguous(), "input tensor has to be contiguous"); AT_ASSERTM(input.is_contiguous(), "input tensor has to be contiguous");
AT_ASSERTM(offset.is_contiguous(), "offset tensor has to be contiguous"); AT_ASSERTM(offset.is_contiguous(), "offset tensor has to be contiguous");
AT_ASSERTM(mask.is_contiguous(), "mask tensor has to be contiguous"); AT_ASSERTM(mask.is_contiguous(), "mask tensor has to be contiguous");
...@@ -61,8 +61,8 @@ at::Tensor dcnv3_cuda_forward(const at::Tensor &input, const at::Tensor &offset, ...@@ -61,8 +61,8 @@ at::Tensor dcnv3_cuda_forward(const at::Tensor &input, const at::Tensor &offset,
width_out, group * group_channels}); width_out, group * group_channels});
auto per_input_size = height_in * width_in * group * group_channels; auto per_input_size = height_in * width_in * group * group_channels;
auto per_offset_size = auto per_offset_size =
height_out * width_out * group * kernel_h * kernel_w * 2; height_out * width_out * group * (kernel_h * kernel_w - remove_center) * 2;
auto per_mask_size = height_out * width_out * group * kernel_h * kernel_w; auto per_mask_size = height_out * width_out * group * (kernel_h * kernel_w - remove_center);
for (int n = 0; n < batch / im2col_step_; ++n) { for (int n = 0; n < batch / im2col_step_; ++n) {
auto columns = output_n.select(0, n); auto columns = output_n.select(0, n);
// AT_DISPATCH_FLOATING_TYPES( // AT_DISPATCH_FLOATING_TYPES(
...@@ -77,7 +77,7 @@ at::Tensor dcnv3_cuda_forward(const at::Tensor &input, const at::Tensor &offset, ...@@ -77,7 +77,7 @@ at::Tensor dcnv3_cuda_forward(const at::Tensor &input, const at::Tensor &offset,
columns.data<scalar_t>(), kernel_h, kernel_w, stride_h, columns.data<scalar_t>(), kernel_h, kernel_w, stride_h,
stride_w, pad_h, pad_w, dilation_h, dilation_w, group, stride_w, pad_h, pad_w, dilation_h, dilation_w, group,
group_channels, batch_n, height_in, width_in, height_out, group_channels, batch_n, height_in, width_in, height_out,
width_out, offset_scale); width_out, offset_scale, remove_center);
})); }));
} }
...@@ -91,7 +91,7 @@ dcnv3_cuda_backward(const at::Tensor &input, const at::Tensor &offset, ...@@ -91,7 +91,7 @@ dcnv3_cuda_backward(const at::Tensor &input, const at::Tensor &offset,
const int pad_h, const int pad_w, const int dilation_h, const int pad_h, const int pad_w, const int dilation_h,
const int dilation_w, const int group, const int dilation_w, const int group,
const int group_channels, const float offset_scale, const int group_channels, const float offset_scale,
const at::Tensor &grad_output, const int im2col_step) { const at::Tensor &grad_output, const int im2col_step, const int remove_center) {
AT_ASSERTM(input.is_contiguous(), "input tensor has to be contiguous"); AT_ASSERTM(input.is_contiguous(), "input tensor has to be contiguous");
AT_ASSERTM(offset.is_contiguous(), "offset tensor has to be contiguous"); AT_ASSERTM(offset.is_contiguous(), "offset tensor has to be contiguous");
...@@ -135,8 +135,8 @@ dcnv3_cuda_backward(const at::Tensor &input, const at::Tensor &offset, ...@@ -135,8 +135,8 @@ dcnv3_cuda_backward(const at::Tensor &input, const at::Tensor &offset,
const int batch_n = im2col_step_; const int batch_n = im2col_step_;
auto per_input_size = height_in * width_in * group * group_channels; auto per_input_size = height_in * width_in * group * group_channels;
auto per_offset_size = auto per_offset_size =
height_out * width_out * group * kernel_h * kernel_w * 2; height_out * width_out * group * (kernel_h * kernel_w - remove_center) * 2;
auto per_mask_size = height_out * width_out * group * kernel_h * kernel_w; auto per_mask_size = height_out * width_out * group * (kernel_h * kernel_w - remove_center);
auto grad_output_n = auto grad_output_n =
grad_output.view({batch / im2col_step_, batch_n, height_out * width_out, grad_output.view({batch / im2col_step_, batch_n, height_out * width_out,
group, group_channels}); group, group_channels});
...@@ -155,7 +155,7 @@ dcnv3_cuda_backward(const at::Tensor &input, const at::Tensor &offset, ...@@ -155,7 +155,7 @@ dcnv3_cuda_backward(const at::Tensor &input, const at::Tensor &offset,
mask.data<scalar_t>() + n * im2col_step_ * per_mask_size, mask.data<scalar_t>() + n * im2col_step_ * per_mask_size,
kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w,
dilation_h, dilation_w, group, group_channels, batch_n, dilation_h, dilation_w, group, group_channels, batch_n,
height_in, width_in, height_out, width_out, offset_scale, height_in, width_in, height_out, width_out, offset_scale, remove_center,
grad_input.data<opmath_t>() + grad_input.data<opmath_t>() +
n * im2col_step_ * per_input_size, n * im2col_step_ * per_input_size,
grad_offset.data<opmath_t>() + grad_offset.data<opmath_t>() +
......
...@@ -19,7 +19,7 @@ at::Tensor dcnv3_cuda_forward(const at::Tensor &input, const at::Tensor &offset, ...@@ -19,7 +19,7 @@ at::Tensor dcnv3_cuda_forward(const at::Tensor &input, const at::Tensor &offset,
const int pad_w, const int dilation_h, const int pad_w, const int dilation_h,
const int dilation_w, const int group, const int dilation_w, const int group,
const int group_channels, const int group_channels,
const float offset_scale, const int im2col_step); const float offset_scale, const int im2col_step, const int remove_center);
std::vector<at::Tensor> std::vector<at::Tensor>
dcnv3_cuda_backward(const at::Tensor &input, const at::Tensor &offset, dcnv3_cuda_backward(const at::Tensor &input, const at::Tensor &offset,
...@@ -28,4 +28,4 @@ dcnv3_cuda_backward(const at::Tensor &input, const at::Tensor &offset, ...@@ -28,4 +28,4 @@ dcnv3_cuda_backward(const at::Tensor &input, const at::Tensor &offset,
const int pad_h, const int pad_w, const int dilation_h, const int pad_h, const int pad_w, const int dilation_h,
const int dilation_w, const int group, const int dilation_w, const int group,
const int group_channels, const float offset_scale, const int group_channels, const float offset_scale,
const at::Tensor &grad_output, const int im2col_step); const at::Tensor &grad_output, const int im2col_step, const int remove_center);
...@@ -221,7 +221,7 @@ __global__ void dcnv3_im2col_gpu_kernel( ...@@ -221,7 +221,7 @@ __global__ void dcnv3_im2col_gpu_kernel(
const int pad_w, const int dilation_h, const int dilation_w, const int pad_w, const int dilation_h, const int dilation_w,
const int group, const int group_channels, const int height_in, const int group, const int group_channels, const int height_in,
const int width_in, const int height_out, const int width_out, const int width_in, const int height_out, const int width_out,
const opmath_t offset_scale) { const opmath_t offset_scale, const int remove_center) {
CUDA_KERNEL_LOOP(index, num_kernels) { CUDA_KERNEL_LOOP(index, num_kernels) {
int _temp = index; int _temp = index;
const int c_col = _temp % group_channels; const int c_col = _temp % group_channels;
...@@ -239,7 +239,7 @@ __global__ void dcnv3_im2col_gpu_kernel( ...@@ -239,7 +239,7 @@ __global__ void dcnv3_im2col_gpu_kernel(
const int input_size = height_in * width_in; const int input_size = height_in * width_in;
scalar_t *data_col_ptr = data_col + index; scalar_t *data_col_ptr = data_col + index;
const int kernel_size = kernel_h * kernel_w; const int kernel_size = kernel_h * kernel_w - remove_center;
int data_weight_ptr = sampling_index * kernel_size; int data_weight_ptr = sampling_index * kernel_size;
int data_loc_w_ptr = data_weight_ptr << 1; int data_loc_w_ptr = data_weight_ptr << 1;
const int qid_stride = group * group_channels; const int qid_stride = group * group_channels;
...@@ -250,24 +250,31 @@ __global__ void dcnv3_im2col_gpu_kernel( ...@@ -250,24 +250,31 @@ __global__ void dcnv3_im2col_gpu_kernel(
p0_w - ((dilation_w * (kernel_w - 1)) >> 1) * offset_scale; p0_w - ((dilation_w * (kernel_w - 1)) >> 1) * offset_scale;
const opmath_t p0_h_ = const opmath_t p0_h_ =
p0_h - ((dilation_h * (kernel_h - 1)) >> 1) * offset_scale; p0_h - ((dilation_h * (kernel_h - 1)) >> 1) * offset_scale;
const int center_h = kernel_h / 2;
const int center_w = kernel_w / 2;
for (int i = 0; i < kernel_w; ++i) { for (int i = 0; i < kernel_w; ++i) {
for (int j = 0; j < kernel_h; ++j) { for (int j = 0; j < kernel_h; ++j) {
const opmath_t offset_w = data_offset[data_loc_w_ptr]; // if not remove center, or remove center and not the center
const opmath_t offset_h = data_offset[data_loc_w_ptr + 1]; if (i!=center_w || j!=center_h || !remove_center) {
const opmath_t loc_w = const opmath_t offset_w = data_offset[data_loc_w_ptr];
p0_w_ + (i * dilation_w + offset_w) * offset_scale; const opmath_t offset_h = data_offset[data_loc_w_ptr + 1];
const opmath_t loc_h = const opmath_t loc_w =
p0_h_ + (j * dilation_h + offset_h) * offset_scale; p0_w_ + (i * dilation_w + offset_w) * offset_scale;
const opmath_t weight = data_mask[data_weight_ptr]; const opmath_t loc_h =
if (loc_h > -1 && loc_w > -1 && loc_h < height_in && p0_h_ + (j * dilation_h + offset_h) * offset_scale;
loc_w < width_in) { const opmath_t weight = data_mask[data_weight_ptr];
col += dcnv3_im2col_bilinear( if (loc_h > -1 && loc_w > -1 && loc_h < height_in &&
data_im_ptr, height_in, width_in, group, loc_w < width_in) {
group_channels, loc_h, loc_w, g_col, c_col) * col += dcnv3_im2col_bilinear(
weight; data_im_ptr, height_in, width_in, group,
group_channels, loc_h, loc_w, g_col, c_col) *
weight;
}
data_weight_ptr += 1;
data_loc_w_ptr += 2;
} }
data_weight_ptr += 1;
data_loc_w_ptr += 2;
} }
} }
*data_col_ptr = col; *data_col_ptr = col;
...@@ -283,7 +290,7 @@ __global__ void dcnv3_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1( ...@@ -283,7 +290,7 @@ __global__ void dcnv3_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(
const int pad_w, const int dilation_h, const int dilation_w, const int pad_w, const int dilation_h, const int dilation_w,
const int group, const int group_channels, const int height_in, const int group, const int group_channels, const int height_in,
const int width_in, const int height_out, const int width_out, const int width_in, const int height_out, const int width_out,
const opmath_t offset_scale, opmath_t *grad_im, opmath_t *grad_offset, const opmath_t offset_scale, const int remove_center, opmath_t *grad_im, opmath_t *grad_offset,
opmath_t *grad_mask) { opmath_t *grad_mask) {
CUDA_KERNEL_LOOP(index, num_kernels) { CUDA_KERNEL_LOOP(index, num_kernels) {
__shared__ opmath_t cache_grad_offset[blockSize * 2]; __shared__ opmath_t cache_grad_offset[blockSize * 2];
...@@ -305,7 +312,7 @@ __global__ void dcnv3_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1( ...@@ -305,7 +312,7 @@ __global__ void dcnv3_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(
const opmath_t top_grad = grad_col[index]; const opmath_t top_grad = grad_col[index];
const int input_size = height_in * width_in; const int input_size = height_in * width_in;
const int kernel_size = kernel_h * kernel_w; const int kernel_size = kernel_h * kernel_w - remove_center;
int data_weight_ptr = sampling_index * kernel_size; int data_weight_ptr = sampling_index * kernel_size;
int data_loc_w_ptr = data_weight_ptr << 1; int data_loc_w_ptr = data_weight_ptr << 1;
const int grad_sampling_ptr = data_weight_ptr; const int grad_sampling_ptr = data_weight_ptr;
...@@ -319,51 +326,58 @@ __global__ void dcnv3_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1( ...@@ -319,51 +326,58 @@ __global__ void dcnv3_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(
p0_w - ((dilation_w * (kernel_w - 1)) >> 1) * offset_scale; p0_w - ((dilation_w * (kernel_w - 1)) >> 1) * offset_scale;
const opmath_t p0_h_ = const opmath_t p0_h_ =
p0_h - ((dilation_h * (kernel_h - 1)) >> 1) * offset_scale; p0_h - ((dilation_h * (kernel_h - 1)) >> 1) * offset_scale;
const int center_h = kernel_h / 2;
const int center_w = kernel_w / 2;
for (int i = 0; i < kernel_w; ++i) { for (int i = 0; i < kernel_w; ++i) {
for (int j = 0; j < kernel_h; ++j) { for (int j = 0; j < kernel_h; ++j) {
const opmath_t offset_w = data_offset[data_loc_w_ptr]; // if not remove center, or remove center and not the center
const opmath_t offset_h = data_offset[data_loc_w_ptr + 1]; if (i!=center_w || j!=center_h || !remove_center) {
const opmath_t loc_w = const opmath_t offset_w = data_offset[data_loc_w_ptr];
p0_w_ + (i * dilation_w + offset_w) * offset_scale; const opmath_t offset_h = data_offset[data_loc_w_ptr + 1];
const opmath_t loc_h = const opmath_t loc_w =
p0_h_ + (j * dilation_h + offset_h) * offset_scale; p0_w_ + (i * dilation_w + offset_w) * offset_scale;
const opmath_t weight = data_mask[data_weight_ptr]; const opmath_t loc_h =
*(cache_grad_offset + (threadIdx.x << 1)) = 0; p0_h_ + (j * dilation_h + offset_h) * offset_scale;
*(cache_grad_offset + ((threadIdx.x << 1) + 1)) = 0; const opmath_t weight = data_mask[data_weight_ptr];
*(cache_grad_mask + threadIdx.x) = 0; *(cache_grad_offset + (threadIdx.x << 1)) = 0;
if (loc_h > -1 && loc_w > -1 && loc_h < height_in && *(cache_grad_offset + ((threadIdx.x << 1) + 1)) = 0;
loc_w < width_in) { *(cache_grad_mask + threadIdx.x) = 0;
dcnv3_col2im_bilinear( if (loc_h > -1 && loc_w > -1 && loc_h < height_in &&
data_im_ptr, height_in, width_in, group, group_channels, loc_w < width_in) {
loc_h, loc_w, g_col, c_col, offset_scale, top_grad, dcnv3_col2im_bilinear(
weight, grad_im_ptr, data_im_ptr, height_in, width_in, group, group_channels,
cache_grad_offset + (threadIdx.x << 1), loc_h, loc_w, g_col, c_col, offset_scale, top_grad,
cache_grad_mask + threadIdx.x); weight, grad_im_ptr,
} cache_grad_offset + (threadIdx.x << 1),
cache_grad_mask + threadIdx.x);
}
__syncthreads(); __syncthreads();
if (tid == 0) { if (tid == 0) {
opmath_t _grad_w = cache_grad_offset[0], opmath_t _grad_w = cache_grad_offset[0],
_grad_h = cache_grad_offset[1], _grad_h = cache_grad_offset[1],
_grad_a = cache_grad_mask[0]; _grad_a = cache_grad_mask[0];
int sid = 2; int sid = 2;
for (unsigned int tid = 1; tid < blockSize; ++tid) { for (unsigned int tid = 1; tid < blockSize; ++tid) {
_grad_w += cache_grad_offset[sid]; _grad_w += cache_grad_offset[sid];
_grad_h += cache_grad_offset[sid + 1]; _grad_h += cache_grad_offset[sid + 1];
_grad_a += cache_grad_mask[tid]; _grad_a += cache_grad_mask[tid];
sid += 2; sid += 2;
}
*grad_offset = _grad_w;
*(grad_offset + 1) = _grad_h;
*grad_mask = _grad_a;
} }
__syncthreads();
*grad_offset = _grad_w; data_weight_ptr += 1;
*(grad_offset + 1) = _grad_h; data_loc_w_ptr += 2;
*grad_mask = _grad_a; grad_mask += 1;
grad_offset += 2;
} }
__syncthreads();
data_weight_ptr += 1;
data_loc_w_ptr += 2;
grad_mask += 1;
grad_offset += 2;
} }
} }
} }
...@@ -377,7 +391,7 @@ __global__ void dcnv3_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2( ...@@ -377,7 +391,7 @@ __global__ void dcnv3_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(
const int pad_w, const int dilation_h, const int dilation_w, const int pad_w, const int dilation_h, const int dilation_w,
const int group, const int group_channels, const int height_in, const int group, const int group_channels, const int height_in,
const int width_in, const int height_out, const int width_out, const int width_in, const int height_out, const int width_out,
const opmath_t offset_scale, opmath_t *grad_im, opmath_t *grad_offset, const opmath_t offset_scale, const int remove_center, opmath_t *grad_im, opmath_t *grad_offset,
opmath_t *grad_mask) { opmath_t *grad_mask) {
CUDA_KERNEL_LOOP(index, num_kernels) { CUDA_KERNEL_LOOP(index, num_kernels) {
__shared__ opmath_t cache_grad_offset[blockSize * 2]; __shared__ opmath_t cache_grad_offset[blockSize * 2];
...@@ -399,7 +413,7 @@ __global__ void dcnv3_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2( ...@@ -399,7 +413,7 @@ __global__ void dcnv3_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(
const opmath_t top_grad = grad_col[index]; const opmath_t top_grad = grad_col[index];
const int input_size = height_in * width_in; const int input_size = height_in * width_in;
const int kernel_size = kernel_h * kernel_w; const int kernel_size = kernel_h * kernel_w - remove_center;
int data_weight_ptr = sampling_index * kernel_size; int data_weight_ptr = sampling_index * kernel_size;
int data_loc_w_ptr = data_weight_ptr << 1; int data_loc_w_ptr = data_weight_ptr << 1;
const int grad_sampling_ptr = data_weight_ptr; const int grad_sampling_ptr = data_weight_ptr;
...@@ -413,53 +427,60 @@ __global__ void dcnv3_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2( ...@@ -413,53 +427,60 @@ __global__ void dcnv3_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(
p0_w - ((dilation_w * (kernel_w - 1)) >> 1) * offset_scale; p0_w - ((dilation_w * (kernel_w - 1)) >> 1) * offset_scale;
const opmath_t p0_h_ = const opmath_t p0_h_ =
p0_h - ((dilation_h * (kernel_h - 1)) >> 1) * offset_scale; p0_h - ((dilation_h * (kernel_h - 1)) >> 1) * offset_scale;
const int center_h = kernel_h / 2;
const int center_w = kernel_w / 2;
for (int i = 0; i < kernel_w; ++i) { for (int i = 0; i < kernel_w; ++i) {
for (int j = 0; j < kernel_h; ++j) { for (int j = 0; j < kernel_h; ++j) {
const opmath_t offset_w = data_offset[data_loc_w_ptr]; // if not remove center, or remove center and not the center
const opmath_t offset_h = data_offset[data_loc_w_ptr + 1]; if (i!=center_w || j!=center_h || !remove_center) {
const opmath_t loc_w = const opmath_t offset_w = data_offset[data_loc_w_ptr];
p0_w_ + (i * dilation_w + offset_w) * offset_scale; const opmath_t offset_h = data_offset[data_loc_w_ptr + 1];
const opmath_t loc_h = const opmath_t loc_w =
p0_h_ + (j * dilation_h + offset_h) * offset_scale; p0_w_ + (i * dilation_w + offset_w) * offset_scale;
const opmath_t weight = data_mask[data_weight_ptr]; const opmath_t loc_h =
*(cache_grad_offset + (threadIdx.x << 1)) = 0; p0_h_ + (j * dilation_h + offset_h) * offset_scale;
*(cache_grad_offset + ((threadIdx.x << 1) + 1)) = 0; const opmath_t weight = data_mask[data_weight_ptr];
*(cache_grad_mask + threadIdx.x) = 0; *(cache_grad_offset + (threadIdx.x << 1)) = 0;
if (loc_h > -1 && loc_w > -1 && loc_h < height_in && *(cache_grad_offset + ((threadIdx.x << 1) + 1)) = 0;
loc_w < width_in) { *(cache_grad_mask + threadIdx.x) = 0;
dcnv3_col2im_bilinear( if (loc_h > -1 && loc_w > -1 && loc_h < height_in &&
data_im_ptr, height_in, width_in, group, group_channels, loc_w < width_in) {
loc_h, loc_w, g_col, c_col, offset_scale, top_grad, dcnv3_col2im_bilinear(
weight, grad_im_ptr, data_im_ptr, height_in, width_in, group, group_channels,
cache_grad_offset + (threadIdx.x << 1), loc_h, loc_w, g_col, c_col, offset_scale, top_grad,
cache_grad_mask + threadIdx.x); weight, grad_im_ptr,
} cache_grad_offset + (threadIdx.x << 1),
cache_grad_mask + threadIdx.x);
}
__syncthreads();
__syncthreads(); for (unsigned int s = blockSize / 2; s > 0; s >>= 1) {
if (tid < s) {
const unsigned int xid1 = tid << 1;
const unsigned int xid2 = (tid + s) << 1;
cache_grad_mask[tid] += cache_grad_mask[tid + s];
cache_grad_offset[xid1] += cache_grad_offset[xid2];
cache_grad_offset[xid1 + 1] +=
cache_grad_offset[xid2 + 1];
}
__syncthreads();
}
for (unsigned int s = blockSize / 2; s > 0; s >>= 1) { if (tid == 0) {
if (tid < s) { *grad_offset = cache_grad_offset[0];
const unsigned int xid1 = tid << 1; *(grad_offset + 1) = cache_grad_offset[1];
const unsigned int xid2 = (tid + s) << 1; *grad_mask = cache_grad_mask[0];
cache_grad_mask[tid] += cache_grad_mask[tid + s];
cache_grad_offset[xid1] += cache_grad_offset[xid2];
cache_grad_offset[xid1 + 1] +=
cache_grad_offset[xid2 + 1];
} }
__syncthreads(); __syncthreads();
}
if (tid == 0) { data_weight_ptr += 1;
*grad_offset = cache_grad_offset[0]; data_loc_w_ptr += 2;
*(grad_offset + 1) = cache_grad_offset[1]; grad_mask += 1;
*grad_mask = cache_grad_mask[0]; grad_offset += 2;
} }
__syncthreads();
data_weight_ptr += 1;
data_loc_w_ptr += 2;
grad_mask += 1;
grad_offset += 2;
} }
} }
} }
...@@ -473,7 +494,7 @@ __global__ void dcnv3_col2im_gpu_kernel_shm_reduce_v1( ...@@ -473,7 +494,7 @@ __global__ void dcnv3_col2im_gpu_kernel_shm_reduce_v1(
const int pad_w, const int dilation_h, const int dilation_w, const int pad_w, const int dilation_h, const int dilation_w,
const int group, const int group_channels, const int height_in, const int group, const int group_channels, const int height_in,
const int width_in, const int height_out, const int width_out, const int width_in, const int height_out, const int width_out,
const opmath_t offset_scale, opmath_t *grad_im, opmath_t *grad_offset, const opmath_t offset_scale, const int remove_center, opmath_t *grad_im, opmath_t *grad_offset,
opmath_t *grad_mask) { opmath_t *grad_mask) {
CUDA_KERNEL_LOOP(index, num_kernels) { CUDA_KERNEL_LOOP(index, num_kernels) {
extern __shared__ int _s[]; extern __shared__ int _s[];
...@@ -496,7 +517,7 @@ __global__ void dcnv3_col2im_gpu_kernel_shm_reduce_v1( ...@@ -496,7 +517,7 @@ __global__ void dcnv3_col2im_gpu_kernel_shm_reduce_v1(
const opmath_t top_grad = grad_col[index]; const opmath_t top_grad = grad_col[index];
const int input_size = height_in * width_in; const int input_size = height_in * width_in;
const int kernel_size = kernel_h * kernel_w; const int kernel_size = kernel_h * kernel_w - remove_center;
int data_weight_ptr = sampling_index * kernel_size; int data_weight_ptr = sampling_index * kernel_size;
int data_loc_w_ptr = data_weight_ptr << 1; int data_loc_w_ptr = data_weight_ptr << 1;
const int grad_sampling_ptr = data_weight_ptr; const int grad_sampling_ptr = data_weight_ptr;
...@@ -510,51 +531,58 @@ __global__ void dcnv3_col2im_gpu_kernel_shm_reduce_v1( ...@@ -510,51 +531,58 @@ __global__ void dcnv3_col2im_gpu_kernel_shm_reduce_v1(
p0_w - ((dilation_w * (kernel_w - 1)) >> 1) * offset_scale; p0_w - ((dilation_w * (kernel_w - 1)) >> 1) * offset_scale;
const opmath_t p0_h_ = const opmath_t p0_h_ =
p0_h - ((dilation_h * (kernel_h - 1)) >> 1) * offset_scale; p0_h - ((dilation_h * (kernel_h - 1)) >> 1) * offset_scale;
const int center_h = kernel_h / 2;
const int center_w = kernel_w / 2;
for (int i = 0; i < kernel_w; ++i) { for (int i = 0; i < kernel_w; ++i) {
for (int j = 0; j < kernel_h; ++j) { for (int j = 0; j < kernel_h; ++j) {
const opmath_t offset_w = data_offset[data_loc_w_ptr]; // if not remove center, or remove center and not the center
const opmath_t offset_h = data_offset[data_loc_w_ptr + 1]; if (i!=center_w || j!=center_h || !remove_center) {
const opmath_t loc_w = const opmath_t offset_w = data_offset[data_loc_w_ptr];
p0_w_ + (i * dilation_w + offset_w) * offset_scale; const opmath_t offset_h = data_offset[data_loc_w_ptr + 1];
const opmath_t loc_h = const opmath_t loc_w =
p0_h_ + (j * dilation_h + offset_h) * offset_scale; p0_w_ + (i * dilation_w + offset_w) * offset_scale;
const opmath_t weight = data_mask[data_weight_ptr]; const opmath_t loc_h =
*(cache_grad_offset + (threadIdx.x << 1)) = 0; p0_h_ + (j * dilation_h + offset_h) * offset_scale;
*(cache_grad_offset + ((threadIdx.x << 1) + 1)) = 0; const opmath_t weight = data_mask[data_weight_ptr];
*(cache_grad_mask + threadIdx.x) = 0; *(cache_grad_offset + (threadIdx.x << 1)) = 0;
if (loc_h > -1 && loc_w > -1 && loc_h < height_in && *(cache_grad_offset + ((threadIdx.x << 1) + 1)) = 0;
loc_w < width_in) { *(cache_grad_mask + threadIdx.x) = 0;
dcnv3_col2im_bilinear( if (loc_h > -1 && loc_w > -1 && loc_h < height_in &&
data_im_ptr, height_in, width_in, group, group_channels, loc_w < width_in) {
loc_h, loc_w, g_col, c_col, offset_scale, top_grad, dcnv3_col2im_bilinear(
weight, grad_im_ptr, data_im_ptr, height_in, width_in, group, group_channels,
cache_grad_offset + (threadIdx.x << 1), loc_h, loc_w, g_col, c_col, offset_scale, top_grad,
cache_grad_mask + threadIdx.x); weight, grad_im_ptr,
} cache_grad_offset + (threadIdx.x << 1),
cache_grad_mask + threadIdx.x);
}
__syncthreads();
if (tid == 0) {
opmath_t _grad_w = cache_grad_offset[0],
_grad_h = cache_grad_offset[1],
_grad_a = cache_grad_mask[0];
int sid = 2;
for (unsigned int tid = 1; tid < blockDim.x; ++tid) {
_grad_w += cache_grad_offset[sid];
_grad_h += cache_grad_offset[sid + 1];
_grad_a += cache_grad_mask[tid];
sid += 2;
}
__syncthreads(); *grad_offset = _grad_w;
if (tid == 0) { *(grad_offset + 1) = _grad_h;
opmath_t _grad_w = cache_grad_offset[0], *grad_mask = _grad_a;
_grad_h = cache_grad_offset[1],
_grad_a = cache_grad_mask[0];
int sid = 2;
for (unsigned int tid = 1; tid < blockDim.x; ++tid) {
_grad_w += cache_grad_offset[sid];
_grad_h += cache_grad_offset[sid + 1];
_grad_a += cache_grad_mask[tid];
sid += 2;
} }
__syncthreads();
*grad_offset = _grad_w; data_weight_ptr += 1;
*(grad_offset + 1) = _grad_h; data_loc_w_ptr += 2;
*grad_mask = _grad_a; grad_mask += 1;
grad_offset += 2;
} }
__syncthreads();
data_weight_ptr += 1;
data_loc_w_ptr += 2;
grad_mask += 1;
grad_offset += 2;
} }
} }
} }
...@@ -568,7 +596,7 @@ __global__ void dcnv3_col2im_gpu_kernel_shm_reduce_v2( ...@@ -568,7 +596,7 @@ __global__ void dcnv3_col2im_gpu_kernel_shm_reduce_v2(
const int pad_w, const int dilation_h, const int dilation_w, const int pad_w, const int dilation_h, const int dilation_w,
const int group, const int group_channels, const int height_in, const int group, const int group_channels, const int height_in,
const int width_in, const int height_out, const int width_out, const int width_in, const int height_out, const int width_out,
const opmath_t offset_scale, opmath_t *grad_im, opmath_t *grad_offset, const opmath_t offset_scale, const int remove_center, opmath_t *grad_im, opmath_t *grad_offset,
opmath_t *grad_mask) { opmath_t *grad_mask) {
CUDA_KERNEL_LOOP(index, num_kernels) { CUDA_KERNEL_LOOP(index, num_kernels) {
extern __shared__ int _s[]; extern __shared__ int _s[];
...@@ -591,7 +619,7 @@ __global__ void dcnv3_col2im_gpu_kernel_shm_reduce_v2( ...@@ -591,7 +619,7 @@ __global__ void dcnv3_col2im_gpu_kernel_shm_reduce_v2(
const opmath_t top_grad = grad_col[index]; const opmath_t top_grad = grad_col[index];
const int input_size = height_in * width_in; const int input_size = height_in * width_in;
const int kernel_size = kernel_h * kernel_w; const int kernel_size = kernel_h * kernel_w - remove_center;
int data_weight_ptr = sampling_index * kernel_size; int data_weight_ptr = sampling_index * kernel_size;
int data_loc_w_ptr = data_weight_ptr << 1; int data_loc_w_ptr = data_weight_ptr << 1;
const int grad_sampling_ptr = data_weight_ptr; const int grad_sampling_ptr = data_weight_ptr;
...@@ -605,62 +633,69 @@ __global__ void dcnv3_col2im_gpu_kernel_shm_reduce_v2( ...@@ -605,62 +633,69 @@ __global__ void dcnv3_col2im_gpu_kernel_shm_reduce_v2(
p0_w - ((dilation_w * (kernel_w - 1)) >> 1) * offset_scale; p0_w - ((dilation_w * (kernel_w - 1)) >> 1) * offset_scale;
const opmath_t p0_h_ = const opmath_t p0_h_ =
p0_h - ((dilation_h * (kernel_h - 1)) >> 1) * offset_scale; p0_h - ((dilation_h * (kernel_h - 1)) >> 1) * offset_scale;
const int center_h = kernel_h / 2;
const int center_w = kernel_w / 2;
for (int i = 0; i < kernel_w; ++i) { for (int i = 0; i < kernel_w; ++i) {
for (int j = 0; j < kernel_h; ++j) { for (int j = 0; j < kernel_h; ++j) {
const opmath_t offset_w = data_offset[data_loc_w_ptr]; // if not remove center, or remove center and not the center
const opmath_t offset_h = data_offset[data_loc_w_ptr + 1]; if (i!=center_w || j!=center_h || !remove_center) {
const opmath_t loc_w = const opmath_t offset_w = data_offset[data_loc_w_ptr];
p0_w_ + (i * dilation_w + offset_w) * offset_scale; const opmath_t offset_h = data_offset[data_loc_w_ptr + 1];
const opmath_t loc_h = const opmath_t loc_w =
p0_h_ + (j * dilation_h + offset_h) * offset_scale; p0_w_ + (i * dilation_w + offset_w) * offset_scale;
const opmath_t weight = data_mask[data_weight_ptr]; const opmath_t loc_h =
*(cache_grad_offset + (threadIdx.x << 1)) = 0; p0_h_ + (j * dilation_h + offset_h) * offset_scale;
*(cache_grad_offset + ((threadIdx.x << 1) + 1)) = 0; const opmath_t weight = data_mask[data_weight_ptr];
*(cache_grad_mask + threadIdx.x) = 0; *(cache_grad_offset + (threadIdx.x << 1)) = 0;
if (loc_h > -1 && loc_w > -1 && loc_h < height_in && *(cache_grad_offset + ((threadIdx.x << 1) + 1)) = 0;
loc_w < width_in) { *(cache_grad_mask + threadIdx.x) = 0;
dcnv3_col2im_bilinear( if (loc_h > -1 && loc_w > -1 && loc_h < height_in &&
data_im_ptr, height_in, width_in, group, group_channels, loc_w < width_in) {
loc_h, loc_w, g_col, c_col, offset_scale, top_grad, dcnv3_col2im_bilinear(
weight, grad_im_ptr, data_im_ptr, height_in, width_in, group, group_channels,
cache_grad_offset + (threadIdx.x << 1), loc_h, loc_w, g_col, c_col, offset_scale, top_grad,
cache_grad_mask + threadIdx.x); weight, grad_im_ptr,
} cache_grad_offset + (threadIdx.x << 1),
cache_grad_mask + threadIdx.x);
}
__syncthreads();
__syncthreads(); for (unsigned int s = blockDim.x / 2, spre = blockDim.x; s > 0;
s >>= 1, spre >>= 1) {
for (unsigned int s = blockDim.x / 2, spre = blockDim.x; s > 0; if (tid < s) {
s >>= 1, spre >>= 1) { const unsigned int xid1 = tid << 1;
if (tid < s) { const unsigned int xid2 = (tid + s) << 1;
const unsigned int xid1 = tid << 1; cache_grad_mask[tid] += cache_grad_mask[tid + s];
const unsigned int xid2 = (tid + s) << 1; cache_grad_offset[xid1] += cache_grad_offset[xid2];
cache_grad_mask[tid] += cache_grad_mask[tid + s];
cache_grad_offset[xid1] += cache_grad_offset[xid2];
cache_grad_offset[xid1 + 1] +=
cache_grad_offset[xid2 + 1];
if (tid + (s << 1) < spre) {
cache_grad_mask[tid] +=
cache_grad_mask[tid + (s << 1)];
cache_grad_offset[xid1] +=
cache_grad_offset[xid2 + (s << 1)];
cache_grad_offset[xid1 + 1] += cache_grad_offset[xid1 + 1] +=
cache_grad_offset[xid2 + 1 + (s << 1)]; cache_grad_offset[xid2 + 1];
if (tid + (s << 1) < spre) {
cache_grad_mask[tid] +=
cache_grad_mask[tid + (s << 1)];
cache_grad_offset[xid1] +=
cache_grad_offset[xid2 + (s << 1)];
cache_grad_offset[xid1 + 1] +=
cache_grad_offset[xid2 + 1 + (s << 1)];
}
} }
__syncthreads();
}
if (tid == 0) {
*grad_offset = cache_grad_offset[0];
*(grad_offset + 1) = cache_grad_offset[1];
*grad_mask = cache_grad_mask[0];
} }
__syncthreads(); __syncthreads();
}
if (tid == 0) { data_weight_ptr += 1;
*grad_offset = cache_grad_offset[0]; data_loc_w_ptr += 2;
*(grad_offset + 1) = cache_grad_offset[1]; grad_mask += 1;
*grad_mask = cache_grad_mask[0]; grad_offset += 2;
} }
__syncthreads();
data_weight_ptr += 1;
data_loc_w_ptr += 2;
grad_mask += 1;
grad_offset += 2;
} }
} }
} }
...@@ -674,7 +709,7 @@ __global__ void dcnv3_col2im_gpu_kernel_shm_reduce_v2_multi_blocks( ...@@ -674,7 +709,7 @@ __global__ void dcnv3_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(
const int pad_w, const int dilation_h, const int dilation_w, const int pad_w, const int dilation_h, const int dilation_w,
const int group, const int group_channels, const int height_in, const int group, const int group_channels, const int height_in,
const int width_in, const int height_out, const int width_out, const int width_in, const int height_out, const int width_out,
const opmath_t offset_scale, opmath_t *grad_im, opmath_t *grad_offset, const opmath_t offset_scale, const int remove_center, opmath_t *grad_im, opmath_t *grad_offset,
opmath_t *grad_mask) { opmath_t *grad_mask) {
CUDA_KERNEL_LOOP(index, num_kernels) { CUDA_KERNEL_LOOP(index, num_kernels) {
extern __shared__ int _s[]; extern __shared__ int _s[];
...@@ -697,7 +732,7 @@ __global__ void dcnv3_col2im_gpu_kernel_shm_reduce_v2_multi_blocks( ...@@ -697,7 +732,7 @@ __global__ void dcnv3_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(
const opmath_t top_grad = grad_col[index]; const opmath_t top_grad = grad_col[index];
const int input_size = height_in * width_in; const int input_size = height_in * width_in;
const int kernel_size = kernel_h * kernel_w; const int kernel_size = kernel_h * kernel_w - remove_center;
int data_weight_ptr = sampling_index * kernel_size; int data_weight_ptr = sampling_index * kernel_size;
int data_loc_w_ptr = data_weight_ptr << 1; int data_loc_w_ptr = data_weight_ptr << 1;
const int grad_sampling_ptr = data_weight_ptr; const int grad_sampling_ptr = data_weight_ptr;
...@@ -711,62 +746,69 @@ __global__ void dcnv3_col2im_gpu_kernel_shm_reduce_v2_multi_blocks( ...@@ -711,62 +746,69 @@ __global__ void dcnv3_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(
p0_w - ((dilation_w * (kernel_w - 1)) >> 1) * offset_scale; p0_w - ((dilation_w * (kernel_w - 1)) >> 1) * offset_scale;
const opmath_t p0_h_ = const opmath_t p0_h_ =
p0_h - ((dilation_h * (kernel_h - 1)) >> 1) * offset_scale; p0_h - ((dilation_h * (kernel_h - 1)) >> 1) * offset_scale;
const int center_h = kernel_h / 2;
const int center_w = kernel_w / 2;
for (int i = 0; i < kernel_w; ++i) { for (int i = 0; i < kernel_w; ++i) {
for (int j = 0; j < kernel_h; ++j) { for (int j = 0; j < kernel_h; ++j) {
const opmath_t offset_w = data_offset[data_loc_w_ptr]; // if not remove center, or remove center and not the center
const opmath_t offset_h = data_offset[data_loc_w_ptr + 1]; if (i!=center_w || j!=center_h || !remove_center) {
const opmath_t loc_w = const opmath_t offset_w = data_offset[data_loc_w_ptr];
p0_w_ + (i * dilation_w + offset_w) * offset_scale; const opmath_t offset_h = data_offset[data_loc_w_ptr + 1];
const opmath_t loc_h = const opmath_t loc_w =
p0_h_ + (j * dilation_h + offset_h) * offset_scale; p0_w_ + (i * dilation_w + offset_w) * offset_scale;
const opmath_t weight = data_mask[data_weight_ptr]; const opmath_t loc_h =
*(cache_grad_offset + (threadIdx.x << 1)) = 0; p0_h_ + (j * dilation_h + offset_h) * offset_scale;
*(cache_grad_offset + ((threadIdx.x << 1) + 1)) = 0; const opmath_t weight = data_mask[data_weight_ptr];
*(cache_grad_mask + threadIdx.x) = 0; *(cache_grad_offset + (threadIdx.x << 1)) = 0;
if (loc_h > -1 && loc_w > -1 && loc_h < height_in && *(cache_grad_offset + ((threadIdx.x << 1) + 1)) = 0;
loc_w < width_in) { *(cache_grad_mask + threadIdx.x) = 0;
dcnv3_col2im_bilinear( if (loc_h > -1 && loc_w > -1 && loc_h < height_in &&
data_im_ptr, height_in, width_in, group, group_channels, loc_w < width_in) {
loc_h, loc_w, g_col, c_col, offset_scale, top_grad, dcnv3_col2im_bilinear(
weight, grad_im_ptr, data_im_ptr, height_in, width_in, group, group_channels,
cache_grad_offset + (threadIdx.x << 1), loc_h, loc_w, g_col, c_col, offset_scale, top_grad,
cache_grad_mask + threadIdx.x); weight, grad_im_ptr,
} cache_grad_offset + (threadIdx.x << 1),
cache_grad_mask + threadIdx.x);
}
__syncthreads(); __syncthreads();
for (unsigned int s = blockDim.x / 2, spre = blockDim.x; s > 0; for (unsigned int s = blockDim.x / 2, spre = blockDim.x; s > 0;
s >>= 1, spre >>= 1) { s >>= 1, spre >>= 1) {
if (tid < s) { if (tid < s) {
const unsigned int xid1 = tid << 1; const unsigned int xid1 = tid << 1;
const unsigned int xid2 = (tid + s) << 1; const unsigned int xid2 = (tid + s) << 1;
cache_grad_mask[tid] += cache_grad_mask[tid + s]; cache_grad_mask[tid] += cache_grad_mask[tid + s];
cache_grad_offset[xid1] += cache_grad_offset[xid2]; cache_grad_offset[xid1] += cache_grad_offset[xid2];
cache_grad_offset[xid1 + 1] +=
cache_grad_offset[xid2 + 1];
if (tid + (s << 1) < spre) {
cache_grad_mask[tid] +=
cache_grad_mask[tid + (s << 1)];
cache_grad_offset[xid1] +=
cache_grad_offset[xid2 + (s << 1)];
cache_grad_offset[xid1 + 1] += cache_grad_offset[xid1 + 1] +=
cache_grad_offset[xid2 + 1 + (s << 1)]; cache_grad_offset[xid2 + 1];
if (tid + (s << 1) < spre) {
cache_grad_mask[tid] +=
cache_grad_mask[tid + (s << 1)];
cache_grad_offset[xid1] +=
cache_grad_offset[xid2 + (s << 1)];
cache_grad_offset[xid1 + 1] +=
cache_grad_offset[xid2 + 1 + (s << 1)];
}
} }
__syncthreads();
}
if (tid == 0) {
atomicAdd(grad_offset, cache_grad_offset[0]);
atomicAdd(grad_offset + 1, cache_grad_offset[1]);
atomicAdd(grad_mask, cache_grad_mask[0]);
} }
__syncthreads(); __syncthreads();
}
if (tid == 0) { data_weight_ptr += 1;
atomicAdd(grad_offset, cache_grad_offset[0]); data_loc_w_ptr += 2;
atomicAdd(grad_offset + 1, cache_grad_offset[1]); grad_mask += 1;
atomicAdd(grad_mask, cache_grad_mask[0]); grad_offset += 2;
} }
__syncthreads();
data_weight_ptr += 1;
data_loc_w_ptr += 2;
grad_mask += 1;
grad_offset += 2;
} }
} }
} }
...@@ -780,7 +822,7 @@ __global__ void dcnv3_col2im_gpu_kernel_gm( ...@@ -780,7 +822,7 @@ __global__ void dcnv3_col2im_gpu_kernel_gm(
const int pad_w, const int dilation_h, const int dilation_w, const int pad_w, const int dilation_h, const int dilation_w,
const int group, const int group_channels, const int height_in, const int group, const int group_channels, const int height_in,
const int width_in, const int height_out, const int width_out, const int width_in, const int height_out, const int width_out,
const opmath_t offset_scale, opmath_t *grad_im, opmath_t *grad_offset, const opmath_t offset_scale, const int remove_center, opmath_t *grad_im, opmath_t *grad_offset,
opmath_t *grad_mask) { opmath_t *grad_mask) {
CUDA_KERNEL_LOOP(index, num_kernels) { CUDA_KERNEL_LOOP(index, num_kernels) {
int _temp = index; int _temp = index;
...@@ -799,7 +841,7 @@ __global__ void dcnv3_col2im_gpu_kernel_gm( ...@@ -799,7 +841,7 @@ __global__ void dcnv3_col2im_gpu_kernel_gm(
const opmath_t top_grad = grad_col[index]; const opmath_t top_grad = grad_col[index];
const int input_size = height_in * width_in; const int input_size = height_in * width_in;
const int kernel_size = kernel_h * kernel_w; const int kernel_size = kernel_h * kernel_w - remove_center;
int data_weight_ptr = sampling_index * kernel_size; int data_weight_ptr = sampling_index * kernel_size;
int data_loc_w_ptr = data_weight_ptr << 1; int data_loc_w_ptr = data_weight_ptr << 1;
const int grad_sampling_ptr = data_weight_ptr; const int grad_sampling_ptr = data_weight_ptr;
...@@ -813,26 +855,33 @@ __global__ void dcnv3_col2im_gpu_kernel_gm( ...@@ -813,26 +855,33 @@ __global__ void dcnv3_col2im_gpu_kernel_gm(
p0_w - ((dilation_w * (kernel_w - 1)) >> 1) * offset_scale; p0_w - ((dilation_w * (kernel_w - 1)) >> 1) * offset_scale;
const opmath_t p0_h_ = const opmath_t p0_h_ =
p0_h - ((dilation_h * (kernel_h - 1)) >> 1) * offset_scale; p0_h - ((dilation_h * (kernel_h - 1)) >> 1) * offset_scale;
const int center_h = kernel_h / 2;
const int center_w = kernel_w / 2;
for (int i = 0; i < kernel_w; ++i) { for (int i = 0; i < kernel_w; ++i) {
for (int j = 0; j < kernel_h; ++j) { for (int j = 0; j < kernel_h; ++j) {
const opmath_t offset_w = data_offset[data_loc_w_ptr]; // if not remove center, or remove center and not the center
const opmath_t offset_h = data_offset[data_loc_w_ptr + 1]; if (i!=center_w || j!=center_h || !remove_center) {
const opmath_t loc_w = const opmath_t offset_w = data_offset[data_loc_w_ptr];
p0_w_ + (i * dilation_w + offset_w) * offset_scale; const opmath_t offset_h = data_offset[data_loc_w_ptr + 1];
const opmath_t loc_h = const opmath_t loc_w =
p0_h_ + (j * dilation_h + offset_h) * offset_scale; p0_w_ + (i * dilation_w + offset_w) * offset_scale;
const opmath_t weight = data_mask[data_weight_ptr]; const opmath_t loc_h =
if (loc_h > -1 && loc_w > -1 && loc_h < height_in && p0_h_ + (j * dilation_h + offset_h) * offset_scale;
loc_w < width_in) { const opmath_t weight = data_mask[data_weight_ptr];
dcnv3_col2im_bilinear_gm( if (loc_h > -1 && loc_w > -1 && loc_h < height_in &&
data_im_ptr, height_in, width_in, group, group_channels, loc_w < width_in) {
loc_h, loc_w, g_col, c_col, offset_scale, top_grad, dcnv3_col2im_bilinear_gm(
weight, grad_im_ptr, grad_offset, grad_mask); data_im_ptr, height_in, width_in, group, group_channels,
loc_h, loc_w, g_col, c_col, offset_scale, top_grad,
weight, grad_im_ptr, grad_offset, grad_mask);
}
data_weight_ptr += 1;
data_loc_w_ptr += 2;
grad_mask += 1;
grad_offset += 2;
} }
data_weight_ptr += 1;
data_loc_w_ptr += 2;
grad_mask += 1;
grad_offset += 2;
} }
} }
} }
...@@ -848,7 +897,7 @@ void dcnv3_im2col_cuda(cudaStream_t stream, const scalar_t *data_im, ...@@ -848,7 +897,7 @@ void dcnv3_im2col_cuda(cudaStream_t stream, const scalar_t *data_im,
const int group, const int group_channels, const int group, const int group_channels,
const int batch_n, const int height_in, const int batch_n, const int height_in,
const int width_in, const int height_out, const int width_in, const int height_out,
const int width_out, const opmath_t offset_scale) { const int width_out, const opmath_t offset_scale, const int remove_center) {
const int num_kernels = const int num_kernels =
batch_n * height_out * width_out * group * group_channels; batch_n * height_out * width_out * group * group_channels;
const int num_actual_kernels = const int num_actual_kernels =
...@@ -859,7 +908,7 @@ void dcnv3_im2col_cuda(cudaStream_t stream, const scalar_t *data_im, ...@@ -859,7 +908,7 @@ void dcnv3_im2col_cuda(cudaStream_t stream, const scalar_t *data_im,
stream>>>(num_kernels, data_im, data_offset, data_mask, data_col, stream>>>(num_kernels, data_im, data_offset, data_mask, data_col,
kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w,
dilation_h, dilation_w, group, group_channels, height_in, dilation_h, dilation_w, group, group_channels, height_in,
width_in, height_out, width_out, offset_scale); width_in, height_out, width_out, offset_scale, remove_center);
cudaError_t err = cudaGetLastError(); cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) { if (err != cudaSuccess) {
...@@ -875,8 +924,8 @@ void dcnv3_col2im_cuda( ...@@ -875,8 +924,8 @@ void dcnv3_col2im_cuda(
const int pad_w, const int dilation_h, const int dilation_w, const int pad_w, const int dilation_h, const int dilation_w,
const int group, const int group_channels, const int batch_n, const int group, const int group_channels, const int batch_n,
const int height_in, const int width_in, const int height_out, const int height_in, const int width_in, const int height_out,
const int width_out, const opmath_t offset_scale, opmath_t *grad_im, const int width_out, const opmath_t offset_scale, const int remove_center,
opmath_t *grad_offset, opmath_t *grad_mask) { opmath_t *grad_im, opmath_t *grad_offset, opmath_t *grad_mask) {
const int num_threads = const int num_threads =
(group_channels > CUDA_NUM_THREADS) ? CUDA_NUM_THREADS : group_channels; (group_channels > CUDA_NUM_THREADS) ? CUDA_NUM_THREADS : group_channels;
const int num_kernels = const int num_kernels =
...@@ -891,7 +940,7 @@ void dcnv3_col2im_cuda( ...@@ -891,7 +940,7 @@ void dcnv3_col2im_cuda(
num_kernels, grad_col, data_im, data_offset, data_mask, num_kernels, grad_col, data_im, data_offset, data_mask,
kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w,
dilation_h, dilation_w, group, group_channels, height_in, dilation_h, dilation_w, group, group_channels, height_in,
width_in, height_out, width_out, offset_scale, grad_im, width_in, height_out, width_out, offset_scale, remove_center, grad_im,
grad_offset, grad_mask); grad_offset, grad_mask);
} else { } else {
dcnv3_col2im_gpu_kernel_gm<scalar_t> dcnv3_col2im_gpu_kernel_gm<scalar_t>
...@@ -900,7 +949,7 @@ void dcnv3_col2im_cuda( ...@@ -900,7 +949,7 @@ void dcnv3_col2im_cuda(
data_mask, kernel_h, kernel_w, stride_h, stride_w, data_mask, kernel_h, kernel_w, stride_h, stride_w,
pad_h, pad_w, dilation_h, dilation_w, group, pad_h, pad_w, dilation_h, dilation_w, group,
group_channels, height_in, width_in, height_out, group_channels, height_in, width_in, height_out,
width_out, offset_scale, grad_im, grad_offset, width_out, offset_scale, remove_center, grad_im, grad_offset,
grad_mask); grad_mask);
} }
} else { } else {
...@@ -912,7 +961,7 @@ void dcnv3_col2im_cuda( ...@@ -912,7 +961,7 @@ void dcnv3_col2im_cuda(
data_mask, kernel_h, kernel_w, stride_h, stride_w, data_mask, kernel_h, kernel_w, stride_h, stride_w,
pad_h, pad_w, dilation_h, dilation_w, group, pad_h, pad_w, dilation_h, dilation_w, group,
group_channels, height_in, width_in, height_out, group_channels, height_in, width_in, height_out,
width_out, offset_scale, grad_im, grad_offset, width_out, offset_scale, remove_center, grad_im, grad_offset,
grad_mask); grad_mask);
break; break;
case 2: case 2:
...@@ -922,7 +971,7 @@ void dcnv3_col2im_cuda( ...@@ -922,7 +971,7 @@ void dcnv3_col2im_cuda(
data_mask, kernel_h, kernel_w, stride_h, stride_w, data_mask, kernel_h, kernel_w, stride_h, stride_w,
pad_h, pad_w, dilation_h, dilation_w, group, pad_h, pad_w, dilation_h, dilation_w, group,
group_channels, height_in, width_in, height_out, group_channels, height_in, width_in, height_out,
width_out, offset_scale, grad_im, grad_offset, width_out, offset_scale, remove_center, grad_im, grad_offset,
grad_mask); grad_mask);
break; break;
case 4: case 4:
...@@ -932,7 +981,7 @@ void dcnv3_col2im_cuda( ...@@ -932,7 +981,7 @@ void dcnv3_col2im_cuda(
data_mask, kernel_h, kernel_w, stride_h, stride_w, data_mask, kernel_h, kernel_w, stride_h, stride_w,
pad_h, pad_w, dilation_h, dilation_w, group, pad_h, pad_w, dilation_h, dilation_w, group,
group_channels, height_in, width_in, height_out, group_channels, height_in, width_in, height_out,
width_out, offset_scale, grad_im, grad_offset, width_out, offset_scale, remove_center, grad_im, grad_offset,
grad_mask); grad_mask);
break; break;
case 8: case 8:
...@@ -942,7 +991,7 @@ void dcnv3_col2im_cuda( ...@@ -942,7 +991,7 @@ void dcnv3_col2im_cuda(
data_mask, kernel_h, kernel_w, stride_h, stride_w, data_mask, kernel_h, kernel_w, stride_h, stride_w,
pad_h, pad_w, dilation_h, dilation_w, group, pad_h, pad_w, dilation_h, dilation_w, group,
group_channels, height_in, width_in, height_out, group_channels, height_in, width_in, height_out,
width_out, offset_scale, grad_im, grad_offset, width_out, offset_scale, remove_center, grad_im, grad_offset,
grad_mask); grad_mask);
break; break;
case 16: case 16:
...@@ -952,7 +1001,7 @@ void dcnv3_col2im_cuda( ...@@ -952,7 +1001,7 @@ void dcnv3_col2im_cuda(
data_mask, kernel_h, kernel_w, stride_h, stride_w, data_mask, kernel_h, kernel_w, stride_h, stride_w,
pad_h, pad_w, dilation_h, dilation_w, group, pad_h, pad_w, dilation_h, dilation_w, group,
group_channels, height_in, width_in, height_out, group_channels, height_in, width_in, height_out,
width_out, offset_scale, grad_im, grad_offset, width_out, offset_scale, remove_center, grad_im, grad_offset,
grad_mask); grad_mask);
break; break;
case 32: case 32:
...@@ -962,7 +1011,7 @@ void dcnv3_col2im_cuda( ...@@ -962,7 +1011,7 @@ void dcnv3_col2im_cuda(
data_mask, kernel_h, kernel_w, stride_h, stride_w, data_mask, kernel_h, kernel_w, stride_h, stride_w,
pad_h, pad_w, dilation_h, dilation_w, group, pad_h, pad_w, dilation_h, dilation_w, group,
group_channels, height_in, width_in, height_out, group_channels, height_in, width_in, height_out,
width_out, offset_scale, grad_im, grad_offset, width_out, offset_scale, remove_center, grad_im, grad_offset,
grad_mask); grad_mask);
break; break;
case 64: case 64:
...@@ -972,7 +1021,7 @@ void dcnv3_col2im_cuda( ...@@ -972,7 +1021,7 @@ void dcnv3_col2im_cuda(
data_mask, kernel_h, kernel_w, stride_h, stride_w, data_mask, kernel_h, kernel_w, stride_h, stride_w,
pad_h, pad_w, dilation_h, dilation_w, group, pad_h, pad_w, dilation_h, dilation_w, group,
group_channels, height_in, width_in, height_out, group_channels, height_in, width_in, height_out,
width_out, offset_scale, grad_im, grad_offset, width_out, offset_scale, remove_center, grad_im, grad_offset,
grad_mask); grad_mask);
break; break;
case 128: case 128:
...@@ -982,7 +1031,7 @@ void dcnv3_col2im_cuda( ...@@ -982,7 +1031,7 @@ void dcnv3_col2im_cuda(
data_mask, kernel_h, kernel_w, stride_h, stride_w, data_mask, kernel_h, kernel_w, stride_h, stride_w,
pad_h, pad_w, dilation_h, dilation_w, group, pad_h, pad_w, dilation_h, dilation_w, group,
group_channels, height_in, width_in, height_out, group_channels, height_in, width_in, height_out,
width_out, offset_scale, grad_im, grad_offset, width_out, offset_scale, remove_center, grad_im, grad_offset,
grad_mask); grad_mask);
break; break;
case 256: case 256:
...@@ -992,7 +1041,7 @@ void dcnv3_col2im_cuda( ...@@ -992,7 +1041,7 @@ void dcnv3_col2im_cuda(
data_mask, kernel_h, kernel_w, stride_h, stride_w, data_mask, kernel_h, kernel_w, stride_h, stride_w,
pad_h, pad_w, dilation_h, dilation_w, group, pad_h, pad_w, dilation_h, dilation_w, group,
group_channels, height_in, width_in, height_out, group_channels, height_in, width_in, height_out,
width_out, offset_scale, grad_im, grad_offset, width_out, offset_scale, remove_center, grad_im, grad_offset,
grad_mask); grad_mask);
break; break;
case 512: case 512:
...@@ -1002,7 +1051,7 @@ void dcnv3_col2im_cuda( ...@@ -1002,7 +1051,7 @@ void dcnv3_col2im_cuda(
data_mask, kernel_h, kernel_w, stride_h, stride_w, data_mask, kernel_h, kernel_w, stride_h, stride_w,
pad_h, pad_w, dilation_h, dilation_w, group, pad_h, pad_w, dilation_h, dilation_w, group,
group_channels, height_in, width_in, height_out, group_channels, height_in, width_in, height_out,
width_out, offset_scale, grad_im, grad_offset, width_out, offset_scale, remove_center, grad_im, grad_offset,
grad_mask); grad_mask);
break; break;
case 1024: case 1024:
...@@ -1013,7 +1062,7 @@ void dcnv3_col2im_cuda( ...@@ -1013,7 +1062,7 @@ void dcnv3_col2im_cuda(
data_mask, kernel_h, kernel_w, stride_h, stride_w, data_mask, kernel_h, kernel_w, stride_h, stride_w,
pad_h, pad_w, dilation_h, dilation_w, group, pad_h, pad_w, dilation_h, dilation_w, group,
group_channels, height_in, width_in, height_out, group_channels, height_in, width_in, height_out,
width_out, offset_scale, grad_im, grad_offset, width_out, offset_scale, remove_center, grad_im, grad_offset,
grad_mask); grad_mask);
break; break;
default: default:
...@@ -1025,7 +1074,7 @@ void dcnv3_col2im_cuda( ...@@ -1025,7 +1074,7 @@ void dcnv3_col2im_cuda(
kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w,
dilation_h, dilation_w, group, group_channels, dilation_h, dilation_w, group, group_channels,
height_in, width_in, height_out, width_out, height_in, width_in, height_out, width_out,
offset_scale, grad_im, grad_offset, grad_mask); offset_scale, remove_center, grad_im, grad_offset, grad_mask);
} else { } else {
dcnv3_col2im_gpu_kernel_shm_reduce_v2<scalar_t> dcnv3_col2im_gpu_kernel_shm_reduce_v2<scalar_t>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
...@@ -1034,7 +1083,7 @@ void dcnv3_col2im_cuda( ...@@ -1034,7 +1083,7 @@ void dcnv3_col2im_cuda(
kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w,
dilation_h, dilation_w, group, group_channels, dilation_h, dilation_w, group, group_channels,
height_in, width_in, height_out, width_out, height_in, width_in, height_out, width_out,
offset_scale, grad_im, grad_offset, grad_mask); offset_scale, remove_center, grad_im, grad_offset, grad_mask);
} }
} }
} }
......
...@@ -23,13 +23,13 @@ at::Tensor dcnv3_forward(const at::Tensor &input, const at::Tensor &offset, ...@@ -23,13 +23,13 @@ at::Tensor dcnv3_forward(const at::Tensor &input, const at::Tensor &offset,
const int stride_w, const int pad_h, const int pad_w, const int stride_w, const int pad_h, const int pad_w,
const int dilation_h, const int dilation_w, const int dilation_h, const int dilation_w,
const int group, const int group_channels, const int group, const int group_channels,
const float offset_scale, const int im2col_step) { const float offset_scale, const int im2col_step, const int remove_center) {
if (input.type().is_cuda()) { if (input.type().is_cuda()) {
#ifdef WITH_CUDA #ifdef WITH_CUDA
return dcnv3_cuda_forward(input, offset, mask, kernel_h, kernel_w, return dcnv3_cuda_forward(input, offset, mask, kernel_h, kernel_w,
stride_h, stride_w, pad_h, pad_w, dilation_h, stride_h, stride_w, pad_h, pad_w, dilation_h,
dilation_w, group, group_channels, dilation_w, group, group_channels,
offset_scale, im2col_step); offset_scale, im2col_step, remove_center);
#else #else
AT_ERROR("Not compiled with GPU support"); AT_ERROR("Not compiled with GPU support");
#endif #endif
...@@ -44,13 +44,13 @@ dcnv3_backward(const at::Tensor &input, const at::Tensor &offset, ...@@ -44,13 +44,13 @@ dcnv3_backward(const at::Tensor &input, const at::Tensor &offset,
const int pad_w, const int dilation_h, const int dilation_w, const int pad_w, const int dilation_h, const int dilation_w,
const int group, const int group_channels, const int group, const int group_channels,
const float offset_scale, const at::Tensor &grad_output, const float offset_scale, const at::Tensor &grad_output,
const int im2col_step) { const int im2col_step, const int remove_center) {
if (input.type().is_cuda()) { if (input.type().is_cuda()) {
#ifdef WITH_CUDA #ifdef WITH_CUDA
return dcnv3_cuda_backward(input, offset, mask, kernel_h, kernel_w, return dcnv3_cuda_backward(input, offset, mask, kernel_h, kernel_w,
stride_h, stride_w, pad_h, pad_w, dilation_h, stride_h, stride_w, pad_h, pad_w, dilation_h,
dilation_w, group, group_channels, dilation_w, group, group_channels,
offset_scale, grad_output, im2col_step); offset_scale, grad_output, im2col_step, remove_center);
#else #else
AT_ERROR("Not compiled with GPU support"); AT_ERROR("Not compiled with GPU support");
#endif #endif
......
...@@ -19,7 +19,8 @@ from functions.dcnv3_func import DCNv3Function, dcnv3_core_pytorch ...@@ -19,7 +19,8 @@ from functions.dcnv3_func import DCNv3Function, dcnv3_core_pytorch
H_in, W_in = 8, 8 H_in, W_in = 8, 8
N, M, D = 2, 4, 16 N, M, D = 2, 4, 16
Kh, Kw = 3, 3 Kh, Kw = 3, 3
P = Kh * Kw remove_center = False
P = Kh * Kw - remove_center
offset_scale = 2.0 offset_scale = 2.0
pad = 1 pad = 1
dilation = 1 dilation = 1
...@@ -42,7 +43,7 @@ def check_forward_equal_with_pytorch_double(): ...@@ -42,7 +43,7 @@ def check_forward_equal_with_pytorch_double():
input.double(), input.double(),
offset.double(), offset.double(),
mask.double(), mask.double(),
Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale).detach().cpu() Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale, remove_center).detach().cpu()
im2col_step = 2 im2col_step = 2
output_cuda = DCNv3Function.apply( output_cuda = DCNv3Function.apply(
...@@ -50,7 +51,7 @@ def check_forward_equal_with_pytorch_double(): ...@@ -50,7 +51,7 @@ def check_forward_equal_with_pytorch_double():
offset.double(), offset.double(),
mask.double(), mask.double(),
Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale, Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale,
im2col_step).detach().cpu() im2col_step, remove_center).detach().cpu()
fwdok = torch.allclose(output_cuda, output_pytorch) fwdok = torch.allclose(output_cuda, output_pytorch)
max_abs_err = (output_cuda - output_pytorch).abs().max() max_abs_err = (output_cuda - output_pytorch).abs().max()
...@@ -72,7 +73,7 @@ def check_forward_equal_with_pytorch_float(): ...@@ -72,7 +73,7 @@ def check_forward_equal_with_pytorch_float():
input, input,
offset, offset,
mask, mask,
Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale).detach().cpu() Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale, remove_center).detach().cpu()
im2col_step = 2 im2col_step = 2
output_cuda = DCNv3Function.apply( output_cuda = DCNv3Function.apply(
...@@ -80,7 +81,7 @@ def check_forward_equal_with_pytorch_float(): ...@@ -80,7 +81,7 @@ def check_forward_equal_with_pytorch_float():
offset, offset,
mask, mask,
Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale, Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale,
im2col_step).detach().cpu() im2col_step, remove_center).detach().cpu()
fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3) fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3)
max_abs_err = (output_cuda - output_pytorch).abs().max() max_abs_err = (output_cuda - output_pytorch).abs().max()
...@@ -111,7 +112,7 @@ def check_backward_equal_with_pytorch_double(channels=4, grad_input=True, grad_o ...@@ -111,7 +112,7 @@ def check_backward_equal_with_pytorch_double(channels=4, grad_input=True, grad_o
input0.double(), input0.double(),
offset0.double(), offset0.double(),
mask0.double(), mask0.double(),
Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale) Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale, remove_center)
output_pytorch.sum().backward() output_pytorch.sum().backward()
input1 = input0.detach() input1 = input0.detach()
...@@ -127,7 +128,7 @@ def check_backward_equal_with_pytorch_double(channels=4, grad_input=True, grad_o ...@@ -127,7 +128,7 @@ def check_backward_equal_with_pytorch_double(channels=4, grad_input=True, grad_o
offset1.double(), offset1.double(),
mask1.double(), mask1.double(),
Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale, Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale,
im2col_step) im2col_step, remove_center)
output_cuda.sum().backward() output_cuda.sum().backward()
print(f'>>> backward double: channels {D}') print(f'>>> backward double: channels {D}')
...@@ -174,7 +175,7 @@ def check_backward_equal_with_pytorch_float(channels=4, grad_input=True, grad_of ...@@ -174,7 +175,7 @@ def check_backward_equal_with_pytorch_float(channels=4, grad_input=True, grad_of
input0, input0,
offset0, offset0,
mask0, mask0,
Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale) Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale, remove_center)
output_pytorch.sum().backward() output_pytorch.sum().backward()
input1 = input0.detach() input1 = input0.detach()
...@@ -190,7 +191,7 @@ def check_backward_equal_with_pytorch_float(channels=4, grad_input=True, grad_of ...@@ -190,7 +191,7 @@ def check_backward_equal_with_pytorch_float(channels=4, grad_input=True, grad_of
offset1, offset1,
mask1, mask1,
Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale, Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale,
im2col_step) im2col_step, remove_center)
output_cuda.sum().backward() output_cuda.sum().backward()
print(f'>>> backward float: channels {D}') print(f'>>> backward float: channels {D}')
...@@ -237,7 +238,7 @@ def check_time_cost(im2col_step=128): ...@@ -237,7 +238,7 @@ def check_time_cost(im2col_step=128):
offset, offset,
mask, mask,
Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, 1.0, Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, 1.0,
im2col_step) im2col_step, remove_center)
torch.cuda.synchronize() torch.cuda.synchronize()
start = time.time() start = time.time()
for i in range(repeat): for i in range(repeat):
...@@ -246,7 +247,7 @@ def check_time_cost(im2col_step=128): ...@@ -246,7 +247,7 @@ def check_time_cost(im2col_step=128):
offset, offset,
mask, mask,
Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, 1.0, Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, 1.0,
im2col_step) im2col_step, remove_center)
torch.cuda.synchronize() torch.cuda.synchronize()
print(f'foward time cost: {(time.time() - start) / repeat}') print(f'foward time cost: {(time.time() - start) / repeat}')
......
...@@ -66,7 +66,27 @@ def load_checkpoint(config, model, optimizer, lr_scheduler, scaler, logger): ...@@ -66,7 +66,27 @@ def load_checkpoint(config, model, optimizer, lr_scheduler, scaler, logger):
checkpoint = torch.load(config.MODEL.RESUME, map_location='cpu') checkpoint = torch.load(config.MODEL.RESUME, map_location='cpu')
print('resuming model') print('resuming model')
msg = model.load_state_dict(checkpoint['model'], strict=False)
model_checkpoint = checkpoint['model']
if config.MODEL.INTERN_IMAGE.REMOVE_CENTER:
for k, v in model_checkpoint.items():
if 'dcn.mask.bias' in k:
idx = list(range(v.shape[0]))
idx = [i for i in idx if i != 4 and (i - 4) % 9 != 0]
model_checkpoint[k] = v[idx]
if 'dcn.offset.bias' in k:
idx = list(range(v.shape[0]))
idx = [i for i in idx if i != 8 and (i - 8) % 18 != 0 and i != 9 and (i - 9) % 18 != 0]
model_checkpoint[k] = v[idx]
if 'dcn.mask.weight' in k:
idx = list(range(v.shape[0]))
idx = [i for i in idx if i != 4 and (i - 4) % 9 != 0]
model_checkpoint[k] = v[idx, :]
if 'dcn.offset.weight' in k:
idx = list(range(v.shape[0]))
idx = [i for i in idx if i != 8 and (i - 8) % 18 != 0 and i != 9 and (i - 9) % 18 != 0]
model_checkpoint[k] = v[idx, :]
msg = model.load_state_dict(model_checkpoint, strict=False)
logger.info(msg) logger.info(msg)
max_accuracy = 0.0 max_accuracy = 0.0
if not config.EVAL_MODE and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint: if not config.EVAL_MODE and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
...@@ -239,6 +259,25 @@ def load_pretrained(config, model, logger): ...@@ -239,6 +259,25 @@ def load_pretrained(config, model, logger):
map22kto1k, :] map22kto1k, :]
state_dict['head.bias'] = state_dict['head.bias'][map22kto1k] state_dict['head.bias'] = state_dict['head.bias'][map22kto1k]
if config.MODEL.INTERN_IMAGE.REMOVE_CENTER:
for k, v in state_dict.items():
if 'dcn.mask.bias' in k:
idx = list(range(v.shape[0]))
idx = [i for i in idx if i != 4 and (i - 4) % 9 != 0]
state_dict[k] = v[idx]
if 'dcn.offset.bias' in k:
idx = list(range(v.shape[0]))
idx = [i for i in idx if i != 8 and (i - 8) % 18 != 0 and i != 9 and (i - 9) % 18 != 0]
state_dict[k] = v[idx]
if 'dcn.mask.weight' in k:
idx = list(range(v.shape[0]))
idx = [i for i in idx if i != 4 and (i - 4) % 9 != 0]
state_dict[k] = v[idx, :]
if 'dcn.offset.weight' in k:
idx = list(range(v.shape[0]))
idx = [i for i in idx if i != 8 and (i - 8) % 18 != 0 and i != 9 and (i - 9) % 18 != 0]
state_dict[k] = v[idx, :]
msg = model.load_state_dict(state_dict, strict=False) msg = model.load_state_dict(state_dict, strict=False)
logger.warning(msg) logger.warning(msg)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment