Commit 3ba75d4c authored by Zhe Chen's avatar Zhe Chen Committed by zhe chen
Browse files

[release] Release InternImage-H/G (#34)

* update internimage

* Update README.md
parent 3a37b813
......@@ -176,9 +176,10 @@
| InternImage-B | ImageNet-1K | 224x224 | 84.9 | 97M | 16G | - | [ckpt](https://github.com/OpenGVLab/InternImage/releases/download/cls_model/internimage_b_1k_224.pth) \| [cfg](classification/configs/internimage_b_1k_224.yaml) |
| InternImage-L | ImageNet-22K | 384x384 | 87.7 | 223M | 108G | [ckpt](https://github.com/OpenGVLab/InternImage/releases/download/cls_model/internimage_l_22k_192to384.pth) | [ckpt](https://github.com/OpenGVLab/InternImage/releases/download/cls_model/internimage_l_22kto1k_384.pth) \| [cfg](classification/configs/internimage_l_22kto1k_384.yaml) |
| InternImage-XL | ImageNet-22K | 384x384 | 88.0 | 335M | 163G | [ckpt](https://github.com/OpenGVLab/InternImage/releases/download/cls_model/internimage_xl_22k_192to384.pth) | [ckpt](https://github.com/OpenGVLab/InternImage/releases/download/cls_model/internimage_xl_22kto1k_384.pth) \| [cfg](classification/configs/internimage_xl_22kto1k_384.yaml) |
| InternImage-H | Joint 427M | 224x224 | 88.9 | 1.08B | 188G | TBD | [ckpt](https://pan.baidu.com/s/1R3niTRjrERUet2xGc6ePPA) \| [cfg](classification/configs/internimage_h_jointto1k_224.yaml) |
| InternImage-H | Joint 427M | 640x640 | 89.6 | 1.08B | 1478G |TBD | [ckpt](https://pan.baidu.com/s/1R3niTRjrERUet2xGc6ePPA) \| [cfg](classification/configs/internimage_h_jointto1k_640.yaml) |
| InternImage-G | Joint 427M | 512x512 | 90.1 | 3B | - | TBD | [ckpt](https://pan.baidu.com/s/1R3niTRjrERUet2xGc6ePPA) \| [cfg](classification/configs/internimage_g_jointto1k_512.yaml)|
| InternImage-H | Joint 427M | 224x224 | 88.9 | 1.08B | 188G | - | [ckpt](https://pan.baidu.com/s/1R3niTRjrERUet2xGc6ePPA) \| [cfg](classification/configs/internimage_h_jointto1k_224.yaml) |
| InternImage-H | Joint 427M | 640x640 | 89.6 | 1.08B | 1478G | - | [ckpt](https://pan.baidu.com/s/1R3niTRjrERUet2xGc6ePPA) \| [cfg](classification/configs/internimage_h_jointto1k_640.yaml) |
| InternImage-G | Joint 427M | 512x512 | 90.1 | 3B | - | - | [ckpt](https://pan.baidu.com/s/1R3niTRjrERUet2xGc6ePPA) \| [cfg](classification/configs/internimage_g_jointto1k_512.yaml) |
- Extraction code for downloading InternImage-H/G: 2vwu
**COCO目标检测和实例分割**
......
......@@ -73,6 +73,14 @@ _C.MODEL.INTERN_IMAGE.OFFSET_SCALE = 1.0
_C.MODEL.INTERN_IMAGE.MLP_RATIO = 4.0
_C.MODEL.INTERN_IMAGE.CORE_OP = 'DCNv3'
_C.MODEL.INTERN_IMAGE.POST_NORM = False
_C.MODEL.INTERN_IMAGE.RES_POST_NORM = False
_C.MODEL.INTERN_IMAGE.DW_KERNEL_SIZE = None
_C.MODEL.INTERN_IMAGE.USE_CLIP_PROJECTOR = False
_C.MODEL.INTERN_IMAGE.LEVEL2_POST_NORM = False
_C.MODEL.INTERN_IMAGE.LEVEL2_POST_NORM_BLOCK_IDS = None
_C.MODEL.INTERN_IMAGE.CENTER_FEATURE_SCALE = False
# -----------------------------------------------------------------------------
# Training settings
......
# Fine-tuning configuration — presumably classification/configs/internimage_g_jointto1k_512.yaml
# (InternImage-G, joint pre-trained, fine-tuned to ImageNet-1k at 512x512;
# matches the README table row for InternImage-G).
DATA:
  IMG_SIZE: 512
  IMG_ON_MEMORY: True  # presumably keeps the decoded dataset in memory — verify against the data loader
AUG:
  # Mixing / erasing augmentations disabled for this fine-tune stage.
  MIXUP: 0.0
  CUTMIX: 0.0
  REPROB: 0.0  # random-erasing probability — TODO confirm
MODEL:
  TYPE: intern_image
  DROP_PATH_RATE: 0.4
  LABEL_SMOOTHING: 0.3
  INTERN_IMAGE:
    CORE_OP: 'DCNv3'
    DEPTHS: [2, 2, 48, 4]
    GROUPS: [16, 32, 64, 128]
    CHANNELS: 512
    DW_KERNEL_SIZE: 5      # depthwise kernel inside DCNv3 (H/G-only option)
    # NOTE(review): unquoted None is parsed by YAML as the string "None",
    # not as null — confirm the config loader coerces it as intended.
    LAYER_SCALE: None
    OFFSET_SCALE: 1.0
    MLP_RATIO: 4.0
    POST_NORM: True
    LEVEL2_POST_NORM: True
    # Extra LayerNorms inserted after these block indices of stage 3
    # (the 48-block level; ids run up to depth-1 = 47).
    LEVEL2_POST_NORM_BLOCK_IDS: [5, 11, 17, 23, 29, 35, 41, 47]
    CENTER_FEATURE_SCALE: True
    USE_CLIP_PROJECTOR: True  # CLIP-style attention-pooling head (H/G only)
TRAIN:
  EMA:
    ENABLE: true
    DECAY: 0.9999
  EPOCHS: 20
  WARMUP_EPOCHS: 2
  WEIGHT_DECAY: 0.05
  BASE_LR: 2e-05 # 512
  WARMUP_LR: .0
  MIN_LR: .0
  LR_LAYER_DECAY: true
  LR_LAYER_DECAY_RATIO: 0.9
  USE_CHECKPOINT: true  # presumably gradient checkpointing to fit the model — TODO confirm
  OPTIMIZER:
    DCN_LR_MUL: 0.1  # reduced LR multiplier for DCN parameters
AMP_OPT_LEVEL: O0
EVAL_FREQ: 1
\ No newline at end of file
# Fine-tuning configuration — presumably classification/configs/internimage_h_jointto1k_224.yaml
# (InternImage-H, joint pre-trained, fine-tuned to ImageNet-1k at 224x224;
# matches the README table row for InternImage-H @ 224).
DATA:
  IMG_SIZE: 224
  IMG_ON_MEMORY: True  # presumably keeps the decoded dataset in memory — verify against the data loader
AUG:
  # Mixing / erasing augmentations disabled for this fine-tune stage.
  MIXUP: 0.0
  CUTMIX: 0.0
  REPROB: 0.0  # random-erasing probability — TODO confirm
MODEL:
  TYPE: intern_image
  DROP_PATH_RATE: 0.6
  LABEL_SMOOTHING: 0.3
  INTERN_IMAGE:
    CORE_OP: 'DCNv3'
    DEPTHS: [6, 6, 32, 6]
    GROUPS: [10, 20, 40, 80]
    CHANNELS: 320
    DW_KERNEL_SIZE: 5      # depthwise kernel inside DCNv3 (H/G-only option)
    # NOTE(review): unquoted None is parsed by YAML as the string "None",
    # not as null — confirm the config loader coerces it as intended.
    LAYER_SCALE: None
    OFFSET_SCALE: 1.0
    MLP_RATIO: 4.0
    POST_NORM: False
    RES_POST_NORM: True    # post-norm on each residual branch instead of global post-norm
    LEVEL2_POST_NORM: True
    # Extra LayerNorms inserted after these block indices of stage 3
    # (the 32-block level).
    LEVEL2_POST_NORM_BLOCK_IDS: [5, 11, 17, 23, 29]
    CENTER_FEATURE_SCALE: True
    USE_CLIP_PROJECTOR: True  # CLIP-style attention-pooling head (H/G only)
TRAIN:
  EMA:
    ENABLE: true
    DECAY: 0.9998
  EPOCHS: 30
  WARMUP_EPOCHS: 0
  WEIGHT_DECAY: 1e-8
  BASE_LR: 3e-05 # 512
  WARMUP_LR: 3e-08
  MIN_LR: 3e-07
  LR_LAYER_DECAY: true
  LR_LAYER_DECAY_RATIO: 0.8
  RAND_INIT_FT_HEAD: true  # re-initialise the classification head for fine-tuning
  USE_CHECKPOINT: true     # presumably gradient checkpointing — TODO confirm
AMP_OPT_LEVEL: O0
EVAL_FREQ: 1
\ No newline at end of file
# Fine-tuning configuration — presumably classification/configs/internimage_h_jointto1k_640.yaml
# (InternImage-H, joint pre-trained, fine-tuned to ImageNet-1k at 640x640;
# matches the README table row for InternImage-H @ 640).
DATA:
  IMG_SIZE: 640
  IMG_ON_MEMORY: True  # presumably keeps the decoded dataset in memory — verify against the data loader
AUG:
  # Mixing / erasing augmentations disabled for this fine-tune stage.
  MIXUP: 0.0
  CUTMIX: 0.0
  REPROB: 0.0  # random-erasing probability — TODO confirm
MODEL:
  TYPE: intern_image
  DROP_PATH_RATE: 0.2
  LABEL_SMOOTHING: 0.3
  INTERN_IMAGE:
    CORE_OP: 'DCNv3'
    DEPTHS: [6, 6, 32, 6]
    GROUPS: [10, 20, 40, 80]
    CHANNELS: 320
    DW_KERNEL_SIZE: 5      # depthwise kernel inside DCNv3 (H/G-only option)
    # NOTE(review): unquoted None is parsed by YAML as the string "None",
    # not as null — confirm the config loader coerces it as intended.
    LAYER_SCALE: None
    OFFSET_SCALE: 1.0
    MLP_RATIO: 4.0
    POST_NORM: False
    RES_POST_NORM: True    # post-norm on each residual branch instead of global post-norm
    LEVEL2_POST_NORM: True
    # Extra LayerNorms inserted after these block indices of stage 3
    # (the 32-block level).
    LEVEL2_POST_NORM_BLOCK_IDS: [5, 11, 17, 23, 29]
    CENTER_FEATURE_SCALE: True
    USE_CLIP_PROJECTOR: True  # CLIP-style attention-pooling head (H/G only)
TRAIN:
  EMA:
    ENABLE: true
    DECAY: 0.9999
  EPOCHS: 20
  WARMUP_EPOCHS: 2
  WEIGHT_DECAY: 0.05
  BASE_LR: 2e-05 # 512
  WARMUP_LR: .0
  MIN_LR: .0
  LR_LAYER_DECAY: true
  LR_LAYER_DECAY_RATIO: 0.9
  USE_CHECKPOINT: true  # presumably gradient checkpointing — TODO confirm
  OPTIMIZER:
    USE_ZERO: True   # ZeRO optimizer state sharding — TODO confirm semantics
    DCN_LR_MUL: 0.1  # reduced LR multiplier for DCN parameters
AMP_OPT_LEVEL: O0
EVAL_FREQ: 1
\ No newline at end of file
......@@ -21,6 +21,12 @@ def build_model(config):
post_norm=config.MODEL.INTERN_IMAGE.POST_NORM,
mlp_ratio=config.MODEL.INTERN_IMAGE.MLP_RATIO,
with_cp=config.TRAIN.USE_CHECKPOINT,
res_post_norm=config.MODEL.INTERN_IMAGE.RES_POST_NORM, # for InternImage-H/G
dw_kernel_size=config.MODEL.INTERN_IMAGE.DW_KERNEL_SIZE, # for InternImage-H/G
use_clip_projector=config.MODEL.INTERN_IMAGE.USE_CLIP_PROJECTOR, # for InternImage-H/G
level2_post_norm=config.MODEL.INTERN_IMAGE.LEVEL2_POST_NORM, # for InternImage-H/G
level2_post_norm_block_ids=config.MODEL.INTERN_IMAGE.LEVEL2_POST_NORM_BLOCK_IDS, # for InternImage-H/G
center_feature_scale=config.MODEL.INTERN_IMAGE.CENTER_FEATURE_SCALE # for InternImage-H/G
)
else:
raise NotImplementedError(f"Unkown model: {model_type}")
......
......@@ -9,6 +9,7 @@ import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from timm.models.layers import trunc_normal_, DropPath
from ops_dcnv3 import modules as opsm
import torch.nn.functional as F
class to_channels_first(nn.Module):
......@@ -62,6 +63,147 @@ def build_act_layer(act_layer):
return nn.GELU()
raise NotImplementedError(f'build_act_layer does not support {act_layer}')
class CrossAttention(nn.Module):
    r"""Multi-head cross-attention with separate query / key / value inputs.

    Queries come from ``x`` while keys and values are supplied explicitly,
    which lets a pooled query attend over a full feature sequence.  Used by
    the CLIP-style attention-pooling head of InternImage-H/G.

    Args:
        dim (int): channel dimension of the query input.
        num_heads (int): number of attention heads.
        qkv_bias (bool): if True, learn additive biases for the q/k/v
            projections.  The biases are kept as separate ``nn.Parameter``s
            (the Linear layers stay bias-free) and applied via ``F.linear``.
        qk_scale (float | None): override for the default
            ``head_dim ** -0.5`` attention scale.
        attn_drop (float): dropout probability on the attention map.
        proj_drop (float): dropout probability after the output projection.
        attn_head_dim (int | None): per-head dimension override.
        out_dim (int | None): output dimension; defaults to ``dim``.
    """

    def __init__(self,
                 dim,
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop=0.,
                 proj_drop=0.,
                 attn_head_dim=None,
                 out_dim=None):
        super().__init__()
        if out_dim is None:
            out_dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        if attn_head_dim is not None:
            head_dim = attn_head_dim
        all_head_dim = head_dim * self.num_heads
        self.scale = qk_scale or head_dim ** -0.5
        # forward() reshapes (B, N, C) into num_heads x head_dim, so the
        # heads must tile the input dimension exactly.
        assert all_head_dim == dim

        self.q = nn.Linear(dim, all_head_dim, bias=False)
        self.k = nn.Linear(dim, all_head_dim, bias=False)
        self.v = nn.Linear(dim, all_head_dim, bias=False)

        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.k_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
        else:
            self.q_bias = None
            self.k_bias = None
            self.v_bias = None

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(all_head_dim, out_dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, bool_masked_pos=None, k=None, v=None):
        """Attend from ``x`` (queries) over ``k``/``v``.

        Args:
            x: query tensor of shape ``(B, N_q, C)``.
            bool_masked_pos: unused; kept for interface compatibility.
            k: key tensor of shape ``(B, N_k, C)``; must not be None.
            v: value tensor of shape ``(B, N_v, C)``; must not be None.

        Returns:
            Tensor of shape ``(B, N_q, out_dim)``.
        """
        B, N, _ = x.shape
        N_k = k.shape[1]
        N_v = v.shape[1]

        q_bias, k_bias, v_bias = None, None, None
        if self.q_bias is not None:
            q_bias = self.q_bias
            k_bias = self.k_bias
            v_bias = self.v_bias

        # Project and split into heads: (B, N, C) -> (B, num_heads, N, head_dim).
        q = F.linear(input=x, weight=self.q.weight, bias=q_bias)
        q = q.reshape(B, N, 1, self.num_heads,
                      -1).permute(2, 0, 3, 1, 4).squeeze(0)
        k = F.linear(input=k, weight=self.k.weight, bias=k_bias)
        k = k.reshape(B, N_k, 1, self.num_heads,
                      -1).permute(2, 0, 3, 1, 4).squeeze(0)
        v = F.linear(input=v, weight=self.v.weight, bias=v_bias)
        v = v.reshape(B, N_v, 1, self.num_heads,
                      -1).permute(2, 0, 3, 1, 4).squeeze(0)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))  # (B, num_heads, N_q, N_k)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # Merge heads back and project to out_dim.
        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
class AttentiveBlock(nn.Module):
    r"""Pre-normalised cross-attention block.

    Applies separate normalisation layers to the query, key and value
    streams (queries get ``pos_q`` added, keys get ``pos_k``), then runs
    :class:`CrossAttention` over them.

    Args:
        dim (int): channel dimension.
        num_heads (int): number of attention heads.
        qkv_bias (bool): learn q/k/v projection biases.
        qk_scale (float | None): attention-scale override.
        drop (float): projection dropout passed to the attention module.
        attn_drop (float): attention-map dropout.
        drop_path (float): stochastic-depth rate; the resulting module is
            created but not applied in :meth:`forward`.
        norm_layer (str): norm type name passed to ``build_norm_layer``.
        attn_head_dim (int | None): per-head dimension override.
        out_dim (int | None): output dimension override.
    """

    def __init__(self,
                 dim,
                 num_heads,
                 qkv_bias=False,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 norm_layer="LN",
                 attn_head_dim=None,
                 out_dim=None):
        super().__init__()
        # One pre-norm per input stream.
        self.norm1_q = build_norm_layer(dim, norm_layer, eps=1e-6)
        self.norm1_k = build_norm_layer(dim, norm_layer, eps=1e-6)
        self.norm1_v = build_norm_layer(dim, norm_layer, eps=1e-6)
        self.cross_dcn = CrossAttention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            attn_head_dim=attn_head_dim,
            out_dim=out_dim)
        # Kept for structural parity; forward() does not invoke it.
        self.drop_path = DropPath(drop_path) if drop_path > 0. \
            else nn.Identity()

    def forward(self,
                x_q,
                x_kv,
                pos_q,
                pos_k,
                bool_masked_pos,
                rel_pos_bias=None):
        """Normalise the three streams and cross-attend.

        ``bool_masked_pos`` and ``rel_pos_bias`` are accepted for interface
        compatibility but not used.
        """
        query = self.norm1_q(x_q + pos_q)
        key = self.norm1_k(x_kv + pos_k)
        value = self.norm1_v(x_kv)
        return self.cross_dcn(query, k=key, v=value)
class AttentionPoolingBlock(AttentiveBlock):
    r"""Pool a token sequence ``(B, N, C)`` into a single vector ``(B, C)``.

    The mean token acts as the sole query and attends over all tokens via
    the parent :class:`AttentiveBlock`; no positional offsets, masking or
    relative bias are used.
    """

    def forward(self, x):
        query = x.mean(1, keepdim=True)  # (B, 1, C) mean-token query
        pooled = super().forward(query, x, 0, 0,
                                 bool_masked_pos=None,
                                 rel_pos_bias=None)
        return pooled.squeeze(1)  # (B, C)
class StemLayer(nn.Module):
r""" Stem layer of InternImage
Args:
......@@ -180,6 +322,7 @@ class InternImageLayer(nn.Module):
core_op,
channels,
groups,
dw_kernel_size,
mlp_ratio=4.,
drop=0.,
drop_path=0.,
......@@ -188,7 +331,9 @@ class InternImageLayer(nn.Module):
post_norm=False,
layer_scale=None,
offset_scale=1.0,
with_cp=False):
with_cp=False,
res_post_norm=False,
center_feature_scale=False):
super().__init__()
self.channels = channels
self.groups = groups
......@@ -197,15 +342,18 @@ class InternImageLayer(nn.Module):
self.norm1 = build_norm_layer(channels, 'LN')
self.post_norm = post_norm
self.dcn = core_op(channels=channels,
kernel_size=3,
stride=1,
pad=1,
dilation=1,
group=groups,
offset_scale=offset_scale,
act_layer=act_layer,
norm_layer=norm_layer)
self.dcn = core_op(
channels=channels,
kernel_size=3,
dw_kernel_size=dw_kernel_size,
stride=1,
pad=1,
dilation=1,
group=groups,
offset_scale=offset_scale,
act_layer=act_layer,
norm_layer=norm_layer,
center_feature_scale=center_feature_scale)
self.drop_path = DropPath(drop_path) if drop_path > 0. \
else nn.Identity()
self.norm2 = build_norm_layer(channels, 'LN')
......@@ -219,6 +367,10 @@ class InternImageLayer(nn.Module):
requires_grad=True)
self.gamma2 = nn.Parameter(layer_scale * torch.ones(channels),
requires_grad=True)
self.res_post_norm = res_post_norm
if res_post_norm:
self.res_post_norm1 = build_norm_layer(channels, 'LN')
self.res_post_norm2 = build_norm_layer(channels, 'LN')
def forward(self, x):
......@@ -227,6 +379,10 @@ class InternImageLayer(nn.Module):
if self.post_norm:
x = x + self.drop_path(self.norm1(self.dcn(x)))
x = x + self.drop_path(self.norm2(self.mlp(x)))
elif self.res_post_norm:
shortcut = x
x = shortcut + self.drop_path(self.res_post_norm1(self.dcn(self.norm1(x))))
x = x + self.drop_path(self.res_post_norm2(self.mlp(self.norm2(x))))
else:
x = x + self.drop_path(self.dcn(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
......@@ -269,6 +425,7 @@ class InternImageBlock(nn.Module):
channels,
depth,
groups,
dw_kernel_size,
downsample=True,
mlp_ratio=4.,
drop=0.,
......@@ -278,36 +435,52 @@ class InternImageBlock(nn.Module):
post_norm=False,
offset_scale=1.0,
layer_scale=None,
with_cp=False):
with_cp=False,
post_norm_block_ids=None, # for InternImage-H/G
res_post_norm=False, # for InternImage-H/G
center_feature_scale=False): # for InternImage-H/G
super().__init__()
self.channels = channels
self.depth = depth
self.post_norm = post_norm
self.center_feature_scale = center_feature_scale
self.blocks = nn.ModuleList([
InternImageLayer(core_op=core_op,
channels=channels,
groups=groups,
mlp_ratio=mlp_ratio,
drop=drop,
drop_path=drop_path[i] if isinstance(
drop_path, list) else drop_path,
act_layer=act_layer,
norm_layer=norm_layer,
post_norm=post_norm,
layer_scale=layer_scale,
offset_scale=offset_scale,
with_cp=with_cp) for i in range(depth)
InternImageLayer(
core_op=core_op,
channels=channels,
groups=groups,
dw_kernel_size=dw_kernel_size,
mlp_ratio=mlp_ratio,
drop=drop,
drop_path=drop_path[i] if isinstance(
drop_path, list) else drop_path,
act_layer=act_layer,
norm_layer=norm_layer,
post_norm=post_norm,
layer_scale=layer_scale,
offset_scale=offset_scale,
with_cp=with_cp,
res_post_norm=res_post_norm,
center_feature_scale=center_feature_scale
) for i in range(depth)
])
if not self.post_norm:
if not self.post_norm or center_feature_scale:
self.norm = build_norm_layer(channels, 'LN')
self.post_norm_block_ids = post_norm_block_ids
if post_norm_block_ids is not None:
self.post_norms = nn.ModuleList(
[build_norm_layer(channels, 'LN', eps=1e-6) for _ in post_norm_block_ids]
)
self.downsample = DownsampleLayer(
channels=channels, norm_layer=norm_layer) if downsample else None
def forward(self, x, return_wo_downsample=False):
for blk in self.blocks:
for i, blk in enumerate(self.blocks):
x = blk(x)
if not self.post_norm:
if (self.post_norm_block_ids is not None) and (i in self.post_norm_block_ids):
index = self.post_norm_block_ids.index(i)
x = self.post_norms[index](x) # for InternImage-H/G
if not self.post_norm or self.center_feature_scale:
x = self.norm(x)
if return_wo_downsample:
x_ = x
......@@ -356,6 +529,12 @@ class InternImage(nn.Module):
post_norm=False,
cls_scale=1.5,
with_cp=False,
dw_kernel_size=None, # for InternImage-H/G
use_clip_projector=False, # for InternImage-H/G
level2_post_norm=False, # for InternImage-H/G
level2_post_norm_block_ids=None, # for InternImage-H/G
res_post_norm=False, # for InternImage-H/G
center_feature_scale=False, # for InternImage-H/G
**kwargs):
super().__init__()
self.core_op = core_op
......@@ -366,11 +545,16 @@ class InternImage(nn.Module):
self.num_features = int(channels * 2**(self.num_levels - 1))
self.post_norm = post_norm
self.mlp_ratio = mlp_ratio
self.use_clip_projector = use_clip_projector
self.level2_post_norm_block_ids = level2_post_norm_block_ids
print(f'using core type: {core_op}')
print(f'using activation layer: {act_layer}')
print(f'using main norm layer: {norm_layer}')
print(f'using dpr: {drop_path_type}, {drop_path_rate}')
print(f"level2_post_norm: {level2_post_norm}")
print(f"level2_post_norm_block_ids: {level2_post_norm_block_ids}")
print(f"res_post_norm: {res_post_norm}")
in_chans = 3
self.patch_embed = StemLayer(in_chans=in_chans,
out_chans=channels,
......@@ -387,11 +571,14 @@ class InternImage(nn.Module):
self.levels = nn.ModuleList()
for i in range(self.num_levels):
post_norm_block_ids = level2_post_norm_block_ids if level2_post_norm and (
i == 2) else None # for InternImage-H/G
level = InternImageBlock(
core_op=getattr(opsm, core_op),
channels=int(channels * 2**i),
depth=depths[i],
groups=groups[i],
dw_kernel_size=dw_kernel_size, # for InternImage-H/G
mlp_ratio=self.mlp_ratio,
drop=drop_rate,
drop_path=dpr[sum(depths[:i]):sum(depths[:i + 1])],
......@@ -401,20 +588,47 @@ class InternImage(nn.Module):
downsample=(i < self.num_levels - 1),
layer_scale=layer_scale,
offset_scale=offset_scale,
with_cp=with_cp)
with_cp=with_cp,
post_norm_block_ids=post_norm_block_ids, # for InternImage-H/G
res_post_norm=res_post_norm, # for InternImage-H/G
center_feature_scale=center_feature_scale # for InternImage-H/G
)
self.levels.append(level)
self.conv_head = nn.Sequential(
nn.Conv2d(self.num_features,
int(self.num_features * cls_scale),
kernel_size=1,
bias=False),
build_norm_layer(int(self.num_features * cls_scale), 'BN',
'channels_first', 'channels_first'),
build_act_layer(act_layer))
if not use_clip_projector: # for InternImage-T/S/B/L/XL
self.conv_head = nn.Sequential(
nn.Conv2d(self.num_features,
int(self.num_features * cls_scale),
kernel_size=1,
bias=False),
build_norm_layer(int(self.num_features * cls_scale), 'BN',
'channels_first', 'channels_first'),
build_act_layer(act_layer))
self.head = nn.Linear(int(self.num_features * cls_scale), num_classes) \
if num_classes > 0 else nn.Identity()
else: # for InternImage-H/G
pretrain_embed_dim, _stride, attnpool_num_heads, clip_embed_dim = 1024, 2, 16, 768
self.dcnv3_head_x4 = nn.Sequential(
nn.Conv2d(in_channels=self.num_features,
out_channels=pretrain_embed_dim * (_stride ** 2),
kernel_size=1), nn.PixelShuffle(_stride))
self.dcnv3_head_x3 = nn.Conv2d(in_channels=self.num_features // 2,
out_channels=pretrain_embed_dim,
kernel_size=1)
self.clip_projector = AttentionPoolingBlock(
dim=pretrain_embed_dim,
num_heads=attnpool_num_heads,
qkv_bias=True,
qk_scale=None,
drop=0.,
attn_drop=0.,
norm_layer=norm_layer,
out_dim=clip_embed_dim)
self.fc_norm = build_norm_layer(clip_embed_dim, norm_layer, eps=1e-6)
self.head = nn.Linear(
clip_embed_dim, num_classes) if num_classes > 0 else nn.Identity()
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.head = nn.Linear(int(self.num_features * cls_scale), num_classes) \
if num_classes > 0 else nn.Identity()
self.num_layers = len(depths)
self.apply(self._init_weights)
self.apply(self._init_deform_weights)
......@@ -480,8 +694,31 @@ class InternImage(nn.Module):
x, x_ = level(x, return_wo_downsample=True)
seq_out.append(x_)
return seq_out
def forward_clip_projector(self, x):  # for InternImage-H/G
    """CLIP-projector head path used by InternImage-H/G.

    Fuses the two deepest feature levels, attention-pools the fused map
    into one embedding per image, and normalises it with ``self.fc_norm``.
    The caller is expected to feed the result into ``self.head``.
    """
    # Per-level feature maps in NHWC layout, shallowest to deepest.
    xs = self.forward_features_seq_out(x)
    x1, x2, x3, x4 = xs

    x1 = x1.permute(0, 3, 1, 2)  # NHWC -> NCHW
    x2 = x2.permute(0, 3, 1, 2)  # NHWC -> NCHW (x1/x2 are not used below)
    x3 = x3.permute(0, 3, 1, 2)  # NHWC -> NCHW
    x4 = x4.permute(0, 3, 1, 2)  # NHWC -> NCHW

    # Project the deepest level; dcnv3_head_x4 ends in a PixelShuffle, which
    # upsamples it to match the level-3 resolution before the sum below.
    x4 = self.dcnv3_head_x4(x4)
    x = x4
    x3 = self.dcnv3_head_x3(x3)
    x = x + x3

    # (B, C, H, W) -> (B, H*W, C): token sequence for attention pooling.
    x = x.flatten(-2).transpose(1, 2).contiguous()
    x = self.clip_projector(x)
    x = self.fc_norm(x)
    return x
def forward(self, x):
x = self.forward_features(x)
if self.use_clip_projector: # for InternImage-H/G
x = self.forward_clip_projector(x)
else: # for InternImage-T/S/B/L/XL
x = self.forward_features(x)
x = self.head(x)
return x
......@@ -9,6 +9,7 @@ from __future__ import print_function
from __future__ import division
import warnings
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.init import xavier_uniform_, constant_
......@@ -76,11 +77,31 @@ def _is_power_of_2(n):
return (n & (n-1) == 0) and n != 0
class CenterFeatureScaleModule(nn.Module):
    r"""Compute per-group gating scales in ``(0, 1)`` from query features.

    Applies a linear projection (weight/bias are passed in by the owning
    DCNv3 module rather than owned here) followed by a sigmoid.
    """

    def forward(self,
                query,
                center_feature_scale_proj_weight,
                center_feature_scale_proj_bias):
        logits = F.linear(query,
                          weight=center_feature_scale_proj_weight,
                          bias=center_feature_scale_proj_bias)
        return logits.sigmoid()
class DCNv3_pytorch(nn.Module):
def __init__(
self, channels=64, kernel_size=3, stride=1,
pad=1, dilation=1, group=4, offset_scale=1.0,
act_layer='GELU', norm_layer='LN'):
self,
channels=64,
kernel_size=3,
dw_kernel_size=None,
stride=1,
pad=1,
dilation=1,
group=4,
offset_scale=1.0,
act_layer='GELU',
norm_layer='LN',
center_feature_scale=False):
"""
DCNv3 Module
:param channels
......@@ -98,6 +119,7 @@ class DCNv3_pytorch(nn.Module):
raise ValueError(
f'channels must be divisible by group, but got {channels} and {group}')
_d_per_group = channels // group
dw_kernel_size = dw_kernel_size if dw_kernel_size is not None else kernel_size
# you'd better set _d_per_group to a power of 2 which is more efficient in our CUDA implementation
if not _is_power_of_2(_d_per_group):
warnings.warn(
......@@ -107,20 +129,22 @@ class DCNv3_pytorch(nn.Module):
self.offset_scale = offset_scale
self.channels = channels
self.kernel_size = kernel_size
self.dw_kernel_size = dw_kernel_size
self.stride = stride
self.dilation = 1
self.dilation = dilation
self.pad = pad
self.group = group
self.group_channels = channels // group
self.offset_scale = offset_scale
self.center_feature_scale = center_feature_scale
self.dw_conv = nn.Sequential(
nn.Conv2d(
channels,
channels,
kernel_size=kernel_size,
kernel_size=dw_kernel_size,
stride=1,
padding=(kernel_size-1)//2,
padding=(dw_kernel_size - 1) // 2,
groups=channels),
build_norm_layer(
channels,
......@@ -137,6 +161,13 @@ class DCNv3_pytorch(nn.Module):
self.input_proj = nn.Linear(channels, channels)
self.output_proj = nn.Linear(channels, channels)
self._reset_parameters()
if center_feature_scale:
self.center_feature_scale_proj_weight = nn.Parameter(
torch.zeros((group, channels), dtype=torch.float))
self.center_feature_scale_proj_bias = nn.Parameter(
torch.tensor(0.0, dtype=torch.float).view((1,)).repeat(group, ))
self.center_feature_scale_module = CenterFeatureScaleModule()
def _reset_parameters(self):
constant_(self.offset.weight.data, 0.)
......@@ -156,6 +187,7 @@ class DCNv3_pytorch(nn.Module):
N, H, W, _ = input.shape
x = self.input_proj(input)
x_proj = x
x1 = input.permute(0, 3, 1, 2)
x1 = self.dw_conv(x1)
......@@ -171,6 +203,13 @@ class DCNv3_pytorch(nn.Module):
self.dilation, self.dilation,
self.group, self.group_channels,
self.offset_scale)
if self.center_feature_scale:
center_feature_scale = self.center_feature_scale_module(
x1, self.center_feature_scale_proj_weight, self.center_feature_scale_proj_bias)
# N, H, W, groups -> N, H, W, groups, 1 -> N, H, W, groups, _d_per_group -> N, H, W, channels
center_feature_scale = center_feature_scale[..., None].repeat(
1, 1, 1, 1, self.channels // self.group).flatten(-2)
x = x * (1 - center_feature_scale) + x_proj * center_feature_scale
x = self.output_proj(x)
return x
......@@ -178,9 +217,18 @@ class DCNv3_pytorch(nn.Module):
class DCNv3(nn.Module):
def __init__(
self, channels=64, kernel_size=3, stride=1,
pad=1, dilation=1, group=4, offset_scale=1.0,
act_layer='GELU', norm_layer='LN'):
self,
channels=64,
kernel_size=3,
dw_kernel_size=None,
stride=1,
pad=1,
dilation=1,
group=4,
offset_scale=1.0,
act_layer='GELU',
norm_layer='LN',
center_feature_scale=False):
"""
DCNv3 Module
:param channels
......@@ -198,6 +246,7 @@ class DCNv3(nn.Module):
raise ValueError(
f'channels must be divisible by group, but got {channels} and {group}')
_d_per_group = channels // group
dw_kernel_size = dw_kernel_size if dw_kernel_size is not None else kernel_size
# you'd better set _d_per_group to a power of 2 which is more efficient in our CUDA implementation
if not _is_power_of_2(_d_per_group):
warnings.warn(
......@@ -207,20 +256,22 @@ class DCNv3(nn.Module):
self.offset_scale = offset_scale
self.channels = channels
self.kernel_size = kernel_size
self.dw_kernel_size = dw_kernel_size
self.stride = stride
self.dilation = 1
self.dilation = dilation
self.pad = pad
self.group = group
self.group_channels = channels // group
self.offset_scale = offset_scale
self.center_feature_scale = center_feature_scale
self.dw_conv = nn.Sequential(
nn.Conv2d(
channels,
channels,
kernel_size=kernel_size,
kernel_size=dw_kernel_size,
stride=1,
padding=(kernel_size-1)//2,
padding=(dw_kernel_size - 1) // 2,
groups=channels),
build_norm_layer(
channels,
......@@ -237,6 +288,13 @@ class DCNv3(nn.Module):
self.input_proj = nn.Linear(channels, channels)
self.output_proj = nn.Linear(channels, channels)
self._reset_parameters()
if center_feature_scale:
self.center_feature_scale_proj_weight = nn.Parameter(
torch.zeros((group, channels), dtype=torch.float))
self.center_feature_scale_proj_bias = nn.Parameter(
torch.tensor(0.0, dtype=torch.float).view((1,)).repeat(group, ))
self.center_feature_scale_module = CenterFeatureScaleModule()
def _reset_parameters(self):
constant_(self.offset.weight.data, 0.)
......@@ -256,6 +314,7 @@ class DCNv3(nn.Module):
N, H, W, _ = input.shape
x = self.input_proj(input)
x_proj = x
dtype = x.dtype
x1 = input.permute(0, 3, 1, 2)
......@@ -273,6 +332,14 @@ class DCNv3(nn.Module):
self.group, self.group_channels,
self.offset_scale,
256)
if self.center_feature_scale:
center_feature_scale = self.center_feature_scale_module(
x1, self.center_feature_scale_proj_weight, self.center_feature_scale_proj_bias)
# N, H, W, groups -> N, H, W, groups, 1 -> N, H, W, groups, _d_per_group -> N, H, W, channels
center_feature_scale = center_feature_scale[..., None].repeat(
1, 1, 1, 1, self.channels // self.group).flatten(-2)
x = x * (1 - center_feature_scale) + x_proj * center_feature_scale
x = self.output_proj(x)
return x
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment