Commit 43e508f4 authored by Zhe Chen, committed by Zhe Chen

Replace GitHub release with Hugging Face (#42)



* Update function descriptions

* Update README.md

* Replace GitHub release with Hugging Face

* Release InternImage-H segmentation model

---------
Co-authored-by: Zhenhang Huang <prc_hzh@163.com>
parent 4bbef509
import argparse
import math
from collections import OrderedDict

import torch

parser = argparse.ArgumentParser(description='Hyperparams')
parser.add_argument('filename', nargs='?', type=str, default=None)
args = parser.parse_args()


def gen_grid(n_heads):
    # Reference 3x3 sampling grid of DCNv3 (9 points), repeated once per head.
    n_points = 9
    kernel_size = int(math.sqrt(n_points))
    y, x = torch.meshgrid(
        torch.linspace((-kernel_size // 2 + 1), (kernel_size // 2),
                       kernel_size,
                       dtype=torch.float32),
        torch.linspace((-kernel_size // 2 + 1), (kernel_size // 2),
                       kernel_size,
                       dtype=torch.float32))
    grid = torch.stack([y, x], -1).reshape(-1, 1, 2).\
        repeat(1, n_heads, 1).permute(1, 0, 2)  # (n_heads, n_points, 2)
    return grid


def remove_ab(m):
    # Fold each learned alpha_beta scale into the corresponding
    # sampling_offsets.bias, then drop the alpha_beta keys entirely.
    new_sd = OrderedDict()
    n_points = 9
    for k, v in m.items():
        if 'alpha_beta' in k:
            ab = v.repeat(1, n_points)
            h, _ = ab.size()
            offset_b = k.replace('alpha_beta', 'sampling_offsets.bias')
            ob = m[offset_b]
            grid = gen_grid(h).reshape(h, -1)
            delta = ((ab - 1) * grid).reshape(-1)
            new_sd[offset_b] = ob + delta
            continue
        if 'sampling_offsets.bias' in k:
            continue  # already rewritten above from the matching alpha_beta key
        new_sd[k] = v
    return new_sd


model = torch.load(args.filename, map_location=torch.device('cpu'))
model = remove_ab(model['state_dict'])
torch.save({'state_dict': model}, args.filename.replace('.pth', '_rmab.pth'))
print('finished!')
\ No newline at end of file
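Invocation is a single positional argument: given ckpt.pth the script writes ckpt_rmab.pth. A quick sanity check of remove_ab (a sketch; the key names and the (heads, 2) alpha_beta shape are assumptions consistent with the code above, and gen_grid/remove_ab are assumed to be defined in the current session):

import torch
from collections import OrderedDict

# assumes gen_grid/remove_ab from the script above are already defined
sd = OrderedDict()
sd['level.alpha_beta'] = torch.ones(4, 2)  # 4 heads; alpha/beta initialized to 1
sd['level.sampling_offsets.bias'] = torch.zeros(4 * 9 * 2)
out = remove_ab(sd)
assert 'level.alpha_beta' not in out  # scales are folded away
# alpha/beta of exactly 1 leave the folded offset biases unchanged
assert torch.allclose(out['level.sampling_offsets.bias'],
                      torch.zeros(4 * 9 * 2))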
import argparse
import math
from collections import OrderedDict

import torch

parser = argparse.ArgumentParser(description='Hyperparams')
parser.add_argument('filename', nargs='?', type=str, default=None)
args = parser.parse_args()


def gen_grid(n_heads):
    # Reference 3x3 sampling grid of DCNv3 (9 points), repeated once per head.
    n_points = 9
    kernel_size = int(math.sqrt(n_points))
    y, x = torch.meshgrid(
        torch.linspace((-kernel_size // 2 + 1), (kernel_size // 2),
                       kernel_size,
                       dtype=torch.float32),
        torch.linspace((-kernel_size // 2 + 1), (kernel_size // 2),
                       kernel_size,
                       dtype=torch.float32))
    grid = torch.stack([y, x], -1).reshape(-1, 1, 2).\
        repeat(1, n_heads, 1).permute(1, 0, 2)
    return grid


def convert_to_newop(m):
    # Rename keys from the old attention-style op to the new DCNv3 op
    # layout and cast all retained weights to fp16.
    new_sd = OrderedDict()
    for k, v in m.items():
        new_k = k
        if 'attn' in k:
            new_k = new_k.replace('attn', 'dcn')
        if 'sampling_offsets' in k:
            new_k = new_k.replace('sampling_offsets', 'offset')
        if 'attention_weights' in k:
            new_k = new_k.replace('attention_weights', 'mask')
        if 'value_proj' in k:
            new_k = new_k.replace('value_proj', 'input_proj')
        if 'ema' in k:
            continue  # drop EMA copies of the weights
        # the new op indexes these norm layers as submodule 0, hence '.0.'
        if '.norm1_k.' in k:
            new_k = new_k.replace('.norm1_k.', '.norm1_k.0.')
        if '.norm1_q.' in k:
            new_k = new_k.replace('.norm1_q.', '.norm1_q.0.')
        if '.norm1_v.' in k:
            new_k = new_k.replace('.norm1_v.', '.norm1_v.0.')
        if '.post_norms.' in k:
            new_k = new_k.replace('.bias', '.0.bias')
            new_k = new_k.replace('.weight', '.0.weight')
        if 'fc_norm.' in k:
            new_k = new_k.replace('fc_norm.', 'fc_norm.0.')
        new_sd[new_k] = v.half()
    return new_sd


model = torch.load(args.filename, map_location=torch.device('cpu'))['state_dict']
new_model = {'state_dict': convert_to_newop(model)}
torch.save(new_model, args.filename.replace('.pth', '_rename.pth'))
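The renaming script is run the same way: given ckpt.pth it writes ckpt_rename.pth, with every retained tensor passed through v.half(), so the output checkpoint is fp16. A sketch of the mapping it applies (the key name is hypothetical; convert_to_newop is assumed defined in the session):

import torch

sd = {'levels.0.blocks.0.attn.sampling_offsets.weight': torch.zeros(2, 2)}
out = convert_to_newop(sd)
# key becomes 'levels.0.blocks.0.dcn.offset.weight', dtype torch.float16
print(list(out.keys()), next(iter(out.values())).dtype)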
@@ -13,9 +13,11 @@ from mmcv.runner import _load_checkpoint
from mmcv.cnn import constant_init, trunc_normal_init
from mmseg.utils import get_root_logger
from mmseg.models.builder import BACKBONES
import torch.nn.functional as F
from ops_dcnv3 import modules as opsm


class to_channels_first(nn.Module):

    def __init__(self):
@@ -69,6 +71,171 @@ def build_act_layer(act_layer):
    raise NotImplementedError(f'build_act_layer does not support {act_layer}')

class CrossAttention(nn.Module):
    r""" Cross Attention Module
    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads. Default: 8
        qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
            Default: False.
        qk_scale (float | None, optional): Override default qk scale of
            head_dim ** -0.5 if set. Default: None.
        attn_drop (float, optional): Dropout ratio of attention weight.
            Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
        attn_head_dim (int, optional): Dimension of attention head.
        out_dim (int, optional): Dimension of output.
    """

    def __init__(self,
                 dim,
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop=0.,
                 proj_drop=0.,
                 attn_head_dim=None,
                 out_dim=None):
        super().__init__()
        if out_dim is None:
            out_dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        if attn_head_dim is not None:
            head_dim = attn_head_dim
        all_head_dim = head_dim * self.num_heads
        self.scale = qk_scale or head_dim ** -0.5
        assert all_head_dim == dim

        self.q = nn.Linear(dim, all_head_dim, bias=False)
        self.k = nn.Linear(dim, all_head_dim, bias=False)
        self.v = nn.Linear(dim, all_head_dim, bias=False)

        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.k_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
        else:
            self.q_bias = None
            self.k_bias = None
            self.v_bias = None

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(all_head_dim, out_dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, k=None, v=None):
        B, N, C = x.shape
        N_k = k.shape[1]
        N_v = v.shape[1]

        q_bias, k_bias, v_bias = None, None, None
        if self.q_bias is not None:
            q_bias = self.q_bias
            k_bias = self.k_bias
            v_bias = self.v_bias

        q = F.linear(input=x, weight=self.q.weight, bias=q_bias)
        q = q.reshape(B, N, 1, self.num_heads,
                      -1).permute(2, 0, 3, 1,
                                  4).squeeze(0)  # (B, N_head, N_q, dim)
        k = F.linear(input=k, weight=self.k.weight, bias=k_bias)
        k = k.reshape(B, N_k, 1, self.num_heads, -1).permute(2, 0, 3, 1,
                                                             4).squeeze(0)
        v = F.linear(input=v, weight=self.v.weight, bias=v_bias)
        v = v.reshape(B, N_v, 1, self.num_heads, -1).permute(2, 0, 3, 1,
                                                             4).squeeze(0)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))  # (B, N_head, N_q, N_k)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x
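A shape-level sketch of CrossAttention in isolation (sizes are illustrative assumptions; the query sequence and the key/value sequence may have different lengths):

attn = CrossAttention(dim=64, num_heads=8, qkv_bias=True)
x = torch.randn(2, 1, 64)    # (B, N_q, C): a single query token per sample
kv = torch.randn(2, 49, 64)  # (B, N_k, C): 49 key/value tokens
out = attn(x, k=kv, v=kv)    # (2, 1, 64)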
class AttentiveBlock(nn.Module):
    r"""Attentive Block
    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads. Default: 8
        qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
            Default: False.
        qk_scale (float | None, optional): Override default qk scale of
            head_dim ** -0.5 if set. Default: None.
        drop (float, optional): Dropout rate. Default: 0.0.
        attn_drop (float, optional): Attention dropout rate. Default: 0.0.
        drop_path (float | tuple[float], optional): Stochastic depth rate.
            Default: 0.0.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm.
        attn_head_dim (int, optional): Dimension of attention head. Default: None.
        out_dim (int, optional): Dimension of output. Default: None.
    """

    def __init__(self,
                 dim,
                 num_heads,
                 qkv_bias=False,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 norm_layer="LN",
                 attn_head_dim=None,
                 out_dim=None):
        super().__init__()

        self.norm1_q = build_norm_layer(dim, norm_layer, eps=1e-6)
        self.norm1_k = build_norm_layer(dim, norm_layer, eps=1e-6)
        self.norm1_v = build_norm_layer(dim, norm_layer, eps=1e-6)
        self.cross_dcn = CrossAttention(dim,
                                        num_heads=num_heads,
                                        qkv_bias=qkv_bias,
                                        qk_scale=qk_scale,
                                        attn_drop=attn_drop,
                                        proj_drop=drop,
                                        attn_head_dim=attn_head_dim,
                                        out_dim=out_dim)

        self.drop_path = DropPath(
            drop_path) if drop_path > 0. else nn.Identity()

    def forward(self,
                x_q,
                x_kv,
                pos_q,
                pos_k,
                bool_masked_pos,
                rel_pos_bias=None):
        x_q = self.norm1_q(x_q + pos_q)
        x_k = self.norm1_k(x_kv + pos_k)
        x_v = self.norm1_v(x_kv)

        x = self.cross_dcn(x_q, k=x_k, v=x_v)

        return x


class AttentionPoolingBlock(AttentiveBlock):

    def forward(self, x):
        x_q = x.mean(1, keepdim=True)
        x_kv = x
        pos_q, pos_k = 0, 0
        x = super().forward(x_q, x_kv, pos_q, pos_k,
                            bool_masked_pos=None,
                            rel_pos_bias=None)
        x = x.squeeze(1)
        return x
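For orientation, a hedged usage sketch of the pooling head just defined (dimensions are illustrative, and the repo's build_norm_layer/DropPath helpers must be in scope):

pool = AttentionPoolingBlock(dim=64, num_heads=8, qkv_bias=True,
                             norm_layer='LN', out_dim=64)
tokens = torch.randn(2, 196, 64)  # (B, N, C) feature tokens
pooled = pool(tokens)             # (2, 64): the mean token cross-attends to all tokens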
class StemLayer(nn.Module):
    r""" Stem layer of InternImage
    Args:
@@ -195,7 +362,10 @@ class InternImageLayer(nn.Module):
                 post_norm=False,
                 layer_scale=None,
                 offset_scale=1.0,
                 with_cp=False,
                 dw_kernel_size=None,  # for InternImage-H/G
                 res_post_norm=False,  # for InternImage-H/G
                 center_feature_scale=False):  # for InternImage-H/G
        super().__init__()
        self.channels = channels
        self.groups = groups
@@ -204,15 +374,18 @@ class InternImageLayer(nn.Module):
        self.norm1 = build_norm_layer(channels, 'LN')
        self.post_norm = post_norm
        self.dcn = core_op(
            channels=channels,
            kernel_size=3,
            stride=1,
            pad=1,
            dilation=1,
            group=groups,
            offset_scale=offset_scale,
            act_layer=act_layer,
            norm_layer=norm_layer,
            dw_kernel_size=dw_kernel_size,  # for InternImage-H/G
            center_feature_scale=center_feature_scale)  # for InternImage-H/G
        self.drop_path = DropPath(drop_path) if drop_path > 0. \
            else nn.Identity()
        self.norm2 = build_norm_layer(channels, 'LN')
@@ -226,6 +399,10 @@ class InternImageLayer(nn.Module):
                                       requires_grad=True)
            self.gamma2 = nn.Parameter(layer_scale * torch.ones(channels),
                                       requires_grad=True)
        self.res_post_norm = res_post_norm
        if res_post_norm:
            self.res_post_norm1 = build_norm_layer(channels, 'LN')
            self.res_post_norm2 = build_norm_layer(channels, 'LN')

    def forward(self, x):
@@ -234,6 +411,9 @@ class InternImageLayer(nn.Module):
            if self.post_norm:
                x = x + self.drop_path(self.norm1(self.dcn(x)))
                x = x + self.drop_path(self.norm2(self.mlp(x)))
            elif self.res_post_norm:  # for InternImage-H/G
                x = x + self.drop_path(self.res_post_norm1(self.dcn(self.norm1(x))))
                x = x + self.drop_path(self.res_post_norm2(self.mlp(self.norm2(x))))
            else:
                x = x + self.drop_path(self.dcn(self.norm1(x)))
                x = x + self.drop_path(self.mlp(self.norm2(x)))
@@ -285,36 +465,54 @@ class InternImageBlock(nn.Module):
                 post_norm=False,
                 offset_scale=1.0,
                 layer_scale=None,
                 with_cp=False,
                 dw_kernel_size=None,  # for InternImage-H/G
                 post_norm_block_ids=None,  # for InternImage-H/G
                 res_post_norm=False,  # for InternImage-H/G
                 center_feature_scale=False):  # for InternImage-H/G
        super().__init__()
        self.channels = channels
        self.depth = depth
        self.post_norm = post_norm
        self.center_feature_scale = center_feature_scale

        self.blocks = nn.ModuleList([
            InternImageLayer(
                core_op=core_op,
                channels=channels,
                groups=groups,
                mlp_ratio=mlp_ratio,
                drop=drop,
                drop_path=drop_path[i] if isinstance(
                    drop_path, list) else drop_path,
                act_layer=act_layer,
                norm_layer=norm_layer,
                post_norm=post_norm,
                layer_scale=layer_scale,
                offset_scale=offset_scale,
                with_cp=with_cp,
                dw_kernel_size=dw_kernel_size,  # for InternImage-H/G
                res_post_norm=res_post_norm,  # for InternImage-H/G
                center_feature_scale=center_feature_scale  # for InternImage-H/G
            ) for i in range(depth)
        ])
        if not self.post_norm or center_feature_scale:
            self.norm = build_norm_layer(channels, 'LN')
        self.post_norm_block_ids = post_norm_block_ids
        if post_norm_block_ids is not None:  # for InternImage-H/G
            self.post_norms = nn.ModuleList(
                [build_norm_layer(channels, 'LN', eps=1e-6)
                 for _ in post_norm_block_ids])
        self.downsample = DownsampleLayer(
            channels=channels, norm_layer=norm_layer) if downsample else None

    def forward(self, x, return_wo_downsample=False):
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if (self.post_norm_block_ids is not None) and (
                    i in self.post_norm_block_ids):
                index = self.post_norm_block_ids.index(i)
                x = self.post_norms[index](x)  # for InternImage-H/G
        if not self.post_norm or self.center_feature_scale:
            x = self.norm(x)
        if return_wo_downsample:
            x_ = x
@@ -344,6 +542,11 @@ class InternImage(nn.Module):
        layer_scale (bool): Whether to use layer scale. Default: False
        cls_scale (bool): Whether to use class scale. Default: False
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
        dw_kernel_size (int): Size of the dwconv. Default: None
        level2_post_norm (bool): Whether to use level2 post norm. Default: False
        level2_post_norm_block_ids (list): Indexes of post norm blocks. Default: None
        res_post_norm (bool): Whether to use res post norm. Default: False
        center_feature_scale (bool): Whether to use center feature scale. Default: False
    """

    def __init__(self,
@@ -361,6 +564,11 @@ class InternImage(nn.Module):
                 offset_scale=1.0,
                 post_norm=False,
                 with_cp=False,
                 dw_kernel_size=None,  # for InternImage-H/G
                 level2_post_norm=False,  # for InternImage-H/G
                 level2_post_norm_block_ids=None,  # for InternImage-H/G
                 res_post_norm=False,  # for InternImage-H/G
                 center_feature_scale=False,  # for InternImage-H/G
                 out_indices=(0, 1, 2, 3),
                 init_cfg=None,
                 **kwargs):
@@ -374,10 +582,15 @@ class InternImage(nn.Module):
        self.mlp_ratio = mlp_ratio
        self.init_cfg = init_cfg
        self.out_indices = out_indices
        self.level2_post_norm_block_ids = level2_post_norm_block_ids
        logger = get_root_logger()
        logger.info(f'using core type: {core_op}')
        logger.info(f'using activation layer: {act_layer}')
        logger.info(f'using main norm layer: {norm_layer}')
        logger.info(f'using dpr: {drop_path_type}, {drop_path_rate}')
        logger.info(f'level2_post_norm: {level2_post_norm}')
        logger.info(f'level2_post_norm_block_ids: {level2_post_norm_block_ids}')
        logger.info(f'res_post_norm: {res_post_norm}')

        in_chans = 3
        self.patch_embed = StemLayer(in_chans=in_chans,
@@ -395,6 +608,8 @@ class InternImage(nn.Module):
        self.levels = nn.ModuleList()
        for i in range(self.num_levels):
            post_norm_block_ids = level2_post_norm_block_ids if level2_post_norm and (
                i == 2) else None  # for InternImage-H/G
            level = InternImageBlock(
                core_op=getattr(opsm, core_op),
                channels=int(channels * 2**i),
@@ -409,7 +624,12 @@ class InternImage(nn.Module):
                downsample=(i < self.num_levels - 1),
                layer_scale=layer_scale,
                offset_scale=offset_scale,
                with_cp=with_cp,
                dw_kernel_size=dw_kernel_size,  # for InternImage-H/G
                post_norm_block_ids=post_norm_block_ids,  # for InternImage-H/G
                res_post_norm=res_post_norm,  # for InternImage-H/G
                center_feature_scale=center_feature_scale  # for InternImage-H/G
            )
            self.levels.append(level)
        self.num_layers = len(depths)
...
@@ -9,6 +9,7 @@ from __future__ import print_function
from __future__ import division

import warnings
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.init import xavier_uniform_, constant_
@@ -72,21 +73,41 @@ def _is_power_of_2(n):
    if (not isinstance(n, int)) or (n < 0):
        raise ValueError(
            "invalid input for _is_power_of_2: {} (type: {})".format(n, type(n)))
    return (n & (n - 1) == 0) and n != 0


class CenterFeatureScaleModule(nn.Module):

    def forward(self,
                query,
                center_feature_scale_proj_weight,
                center_feature_scale_proj_bias):
        center_feature_scale = F.linear(
            query,
            weight=center_feature_scale_proj_weight,
            bias=center_feature_scale_proj_bias).sigmoid()
        return center_feature_scale
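CenterFeatureScaleModule is just a per-group linear projection followed by a sigmoid; a shape-level sketch with assumed sizes:

import torch
import torch.nn.functional as F

q = torch.randn(2, 7, 7, 16)        # (N, H, W, channels)
w = torch.zeros(4, 16)              # (group, channels), zero-initialized as below
b = torch.zeros(4)
gate = F.linear(q, w, b).sigmoid()  # (2, 7, 7, 4): one gate per group, 0.5 at init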
class DCNv3_pytorch(nn.Module):

    def __init__(
            self,
            channels=64,
            kernel_size=3,
            dw_kernel_size=None,
            stride=1,
            pad=1,
            dilation=1,
            group=4,
            offset_scale=1.0,
            act_layer='GELU',
            norm_layer='LN',
            center_feature_scale=False):
        """
        DCNv3 Module
        :param channels
        :param kernel_size
        :param stride
        :param pad
        :param dilation
        :param group
        :param offset_scale
@@ -98,29 +119,32 @@ class DCNv3_pytorch(nn.Module):
            raise ValueError(
                f'channels must be divisible by group, but got {channels} and {group}')
        _d_per_group = channels // group
        dw_kernel_size = dw_kernel_size if dw_kernel_size is not None else kernel_size
        # you'd better set _d_per_group to a power of 2 which is more efficient in our CUDA implementation
        if not _is_power_of_2(_d_per_group):
            warnings.warn(
                "You'd better set channels in DCNv3 to make the dimension of each attention head a power of 2 "
                "which is more efficient in our CUDA implementation.")

        self.offset_scale = offset_scale
        self.channels = channels
        self.kernel_size = kernel_size
        self.dw_kernel_size = dw_kernel_size
        self.stride = stride
        self.dilation = dilation
        self.pad = pad
        self.group = group
        self.group_channels = channels // group
        self.offset_scale = offset_scale
        self.center_feature_scale = center_feature_scale

        self.dw_conv = nn.Sequential(
            nn.Conv2d(
                channels,
                channels,
                kernel_size=dw_kernel_size,
                stride=1,
                padding=(dw_kernel_size - 1) // 2,
                groups=channels),
            build_norm_layer(
                channels,
@@ -137,7 +161,14 @@ class DCNv3_pytorch(nn.Module):
        self.input_proj = nn.Linear(channels, channels)
        self.output_proj = nn.Linear(channels, channels)
        self._reset_parameters()
        if center_feature_scale:
            self.center_feature_scale_proj_weight = nn.Parameter(
                torch.zeros((group, channels), dtype=torch.float))
            self.center_feature_scale_proj_bias = nn.Parameter(
                torch.tensor(0.0, dtype=torch.float).view((1,)).repeat(group, ))
            self.center_feature_scale_module = CenterFeatureScaleModule()

    def _reset_parameters(self):
        constant_(self.offset.weight.data, 0.)
        constant_(self.offset.bias.data, 0.)
@@ -171,22 +202,38 @@ class DCNv3_pytorch(nn.Module):
            self.dilation, self.dilation,
            self.group, self.group_channels,
            self.offset_scale)
        if self.center_feature_scale:
            center_feature_scale = self.center_feature_scale_module(
                x1, self.center_feature_scale_proj_weight, self.center_feature_scale_proj_bias)
            # N, H, W, groups -> N, H, W, groups, 1 -> N, H, W, groups, _d_per_group -> N, H, W, channels
            center_feature_scale = center_feature_scale[..., None].repeat(
                1, 1, 1, 1, self.channels // self.group).flatten(-2)
            x = x * (1 - center_feature_scale) + x_proj * center_feature_scale
        x = self.output_proj(x)

        return x
class DCNv3(nn.Module):

    def __init__(
            self,
            channels=64,
            kernel_size=3,
            dw_kernel_size=None,
            stride=1,
            pad=1,
            dilation=1,
            group=4,
            offset_scale=1.0,
            act_layer='GELU',
            norm_layer='LN',
            center_feature_scale=False):
        """
        DCNv3 Module
        :param channels
        :param kernel_size
        :param stride
        :param pad
        :param dilation
        :param group
        :param offset_scale
@@ -198,29 +245,32 @@ class DCNv3(nn.Module):
            raise ValueError(
                f'channels must be divisible by group, but got {channels} and {group}')
        _d_per_group = channels // group
        dw_kernel_size = dw_kernel_size if dw_kernel_size is not None else kernel_size
        # you'd better set _d_per_group to a power of 2 which is more efficient in our CUDA implementation
        if not _is_power_of_2(_d_per_group):
            warnings.warn(
                "You'd better set channels in DCNv3 to make the dimension of each attention head a power of 2 "
                "which is more efficient in our CUDA implementation.")

        self.offset_scale = offset_scale
        self.channels = channels
        self.kernel_size = kernel_size
        self.dw_kernel_size = dw_kernel_size
        self.stride = stride
        self.dilation = dilation
        self.pad = pad
        self.group = group
        self.group_channels = channels // group
        self.offset_scale = offset_scale
        self.center_feature_scale = center_feature_scale

        self.dw_conv = nn.Sequential(
            nn.Conv2d(
                channels,
                channels,
                kernel_size=dw_kernel_size,
                stride=1,
                padding=(dw_kernel_size - 1) // 2,
                groups=channels),
            build_norm_layer(
                channels,
@@ -237,7 +287,14 @@ class DCNv3(nn.Module):
        self.input_proj = nn.Linear(channels, channels)
        self.output_proj = nn.Linear(channels, channels)
        self._reset_parameters()
        if center_feature_scale:
            self.center_feature_scale_proj_weight = nn.Parameter(
                torch.zeros((group, channels), dtype=torch.float))
            self.center_feature_scale_proj_bias = nn.Parameter(
                torch.tensor(0.0, dtype=torch.float).view((1,)).repeat(group, ))
            self.center_feature_scale_module = CenterFeatureScaleModule()

    def _reset_parameters(self):
        constant_(self.offset.weight.data, 0.)
        constant_(self.offset.bias.data, 0.)
@@ -273,6 +330,14 @@ class DCNv3(nn.Module):
            self.group, self.group_channels,
            self.offset_scale,
            256)
        if self.center_feature_scale:
            center_feature_scale = self.center_feature_scale_module(
                x1, self.center_feature_scale_proj_weight, self.center_feature_scale_proj_bias)
            # N, H, W, groups -> N, H, W, groups, 1 -> N, H, W, groups, _d_per_group -> N, H, W, channels
            center_feature_scale = center_feature_scale[..., None].repeat(
                1, 1, 1, 1, self.channels // self.group).flatten(-2)
            x = x * (1 - center_feature_scale) + x_proj * center_feature_scale
        x = self.output_proj(x)

        return x
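Written out at shape level, the center-feature-scale blend used in both forward passes above does the following (sizes assumed for illustration; it mirrors the repeat/flatten comment in the code):

import torch

N, H, W, group, c_per_g = 2, 7, 7, 4, 4
x = torch.randn(N, H, W, group * c_per_g)       # DCN output
x_proj = torch.randn(N, H, W, group * c_per_g)  # input projection
s = torch.rand(N, H, W, group)                  # per-group gates in (0, 1)
# (N, H, W, group) -> (N, H, W, group, c_per_g) -> (N, H, W, channels)
s = s[..., None].repeat(1, 1, 1, 1, c_per_g).flatten(-2)
out = x * (1 - s) + x_proj * s                  # convex per-channel blend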