Commit fe6cdd2e authored by zhe chen

Update huggingface model
Update README.md

parent 3bd2e7b9
{
"_name_or_path": "OpenGVLab/internimage_s_1k_224",
"act_layer": "GELU",
"architectures": [
"InternImageModel"
],
"auto_map": {
"AutoConfig": "configuration_internimage.InternImageConfig",
"AutoModel": "modeling_internimage.InternImageModel",
"AutoModelForImageClassification": "modeling_internimage.InternImageModelForImageClassification"
},
"center_feature_scale": false,
"channels": 80,
"cls_scale": 1.5,
"core_op": "DCNv3",
"depths": [
4,
4,
21,
4
],
"drop_path_rate": 0.0,
"drop_path_type": "linear",
"drop_rate": 0.0,
"dw_kernel_size": null,
"groups": [
5,
10,
20,
40
],
"layer_scale": 1e-05,
"level2_post_norm": false,
"level2_post_norm_block_ids": null,
"mlp_ratio": 4.0,
"model_type": "internimage",
"norm_layer": "LN",
"num_classes": 1000,
"offset_scale": 1.0,
"post_norm": true,
"remove_center": false,
"res_post_norm": false,
"torch_dtype": "float32",
"transformers_version": "4.37.2",
"use_clip_projector": false,
"with_cp": false
}
# --------------------------------------------------------
# InternImage
# Copyright (c) 2025 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from transformers import PretrainedConfig
class InternImageConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`~InternImageModel`].
It is used to instantiate an internimage model according to the specified arguments, defining the model
architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of
the internimage [OpenGVLab/internimage](https://huggingface.co/OpenGVLab/internimage) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used
to control the model outputs. Read the documentation from [`PretrainedConfig`]
for more information.
Args:
core_op (`str`, *optional*, defaults to `"DCNv3"`):
Core operation used in the InternImageModel.
depths (`tuple`, *optional*, defaults to `(4, 4, 18, 4)`):
Tuple specifying the depth of layers in the InternImageModel.
groups (`tuple`, *optional*, defaults to `(4, 8, 16, 32)`):
Tuple specifying the group of layers in the InternImageModel.
channels (`int`, *optional*, defaults to `64`):
Number of channels in the InternImageModel.
dw_kernel_size (`int`, *optional*, defaults to `None`):
Kernel size for depthwise convolutions.
layer_scale (`float`, *optional*, defaults to `None`):
Scale of the layers in the model.
offset_scale (`float`, *optional*, defaults to `1.0`):
Offset scale in the model.
mlp_ratio (`float`, *optional*, defaults to `4.0`):
Ratio of mlp layers in the InternImageModel.
post_norm (`bool`, *optional*, defaults to `False`):
Whether to use post normalization in the model.
level2_post_norm (`bool`, *optional*, defaults to `False`):
Whether to use level 2 post normalization.
level2_post_norm_block_ids (`list`, *optional*, defaults to `None`):
Specific block IDs for level 2 post normalization.
center_feature_scale (`bool`, *optional*, defaults to `False`):
Whether to apply center feature scaling.
use_clip_projector (`bool`, *optional*, defaults to `False`):
Whether to use CLIP projector.
remove_center (`bool`, *optional*, defaults to `False`):
Whether to remove center pixels in some operations.
num_classes (`int`, *optional*, defaults to `1000`):
Number of classes for the model output.
drop_rate (`float`, *optional*, defaults to `0.0`):
Dropout rate in the model.
drop_path_rate (`float`, *optional*, defaults to `0.0`):
Dropout path rate in the model.
drop_path_type (`str`, *optional*, defaults to `"linear"`):
Type of dropout path used in the model.
act_layer (`str`, *optional*, defaults to `"GELU"`):
Activation function used in the model.
norm_layer (`str`, *optional*, defaults to `"LN"`):
Normalization layer used in the model.
cls_scale (`float`, *optional*, defaults to `1.5`):
Scale of the classification layer in the model.
with_cp (`bool`, *optional*, defaults to `False`):
Whether to use checkpointing in the model.
"""
model_type = 'internimage'
def __init__(
self,
core_op='DCNv3',
depths=(4, 4, 18, 4),
groups=(4, 8, 16, 32),
channels=64,
dw_kernel_size=None,
layer_scale=None,
offset_scale=1.0,
mlp_ratio=4.0,
post_norm=False,
res_post_norm=False,
level2_post_norm=False,
level2_post_norm_block_ids=None,
center_feature_scale=False,
use_clip_projector=False,
remove_center=False,
num_classes=1000,
drop_rate=0.0,
drop_path_rate=0.0,
drop_path_type='linear',
act_layer='GELU',
norm_layer='LN',
cls_scale=1.5,
with_cp=False,
**kwargs,
):
super().__init__(**kwargs)
# Model configuration parameters
self.core_op = core_op
self.depths = depths
self.groups = groups
self.channels = channels
self.dw_kernel_size = dw_kernel_size
self.layer_scale = layer_scale
self.offset_scale = offset_scale
self.mlp_ratio = mlp_ratio
self.post_norm = post_norm
self.res_post_norm = res_post_norm
self.level2_post_norm = level2_post_norm
self.level2_post_norm_block_ids = level2_post_norm_block_ids
self.center_feature_scale = center_feature_scale
self.use_clip_projector = use_clip_projector
self.remove_center = remove_center
self.num_classes = num_classes
self.drop_rate = drop_rate
self.drop_path_rate = drop_path_rate
self.drop_path_type = drop_path_type
self.act_layer = act_layer
self.norm_layer = norm_layer
self.cls_scale = cls_scale
self.with_cp = with_cp
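# A minimal usage sketch (added for illustration, not part of the original file):
# the shipped internimage_s_1k_224 config.json above corresponds roughly to the
# arguments below; unspecified values keep the defaults documented in the class
# docstring.
if __name__ == '__main__':
    cfg = InternImageConfig(
        channels=80,
        depths=(4, 4, 21, 4),
        groups=(5, 10, 20, 40),
        layer_scale=1e-05,
        post_norm=True,
    )
    print(cfg)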
# --------------------------------------------------------
# InternImage
# Copyright (c) 2025 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from __future__ import absolute_import, division, print_function
import warnings
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.init import constant_, xavier_uniform_
from .dcnv3_func import DCNv3Function, dcnv3_core_pytorch, has_cuda_kernel
class to_channels_first(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return x.permute(0, 3, 1, 2)
class to_channels_last(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return x.permute(0, 2, 3, 1)
def build_norm_layer(dim,
norm_layer,
in_format='channels_last',
out_format='channels_last',
eps=1e-6):
layers = []
if norm_layer == 'BN':
if in_format == 'channels_last':
layers.append(to_channels_first())
layers.append(nn.BatchNorm2d(dim))
if out_format == 'channels_last':
layers.append(to_channels_last())
elif norm_layer == 'LN':
if in_format == 'channels_first':
layers.append(to_channels_last())
layers.append(nn.LayerNorm(dim, eps=eps))
if out_format == 'channels_first':
layers.append(to_channels_first())
else:
raise NotImplementedError(
f'build_norm_layer does not support {norm_layer}')
return nn.Sequential(*layers)
def build_act_layer(act_layer):
if act_layer == 'ReLU':
return nn.ReLU(inplace=True)
elif act_layer == 'SiLU':
return nn.SiLU(inplace=True)
elif act_layer == 'GELU':
return nn.GELU()
raise NotImplementedError(f'build_act_layer does not support {act_layer}')
def _is_power_of_2(n):
if (not isinstance(n, int)) or (n < 0):
raise ValueError(
'invalid input for _is_power_of_2: {} (type: {})'.format(n, type(n)))
return (n & (n - 1) == 0) and n != 0
class CenterFeatureScaleModule(nn.Module):
def forward(self,
query,
center_feature_scale_proj_weight,
center_feature_scale_proj_bias):
center_feature_scale = F.linear(query,
weight=center_feature_scale_proj_weight,
bias=center_feature_scale_proj_bias).sigmoid()
return center_feature_scale
class DCNv3_pytorch(nn.Module):
def __init__(
self,
channels=64,
kernel_size=3,
dw_kernel_size=None,
stride=1,
pad=1,
dilation=1,
group=4,
offset_scale=1.0,
act_layer='GELU',
norm_layer='LN',
center_feature_scale=False,
remove_center=False,
):
"""
DCNv3 Module
:param channels
:param kernel_size
:param stride
:param pad
:param dilation
:param group
:param offset_scale
:param act_layer
:param norm_layer
"""
super().__init__()
if channels % group != 0:
raise ValueError(
f'channels must be divisible by group, but got {channels} and {group}')
_d_per_group = channels // group
dw_kernel_size = dw_kernel_size if dw_kernel_size is not None else kernel_size
# you'd better set _d_per_group to a power of 2 which is more efficient in our CUDA implementation
if not _is_power_of_2(_d_per_group):
warnings.warn(
"You'd better set channels in DCNv3 to make the dimension of each attention head a power of 2 "
'which is more efficient in our CUDA implementation.')
self.offset_scale = offset_scale
self.channels = channels
self.kernel_size = kernel_size
self.dw_kernel_size = dw_kernel_size
self.stride = stride
self.dilation = dilation
self.pad = pad
self.group = group
self.group_channels = channels // group
self.offset_scale = offset_scale
self.center_feature_scale = center_feature_scale
self.remove_center = int(remove_center)
self.dw_conv = nn.Sequential(
nn.Conv2d(
channels,
channels,
kernel_size=dw_kernel_size,
stride=1,
padding=(dw_kernel_size - 1) // 2,
groups=channels),
build_norm_layer(
channels,
norm_layer,
'channels_first',
'channels_last'),
build_act_layer(act_layer))
self.offset = nn.Linear(
channels,
group * (kernel_size * kernel_size - remove_center) * 2)
self.mask = nn.Linear(
channels,
group * (kernel_size * kernel_size - remove_center))
self.input_proj = nn.Linear(channels, channels)
self.output_proj = nn.Linear(channels, channels)
self._reset_parameters()
if center_feature_scale:
self.center_feature_scale_proj_weight = nn.Parameter(
torch.zeros((group, channels), dtype=torch.float))
self.center_feature_scale_proj_bias = nn.Parameter(
torch.tensor(0.0, dtype=torch.float).view((1,)).repeat(group, ))
self.center_feature_scale_module = CenterFeatureScaleModule()
def _reset_parameters(self):
constant_(self.offset.weight.data, 0.)
constant_(self.offset.bias.data, 0.)
constant_(self.mask.weight.data, 0.)
constant_(self.mask.bias.data, 0.)
xavier_uniform_(self.input_proj.weight.data)
constant_(self.input_proj.bias.data, 0.)
xavier_uniform_(self.output_proj.weight.data)
constant_(self.output_proj.bias.data, 0.)
def forward(self, input):
"""
:param query (N, H, W, C)
:return output (N, H, W, C)
"""
N, H, W, _ = input.shape
x = self.input_proj(input)
x_proj = x
x1 = input.permute(0, 3, 1, 2)
x1 = self.dw_conv(x1)
offset = self.offset(x1)
mask = self.mask(x1).reshape(N, H, W, self.group, -1)
mask = F.softmax(mask, -1).reshape(N, H, W, -1)
x = dcnv3_core_pytorch(
x, offset, mask,
self.kernel_size, self.kernel_size,
self.stride, self.stride,
self.pad, self.pad,
self.dilation, self.dilation,
self.group, self.group_channels,
self.offset_scale, self.remove_center)
if self.center_feature_scale:
center_feature_scale = self.center_feature_scale_module(
x1, self.center_feature_scale_proj_weight, self.center_feature_scale_proj_bias)
# N, H, W, groups -> N, H, W, groups, 1 -> N, H, W, groups, _d_per_group -> N, H, W, channels
center_feature_scale = center_feature_scale[..., None].repeat(
1, 1, 1, 1, self.channels // self.group).flatten(-2)
x = x * (1 - center_feature_scale) + x_proj * center_feature_scale
x = self.output_proj(x)
return x
class DCNv3(nn.Module):
def __init__(
self,
channels=64,
kernel_size=3,
dw_kernel_size=None,
stride=1,
pad=1,
dilation=1,
group=4,
offset_scale=1.0,
act_layer='GELU',
norm_layer='LN',
center_feature_scale=False,
remove_center=False,
):
"""
DCNv3 Module
:param channels
:param kernel_size
:param stride
:param pad
:param dilation
:param group
:param offset_scale
:param act_layer
:param norm_layer
"""
super().__init__()
if channels % group != 0:
raise ValueError(
f'channels must be divisible by group, but got {channels} and {group}')
_d_per_group = channels // group
dw_kernel_size = dw_kernel_size if dw_kernel_size is not None else kernel_size
# you'd better set _d_per_group to a power of 2 which is more efficient in our CUDA implementation
if not _is_power_of_2(_d_per_group):
warnings.warn(
"You'd better set channels in DCNv3 to make the dimension of each attention head a power of 2 "
'which is more efficient in our CUDA implementation.')
self.offset_scale = offset_scale
self.channels = channels
self.kernel_size = kernel_size
self.dw_kernel_size = dw_kernel_size
self.stride = stride
self.dilation = dilation
self.pad = pad
self.group = group
self.group_channels = channels // group
self.offset_scale = offset_scale
self.center_feature_scale = center_feature_scale
self.remove_center = int(remove_center)
if self.remove_center and self.kernel_size % 2 == 0:
raise ValueError('remove_center is only compatible with odd kernel size.')
self.dw_conv = nn.Sequential(
nn.Conv2d(
channels,
channels,
kernel_size=dw_kernel_size,
stride=1,
padding=(dw_kernel_size - 1) // 2,
groups=channels),
build_norm_layer(
channels,
norm_layer,
'channels_first',
'channels_last'),
build_act_layer(act_layer))
self.offset = nn.Linear(
channels,
group * (kernel_size * kernel_size - remove_center) * 2)
self.mask = nn.Linear(
channels,
group * (kernel_size * kernel_size - remove_center))
self.input_proj = nn.Linear(channels, channels)
self.output_proj = nn.Linear(channels, channels)
self._reset_parameters()
if center_feature_scale:
self.center_feature_scale_proj_weight = nn.Parameter(
torch.zeros((group, channels), dtype=torch.float))
self.center_feature_scale_proj_bias = nn.Parameter(
torch.tensor(0.0, dtype=torch.float).view((1,)).repeat(group, ))
self.center_feature_scale_module = CenterFeatureScaleModule()
def _reset_parameters(self):
constant_(self.offset.weight.data, 0.)
constant_(self.offset.bias.data, 0.)
constant_(self.mask.weight.data, 0.)
constant_(self.mask.bias.data, 0.)
xavier_uniform_(self.input_proj.weight.data)
constant_(self.input_proj.bias.data, 0.)
xavier_uniform_(self.output_proj.weight.data)
constant_(self.output_proj.bias.data, 0.)
def forward(self, input):
"""
:param query (N, H, W, C)
:return output (N, H, W, C)
"""
N, H, W, _ = input.shape
x = self.input_proj(input)
x_proj = x
dtype = x.dtype
x1 = input.permute(0, 3, 1, 2)
x1 = self.dw_conv(x1)
offset = self.offset(x1)
mask = self.mask(x1).reshape(N, H, W, self.group, -1)
mask = F.softmax(mask, -1)
mask = mask.reshape(N, H, W, -1).type(dtype)
x = DCNv3Function.apply(
x, offset, mask,
self.kernel_size, self.kernel_size,
self.stride, self.stride,
self.pad, self.pad,
self.dilation, self.dilation,
self.group, self.group_channels,
self.offset_scale,
256,
self.remove_center)
if self.center_feature_scale:
center_feature_scale = self.center_feature_scale_module(
x1, self.center_feature_scale_proj_weight, self.center_feature_scale_proj_bias)
# N, H, W, groups -> N, H, W, groups, 1 -> N, H, W, groups, _d_per_group -> N, H, W, channels
center_feature_scale = center_feature_scale[..., None].repeat(
1, 1, 1, 1, self.channels // self.group).flatten(-2)
x = x * (1 - center_feature_scale) + x_proj * center_feature_scale
x = self.output_proj(x)
return x
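# A small self-contained sanity check (a sketch added for illustration; the
# shapes and hyper-parameters are arbitrary and not tied to any released model).
# It exercises the pure-PyTorch path, which needs no compiled CUDA kernel.
def _dcnv3_pytorch_smoke_test():
    # NHWC input: batch 1, 8x8 spatial, 16 channels split into 4 groups of 4.
    x = torch.randn(1, 8, 8, 16)
    module = DCNv3_pytorch(channels=16, group=4)
    y = module(x)
    assert y.shape == x.shape  # DCNv3 preserves the (N, H, W, C) layout
    return y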
# --------------------------------------------------------
# InternImage
# Copyright (c) 2025 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from __future__ import absolute_import, division, print_function
import pkg_resources
import torch
import torch.nn.functional as F
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.cuda.amp import custom_bwd, custom_fwd
# pkg_resources must be imported before probing the compiled DCNv3 extension,
# otherwise the version lookup below fails and the CUDA kernel is never used.
try:
    import DCNv3
    dcn_version = float(pkg_resources.get_distribution('DCNv3').version)
    has_cuda_kernel = True
except Exception:
    has_cuda_kernel = False
class DCNv3Function(Function):
@staticmethod
@custom_fwd
def forward(
ctx, input, offset, mask,
kernel_h, kernel_w, stride_h, stride_w,
pad_h, pad_w, dilation_h, dilation_w,
group, group_channels, offset_scale, im2col_step, remove_center):
ctx.kernel_h = kernel_h
ctx.kernel_w = kernel_w
ctx.stride_h = stride_h
ctx.stride_w = stride_w
ctx.pad_h = pad_h
ctx.pad_w = pad_w
ctx.dilation_h = dilation_h
ctx.dilation_w = dilation_w
ctx.group = group
ctx.group_channels = group_channels
ctx.offset_scale = offset_scale
ctx.im2col_step = im2col_step
ctx.remove_center = remove_center
args = [
input, offset, mask, kernel_h,
kernel_w, stride_h, stride_w, pad_h,
pad_w, dilation_h, dilation_w, group,
group_channels, offset_scale, ctx.im2col_step
]
if remove_center or dcn_version > 1.0:
args.append(remove_center)
output = DCNv3.dcnv3_forward(*args)
ctx.save_for_backward(input, offset, mask)
return output
@staticmethod
@once_differentiable
@custom_bwd
def backward(ctx, grad_output):
input, offset, mask = ctx.saved_tensors
args = [
input, offset, mask, ctx.kernel_h,
ctx.kernel_w, ctx.stride_h, ctx.stride_w, ctx.pad_h,
ctx.pad_w, ctx.dilation_h, ctx.dilation_w, ctx.group,
ctx.group_channels, ctx.offset_scale, grad_output.contiguous(), ctx.im2col_step
]
if ctx.remove_center or dcn_version > 1.0:
args.append(ctx.remove_center)
grad_input, grad_offset, grad_mask = \
DCNv3.dcnv3_backward(*args)
return grad_input, grad_offset, grad_mask, \
None, None, None, None, None, None, None, None, None, None, None, None, None
@staticmethod
def symbolic(g, input, offset, mask, kernel_h, kernel_w, stride_h,
stride_w, pad_h, pad_w, dilation_h, dilation_w, group,
group_channels, offset_scale, im2col_step, remove_center):
"""Symbolic function for mmdeploy::DCNv3.
Returns:
DCNv3 op for onnx.
"""
return g.op(
'mmdeploy::TRTDCNv3',
input,
offset,
mask,
kernel_h_i=int(kernel_h),
kernel_w_i=int(kernel_w),
stride_h_i=int(stride_h),
stride_w_i=int(stride_w),
pad_h_i=int(pad_h),
pad_w_i=int(pad_w),
dilation_h_i=int(dilation_h),
dilation_w_i=int(dilation_w),
group_i=int(group),
group_channels_i=int(group_channels),
offset_scale_f=float(offset_scale),
im2col_step_i=int(im2col_step),
remove_center_i=int(remove_center),
)
def _get_reference_points(spatial_shapes, device, kernel_h, kernel_w, dilation_h, dilation_w, pad_h=0, pad_w=0, stride_h=1, stride_w=1):
_, H_, W_, _ = spatial_shapes
H_out = (H_ - (dilation_h * (kernel_h - 1) + 1)) // stride_h + 1
W_out = (W_ - (dilation_w * (kernel_w - 1) + 1)) // stride_w + 1
ref_y, ref_x = torch.meshgrid(
torch.linspace(
# pad_h + 0.5,
# H_ - pad_h - 0.5,
(dilation_h * (kernel_h - 1)) // 2 + 0.5,
(dilation_h * (kernel_h - 1)) // 2 + 0.5 + (H_out - 1) * stride_h,
H_out,
dtype=torch.float32,
device=device),
torch.linspace(
# pad_w + 0.5,
# W_ - pad_w - 0.5,
(dilation_w * (kernel_w - 1)) // 2 + 0.5,
(dilation_w * (kernel_w - 1)) // 2 + 0.5 + (W_out - 1) * stride_w,
W_out,
dtype=torch.float32,
device=device))
ref_y = ref_y.reshape(-1)[None] / H_
ref_x = ref_x.reshape(-1)[None] / W_
ref = torch.stack((ref_x, ref_y), -1).reshape(
1, H_out, W_out, 1, 2)
return ref
def _generate_dilation_grids(spatial_shapes, kernel_h, kernel_w, dilation_h, dilation_w, group, device):
_, H_, W_, _ = spatial_shapes
points_list = []
x, y = torch.meshgrid(
torch.linspace(
-((dilation_w * (kernel_w - 1)) // 2),
-((dilation_w * (kernel_w - 1)) // 2) + (kernel_w - 1) * dilation_w,
kernel_w,
dtype=torch.float32,
device=device),
torch.linspace(
-((dilation_h * (kernel_h - 1)) // 2),
-((dilation_h * (kernel_h - 1)) // 2) + (kernel_h - 1) * dilation_h,
kernel_h,
dtype=torch.float32,
device=device))
points_list.extend([x / W_, y / H_])
grid = torch.stack(points_list, -1).reshape(-1, 1, 2).\
repeat(1, group, 1).permute(1, 0, 2)
grid = grid.reshape(1, 1, 1, group * kernel_h * kernel_w, 2)
return grid
def remove_center_sampling_locations(sampling_locations, kernel_w, kernel_h):
idx = list(range(sampling_locations.shape[-2]))
C = (kernel_w * kernel_h - 1)//2
idx = [i for i in idx if i != C and (i-C) % (C*2+1) != 0]
sampling_locations = sampling_locations[:,:,:,idx, :]
return sampling_locations
def dcnv3_core_pytorch(
input, offset, mask, kernel_h,
kernel_w, stride_h, stride_w, pad_h,
pad_w, dilation_h, dilation_w, group,
group_channels, offset_scale, remove_center):
# for debug and test only,
# need to use cuda version instead
if remove_center and (kernel_h % 2 == 0 or kernel_w % 2 == 0 or kernel_w != kernel_h):
raise ValueError('remove_center is only compatible with square odd kernel size.')
input = F.pad(
input,
[0, 0, pad_h, pad_h, pad_w, pad_w])
N_, H_in, W_in, _ = input.shape
_, H_out, W_out, _ = offset.shape
ref = _get_reference_points(
input.shape, input.device, kernel_h, kernel_w, dilation_h, dilation_w, pad_h, pad_w, stride_h, stride_w)
grid = _generate_dilation_grids(
input.shape, kernel_h, kernel_w, dilation_h, dilation_w, group, input.device)
spatial_norm = torch.tensor([W_in, H_in]).reshape(1, 1, 1, 2).\
repeat(1, 1, 1, group*(kernel_h*kernel_w-remove_center)).to(input.device)
sampling_locations = (ref + grid * offset_scale).repeat(N_, 1, 1, 1, 1)
if remove_center:
sampling_locations = remove_center_sampling_locations(sampling_locations, kernel_w=kernel_w, kernel_h=kernel_h)
sampling_locations = sampling_locations.flatten(3, 4)
sampling_locations = sampling_locations + offset * offset_scale / spatial_norm
P_ = kernel_h * kernel_w - remove_center
sampling_grids = 2 * sampling_locations - 1
# N_, H_in, W_in, group*group_channels -> N_, H_in*W_in, group*group_channels -> N_, group*group_channels, H_in*W_in -> N_*group, group_channels, H_in, W_in
input_ = input.view(N_, H_in*W_in, group*group_channels).transpose(1, 2).\
reshape(N_*group, group_channels, H_in, W_in)
# N_, H_out, W_out, group*P_*2 -> N_, H_out*W_out, group, P_, 2 -> N_, group, H_out*W_out, P_, 2 -> N_*group, H_out*W_out, P_, 2
sampling_grid_ = sampling_grids.view(N_, H_out*W_out, group, P_, 2).transpose(1, 2).\
flatten(0, 1)
# N_*group, group_channels, H_out*W_out, P_
sampling_input_ = F.grid_sample(
input_, sampling_grid_, mode='bilinear', padding_mode='zeros', align_corners=False)
# (N_, H_out, W_out, group*P_) -> N_, H_out*W_out, group, P_ -> (N_, group, H_out*W_out, P_) -> (N_*group, 1, H_out*W_out, P_)
mask = mask.view(N_, H_out*W_out, group, P_).transpose(1, 2).\
reshape(N_*group, 1, H_out*W_out, P_)
output = (sampling_input_ * mask).sum(-1).view(N_,
group*group_channels, H_out*W_out)
return output.transpose(1, 2).reshape(N_, H_out, W_out, -1).contiguous()
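# A minimal sketch (added for illustration) showing the tensor shapes that
# dcnv3_core_pytorch expects; they mirror what DCNv3_pytorch.forward passes in
# for a 3x3 kernel with stride 1, pad 1, dilation 1 and remove_center=0.
def _dcnv3_core_pytorch_smoke_test():
    N, H, W, group, group_channels = 1, 8, 8, 4, 4
    P = 3 * 3  # sampling points per group
    x = torch.randn(N, H, W, group * group_channels)
    offset = torch.zeros(N, H, W, group * P * 2)
    mask = torch.softmax(torch.randn(N, H, W, group, P), -1).reshape(N, H, W, -1)
    out = dcnv3_core_pytorch(
        x, offset, mask,
        3, 3,  # kernel_h, kernel_w
        1, 1,  # stride_h, stride_w
        1, 1,  # pad_h, pad_w
        1, 1,  # dilation_h, dilation_w
        group, group_channels,
        1.0,   # offset_scale
        0)     # remove_center
    assert out.shape == (N, H, W, group * group_channels)
    return out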
# --------------------------------------------------------
# InternImage
# Copyright (c) 2025 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from dataclasses import dataclass
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, trunc_normal_
from torch import nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import ModelOutput
from .configuration_internimage import InternImageConfig
from .dcnv3 import DCNv3, DCNv3_pytorch, has_cuda_kernel
from .dcnv3_func import dcnv3_core_pytorch
@dataclass
class BackboneOutput(ModelOutput):
"""
Base class for outputs of backbones.
"""
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
pooler_output: Optional[torch.FloatTensor] = None
last_hidden_state: Optional[torch.FloatTensor] = None
logits: Optional[torch.FloatTensor] = None
loss: Optional[torch.FloatTensor] = None
class to_channels_first(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return x.permute(0, 3, 1, 2)
class to_channels_last(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return x.permute(0, 2, 3, 1)
def build_norm_layer(dim,
norm_layer,
in_format='channels_last',
out_format='channels_last',
eps=1e-6):
layers = []
if norm_layer == 'BN':
if in_format == 'channels_last':
layers.append(to_channels_first())
layers.append(nn.BatchNorm2d(dim))
if out_format == 'channels_last':
layers.append(to_channels_last())
elif norm_layer == 'LN':
if in_format == 'channels_first':
layers.append(to_channels_last())
layers.append(nn.LayerNorm(dim, eps=eps))
if out_format == 'channels_first':
layers.append(to_channels_first())
else:
raise NotImplementedError(
f'build_norm_layer does not support {norm_layer}')
return nn.Sequential(*layers)
def build_act_layer(act_layer):
if act_layer == 'ReLU':
return nn.ReLU(inplace=True)
elif act_layer == 'SiLU':
return nn.SiLU(inplace=True)
elif act_layer == 'GELU':
return nn.GELU()
raise NotImplementedError(f'build_act_layer does not support {act_layer}')
class CrossAttention(nn.Module):
r""" Cross Attention Module
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads. Default: 8
qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
Default: False.
qk_scale (float | None, optional): Override default qk scale of
head_dim ** -0.5 if set. Default: None.
attn_drop (float, optional): Dropout ratio of attention weight.
Default: 0.0
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
attn_head_dim (int, optional): Dimension of attention head.
out_dim (int, optional): Dimension of output.
"""
def __init__(self,
dim,
num_heads=8,
qkv_bias=False,
qk_scale=None,
attn_drop=0.,
proj_drop=0.,
attn_head_dim=None,
out_dim=None):
super().__init__()
if out_dim is None:
out_dim = dim
self.num_heads = num_heads
head_dim = dim // num_heads
if attn_head_dim is not None:
head_dim = attn_head_dim
all_head_dim = head_dim * self.num_heads
self.scale = qk_scale or head_dim ** -0.5
assert all_head_dim == dim
self.q = nn.Linear(dim, all_head_dim, bias=False)
self.k = nn.Linear(dim, all_head_dim, bias=False)
self.v = nn.Linear(dim, all_head_dim, bias=False)
if qkv_bias:
self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
self.k_bias = nn.Parameter(torch.zeros(all_head_dim))
self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
else:
self.q_bias = None
self.k_bias = None
self.v_bias = None
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(all_head_dim, out_dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x, k=None, v=None):
B, N, C = x.shape
N_k = k.shape[1]
N_v = v.shape[1]
q_bias, k_bias, v_bias = None, None, None
if self.q_bias is not None:
q_bias = self.q_bias
k_bias = self.k_bias
v_bias = self.v_bias
q = F.linear(input=x, weight=self.q.weight, bias=q_bias)
q = q.reshape(B, N, 1, self.num_heads,
-1).permute(2, 0, 3, 1,
4).squeeze(0) # (B, N_head, N_q, dim)
k = F.linear(input=k, weight=self.k.weight, bias=k_bias)
k = k.reshape(B, N_k, 1, self.num_heads, -1).permute(2, 0, 3, 1,
4).squeeze(0)
v = F.linear(input=v, weight=self.v.weight, bias=v_bias)
v = v.reshape(B, N_v, 1, self.num_heads, -1).permute(2, 0, 3, 1,
4).squeeze(0)
q = q * self.scale
attn = (q @ k.transpose(-2, -1)) # (B, N_head, N_q, N_k)
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
x = self.proj(x)
x = self.proj_drop(x)
return x
class AttentiveBlock(nn.Module):
r"""Attentive Block
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads. Default: 8
qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
Default: False.
qk_scale (float | None, optional): Override default qk scale of
head_dim ** -0.5 if set. Default: None.
drop (float, optional): Dropout rate. Default: 0.0.
attn_drop (float, optional): Attention dropout rate. Default: 0.0.
drop_path (float | tuple[float], optional): Stochastic depth rate.
Default: 0.0.
norm_layer (str, optional): Normalization layer. Default: 'LN'.
attn_head_dim (int, optional): Dimension of attention head. Default: None.
out_dim (int, optional): Dimension of output. Default: None.
"""
def __init__(self,
dim,
num_heads,
qkv_bias=False,
qk_scale=None,
drop=0.,
attn_drop=0.,
drop_path=0.,
norm_layer='LN',
attn_head_dim=None,
out_dim=None):
super().__init__()
self.norm1_q = build_norm_layer(dim, norm_layer, eps=1e-6)
self.norm1_k = build_norm_layer(dim, norm_layer, eps=1e-6)
self.norm1_v = build_norm_layer(dim, norm_layer, eps=1e-6)
self.cross_dcn = CrossAttention(dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=drop,
attn_head_dim=attn_head_dim,
out_dim=out_dim)
self.drop_path = DropPath(
drop_path) if drop_path > 0. else nn.Identity()
def forward(self,
x_q,
x_kv,
pos_q,
pos_k,
bool_masked_pos,
rel_pos_bias=None):
x_q = self.norm1_q(x_q + pos_q)
x_k = self.norm1_k(x_kv + pos_k)
x_v = self.norm1_v(x_kv)
x = self.cross_dcn(x_q, k=x_k, v=x_v)
return x
class AttentionPoolingBlock(AttentiveBlock):
def forward(self, x):
x_q = x.mean(1, keepdim=True)
x_kv = x
pos_q, pos_k = 0, 0
x = super().forward(x_q, x_kv, pos_q, pos_k,
bool_masked_pos=None,
rel_pos_bias=None)
x = x.squeeze(1)
return x
class StemLayer(nn.Module):
r"""Stem layer of InternImage
Args:
in_chans (int): number of input channels
out_chans (int): number of output channels
act_layer (str): activation layer
norm_layer (str): normalization layer
"""
def __init__(self,
in_chans=3,
out_chans=96,
act_layer='GELU',
norm_layer='BN'):
super().__init__()
self.conv1 = nn.Conv2d(in_chans,
out_chans // 2,
kernel_size=3,
stride=2,
padding=1)
self.norm1 = build_norm_layer(out_chans // 2, norm_layer,
'channels_first', 'channels_first')
self.act = build_act_layer(act_layer)
self.conv2 = nn.Conv2d(out_chans // 2,
out_chans,
kernel_size=3,
stride=2,
padding=1)
self.norm2 = build_norm_layer(out_chans, norm_layer, 'channels_first',
'channels_last')
def forward(self, x):
x = self.conv1(x)
x = self.norm1(x)
x = self.act(x)
x = self.conv2(x)
x = self.norm2(x)
return x
class DownsampleLayer(nn.Module):
r"""Downsample layer of InternImage
Args:
channels (int): number of input channels
norm_layer (str): normalization layer
"""
def __init__(self, channels, norm_layer='LN'):
super().__init__()
self.conv = nn.Conv2d(channels,
2 * channels,
kernel_size=3,
stride=2,
padding=1,
bias=False)
self.norm = build_norm_layer(2 * channels, norm_layer,
'channels_first', 'channels_last')
def forward(self, x):
x = self.conv(x.permute(0, 3, 1, 2))
x = self.norm(x)
return x
class MLPLayer(nn.Module):
r"""MLP layer of InternImage
Args:
in_features (int): number of input features
hidden_features (int): number of hidden features
out_features (int): number of output features
act_layer (str): activation layer
drop (float): dropout rate
"""
def __init__(self,
in_features,
hidden_features=None,
out_features=None,
act_layer='GELU',
drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = build_act_layer(act_layer)
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class InternImageLayer(nn.Module):
r"""Basic layer of InternImage
Args:
core_op (nn.Module): core operation of InternImage
channels (int): number of input channels
groups (int): number of DCNv3 groups in this layer
mlp_ratio (float): ratio of mlp hidden features to input channels
drop (float): dropout rate
drop_path (float): drop path rate
act_layer (str): activation layer
norm_layer (str): normalization layer
post_norm (bool): whether to use post normalization
layer_scale (float): layer scale
offset_scale (float): offset scale
with_cp (bool): whether to use checkpoint
"""
def __init__(self,
core_op,
channels,
groups,
mlp_ratio=4.,
drop=0.,
drop_path=0.,
act_layer='GELU',
norm_layer='LN',
post_norm=False,
layer_scale=None,
offset_scale=1.0,
with_cp=False,
dw_kernel_size=None, # for InternImage-H/G
res_post_norm=False, # for InternImage-H/G
center_feature_scale=False, # for InternImage-H/G
remove_center=False, # for InternImage-H/G
):
super().__init__()
self.channels = channels
self.groups = groups
self.mlp_ratio = mlp_ratio
self.with_cp = with_cp
self.norm1 = build_norm_layer(channels, 'LN')
self.post_norm = post_norm
self.dcn = core_op(
channels=channels,
kernel_size=3,
stride=1,
pad=1,
dilation=1,
group=groups,
offset_scale=offset_scale,
act_layer=act_layer,
norm_layer=norm_layer,
dw_kernel_size=dw_kernel_size, # for InternImage-H/G
center_feature_scale=center_feature_scale, # for InternImage-H/G
remove_center=remove_center, # for InternImage-H/G
)
self.drop_path = DropPath(drop_path) if drop_path > 0. \
else nn.Identity()
self.norm2 = build_norm_layer(channels, 'LN')
self.mlp = MLPLayer(in_features=channels,
hidden_features=int(channels * mlp_ratio),
act_layer=act_layer,
drop=drop)
self.layer_scale = layer_scale is not None
if self.layer_scale:
self.layer_scale1 = nn.Parameter(layer_scale * torch.ones(channels),
requires_grad=True)
self.layer_scale2 = nn.Parameter(layer_scale * torch.ones(channels),
requires_grad=True)
self.res_post_norm = res_post_norm
if res_post_norm:
self.res_post_norm1 = build_norm_layer(channels, 'LN')
self.res_post_norm2 = build_norm_layer(channels, 'LN')
def forward(self, x):
def _inner_forward(x):
if not self.layer_scale:
if self.post_norm:
x = x + self.drop_path(self.norm1(self.dcn(x)))
x = x + self.drop_path(self.norm2(self.mlp(x)))
elif self.res_post_norm: # for InternImage-H/G
x = x + self.drop_path(self.res_post_norm1(self.dcn(self.norm1(x))))
x = x + self.drop_path(self.res_post_norm2(self.mlp(self.norm2(x))))
else:
x = x + self.drop_path(self.dcn(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
if self.post_norm:
x = x + self.drop_path(self.layer_scale1 * self.norm1(self.dcn(x)))
x = x + self.drop_path(self.layer_scale2 * self.norm2(self.mlp(x)))
else:
x = x + self.drop_path(self.layer_scale1 * self.dcn(self.norm1(x)))
x = x + self.drop_path(self.layer_scale2 * self.mlp(self.norm2(x)))
return x
if self.with_cp and x.requires_grad:
x = checkpoint.checkpoint(_inner_forward, x)
else:
x = _inner_forward(x)
return x
class InternImageBlock(nn.Module):
r"""Block of InternImage
Args:
core_op (nn.Module): core operation of InternImage
channels (int): number of input channels
depth (int): number of layers in this block.
groups (int): number of DCNv3 groups in this block.
mlp_ratio (float): ratio of mlp hidden features to input channels
drop (float): dropout rate
drop_path (float): drop path rate
act_layer (str): activation layer
norm_layer (str): normalization layer
post_norm (bool): whether to use post normalization
layer_scale (float): layer scale
offset_scale (float): offset scale
with_cp (bool): whether to use checkpoint
"""
def __init__(self,
core_op,
channels,
depth,
groups,
downsample=True,
mlp_ratio=4.,
drop=0.,
drop_path=0.,
act_layer='GELU',
norm_layer='LN',
post_norm=False,
offset_scale=1.0,
layer_scale=None,
with_cp=False,
dw_kernel_size=None, # for InternImage-H/G
post_norm_block_ids=None, # for InternImage-H/G
res_post_norm=False, # for InternImage-H/G
center_feature_scale=False, # for InternImage-H/G
remove_center=False, # for InternImage-H/G
):
super().__init__()
self.channels = channels
self.depth = depth
self.post_norm = post_norm
self.center_feature_scale = center_feature_scale
self.blocks = nn.ModuleList([
InternImageLayer(
core_op=core_op,
channels=channels,
groups=groups,
mlp_ratio=mlp_ratio,
drop=drop,
drop_path=drop_path[i] if isinstance(
drop_path, list) else drop_path,
act_layer=act_layer,
norm_layer=norm_layer,
post_norm=post_norm,
layer_scale=layer_scale,
offset_scale=offset_scale,
with_cp=with_cp,
dw_kernel_size=dw_kernel_size, # for InternImage-H/G
res_post_norm=res_post_norm, # for InternImage-H/G
center_feature_scale=center_feature_scale, # for InternImage-H/G
remove_center=remove_center, # for InternImage-H/G
) for i in range(depth)
])
if not self.post_norm or center_feature_scale:
self.norm = build_norm_layer(channels, 'LN')
self.post_norm_block_ids = post_norm_block_ids
if post_norm_block_ids is not None: # for InternImage-H/G
self.post_norms = nn.ModuleList(
[build_norm_layer(channels, 'LN', eps=1e-6) for _ in post_norm_block_ids]
)
self.downsample = DownsampleLayer(
channels=channels, norm_layer=norm_layer) if downsample else None
def forward(self, x, return_wo_downsample=False):
for i, blk in enumerate(self.blocks):
x = blk(x)
if (self.post_norm_block_ids is not None) and (i in self.post_norm_block_ids):
index = self.post_norm_block_ids.index(i)
x = self.post_norms[index](x) # for InternImage-H/G
if not self.post_norm or self.center_feature_scale:
x = self.norm(x)
if return_wo_downsample:
x_ = x
if self.downsample is not None:
x = self.downsample(x)
if return_wo_downsample:
return x, x_
return x
class InternImage(nn.Module):
r"""InternImage
A PyTorch implementation of: `InternImage: Exploring Large-Scale Vision Foundation Models with Deformable Convolutions` -
    https://arxiv.org/abs/2211.05778
Args:
core_op (str): Core operator. Default: 'DCNv3'
channels (int): Number of channels in the first stage. Default: 64
depths (list): Depth of each block. Default: [3, 4, 18, 5]
groups (list): Groups of each block. Default: [3, 6, 12, 24]
num_classes (int): Number of classes. Default: 1000
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
drop_rate (float): Probability of an element to be zeroed. Default: 0.
drop_path_rate (float): Stochastic depth rate. Default: 0.
act_layer (str): Activation layer. Default: 'GELU'
norm_layer (str): Normalization layer. Default: 'LN'
layer_scale (float): The initial value of layer scale. Default: None
cls_scale (float): Width multiplier of the classification head. Default: 1.5
with_cp (bool): Use gradient checkpointing or not. Default: False
dw_kernel_size (int): Size of the dwconv. Default: None
use_clip_projector (bool): Whether to use clip projector. Default: False
level2_post_norm (bool): Whether to use level2 post norm. Default: False
level2_post_norm_block_ids (list): Indexes of post norm blocks. Default: None
res_post_norm (bool): Whether to use res post norm. Default: False
center_feature_scale (bool): Whether to use center feature scale. Default: False
"""
def __init__(self,
core_op='DCNv3',
channels=64,
depths=[3, 4, 18, 5],
groups=[3, 6, 12, 24],
num_classes=1000,
mlp_ratio=4.,
drop_rate=0.,
drop_path_rate=0.2,
drop_path_type='linear',
act_layer='GELU',
norm_layer='LN',
layer_scale=None,
offset_scale=1.0,
post_norm=False,
cls_scale=1.5,
with_cp=False,
dw_kernel_size=None, # for InternImage-H/G
use_clip_projector=False, # for InternImage-H/G
level2_post_norm=False, # for InternImage-H/G
level2_post_norm_block_ids=None, # for InternImage-H/G
res_post_norm=False, # for InternImage-H/G
center_feature_scale=False, # for InternImage-H/G
remove_center=False, # for InternImage-H/G
**kwargs):
super().__init__()
if core_op == 'DCNv3' and has_cuda_kernel:
self.core_op = DCNv3
print('DCNv3 is installed, using CUDA implementation.')
elif core_op == 'DCNv3' and not has_cuda_kernel:
self.core_op = DCNv3_pytorch
print('DCNv3 is not installed, using PyTorch implementation.')
else:
self.core_op = DCNv3_pytorch
print('Using DCNv3 PyTorch implementation.')
self.num_classes = num_classes
self.num_levels = len(depths)
self.depths = depths
self.channels = channels
self.num_features = int(channels * 2 ** (self.num_levels - 1))
self.post_norm = post_norm
self.mlp_ratio = mlp_ratio
self.use_clip_projector = use_clip_projector
self.level2_post_norm_block_ids = level2_post_norm_block_ids
self.remove_center = remove_center
print(f'using core type: {core_op}')
print(f'level2_post_norm: {level2_post_norm}')
print(f'level2_post_norm_block_ids: {level2_post_norm_block_ids}')
print(f'res_post_norm: {res_post_norm}')
print(f'remove_center: {remove_center}')
in_chans = 3
self.patch_embed = StemLayer(in_chans=in_chans,
out_chans=channels,
act_layer=act_layer,
norm_layer=norm_layer)
self.pos_drop = nn.Dropout(p=drop_rate)
dpr = [
x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
]
if drop_path_type == 'uniform':
for i in range(len(dpr)):
dpr[i] = drop_path_rate
self.levels = nn.ModuleList()
for i in range(self.num_levels):
post_norm_block_ids = level2_post_norm_block_ids if level2_post_norm and (
i == 2) else None # for InternImage-H/G
level = InternImageBlock(
core_op=self.core_op,
channels=int(channels * 2 ** i),
depth=depths[i],
groups=groups[i],
mlp_ratio=self.mlp_ratio,
drop=drop_rate,
drop_path=dpr[sum(depths[:i]):sum(depths[:i + 1])],
act_layer=act_layer,
norm_layer=norm_layer,
post_norm=post_norm,
downsample=(i < self.num_levels - 1),
layer_scale=layer_scale,
offset_scale=offset_scale,
with_cp=with_cp,
dw_kernel_size=dw_kernel_size, # for InternImage-H/G
post_norm_block_ids=post_norm_block_ids, # for InternImage-H/G
res_post_norm=res_post_norm, # for InternImage-H/G
center_feature_scale=center_feature_scale, # for InternImage-H/G
remove_center=remove_center, # for InternImage-H/G
)
self.levels.append(level)
if self.num_classes > 0:
if not use_clip_projector: # for InternImage-T/S/B/L/XL
self.conv_head = nn.Sequential(
nn.Conv2d(self.num_features,
int(self.num_features * cls_scale),
kernel_size=1,
bias=False),
build_norm_layer(int(self.num_features * cls_scale), 'BN',
'channels_first', 'channels_first'),
build_act_layer(act_layer))
self.head = nn.Linear(int(self.num_features * cls_scale), num_classes) \
if num_classes > 0 else nn.Identity()
else: # for InternImage-H/G
pretrain_embed_dim, _stride, attnpool_num_heads, clip_embed_dim = 1024, 2, 16, 768
self.dcnv3_head_x4 = nn.Sequential(
nn.Conv2d(in_channels=self.num_features,
out_channels=pretrain_embed_dim * (_stride ** 2),
kernel_size=1), nn.PixelShuffle(_stride))
self.dcnv3_head_x3 = nn.Conv2d(in_channels=self.num_features // 2,
out_channels=pretrain_embed_dim,
kernel_size=1)
self.clip_projector = AttentionPoolingBlock(
dim=pretrain_embed_dim,
num_heads=attnpool_num_heads,
qkv_bias=True,
qk_scale=None,
drop=0.,
attn_drop=0.,
norm_layer=norm_layer,
out_dim=clip_embed_dim)
self.fc_norm = build_norm_layer(clip_embed_dim, norm_layer, eps=1e-6)
self.head = nn.Linear(
clip_embed_dim, num_classes) if num_classes > 0 else nn.Identity()
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.num_layers = len(depths)
self.apply(self._init_weights)
self.apply(self._init_deform_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def _init_deform_weights(self, m):
if isinstance(m, self.core_op):
m._reset_parameters()
@torch.jit.ignore
def lr_decay_keywords(self, decay_ratio=0.87):
lr_ratios = {}
# blocks
idx = 0
for i in range(4):
layer_num = 3 - i # 3 2 1 0
for j in range(self.depths[layer_num]):
block_num = self.depths[layer_num] - j - 1
tag = 'levels.{}.blocks.{}.'.format(layer_num, block_num)
decay = 1.0 * (decay_ratio ** idx)
lr_ratios[tag] = decay
idx += 1
# patch_embed (before stage-1)
lr_ratios['patch_embed'] = lr_ratios['levels.0.blocks.0.']
# levels.0.downsample (between stage-1 and stage-2)
lr_ratios['levels.0.downsample'] = lr_ratios['levels.1.blocks.0.']
lr_ratios['levels.0.norm'] = lr_ratios['levels.1.blocks.0.']
# levels.1.downsample (between stage-2 and stage-3)
lr_ratios['levels.1.downsample'] = lr_ratios['levels.2.blocks.0.']
lr_ratios['levels.1.norm'] = lr_ratios['levels.2.blocks.0.']
# levels.2.downsample (between stage-3 and stage-4)
lr_ratios['levels.2.downsample'] = lr_ratios['levels.3.blocks.0.']
lr_ratios['levels.2.norm'] = lr_ratios['levels.3.blocks.0.']
return lr_ratios
def forward_features_seq_out(self, x):
x = self.patch_embed(x)
x = self.pos_drop(x)
seq_out = []
for level in self.levels:
x, x_ = level(x, return_wo_downsample=True)
seq_out.append(x_)
return seq_out
def forward_features(self, x):
xs = self.forward_features_seq_out(x)
x1, x2, x3, x4 = xs
x1 = x1.permute(0, 3, 1, 2) # NHWC -> NCHW
x2 = x2.permute(0, 3, 1, 2) # NHWC -> NCHW
x3 = x3.permute(0, 3, 1, 2) # NHWC -> NCHW
x4 = x4.permute(0, 3, 1, 2) # NHWC -> NCHW
hidden_states = [x1, x2, x3, x4]
if self.num_classes > 0:
x = self.conv_head(x4)
x = self.avgpool(x)
x = torch.flatten(x, 1)
return {
'hidden_states': hidden_states,
'pooler_output': x if self.num_classes > 0 else None
}
def forward_clip_projector(self, x): # for InternImage-H/G
xs = self.forward_features_seq_out(x)
x1, x2, x3, x4 = xs
x1 = x1.permute(0, 3, 1, 2) # NHWC -> NCHW
x2 = x2.permute(0, 3, 1, 2) # NHWC -> NCHW
x3 = x3.permute(0, 3, 1, 2) # NHWC -> NCHW
x4 = x4.permute(0, 3, 1, 2) # NHWC -> NCHW
hidden_states = [x1, x2, x3, x4]
if self.num_classes > 0:
x4 = self.dcnv3_head_x4(x4)
x = x4
x3 = self.dcnv3_head_x3(x3)
x = x + x3
x = x.flatten(-2).transpose(1, 2).contiguous()
x = self.clip_projector(x)
x = self.fc_norm(x)
return {
'hidden_states': hidden_states,
'pooler_output': x if self.num_classes > 0 else None
}
def forward(self, x):
if self.use_clip_projector: # for InternImage-H/G
outputs = self.forward_clip_projector(x)
else: # for InternImage-T/S/B/L/XL
outputs = self.forward_features(x)
hidden_states = outputs['hidden_states']
pooler_output = outputs['pooler_output']
if self.num_classes > 0:
logits = self.head(pooler_output)
else:
logits = None
return BackboneOutput(
hidden_states=hidden_states,
last_hidden_state=hidden_states[-1],
pooler_output=pooler_output,
logits=logits
)
class InternImageModel(PreTrainedModel):
config_class = InternImageConfig
def __init__(self, config):
super().__init__(config)
self.model = InternImage(
core_op=config.core_op,
channels=config.channels,
depths=config.depths,
groups=config.groups,
num_classes=0,
mlp_ratio=config.mlp_ratio,
drop_rate=config.drop_rate,
drop_path_rate=config.drop_path_rate,
drop_path_type=config.drop_path_type,
act_layer=config.act_layer,
norm_layer=config.norm_layer,
layer_scale=config.layer_scale,
offset_scale=config.offset_scale,
post_norm=config.post_norm,
cls_scale=config.cls_scale,
with_cp=config.with_cp,
dw_kernel_size=config.dw_kernel_size, # for InternImage-H/G
use_clip_projector=config.use_clip_projector, # for InternImage-H/G
level2_post_norm=config.level2_post_norm, # for InternImage-H/G
level2_post_norm_block_ids=config.level2_post_norm_block_ids, # for InternImage-H/G
res_post_norm=config.res_post_norm, # for InternImage-H/G
center_feature_scale=config.center_feature_scale, # for InternImage-H/G
remove_center=config.remove_center, # for InternImage-H/G
)
    def forward(self, tensor):
        outputs = self.model.forward_features(tensor)
        # Wrap the plain dict in BackboneOutput so attribute access
        # (e.g. `model(pixel_values).hidden_states`) works as in the README.
        return BackboneOutput(
            hidden_states=outputs['hidden_states'],
            last_hidden_state=outputs['hidden_states'][-1],
            pooler_output=outputs['pooler_output'],
        )
class InternImageModelForImageClassification(PreTrainedModel):
config_class = InternImageConfig
def __init__(self, config):
super().__init__(config)
self.model = InternImage(
core_op=config.core_op,
channels=config.channels,
depths=config.depths,
groups=config.groups,
num_classes=config.num_classes,
mlp_ratio=config.mlp_ratio,
drop_rate=config.drop_rate,
drop_path_rate=config.drop_path_rate,
drop_path_type=config.drop_path_type,
act_layer=config.act_layer,
norm_layer=config.norm_layer,
layer_scale=config.layer_scale,
offset_scale=config.offset_scale,
post_norm=config.post_norm,
cls_scale=config.cls_scale,
with_cp=config.with_cp,
dw_kernel_size=config.dw_kernel_size, # for InternImage-H/G
use_clip_projector=config.use_clip_projector, # for InternImage-H/G
level2_post_norm=config.level2_post_norm, # for InternImage-H/G
level2_post_norm_block_ids=config.level2_post_norm_block_ids, # for InternImage-H/G
res_post_norm=config.res_post_norm, # for InternImage-H/G
center_feature_scale=config.center_feature_scale, # for InternImage-H/G
remove_center=config.remove_center, # for InternImage-H/G
)
def forward(self, tensor, labels=None):
outputs = self.model.forward(tensor)
if labels is not None:
logits = outputs['logits']
loss = F.cross_entropy(logits, labels)
outputs['loss'] = loss
return outputs
{
"crop_size": 224,
"do_center_crop": true,
"do_normalize": true,
"do_resize": true,
"feature_extractor_type": "CLIPFeatureExtractor",
"image_mean": [
0.485,
0.456,
0.406
],
"image_std": [
0.229,
0.224,
0.225
],
"resample": 3,
"size": 224
}
---
license: mit
pipeline_tag: image-classification
library_name: transformers
tags:
- internimage
- custom_code
datasets:
- ILSVRC/imagenet-1k
---
# InternImage Model Card
## Introduction
InternImage is an advanced vision foundation model developed by researchers from Shanghai AI Laboratory, Tsinghua University, and other institutions. Unlike models based on Transformers, InternImage employs DCNv3 as its core operator. This approach equips the model with dynamic and effective receptive fields required for downstream tasks like object detection and segmentation, while enabling adaptive spatial aggregation.
<div style="text-align: center;"> <img src="https://github.com/OpenGVLab/InternImage/raw/master/docs/figs/arch.png" style="width:60%;" /> </div>
## Performance
- InternImage achieves an impressive Top-1 accuracy of 90.1% on the ImageNet benchmark using only publicly available training data. Apart from two undisclosed models trained on additional private datasets by Google and Microsoft, it is the only open-source model to exceed 90.0% Top-1 accuracy, and it is also the largest in scale.
- On the COCO object detection benchmark, InternImage reaches a remarkable 65.5 mAP, making it the only model to surpass 65 mAP.
- InternImage also sets state-of-the-art results on 16 other important vision benchmarks spanning classification, detection, and segmentation, making it a top-performing model across multiple domains.
## Released Models
### Open‑Source Visual Pretrained Models
| huggingface name | model name | pretrain | resolution | #param |
| :-------------------------------------------------------------------------------------------: | :------------: | :------------------: | :--------: | :----: |
| [internimage_l_22k_384](https://huggingface.co/OpenGVLab/internimage_l_22k_384) | InternImage-L | IN-22K | 384x384 | 223M |
| [internimage_xl_22k_384](https://huggingface.co/OpenGVLab/internimage_xl_22k_384) | InternImage-XL | IN-22K | 384x384 | 335M |
| [internimage_h_jointto22k_384](https://huggingface.co/OpenGVLab/internimage_h_jointto22k_384) | InternImage-H | Joint 427M -> IN-22K | 384x384 | 1.08B |
| [internimage_g_jointto22k_384](https://huggingface.co/OpenGVLab/internimage_g_jointto22k_384) | InternImage-G | Joint 427M -> IN-22K | 384x384 | 3B |
### ImageNet-1K Image Classification
| huggingface name | model name | pretrain | resolution | acc@1 | #param | FLOPs |
| :---------------------------------------------------------------------------------------: | :------------: | :------------------: | :--------: | :---: | :----: | :---: |
| [internimage_t_1k_224](https://huggingface.co/OpenGVLab/internimage_t_1k_224) | InternImage-T | IN-1K | 224x224 | 83.5 | 30M | 5G |
| [internimage_s_1k_224](https://huggingface.co/OpenGVLab/internimage_s_1k_224) | InternImage-S | IN-1K | 224x224 | 84.2 | 50M | 8G |
| [internimage_b_1k_224](https://huggingface.co/OpenGVLab/internimage_b_1k_224) | InternImage-B | IN-1K | 224x224 | 84.9 | 97M | 16G |
| [internimage_l_22kto1k_384](https://huggingface.co/OpenGVLab/internimage_l_22kto1k_384) | InternImage-L | IN-22K | 384x384 | 87.7 | 223M | 108G |
| [internimage_xl_22kto1k_384](https://huggingface.co/OpenGVLab/internimage_xl_22kto1k_384) | InternImage-XL | IN-22K | 384x384 | 88.0 | 335M | 163G |
| [internimage_h_22kto1k_640](https://huggingface.co/OpenGVLab/internimage_h_22kto1k_640) | InternImage-H | Joint 427M -> IN-22K | 640x640 | 89.6 | 1.08B | 1478G |
| [internimage_g_22kto1k_512](https://huggingface.co/OpenGVLab/internimage_g_22kto1k_512) | InternImage-G | Joint 427M -> IN-22K | 512x512 | 90.1 | 3B | 2700G |
## DCNv3 CUDA Kernel Installation
If you do not install the CUDA version of DCNv3, InternImage will automatically fall back to a PyTorch implementation. However, the CUDA implementation can significantly reduce GPU memory usage and improve inference efficiency.
**Installation Tutorial:**
1. Open your terminal and run:
```bash
git clone https://github.com/OpenGVLab/InternImage.git
cd InternImage/classification/ops_dcnv3
```
2. Make sure you have an available GPU for compilation, then run:
```bash
sh make.sh
```
This will compile the CUDA version of DCNv3. Once installed, InternImage will automatically leverage the optimized CUDA implementation for better performance.
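If you want to confirm which implementation will be used, the check below mirrors the probe performed inside the bundled `dcnv3_func.py` (a minimal sketch): the model falls back to the pure-PyTorch path whenever the compiled extension cannot be imported. The model also prints which implementation it is using when it is instantiated.
```python
try:
    import DCNv3  # the CUDA extension built by make.sh
    print("DCNv3 CUDA kernel available")
except ImportError:
    print("DCNv3 CUDA kernel not found; the pure-PyTorch fallback will be used")
```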
## Usage with Transformers
Below are two usage examples for InternImage with the Transformers framework:
### Example 1: Using InternImage as an Image Backbone
```python
import torch
from PIL import Image
from transformers import AutoModel, CLIPImageProcessor
# Replace 'model_name' with the appropriate model identifier
model_name = "OpenGVLab/internimage_t_1k_224" # example model
# Prepare the image
image_path = 'img.png'
image_processor = CLIPImageProcessor.from_pretrained(model_name)
image = Image.open(image_path)
image = image_processor(images=image, return_tensors='pt').pixel_values
print('image shape:', image.shape)
# Load the model as a backbone
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
# 'hidden_states' contains the outputs from the 4 stages of the InternImage backbone
hidden_states = model(image).hidden_states
```
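The backbone returns one feature map per stage in `NCHW` layout. As a rough guide (assuming the 224x224 input above and the `internimage_t_1k_224` configuration with `channels=64`), the four maps have 64, 128, 256, and 512 channels at strides 4, 8, 16, and 32:
```python
for i, feat in enumerate(hidden_states):
    print(f"stage {i} feature map:", tuple(feat.shape))
# Expected (assuming a 1x3x224x224 input):
# (1, 64, 56, 56), (1, 128, 28, 28), (1, 256, 14, 14), (1, 512, 7, 7)
```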
### Example 2: Using InternImage for Image Classification
```python
import torch
from PIL import Image
from transformers import AutoModelForImageClassification, CLIPImageProcessor
# Replace 'model_name' with the appropriate model identifier
model_name = "OpenGVLab/internimage_t_1k_224" # example model
# Prepare the image
image_path = 'img.png'
image_processor = CLIPImageProcessor.from_pretrained(model_name)
image = Image.open(image_path)
image = image_processor(images=image, return_tensors='pt').pixel_values
print('image shape:', image.shape)
# Load the model as an image classifier
model = AutoModelForImageClassification.from_pretrained(model_name, trust_remote_code=True)
logits = model(image).logits
label = torch.argmax(logits, dim=1)
print("Predicted label:", label.item())
```
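When integer class labels are supplied, the classification model also returns a cross-entropy loss alongside the logits (a minimal sketch; the label below is a dummy value for illustration):
```python
labels = torch.tensor([0])  # dummy ImageNet-1k class index
outputs = model(image, labels=labels)
print("loss:", outputs.loss.item())
```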
## Citation
If this work is helpful for your research, please consider citing the following BibTeX entry.
```bibtex
@inproceedings{wang2023internimage,
title={Internimage: Exploring large-scale vision foundation models with deformable convolutions},
author={Wang, Wenhai and Dai, Jifeng and Chen, Zhe and Huang, Zhenhang and Li, Zhiqi and Zhu, Xizhou and Hu, Xiaowei and Lu, Tong and Lu, Lewei and Li, Hongsheng and others},
booktitle={Proceedings of the IEEE/CVF conference on computer vision and pattern recognition},
pages={14408--14419},
year={2023}
}
```
{
"_name_or_path": "OpenGVLab/internimage_t_1k_224",
"act_layer": "GELU",
"architectures": [
"InternImageModel"
],
"auto_map": {
"AutoConfig": "configuration_internimage.InternImageConfig",
"AutoModel": "modeling_internimage.InternImageModel",
"AutoModelForImageClassification": "modeling_internimage.InternImageModelForImageClassification"
},
"center_feature_scale": false,
"channels": 64,
"cls_scale": 1.5,
"core_op": "DCNv3",
"depths": [
4,
4,
18,
4
],
"drop_path_rate": 0.0,
"drop_path_type": "linear",
"drop_rate": 0.0,
"dw_kernel_size": null,
"groups": [
4,
8,
16,
32
],
"layer_scale": null,
"level2_post_norm": false,
"level2_post_norm_block_ids": null,
"mlp_ratio": 4.0,
"model_type": "internimage",
"norm_layer": "LN",
"num_classes": 1000,
"offset_scale": 1.0,
"post_norm": false,
"remove_center": false,
"res_post_norm": false,
"torch_dtype": "float32",
"transformers_version": "4.37.2",
"use_clip_projector": false,
"with_cp": false
}
# --------------------------------------------------------
# InternImage
# Copyright (c) 2025 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from transformers import PretrainedConfig
class InternImageConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`~InternImageModel`].
It is used to instantiate an internimage model according to the specified arguments, defining the model
architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of
the internimage [OpenGVLab/internimage](https://huggingface.co/OpenGVLab/internimage) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used
to control the model outputs. Read the documentation from [`PretrainedConfig`]
for more information.
Args:
core_op (`str`, *optional*, defaults to `"DCNv3"`):
Core operation used in the InternImageModel.
depths (`tuple`, *optional*, defaults to `(4, 4, 18, 4)`):
Tuple specifying the depth of layers in the InternImageModel.
groups (`tuple`, *optional*, defaults to `(4, 8, 16, 32)`):
Tuple specifying the group of layers in the InternImageModel.
channels (`int`, *optional*, defaults to `64`):
Number of channels in the InternImageModel.
dw_kernel_size (`int`, *optional*, defaults to `None`):
Kernel size for depthwise convolutions.
layer_scale (`float`, *optional*, defaults to `None`):
Scale of the layers in the model.
offset_scale (`float`, *optional*, defaults to `1.0`):
Offset scale in the model.
mlp_ratio (`float`, *optional*, defaults to `4.0`):
Ratio of the MLP hidden dimension to the embedding dimension in the InternImageModel.
post_norm (`bool`, *optional*, defaults to `False`):
Whether to use post normalization in the model.
level2_post_norm (`bool`, *optional*, defaults to `False`):
Whether to use level 2 post normalization.
level2_post_norm_block_ids (`list`, *optional*, defaults to `None`):
Specific block IDs for level 2 post normalization.
center_feature_scale (`bool`, *optional*, defaults to `False`):
Whether to apply center feature scaling.
use_clip_projector (`bool`, *optional*, defaults to `False`):
Whether to use CLIP projector.
remove_center (`bool`, *optional*, defaults to `False`):
Whether to remove center pixels in some operations.
num_classes (`int`, *optional*, defaults to `1000`):
Number of classes for the model output.
drop_rate (`float`, *optional*, defaults to `0.0`):
Dropout rate in the model.
drop_path_rate (`float`, *optional*, defaults to `0.0`):
Dropout path rate in the model.
drop_path_type (`str`, *optional*, defaults to `"linear"`):
Type of dropout path used in the model.
act_layer (`str`, *optional*, defaults to `"GELU"`):
Activation function used in the model.
norm_layer (`str`, *optional*, defaults to `"LN"`):
Normalization layer used in the model.
cls_scale (`float`, *optional*, defaults to `1.5`):
Scale of the classification layer in the model.
with_cp (`bool`, *optional*, defaults to `False`):
Whether to use checkpointing in the model.
"""
model_type = 'internimage'
def __init__(
self,
core_op='DCNv3',
depths=(4, 4, 18, 4),
groups=(4, 8, 16, 32),
channels=64,
dw_kernel_size=None,
layer_scale=None,
offset_scale=1.0,
mlp_ratio=4.0,
post_norm=False,
res_post_norm=False,
level2_post_norm=False,
level2_post_norm_block_ids=None,
center_feature_scale=False,
use_clip_projector=False,
remove_center=False,
num_classes=1000,
drop_rate=0.0,
drop_path_rate=0.0,
drop_path_type='linear',
act_layer='GELU',
norm_layer='LN',
cls_scale=1.5,
with_cp=False,
**kwargs,
):
super().__init__(**kwargs)
# Model configuration parameters
self.core_op = core_op
self.depths = depths
self.groups = groups
self.channels = channels
self.dw_kernel_size = dw_kernel_size
self.layer_scale = layer_scale
self.offset_scale = offset_scale
self.mlp_ratio = mlp_ratio
self.post_norm = post_norm
self.res_post_norm = res_post_norm
self.level2_post_norm = level2_post_norm
self.level2_post_norm_block_ids = level2_post_norm_block_ids
self.center_feature_scale = center_feature_scale
self.use_clip_projector = use_clip_projector
self.remove_center = remove_center
self.num_classes = num_classes
self.drop_rate = drop_rate
self.drop_path_rate = drop_path_rate
self.drop_path_type = drop_path_type
self.act_layer = act_layer
self.norm_layer = norm_layer
self.cls_scale = cls_scale
self.with_cp = with_cp
# --------------------------------------------------------
# InternImage
# Copyright (c) 2025 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from __future__ import absolute_import, division, print_function
import warnings
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.init import constant_, xavier_uniform_
from .dcnv3_func import DCNv3Function, dcnv3_core_pytorch, has_cuda_kernel
class to_channels_first(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return x.permute(0, 3, 1, 2)
class to_channels_last(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return x.permute(0, 2, 3, 1)
def build_norm_layer(dim,
norm_layer,
in_format='channels_last',
out_format='channels_last',
eps=1e-6):
layers = []
if norm_layer == 'BN':
if in_format == 'channels_last':
layers.append(to_channels_first())
layers.append(nn.BatchNorm2d(dim))
if out_format == 'channels_last':
layers.append(to_channels_last())
elif norm_layer == 'LN':
if in_format == 'channels_first':
layers.append(to_channels_last())
layers.append(nn.LayerNorm(dim, eps=eps))
if out_format == 'channels_first':
layers.append(to_channels_first())
else:
raise NotImplementedError(
f'build_norm_layer does not support {norm_layer}')
return nn.Sequential(*layers)
def build_act_layer(act_layer):
if act_layer == 'ReLU':
return nn.ReLU(inplace=True)
elif act_layer == 'SiLU':
return nn.SiLU(inplace=True)
elif act_layer == 'GELU':
return nn.GELU()
raise NotImplementedError(f'build_act_layer does not support {act_layer}')
def _is_power_of_2(n):
if (not isinstance(n, int)) or (n < 0):
raise ValueError(
'invalid input for _is_power_of_2: {} (type: {})'.format(n, type(n)))
return (n & (n - 1) == 0) and n != 0
class CenterFeatureScaleModule(nn.Module):
def forward(self,
query,
center_feature_scale_proj_weight,
center_feature_scale_proj_bias):
center_feature_scale = F.linear(query,
weight=center_feature_scale_proj_weight,
bias=center_feature_scale_proj_bias).sigmoid()
return center_feature_scale
class DCNv3_pytorch(nn.Module):
def __init__(
self,
channels=64,
kernel_size=3,
dw_kernel_size=None,
stride=1,
pad=1,
dilation=1,
group=4,
offset_scale=1.0,
act_layer='GELU',
norm_layer='LN',
center_feature_scale=False,
remove_center=False,
):
"""
DCNv3 Module
:param channels
:param kernel_size
:param stride
:param pad
:param dilation
:param group
:param offset_scale
:param act_layer
:param norm_layer
"""
super().__init__()
if channels % group != 0:
raise ValueError(
f'channels must be divisible by group, but got {channels} and {group}')
_d_per_group = channels // group
dw_kernel_size = dw_kernel_size if dw_kernel_size is not None else kernel_size
# you'd better set _d_per_group to a power of 2 which is more efficient in our CUDA implementation
if not _is_power_of_2(_d_per_group):
warnings.warn(
"You'd better set channels in DCNv3 to make the dimension of each attention head a power of 2 "
'which is more efficient in our CUDA implementation.')
self.offset_scale = offset_scale
self.channels = channels
self.kernel_size = kernel_size
self.dw_kernel_size = dw_kernel_size
self.stride = stride
self.dilation = dilation
self.pad = pad
self.group = group
self.group_channels = channels // group
self.offset_scale = offset_scale
self.center_feature_scale = center_feature_scale
self.remove_center = int(remove_center)
self.dw_conv = nn.Sequential(
nn.Conv2d(
channels,
channels,
kernel_size=dw_kernel_size,
stride=1,
padding=(dw_kernel_size - 1) // 2,
groups=channels),
build_norm_layer(
channels,
norm_layer,
'channels_first',
'channels_last'),
build_act_layer(act_layer))
self.offset = nn.Linear(
channels,
group * (kernel_size * kernel_size - remove_center) * 2)
self.mask = nn.Linear(
channels,
group * (kernel_size * kernel_size - remove_center))
self.input_proj = nn.Linear(channels, channels)
self.output_proj = nn.Linear(channels, channels)
self._reset_parameters()
if center_feature_scale:
self.center_feature_scale_proj_weight = nn.Parameter(
torch.zeros((group, channels), dtype=torch.float))
self.center_feature_scale_proj_bias = nn.Parameter(
torch.tensor(0.0, dtype=torch.float).view((1,)).repeat(group, ))
self.center_feature_scale_module = CenterFeatureScaleModule()
def _reset_parameters(self):
constant_(self.offset.weight.data, 0.)
constant_(self.offset.bias.data, 0.)
constant_(self.mask.weight.data, 0.)
constant_(self.mask.bias.data, 0.)
xavier_uniform_(self.input_proj.weight.data)
constant_(self.input_proj.bias.data, 0.)
xavier_uniform_(self.output_proj.weight.data)
constant_(self.output_proj.bias.data, 0.)
def forward(self, input):
"""
:param query (N, H, W, C)
:return output (N, H, W, C)
"""
N, H, W, _ = input.shape
x = self.input_proj(input)
x_proj = x
x1 = input.permute(0, 3, 1, 2)
x1 = self.dw_conv(x1)
offset = self.offset(x1)
mask = self.mask(x1).reshape(N, H, W, self.group, -1)
mask = F.softmax(mask, -1).reshape(N, H, W, -1)
x = dcnv3_core_pytorch(
x, offset, mask,
self.kernel_size, self.kernel_size,
self.stride, self.stride,
self.pad, self.pad,
self.dilation, self.dilation,
self.group, self.group_channels,
self.offset_scale, self.remove_center)
if self.center_feature_scale:
center_feature_scale = self.center_feature_scale_module(
x1, self.center_feature_scale_proj_weight, self.center_feature_scale_proj_bias)
# N, H, W, groups -> N, H, W, groups, 1 -> N, H, W, groups, _d_per_group -> N, H, W, channels
center_feature_scale = center_feature_scale[..., None].repeat(
1, 1, 1, 1, self.channels // self.group).flatten(-2)
x = x * (1 - center_feature_scale) + x_proj * center_feature_scale
x = self.output_proj(x)
return x
class DCNv3(nn.Module):
def __init__(
self,
channels=64,
kernel_size=3,
dw_kernel_size=None,
stride=1,
pad=1,
dilation=1,
group=4,
offset_scale=1.0,
act_layer='GELU',
norm_layer='LN',
center_feature_scale=False,
remove_center=False,
):
"""
DCNv3 Module
:param channels
:param kernel_size
:param stride
:param pad
:param dilation
:param group
:param offset_scale
:param act_layer
:param norm_layer
"""
super().__init__()
if channels % group != 0:
raise ValueError(
f'channels must be divisible by group, but got {channels} and {group}')
_d_per_group = channels // group
dw_kernel_size = dw_kernel_size if dw_kernel_size is not None else kernel_size
# you'd better set _d_per_group to a power of 2 which is more efficient in our CUDA implementation
if not _is_power_of_2(_d_per_group):
warnings.warn(
"You'd better set channels in DCNv3 to make the dimension of each attention head a power of 2 "
'which is more efficient in our CUDA implementation.')
self.offset_scale = offset_scale
self.channels = channels
self.kernel_size = kernel_size
self.dw_kernel_size = dw_kernel_size
self.stride = stride
self.dilation = dilation
self.pad = pad
self.group = group
self.group_channels = channels // group
self.offset_scale = offset_scale
self.center_feature_scale = center_feature_scale
self.remove_center = int(remove_center)
if self.remove_center and self.kernel_size % 2 == 0:
raise ValueError('remove_center is only compatible with odd kernel size.')
self.dw_conv = nn.Sequential(
nn.Conv2d(
channels,
channels,
kernel_size=dw_kernel_size,
stride=1,
padding=(dw_kernel_size - 1) // 2,
groups=channels),
build_norm_layer(
channels,
norm_layer,
'channels_first',
'channels_last'),
build_act_layer(act_layer))
self.offset = nn.Linear(
channels,
group * (kernel_size * kernel_size - remove_center) * 2)
self.mask = nn.Linear(
channels,
group * (kernel_size * kernel_size - remove_center))
self.input_proj = nn.Linear(channels, channels)
self.output_proj = nn.Linear(channels, channels)
self._reset_parameters()
if center_feature_scale:
self.center_feature_scale_proj_weight = nn.Parameter(
torch.zeros((group, channels), dtype=torch.float))
self.center_feature_scale_proj_bias = nn.Parameter(
torch.tensor(0.0, dtype=torch.float).view((1,)).repeat(group, ))
self.center_feature_scale_module = CenterFeatureScaleModule()
def _reset_parameters(self):
constant_(self.offset.weight.data, 0.)
constant_(self.offset.bias.data, 0.)
constant_(self.mask.weight.data, 0.)
constant_(self.mask.bias.data, 0.)
xavier_uniform_(self.input_proj.weight.data)
constant_(self.input_proj.bias.data, 0.)
xavier_uniform_(self.output_proj.weight.data)
constant_(self.output_proj.bias.data, 0.)
def forward(self, input):
"""
:param query (N, H, W, C)
:return output (N, H, W, C)
"""
N, H, W, _ = input.shape
x = self.input_proj(input)
x_proj = x
dtype = x.dtype
x1 = input.permute(0, 3, 1, 2)
x1 = self.dw_conv(x1)
offset = self.offset(x1)
mask = self.mask(x1).reshape(N, H, W, self.group, -1)
mask = F.softmax(mask, -1)
mask = mask.reshape(N, H, W, -1).type(dtype)
x = DCNv3Function.apply(
x, offset, mask,
self.kernel_size, self.kernel_size,
self.stride, self.stride,
self.pad, self.pad,
self.dilation, self.dilation,
self.group, self.group_channels,
self.offset_scale,
256,
self.remove_center)
if self.center_feature_scale:
center_feature_scale = self.center_feature_scale_module(
x1, self.center_feature_scale_proj_weight, self.center_feature_scale_proj_bias)
# N, H, W, groups -> N, H, W, groups, 1 -> N, H, W, groups, _d_per_group -> N, H, W, channels
center_feature_scale = center_feature_scale[..., None].repeat(
1, 1, 1, 1, self.channels // self.group).flatten(-2)
x = x * (1 - center_feature_scale) + x_proj * center_feature_scale
x = self.output_proj(x)
return x
# --------------------------------------------------------
# InternImage
# Copyright (c) 2025 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from __future__ import absolute_import, division, print_function
import pkg_resources
import torch
import torch.nn.functional as F
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.cuda.amp import custom_bwd, custom_fwd
# Probe for the compiled CUDA extension; fall back to the pure-PyTorch path if it is missing.
# Note: pkg_resources must be imported before this block, otherwise the version lookup raises a
# NameError and the CUDA kernel is silently disabled even when it is installed.
try:
    import DCNv3
    dcn_version = float(pkg_resources.get_distribution('DCNv3').version)
    has_cuda_kernel = True
except Exception:
    has_cuda_kernel = False
class DCNv3Function(Function):
@staticmethod
@custom_fwd
def forward(
ctx, input, offset, mask,
kernel_h, kernel_w, stride_h, stride_w,
pad_h, pad_w, dilation_h, dilation_w,
group, group_channels, offset_scale, im2col_step, remove_center):
ctx.kernel_h = kernel_h
ctx.kernel_w = kernel_w
ctx.stride_h = stride_h
ctx.stride_w = stride_w
ctx.pad_h = pad_h
ctx.pad_w = pad_w
ctx.dilation_h = dilation_h
ctx.dilation_w = dilation_w
ctx.group = group
ctx.group_channels = group_channels
ctx.offset_scale = offset_scale
ctx.im2col_step = im2col_step
ctx.remove_center = remove_center
args = [
input, offset, mask, kernel_h,
kernel_w, stride_h, stride_w, pad_h,
pad_w, dilation_h, dilation_w, group,
group_channels, offset_scale, ctx.im2col_step
]
if remove_center or dcn_version > 1.0:
args.append(remove_center)
output = DCNv3.dcnv3_forward(*args)
ctx.save_for_backward(input, offset, mask)
return output
@staticmethod
@once_differentiable
@custom_bwd
def backward(ctx, grad_output):
input, offset, mask = ctx.saved_tensors
args = [
input, offset, mask, ctx.kernel_h,
ctx.kernel_w, ctx.stride_h, ctx.stride_w, ctx.pad_h,
ctx.pad_w, ctx.dilation_h, ctx.dilation_w, ctx.group,
ctx.group_channels, ctx.offset_scale, grad_output.contiguous(), ctx.im2col_step
]
if ctx.remove_center or dcn_version > 1.0:
args.append(ctx.remove_center)
grad_input, grad_offset, grad_mask = \
DCNv3.dcnv3_backward(*args)
return grad_input, grad_offset, grad_mask, \
None, None, None, None, None, None, None, None, None, None, None, None, None
@staticmethod
def symbolic(g, input, offset, mask, kernel_h, kernel_w, stride_h,
stride_w, pad_h, pad_w, dilation_h, dilation_w, group,
group_channels, offset_scale, im2col_step, remove_center):
"""Symbolic function for mmdeploy::DCNv3.
Returns:
DCNv3 op for onnx.
"""
return g.op(
'mmdeploy::TRTDCNv3',
input,
offset,
mask,
kernel_h_i=int(kernel_h),
kernel_w_i=int(kernel_w),
stride_h_i=int(stride_h),
stride_w_i=int(stride_w),
pad_h_i=int(pad_h),
pad_w_i=int(pad_w),
dilation_h_i=int(dilation_h),
dilation_w_i=int(dilation_w),
group_i=int(group),
group_channels_i=int(group_channels),
offset_scale_f=float(offset_scale),
im2col_step_i=int(im2col_step),
remove_center_i=int(remove_center),
)
def _get_reference_points(spatial_shapes, device, kernel_h, kernel_w, dilation_h, dilation_w, pad_h=0, pad_w=0, stride_h=1, stride_w=1):
_, H_, W_, _ = spatial_shapes
H_out = (H_ - (dilation_h * (kernel_h - 1) + 1)) // stride_h + 1
W_out = (W_ - (dilation_w * (kernel_w - 1) + 1)) // stride_w + 1
ref_y, ref_x = torch.meshgrid(
torch.linspace(
# pad_h + 0.5,
# H_ - pad_h - 0.5,
(dilation_h * (kernel_h - 1)) // 2 + 0.5,
(dilation_h * (kernel_h - 1)) // 2 + 0.5 + (H_out - 1) * stride_h,
H_out,
dtype=torch.float32,
device=device),
torch.linspace(
# pad_w + 0.5,
# W_ - pad_w - 0.5,
(dilation_w * (kernel_w - 1)) // 2 + 0.5,
(dilation_w * (kernel_w - 1)) // 2 + 0.5 + (W_out - 1) * stride_w,
W_out,
dtype=torch.float32,
device=device))
ref_y = ref_y.reshape(-1)[None] / H_
ref_x = ref_x.reshape(-1)[None] / W_
ref = torch.stack((ref_x, ref_y), -1).reshape(
1, H_out, W_out, 1, 2)
return ref
def _generate_dilation_grids(spatial_shapes, kernel_h, kernel_w, dilation_h, dilation_w, group, device):
_, H_, W_, _ = spatial_shapes
points_list = []
x, y = torch.meshgrid(
torch.linspace(
-((dilation_w * (kernel_w - 1)) // 2),
-((dilation_w * (kernel_w - 1)) // 2) + (kernel_w - 1) * dilation_w,
kernel_w,
dtype=torch.float32,
device=device),
torch.linspace(
-((dilation_h * (kernel_h - 1)) // 2),
-((dilation_h * (kernel_h - 1)) // 2) + (kernel_h - 1) * dilation_h,
kernel_h,
dtype=torch.float32,
device=device))
points_list.extend([x / W_, y / H_])
grid = torch.stack(points_list, -1).reshape(-1, 1, 2).\
repeat(1, group, 1).permute(1, 0, 2)
grid = grid.reshape(1, 1, 1, group * kernel_h * kernel_w, 2)
return grid
def remove_center_sampling_locations(sampling_locations, kernel_w, kernel_h):
idx = list(range(sampling_locations.shape[-2]))
C = (kernel_w * kernel_h - 1)//2
idx = [i for i in idx if i != C and (i-C) % (C*2+1) != 0]
sampling_locations = sampling_locations[:,:,:,idx, :]
return sampling_locations
def dcnv3_core_pytorch(
input, offset, mask, kernel_h,
kernel_w, stride_h, stride_w, pad_h,
pad_w, dilation_h, dilation_w, group,
group_channels, offset_scale, remove_center):
# for debug and test only,
# need to use cuda version instead
if remove_center and (kernel_h % 2 == 0 or kernel_w % 2 == 0 or kernel_w != kernel_h):
raise ValueError('remove_center is only compatible with square odd kernel size.')
input = F.pad(
input,
[0, 0, pad_h, pad_h, pad_w, pad_w])
N_, H_in, W_in, _ = input.shape
_, H_out, W_out, _ = offset.shape
ref = _get_reference_points(
input.shape, input.device, kernel_h, kernel_w, dilation_h, dilation_w, pad_h, pad_w, stride_h, stride_w)
grid = _generate_dilation_grids(
input.shape, kernel_h, kernel_w, dilation_h, dilation_w, group, input.device)
spatial_norm = torch.tensor([W_in, H_in]).reshape(1, 1, 1, 2).\
repeat(1, 1, 1, group*(kernel_h*kernel_w-remove_center)).to(input.device)
sampling_locations = (ref + grid * offset_scale).repeat(N_, 1, 1, 1, 1)
if remove_center:
sampling_locations = remove_center_sampling_locations(sampling_locations, kernel_w=kernel_w, kernel_h=kernel_h)
sampling_locations = sampling_locations.flatten(3, 4)
sampling_locations = sampling_locations + offset * offset_scale / spatial_norm
P_ = kernel_h * kernel_w - remove_center
sampling_grids = 2 * sampling_locations - 1
# N_, H_in, W_in, group*group_channels -> N_, H_in*W_in, group*group_channels -> N_, group*group_channels, H_in*W_in -> N_*group, group_channels, H_in, W_in
input_ = input.view(N_, H_in*W_in, group*group_channels).transpose(1, 2).\
reshape(N_*group, group_channels, H_in, W_in)
# N_, H_out, W_out, group*P_*2 -> N_, H_out*W_out, group, P_, 2 -> N_, group, H_out*W_out, P_, 2 -> N_*group, H_out*W_out, P_, 2
sampling_grid_ = sampling_grids.view(N_, H_out*W_out, group, P_, 2).transpose(1, 2).\
flatten(0, 1)
# N_*group, group_channels, H_out*W_out, P_
sampling_input_ = F.grid_sample(
input_, sampling_grid_, mode='bilinear', padding_mode='zeros', align_corners=False)
# (N_, H_out, W_out, group*P_) -> N_, H_out*W_out, group, P_ -> (N_, group, H_out*W_out, P_) -> (N_*group, 1, H_out*W_out, P_)
mask = mask.view(N_, H_out*W_out, group, P_).transpose(1, 2).\
reshape(N_*group, 1, H_out*W_out, P_)
output = (sampling_input_ * mask).sum(-1).view(N_,
group*group_channels, H_out*W_out)
return output.transpose(1, 2).reshape(N_, H_out, W_out, -1).contiguous()
# --------------------------------------------------------
# InternImage
# Copyright (c) 2025 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from dataclasses import dataclass
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, trunc_normal_
from torch import nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import ModelOutput
from .configuration_internimage import InternImageConfig
from .dcnv3 import DCNv3, DCNv3_pytorch, has_cuda_kernel
from .dcnv3_func import dcnv3_core_pytorch
@dataclass
class BackboneOutput(ModelOutput):
"""
Base class for outputs of backbones.
"""
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
pooler_output: Optional[torch.FloatTensor] = None
last_hidden_state: Optional[torch.FloatTensor] = None
logits: Optional[torch.FloatTensor] = None
loss: Optional[torch.FloatTensor] = None
class to_channels_first(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return x.permute(0, 3, 1, 2)
class to_channels_last(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return x.permute(0, 2, 3, 1)
def build_norm_layer(dim,
norm_layer,
in_format='channels_last',
out_format='channels_last',
eps=1e-6):
layers = []
if norm_layer == 'BN':
if in_format == 'channels_last':
layers.append(to_channels_first())
layers.append(nn.BatchNorm2d(dim))
if out_format == 'channels_last':
layers.append(to_channels_last())
elif norm_layer == 'LN':
if in_format == 'channels_first':
layers.append(to_channels_last())
layers.append(nn.LayerNorm(dim, eps=eps))
if out_format == 'channels_first':
layers.append(to_channels_first())
else:
raise NotImplementedError(
f'build_norm_layer does not support {norm_layer}')
return nn.Sequential(*layers)
def build_act_layer(act_layer):
if act_layer == 'ReLU':
return nn.ReLU(inplace=True)
elif act_layer == 'SiLU':
return nn.SiLU(inplace=True)
elif act_layer == 'GELU':
return nn.GELU()
raise NotImplementedError(f'build_act_layer does not support {act_layer}')
class CrossAttention(nn.Module):
r""" Cross Attention Module
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads. Default: 8
qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
Default: False.
qk_scale (float | None, optional): Override default qk scale of
head_dim ** -0.5 if set. Default: None.
attn_drop (float, optional): Dropout ratio of attention weight.
Default: 0.0
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
attn_head_dim (int, optional): Dimension of attention head.
out_dim (int, optional): Dimension of output.
"""
def __init__(self,
dim,
num_heads=8,
qkv_bias=False,
qk_scale=None,
attn_drop=0.,
proj_drop=0.,
attn_head_dim=None,
out_dim=None):
super().__init__()
if out_dim is None:
out_dim = dim
self.num_heads = num_heads
head_dim = dim // num_heads
if attn_head_dim is not None:
head_dim = attn_head_dim
all_head_dim = head_dim * self.num_heads
self.scale = qk_scale or head_dim ** -0.5
assert all_head_dim == dim
self.q = nn.Linear(dim, all_head_dim, bias=False)
self.k = nn.Linear(dim, all_head_dim, bias=False)
self.v = nn.Linear(dim, all_head_dim, bias=False)
if qkv_bias:
self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
self.k_bias = nn.Parameter(torch.zeros(all_head_dim))
self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
else:
self.q_bias = None
self.k_bias = None
self.v_bias = None
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(all_head_dim, out_dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x, k=None, v=None):
B, N, C = x.shape
N_k = k.shape[1]
N_v = v.shape[1]
q_bias, k_bias, v_bias = None, None, None
if self.q_bias is not None:
q_bias = self.q_bias
k_bias = self.k_bias
v_bias = self.v_bias
q = F.linear(input=x, weight=self.q.weight, bias=q_bias)
q = q.reshape(B, N, 1, self.num_heads,
-1).permute(2, 0, 3, 1,
4).squeeze(0) # (B, N_head, N_q, dim)
k = F.linear(input=k, weight=self.k.weight, bias=k_bias)
k = k.reshape(B, N_k, 1, self.num_heads, -1).permute(2, 0, 3, 1,
4).squeeze(0)
v = F.linear(input=v, weight=self.v.weight, bias=v_bias)
v = v.reshape(B, N_v, 1, self.num_heads, -1).permute(2, 0, 3, 1,
4).squeeze(0)
q = q * self.scale
attn = (q @ k.transpose(-2, -1)) # (B, N_head, N_q, N_k)
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
x = self.proj(x)
x = self.proj_drop(x)
return x
class AttentiveBlock(nn.Module):
r"""Attentive Block
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads. Default: 8
qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
Default: False.
qk_scale (float | None, optional): Override default qk scale of
head_dim ** -0.5 if set. Default: None.
drop (float, optional): Dropout rate. Default: 0.0.
attn_drop (float, optional): Attention dropout rate. Default: 0.0.
drop_path (float | tuple[float], optional): Stochastic depth rate.
Default: 0.0.
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm.
attn_head_dim (int, optional): Dimension of attention head. Default: None.
out_dim (int, optional): Dimension of output. Default: None.
"""
def __init__(self,
dim,
num_heads,
qkv_bias=False,
qk_scale=None,
drop=0.,
attn_drop=0.,
drop_path=0.,
norm_layer='LN',
attn_head_dim=None,
out_dim=None):
super().__init__()
self.norm1_q = build_norm_layer(dim, norm_layer, eps=1e-6)
self.norm1_k = build_norm_layer(dim, norm_layer, eps=1e-6)
self.norm1_v = build_norm_layer(dim, norm_layer, eps=1e-6)
self.cross_dcn = CrossAttention(dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=drop,
attn_head_dim=attn_head_dim,
out_dim=out_dim)
self.drop_path = DropPath(
drop_path) if drop_path > 0. else nn.Identity()
def forward(self,
x_q,
x_kv,
pos_q,
pos_k,
bool_masked_pos,
rel_pos_bias=None):
x_q = self.norm1_q(x_q + pos_q)
x_k = self.norm1_k(x_kv + pos_k)
x_v = self.norm1_v(x_kv)
x = self.cross_dcn(x_q, k=x_k, v=x_v)
return x
class AttentionPoolingBlock(AttentiveBlock):
def forward(self, x):
x_q = x.mean(1, keepdim=True)
x_kv = x
pos_q, pos_k = 0, 0
x = super().forward(x_q, x_kv, pos_q, pos_k,
bool_masked_pos=None,
rel_pos_bias=None)
x = x.squeeze(1)
return x
class StemLayer(nn.Module):
r"""Stem layer of InternImage
Args:
in_chans (int): number of input channels
out_chans (int): number of output channels
act_layer (str): activation layer
norm_layer (str): normalization layer
"""
def __init__(self,
in_chans=3,
out_chans=96,
act_layer='GELU',
norm_layer='BN'):
super().__init__()
self.conv1 = nn.Conv2d(in_chans,
out_chans // 2,
kernel_size=3,
stride=2,
padding=1)
self.norm1 = build_norm_layer(out_chans // 2, norm_layer,
'channels_first', 'channels_first')
self.act = build_act_layer(act_layer)
self.conv2 = nn.Conv2d(out_chans // 2,
out_chans,
kernel_size=3,
stride=2,
padding=1)
self.norm2 = build_norm_layer(out_chans, norm_layer, 'channels_first',
'channels_last')
def forward(self, x):
x = self.conv1(x)
x = self.norm1(x)
x = self.act(x)
x = self.conv2(x)
x = self.norm2(x)
return x
class DownsampleLayer(nn.Module):
r"""Downsample layer of InternImage
Args:
channels (int): number of input channels
norm_layer (str): normalization layer
"""
def __init__(self, channels, norm_layer='LN'):
super().__init__()
self.conv = nn.Conv2d(channels,
2 * channels,
kernel_size=3,
stride=2,
padding=1,
bias=False)
self.norm = build_norm_layer(2 * channels, norm_layer,
'channels_first', 'channels_last')
def forward(self, x):
x = self.conv(x.permute(0, 3, 1, 2))
x = self.norm(x)
return x
class MLPLayer(nn.Module):
r"""MLP layer of InternImage
Args:
in_features (int): number of input features
hidden_features (int): number of hidden features
out_features (int): number of output features
act_layer (str): activation layer
drop (float): dropout rate
"""
def __init__(self,
in_features,
hidden_features=None,
out_features=None,
act_layer='GELU',
drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = build_act_layer(act_layer)
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class InternImageLayer(nn.Module):
r"""Basic layer of InternImage
Args:
core_op (nn.Module): core operation of InternImage
channels (int): number of input channels
groups (list): Groups of each block.
mlp_ratio (float): ratio of mlp hidden features to input channels
drop (float): dropout rate
drop_path (float): drop path rate
act_layer (str): activation layer
norm_layer (str): normalization layer
post_norm (bool): whether to use post normalization
layer_scale (float): layer scale
offset_scale (float): offset scale
with_cp (bool): whether to use checkpoint
"""
def __init__(self,
core_op,
channels,
groups,
mlp_ratio=4.,
drop=0.,
drop_path=0.,
act_layer='GELU',
norm_layer='LN',
post_norm=False,
layer_scale=None,
offset_scale=1.0,
with_cp=False,
dw_kernel_size=None, # for InternImage-H/G
res_post_norm=False, # for InternImage-H/G
center_feature_scale=False, # for InternImage-H/G
remove_center=False, # for InternImage-H/G
):
super().__init__()
self.channels = channels
self.groups = groups
self.mlp_ratio = mlp_ratio
self.with_cp = with_cp
self.norm1 = build_norm_layer(channels, 'LN')
self.post_norm = post_norm
self.dcn = core_op(
channels=channels,
kernel_size=3,
stride=1,
pad=1,
dilation=1,
group=groups,
offset_scale=offset_scale,
act_layer=act_layer,
norm_layer=norm_layer,
dw_kernel_size=dw_kernel_size, # for InternImage-H/G
center_feature_scale=center_feature_scale, # for InternImage-H/G
remove_center=remove_center, # for InternImage-H/G
)
self.drop_path = DropPath(drop_path) if drop_path > 0. \
else nn.Identity()
self.norm2 = build_norm_layer(channels, 'LN')
self.mlp = MLPLayer(in_features=channels,
hidden_features=int(channels * mlp_ratio),
act_layer=act_layer,
drop=drop)
self.layer_scale = layer_scale is not None
if self.layer_scale:
self.layer_scale1 = nn.Parameter(layer_scale * torch.ones(channels),
requires_grad=True)
self.layer_scale2 = nn.Parameter(layer_scale * torch.ones(channels),
requires_grad=True)
self.res_post_norm = res_post_norm
if res_post_norm:
self.res_post_norm1 = build_norm_layer(channels, 'LN')
self.res_post_norm2 = build_norm_layer(channels, 'LN')
def forward(self, x):
def _inner_forward(x):
if not self.layer_scale:
if self.post_norm:
x = x + self.drop_path(self.norm1(self.dcn(x)))
x = x + self.drop_path(self.norm2(self.mlp(x)))
elif self.res_post_norm: # for InternImage-H/G
x = x + self.drop_path(self.res_post_norm1(self.dcn(self.norm1(x))))
x = x + self.drop_path(self.res_post_norm2(self.mlp(self.norm2(x))))
else:
x = x + self.drop_path(self.dcn(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
if self.post_norm:
x = x + self.drop_path(self.layer_scale1 * self.norm1(self.dcn(x)))
x = x + self.drop_path(self.layer_scale2 * self.norm2(self.mlp(x)))
else:
x = x + self.drop_path(self.layer_scale1 * self.dcn(self.norm1(x)))
x = x + self.drop_path(self.layer_scale2 * self.mlp(self.norm2(x)))
return x
if self.with_cp and x.requires_grad:
x = checkpoint.checkpoint(_inner_forward, x)
else:
x = _inner_forward(x)
return x
class InternImageBlock(nn.Module):
r"""Block of InternImage
Args:
core_op (nn.Module): core operation of InternImage
channels (int): number of input channels
depths (list): Depth of each block.
groups (list): Groups of each block.
mlp_ratio (float): ratio of mlp hidden features to input channels
drop (float): dropout rate
drop_path (float): drop path rate
act_layer (str): activation layer
norm_layer (str): normalization layer
post_norm (bool): whether to use post normalization
layer_scale (float): layer scale
offset_scale (float): offset scale
with_cp (bool): whether to use checkpoint
"""
def __init__(self,
core_op,
channels,
depth,
groups,
downsample=True,
mlp_ratio=4.,
drop=0.,
drop_path=0.,
act_layer='GELU',
norm_layer='LN',
post_norm=False,
offset_scale=1.0,
layer_scale=None,
with_cp=False,
dw_kernel_size=None, # for InternImage-H/G
post_norm_block_ids=None, # for InternImage-H/G
res_post_norm=False, # for InternImage-H/G
center_feature_scale=False, # for InternImage-H/G
remove_center=False, # for InternImage-H/G
):
super().__init__()
self.channels = channels
self.depth = depth
self.post_norm = post_norm
self.center_feature_scale = center_feature_scale
self.blocks = nn.ModuleList([
InternImageLayer(
core_op=core_op,
channels=channels,
groups=groups,
mlp_ratio=mlp_ratio,
drop=drop,
drop_path=drop_path[i] if isinstance(
drop_path, list) else drop_path,
act_layer=act_layer,
norm_layer=norm_layer,
post_norm=post_norm,
layer_scale=layer_scale,
offset_scale=offset_scale,
with_cp=with_cp,
dw_kernel_size=dw_kernel_size, # for InternImage-H/G
res_post_norm=res_post_norm, # for InternImage-H/G
center_feature_scale=center_feature_scale, # for InternImage-H/G
remove_center=remove_center, # for InternImage-H/G
) for i in range(depth)
])
if not self.post_norm or center_feature_scale:
self.norm = build_norm_layer(channels, 'LN')
self.post_norm_block_ids = post_norm_block_ids
if post_norm_block_ids is not None: # for InternImage-H/G
self.post_norms = nn.ModuleList(
[build_norm_layer(channels, 'LN', eps=1e-6) for _ in post_norm_block_ids]
)
self.downsample = DownsampleLayer(
channels=channels, norm_layer=norm_layer) if downsample else None
def forward(self, x, return_wo_downsample=False):
for i, blk in enumerate(self.blocks):
x = blk(x)
if (self.post_norm_block_ids is not None) and (i in self.post_norm_block_ids):
index = self.post_norm_block_ids.index(i)
x = self.post_norms[index](x) # for InternImage-H/G
if not self.post_norm or self.center_feature_scale:
x = self.norm(x)
if return_wo_downsample:
x_ = x
if self.downsample is not None:
x = self.downsample(x)
if return_wo_downsample:
return x, x_
return x
class InternImage(nn.Module):
r"""InternImage
A PyTorch impl of : `InternImage: Exploring Large-Scale Vision Foundation Models with Deformable Convolutions` -
https://arxiv.org/abs/2211.05778
Args:
core_op (str): Core operator. Default: 'DCNv3'
channels (int): Number of channels of the first stage. Default: 64
depths (list): Depth of each block. Default: [3, 4, 18, 5]
groups (list): Groups of each block. Default: [3, 6, 12, 24]
num_classes (int): Number of classes. Default: 1000
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
drop_rate (float): Probability of an element to be zeroed. Default: 0.
drop_path_rate (float): Stochastic depth rate. Default: 0.
act_layer (str): Activation layer. Default: 'GELU'
norm_layer (str): Normalization layer. Default: 'LN'
layer_scale (float): The initial value of layer scale. Default: None
cls_scale (float): Scale of the classification head. Default: 1.5
with_cp (bool): Use gradient checkpointing or not. Default: False
dw_kernel_size (int): Size of the dwconv. Default: None
use_clip_projector (bool): Whether to use clip projector. Default: False
level2_post_norm (bool): Whether to use level2 post norm. Default: False
level2_post_norm_block_ids (list): Indexes of post norm blocks. Default: None
res_post_norm (bool): Whether to use res post norm. Default: False
center_feature_scale (bool): Whether to use center feature scale. Default: False
"""
def __init__(self,
core_op='DCNv3',
channels=64,
depths=[3, 4, 18, 5],
groups=[3, 6, 12, 24],
num_classes=1000,
mlp_ratio=4.,
drop_rate=0.,
drop_path_rate=0.2,
drop_path_type='linear',
act_layer='GELU',
norm_layer='LN',
layer_scale=None,
offset_scale=1.0,
post_norm=False,
cls_scale=1.5,
with_cp=False,
dw_kernel_size=None, # for InternImage-H/G
use_clip_projector=False, # for InternImage-H/G
level2_post_norm=False, # for InternImage-H/G
level2_post_norm_block_ids=None, # for InternImage-H/G
res_post_norm=False, # for InternImage-H/G
center_feature_scale=False, # for InternImage-H/G
remove_center=False, # for InternImage-H/G
**kwargs):
super().__init__()
if core_op == 'DCNv3' and has_cuda_kernel:
self.core_op = DCNv3
print('DCNv3 is installed, using CUDA implementation.')
elif core_op == 'DCNv3' and not has_cuda_kernel:
self.core_op = DCNv3_pytorch
print('DCNv3 is not installed, using PyTorch implementation.')
else:
self.core_op = DCNv3_pytorch
print('Using DCNv3 PyTorch implementation.')
self.num_classes = num_classes
self.num_levels = len(depths)
self.depths = depths
self.channels = channels
self.num_features = int(channels * 2 ** (self.num_levels - 1))
self.post_norm = post_norm
self.mlp_ratio = mlp_ratio
self.use_clip_projector = use_clip_projector
self.level2_post_norm_block_ids = level2_post_norm_block_ids
self.remove_center = remove_center
print(f'using core type: {core_op}')
print(f'level2_post_norm: {level2_post_norm}')
print(f'level2_post_norm_block_ids: {level2_post_norm_block_ids}')
print(f'res_post_norm: {res_post_norm}')
print(f'remove_center: {remove_center}')
in_chans = 3
self.patch_embed = StemLayer(in_chans=in_chans,
out_chans=channels,
act_layer=act_layer,
norm_layer=norm_layer)
self.pos_drop = nn.Dropout(p=drop_rate)
dpr = [
x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
]
if drop_path_type == 'uniform':
for i in range(len(dpr)):
dpr[i] = drop_path_rate
self.levels = nn.ModuleList()
for i in range(self.num_levels):
post_norm_block_ids = level2_post_norm_block_ids if level2_post_norm and (
i == 2) else None # for InternImage-H/G
level = InternImageBlock(
core_op=self.core_op,
channels=int(channels * 2 ** i),
depth=depths[i],
groups=groups[i],
mlp_ratio=self.mlp_ratio,
drop=drop_rate,
drop_path=dpr[sum(depths[:i]):sum(depths[:i + 1])],
act_layer=act_layer,
norm_layer=norm_layer,
post_norm=post_norm,
downsample=(i < self.num_levels - 1),
layer_scale=layer_scale,
offset_scale=offset_scale,
with_cp=with_cp,
dw_kernel_size=dw_kernel_size, # for InternImage-H/G
post_norm_block_ids=post_norm_block_ids, # for InternImage-H/G
res_post_norm=res_post_norm, # for InternImage-H/G
center_feature_scale=center_feature_scale, # for InternImage-H/G
remove_center=remove_center, # for InternImage-H/G
)
self.levels.append(level)
if self.num_classes > 0:
if not use_clip_projector: # for InternImage-T/S/B/L/XL
self.conv_head = nn.Sequential(
nn.Conv2d(self.num_features,
int(self.num_features * cls_scale),
kernel_size=1,
bias=False),
build_norm_layer(int(self.num_features * cls_scale), 'BN',
'channels_first', 'channels_first'),
build_act_layer(act_layer))
self.head = nn.Linear(int(self.num_features * cls_scale), num_classes) \
if num_classes > 0 else nn.Identity()
else: # for InternImage-H/G
pretrain_embed_dim, _stride, attnpool_num_heads, clip_embed_dim = 1024, 2, 16, 768
self.dcnv3_head_x4 = nn.Sequential(
nn.Conv2d(in_channels=self.num_features,
out_channels=pretrain_embed_dim * (_stride ** 2),
kernel_size=1), nn.PixelShuffle(_stride))
self.dcnv3_head_x3 = nn.Conv2d(in_channels=self.num_features // 2,
out_channels=pretrain_embed_dim,
kernel_size=1)
self.clip_projector = AttentionPoolingBlock(
dim=pretrain_embed_dim,
num_heads=attnpool_num_heads,
qkv_bias=True,
qk_scale=None,
drop=0.,
attn_drop=0.,
norm_layer=norm_layer,
out_dim=clip_embed_dim)
self.fc_norm = build_norm_layer(clip_embed_dim, norm_layer, eps=1e-6)
self.head = nn.Linear(
clip_embed_dim, num_classes) if num_classes > 0 else nn.Identity()
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.num_layers = len(depths)
self.apply(self._init_weights)
self.apply(self._init_deform_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def _init_deform_weights(self, m):
if isinstance(m, self.core_op):
m._reset_parameters()
@torch.jit.ignore
def lr_decay_keywords(self, decay_ratio=0.87):
lr_ratios = {}
# blocks
idx = 0
for i in range(4):
layer_num = 3 - i # 3 2 1 0
for j in range(self.depths[layer_num]):
block_num = self.depths[layer_num] - j - 1
tag = 'levels.{}.blocks.{}.'.format(layer_num, block_num)
decay = 1.0 * (decay_ratio ** idx)
lr_ratios[tag] = decay
idx += 1
# patch_embed (before stage-1)
lr_ratios['patch_embed'] = lr_ratios['levels.0.blocks.0.']
# levels.0.downsample (between stage-1 and stage-2)
lr_ratios['levels.0.downsample'] = lr_ratios['levels.1.blocks.0.']
lr_ratios['levels.0.norm'] = lr_ratios['levels.1.blocks.0.']
# levels.1.downsample (between stage-2 and stage-3)
lr_ratios['levels.1.downsample'] = lr_ratios['levels.2.blocks.0.']
lr_ratios['levels.1.norm'] = lr_ratios['levels.2.blocks.0.']
# levels.2.downsample (between stage-3 and stage-4)
lr_ratios['levels.2.downsample'] = lr_ratios['levels.3.blocks.0.']
lr_ratios['levels.2.norm'] = lr_ratios['levels.3.blocks.0.']
return lr_ratios
def forward_features_seq_out(self, x):
x = self.patch_embed(x)
x = self.pos_drop(x)
seq_out = []
for level in self.levels:
x, x_ = level(x, return_wo_downsample=True)
seq_out.append(x_)
return seq_out
def forward_features(self, x):
xs = self.forward_features_seq_out(x)
x1, x2, x3, x4 = xs
x1 = x1.permute(0, 3, 1, 2) # NHWC -> NCHW
x2 = x2.permute(0, 3, 1, 2) # NHWC -> NCHW
x3 = x3.permute(0, 3, 1, 2) # NHWC -> NCHW
x4 = x4.permute(0, 3, 1, 2) # NHWC -> NCHW
hidden_states = [x1, x2, x3, x4]
if self.num_classes > 0:
x = self.conv_head(x4)
x = self.avgpool(x)
x = torch.flatten(x, 1)
return {
'hidden_states': hidden_states,
'pooler_output': x if self.num_classes > 0 else None
}
def forward_clip_projector(self, x): # for InternImage-H/G
xs = self.forward_features_seq_out(x)
x1, x2, x3, x4 = xs
x1 = x1.permute(0, 3, 1, 2) # NHWC -> NCHW
x2 = x2.permute(0, 3, 1, 2) # NHWC -> NCHW
x3 = x3.permute(0, 3, 1, 2) # NHWC -> NCHW
x4 = x4.permute(0, 3, 1, 2) # NHWC -> NCHW
hidden_states = [x1, x2, x3, x4]
if self.num_classes > 0:
x4 = self.dcnv3_head_x4(x4)
x = x4
x3 = self.dcnv3_head_x3(x3)
x = x + x3
x = x.flatten(-2).transpose(1, 2).contiguous()
x = self.clip_projector(x)
x = self.fc_norm(x)
return {
'hidden_states': hidden_states,
'pooler_output': x if self.num_classes > 0 else None
}
def forward(self, x):
if self.use_clip_projector: # for InternImage-H/G
outputs = self.forward_clip_projector(x)
else: # for InternImage-T/S/B/L/XL
outputs = self.forward_features(x)
hidden_states = outputs['hidden_states']
pooler_output = outputs['pooler_output']
if self.num_classes > 0:
logits = self.head(pooler_output)
else:
logits = None
return BackboneOutput(
hidden_states=hidden_states,
last_hidden_state=hidden_states[-1],
pooler_output=pooler_output,
logits=logits
)
class InternImageModel(PreTrainedModel):
config_class = InternImageConfig
def __init__(self, config):
super().__init__(config)
self.model = InternImage(
core_op=config.core_op,
channels=config.channels,
depths=config.depths,
groups=config.groups,
num_classes=0,
mlp_ratio=config.mlp_ratio,
drop_rate=config.drop_rate,
drop_path_rate=config.drop_path_rate,
drop_path_type=config.drop_path_type,
act_layer=config.act_layer,
norm_layer=config.norm_layer,
layer_scale=config.layer_scale,
offset_scale=config.offset_scale,
post_norm=config.post_norm,
cls_scale=config.cls_scale,
with_cp=config.with_cp,
dw_kernel_size=config.dw_kernel_size, # for InternImage-H/G
use_clip_projector=config.use_clip_projector, # for InternImage-H/G
level2_post_norm=config.level2_post_norm, # for InternImage-H/G
level2_post_norm_block_ids=config.level2_post_norm_block_ids, # for InternImage-H/G
res_post_norm=config.res_post_norm, # for InternImage-H/G
center_feature_scale=config.center_feature_scale, # for InternImage-H/G
remove_center=config.remove_center, # for InternImage-H/G
)
def forward(self, tensor):
    # Return a BackboneOutput (not a plain dict) so attribute-style access such as
    # `model(pixel_values).hidden_states` works as shown in the usage examples.
    return self.model(tensor)
class InternImageModelForImageClassification(PreTrainedModel):
config_class = InternImageConfig
def __init__(self, config):
super().__init__(config)
self.model = InternImage(
core_op=config.core_op,
channels=config.channels,
depths=config.depths,
groups=config.groups,
num_classes=config.num_classes,
mlp_ratio=config.mlp_ratio,
drop_rate=config.drop_rate,
drop_path_rate=config.drop_path_rate,
drop_path_type=config.drop_path_type,
act_layer=config.act_layer,
norm_layer=config.norm_layer,
layer_scale=config.layer_scale,
offset_scale=config.offset_scale,
post_norm=config.post_norm,
cls_scale=config.cls_scale,
with_cp=config.with_cp,
dw_kernel_size=config.dw_kernel_size, # for InternImage-H/G
use_clip_projector=config.use_clip_projector, # for InternImage-H/G
level2_post_norm=config.level2_post_norm, # for InternImage-H/G
level2_post_norm_block_ids=config.level2_post_norm_block_ids, # for InternImage-H/G
res_post_norm=config.res_post_norm, # for InternImage-H/G
center_feature_scale=config.center_feature_scale, # for InternImage-H/G
remove_center=config.remove_center, # for InternImage-H/G
)
def forward(self, tensor, labels=None):
outputs = self.model.forward(tensor)
if labels is not None:
logits = outputs['logits']
loss = F.cross_entropy(logits, labels)
outputs['loss'] = loss
return outputs
{
"crop_size": 224,
"do_center_crop": true,
"do_normalize": true,
"do_resize": true,
"feature_extractor_type": "CLIPFeatureExtractor",
"image_mean": [
0.485,
0.456,
0.406
],
"image_std": [
0.229,
0.224,
0.225
],
"resample": 3,
"size": 224
}
---
license: mit
pipeline_tag: image-classification
library_name: transformers
tags:
- internimage
- custom_code
datasets:
- ILSVRC/imagenet-1k
---
# InternImage Model Card
## Introduction
InternImage is an advanced vision foundation model developed by researchers from Shanghai AI Laboratory, Tsinghua University, and other institutions. Unlike models based on Transformers, InternImage employs DCNv3 as its core operator. This approach equips the model with dynamic and effective receptive fields required for downstream tasks like object detection and segmentation, while enabling adaptive spatial aggregation.
<div style="text-align: center;"> <img src="https://github.com/OpenGVLab/InternImage/raw/master/docs/figs/arch.png" style="width:60%;" /> </div>
## Performance
- InternImage achieves a Top-1 accuracy of 90.1% on the ImageNet classification benchmark using only publicly available data. Apart from two undisclosed models from Google and Microsoft trained with additional datasets, it is the only open-source model exceeding 90.0% Top-1 accuracy, and it is also the largest such model in scale.
- On the COCO object detection benchmark, InternImage outperforms all other models with an mAP of 65.5, making it the only model to date to surpass 65 mAP.
- InternImage also sets the best reported performance on 16 other important visual benchmarks spanning classification, detection, and segmentation, making it the top-performing model across multiple domains.
## Released Models
### Open‑Source Visual Pretrained Models
| huggingface name | model name | pretrain | resolution | #param |
| :-------------------------------------------------------------------------------------------: | :------------: | :------------------: | :--------: | :----: |
| [internimage_l_22k_384](https://huggingface.co/OpenGVLab/internimage_l_22k_384) | InternImage-L | IN-22K | 384x384 | 223M |
| [internimage_xl_22k_384](https://huggingface.co/OpenGVLab/internimage_xl_22k_384) | InternImage-XL | IN-22K | 384x384 | 335M |
| [internimage_h_jointto22k_384](https://huggingface.co/OpenGVLab/internimage_h_jointto22k_384) | InternImage-H | Joint 427M -> IN-22K | 384x384 | 1.08B |
| [internimage_g_jointto22k_384](https://huggingface.co/OpenGVLab/internimage_g_jointto22k_384) | InternImage-G | Joint 427M -> IN-22K | 384x384 | 3B |
### ImageNet-1K Image Classification
| huggingface name | model name | pretrain | resolution | acc@1 | #param | FLOPs |
| :---------------------------------------------------------------------------------------: | :------------: | :------------------: | :--------: | :---: | :----: | :---: |
| [internimage_t_1k_224](https://huggingface.co/OpenGVLab/internimage_t_1k_224) | InternImage-T | IN-1K | 224x224 | 83.5 | 30M | 5G |
| [internimage_s_1k_224](https://huggingface.co/OpenGVLab/internimage_s_1k_224) | InternImage-S | IN-1K | 224x224 | 84.2 | 50M | 8G |
| [internimage_b_1k_224](https://huggingface.co/OpenGVLab/internimage_b_1k_224) | InternImage-B | IN-1K | 224x224 | 84.9 | 97M | 16G |
| [internimage_l_22kto1k_384](https://huggingface.co/OpenGVLab/internimage_l_22kto1k_384) | InternImage-L | IN-22K | 384x384 | 87.7 | 223M | 108G |
| [internimage_xl_22kto1k_384](https://huggingface.co/OpenGVLab/internimage_xl_22kto1k_384) | InternImage-XL | IN-22K | 384x384 | 88.0 | 335M | 163G |
| [internimage_h_22kto1k_640](https://huggingface.co/OpenGVLab/internimage_h_22kto1k_640) | InternImage-H | Joint 427M -> IN-22K | 640x640 | 89.6 | 1.08B | 1478G |
| [internimage_g_22kto1k_512](https://huggingface.co/OpenGVLab/internimage_g_22kto1k_512) | InternImage-G | Joint 427M -> IN-22K | 512x512 | 90.1 | 3B | 2700G |
## DCNv3 CUDA Kernel Installation
If you do not install the CUDA version of DCNv3, InternImage will automatically fall back to a PyTorch implementation. However, the CUDA implementation can significantly reduce GPU memory usage and improve inference efficiency.
**Installation Tutorial:**
1. Open your terminal and run:
```bash
git clone https://github.com/OpenGVLab/InternImage.git
cd InternImage/classification/ops_dcnv3
```
2. Make sure you have an available GPU for compilation, then run:
```bash
sh make.sh
```
This will compile the CUDA version of DCNv3. Once installed, InternImage will automatically leverage the optimized CUDA implementation for better performance.
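To confirm which implementation will actually be used, you can probe for the compiled extension the same way the bundled `dcnv3_func.py` does. This is only a minimal sketch; it assumes the extension built by `make.sh` is importable under the name `DCNv3`:
```python
# Minimal availability check, mirroring the try/except fallback in dcnv3_func.py
try:
    import DCNv3  # compiled CUDA extension produced by make.sh
    print('DCNv3 CUDA kernel found; the CUDA implementation will be used.')
except ImportError:
    print('DCNv3 CUDA kernel not found; falling back to the PyTorch implementation.')
```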
## Usage with Transformers
Below are two usage examples for InternImage with the Transformers framework:
### Example 1: Using InternImage as an Image Backbone
```python
import torch
from PIL import Image
from transformers import AutoModel, CLIPImageProcessor
# Replace 'model_name' with the appropriate model identifier
model_name = "OpenGVLab/internimage_t_1k_224" # example model
# Prepare the image
image_path = 'img.png'
image_processor = CLIPImageProcessor.from_pretrained(model_name)
image = Image.open(image_path)
image = image_processor(images=image, return_tensors='pt').pixel_values
print('image shape:', image.shape)
# Load the model as a backbone
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
# 'hidden_states' contains the outputs from the 4 stages of the InternImage backbone
hidden_states = model(image).hidden_states
```
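The returned `hidden_states` is a list of four feature maps in NCHW layout, one per stage. As a rough sanity check (a sketch assuming a single 224x224 input and the tiny configuration above, whose stage widths are 64, 128, 256, and 512 channels):
```python
# Inspect the four stage outputs; expected roughly
# (1, 64, 56, 56), (1, 128, 28, 28), (1, 256, 14, 14), (1, 512, 7, 7)
for i, feat in enumerate(hidden_states):
    print(f'stage {i}:', tuple(feat.shape))
```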
### Example 2: Using InternImage for Image Classification
```python
import torch
from PIL import Image
from transformers import AutoModelForImageClassification, CLIPImageProcessor
# Replace 'model_name' with the appropriate model identifier
model_name = "OpenGVLab/internimage_t_1k_224" # example model
# Prepare the image
image_path = 'img.png'
image_processor = CLIPImageProcessor.from_pretrained(model_name)
image = Image.open(image_path)
image = image_processor(images=image, return_tensors='pt').pixel_values
print('image shape:', image.shape)
# Load the model as an image classifier
model = AutoModelForImageClassification.from_pretrained(model_name, trust_remote_code=True)
logits = model(image).logits
label = torch.argmax(logits, dim=1)
print("Predicted label:", label.item())
```
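If you also want class probabilities rather than just the argmax index, a small follow-up in plain PyTorch (no extra dependencies assumed) is:
```python
# Convert logits to probabilities and list the five most likely ImageNet-1k class indices
probs = torch.softmax(logits, dim=1)
top5_prob, top5_idx = torch.topk(probs, k=5, dim=1)
for p, idx in zip(top5_prob[0].tolist(), top5_idx[0].tolist()):
    print(f'class {idx}: {p:.4f}')
```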
## Citation
If this work is helpful for your research, please consider citing the following BibTeX entry.
```bibtex
@inproceedings{wang2023internimage,
title={Internimage: Exploring large-scale vision foundation models with deformable convolutions},
author={Wang, Wenhai and Dai, Jifeng and Chen, Zhe and Huang, Zhenhang and Li, Zhiqi and Zhu, Xizhou and Hu, Xiaowei and Lu, Tong and Lu, Lewei and Li, Hongsheng and others},
booktitle={Proceedings of the IEEE/CVF conference on computer vision and pattern recognition},
pages={14408--14419},
year={2023}
}
```
{
"_name_or_path": "OpenGVLab/internimage_xl_22kto1k_384",
"act_layer": "GELU",
"architectures": [
"InternImageModel"
],
"auto_map": {
"AutoConfig": "configuration_internimage.InternImageConfig",
"AutoModel": "modeling_internimage.InternImageModel",
"AutoModelForImageClassification": "modeling_internimage.InternImageModelForImageClassification"
},
"center_feature_scale": false,
"channels": 192,
"cls_scale": 1.5,
"core_op": "DCNv3",
"depths": [
5,
5,
24,
5
],
"drop_path_rate": 0.0,
"drop_path_type": "linear",
"drop_rate": 0.0,
"dw_kernel_size": null,
"groups": [
12,
24,
48,
96
],
"layer_scale": 1e-05,
"level2_post_norm": false,
"level2_post_norm_block_ids": null,
"mlp_ratio": 4.0,
"model_type": "internimage",
"norm_layer": "LN",
"num_classes": 1000,
"offset_scale": 2.0,
"post_norm": true,
"remove_center": false,
"res_post_norm": false,
"torch_dtype": "float32",
"transformers_version": "4.37.2",
"use_clip_projector": false,
"with_cp": false
}
# --------------------------------------------------------
# InternImage
# Copyright (c) 2025 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from transformers import PretrainedConfig
class InternImageConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`~InternImageModel`].
It is used to instantiate an internimage model according to the specified arguments, defining the model
architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of
the internimage [OpenGVLab/internimage](https://huggingface.co/OpenGVLab/internimage) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used
to control the model outputs. Read the documentation from [`PretrainedConfig`]
for more information.
Args:
core_op (`str`, *optional*, defaults to `"DCNv3"`):
Core operation used in the InternImageModel.
depths (`tuple`, *optional*, defaults to `(4, 4, 18, 4)`):
Tuple specifying the depth of layers in the InternImageModel.
groups (`tuple`, *optional*, defaults to `(4, 8, 16, 32)`):
Tuple specifying the group of layers in the InternImageModel.
channels (`int`, *optional*, defaults to `64`):
Number of channels in the InternImageModel.
dw_kernel_size (`int`, *optional*, defaults to `None`):
Kernel size for depthwise convolutions.
layer_scale (`float`, *optional*, defaults to `None`):
Scale of the layers in the model.
offset_scale (`float`, *optional*, defaults to `1.0`):
Offset scale in the model.
mlp_ratio (`float`, *optional*, defaults to `4.0`):
Ratio of the MLP hidden dimension to the embedding dimension in the InternImageModel.
post_norm (`bool`, *optional*, defaults to `False`):
Whether to use post normalization in the model.
level2_post_norm (`bool`, *optional*, defaults to `False`):
Whether to use level 2 post normalization.
level2_post_norm_block_ids (`list`, *optional*, defaults to `None`):
Specific block IDs for level 2 post normalization.
center_feature_scale (`bool`, *optional*, defaults to `False`):
Whether to apply center feature scaling.
use_clip_projector (`bool`, *optional*, defaults to `False`):
Whether to use CLIP projector.
remove_center (`bool`, *optional*, defaults to `False`):
Whether to remove center pixels in some operations.
num_classes (`int`, *optional*, defaults to `1000`):
Number of classes for the model output.
drop_rate (`float`, *optional*, defaults to `0.0`):
Dropout rate in the model.
drop_path_rate (`float`, *optional*, defaults to `0.0`):
Dropout path rate in the model.
drop_path_type (`str`, *optional*, defaults to `"linear"`):
Type of dropout path used in the model.
act_layer (`str`, *optional*, defaults to `"GELU"`):
Activation function used in the model.
norm_layer (`str`, *optional*, defaults to `"LN"`):
Normalization layer used in the model.
cls_scale (`float`, *optional*, defaults to `1.5`):
Scale of the classification layer in the model.
with_cp (`bool`, *optional*, defaults to `False`):
Whether to use checkpointing in the model.
"""
model_type = 'internimage'
def __init__(
self,
core_op='DCNv3',
depths=(4, 4, 18, 4),
groups=(4, 8, 16, 32),
channels=64,
dw_kernel_size=None,
layer_scale=None,
offset_scale=1.0,
mlp_ratio=4.0,
post_norm=False,
res_post_norm=False,
level2_post_norm=False,
level2_post_norm_block_ids=None,
center_feature_scale=False,
use_clip_projector=False,
remove_center=False,
num_classes=1000,
drop_rate=0.0,
drop_path_rate=0.0,
drop_path_type='linear',
act_layer='GELU',
norm_layer='LN',
cls_scale=1.5,
with_cp=False,
**kwargs,
):
super().__init__(**kwargs)
# Model configuration parameters
self.core_op = core_op
self.depths = depths
self.groups = groups
self.channels = channels
self.dw_kernel_size = dw_kernel_size
self.layer_scale = layer_scale
self.offset_scale = offset_scale
self.mlp_ratio = mlp_ratio
self.post_norm = post_norm
self.res_post_norm = res_post_norm
self.level2_post_norm = level2_post_norm
self.level2_post_norm_block_ids = level2_post_norm_block_ids
self.center_feature_scale = center_feature_scale
self.use_clip_projector = use_clip_projector
self.remove_center = remove_center
self.num_classes = num_classes
self.drop_rate = drop_rate
self.drop_path_rate = drop_path_rate
self.drop_path_type = drop_path_type
self.act_layer = act_layer
self.norm_layer = norm_layer
self.cls_scale = cls_scale
self.with_cp = with_cp
# --------------------------------------------------------
# InternImage
# Copyright (c) 2025 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from __future__ import absolute_import, division, print_function
import warnings
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.init import constant_, xavier_uniform_
from .dcnv3_func import DCNv3Function, dcnv3_core_pytorch, has_cuda_kernel
class to_channels_first(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return x.permute(0, 3, 1, 2)
class to_channels_last(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return x.permute(0, 2, 3, 1)
def build_norm_layer(dim,
norm_layer,
in_format='channels_last',
out_format='channels_last',
eps=1e-6):
layers = []
if norm_layer == 'BN':
if in_format == 'channels_last':
layers.append(to_channels_first())
layers.append(nn.BatchNorm2d(dim))
if out_format == 'channels_last':
layers.append(to_channels_last())
elif norm_layer == 'LN':
if in_format == 'channels_first':
layers.append(to_channels_last())
layers.append(nn.LayerNorm(dim, eps=eps))
if out_format == 'channels_first':
layers.append(to_channels_first())
else:
raise NotImplementedError(
f'build_norm_layer does not support {norm_layer}')
return nn.Sequential(*layers)
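# Usage sketch (illustrative, not part of the model code): wrap a LayerNorm so it
# accepts a channels-first feature map; the helper adds the permutes on either side.
#   norm = build_norm_layer(64, 'LN', in_format='channels_first', out_format='channels_first')
#   norm(torch.randn(2, 64, 14, 14)).shape  # torch.Size([2, 64, 14, 14])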
def build_act_layer(act_layer):
if act_layer == 'ReLU':
return nn.ReLU(inplace=True)
elif act_layer == 'SiLU':
return nn.SiLU(inplace=True)
elif act_layer == 'GELU':
return nn.GELU()
raise NotImplementedError(f'build_act_layer does not support {act_layer}')
def _is_power_of_2(n):
if (not isinstance(n, int)) or (n < 0):
raise ValueError(
'invalid input for _is_power_of_2: {} (type: {})'.format(n, type(n)))
return (n & (n - 1) == 0) and n != 0
class CenterFeatureScaleModule(nn.Module):
def forward(self,
query,
center_feature_scale_proj_weight,
center_feature_scale_proj_bias):
center_feature_scale = F.linear(query,
weight=center_feature_scale_proj_weight,
bias=center_feature_scale_proj_bias).sigmoid()
return center_feature_scale
class DCNv3_pytorch(nn.Module):
def __init__(
self,
channels=64,
kernel_size=3,
dw_kernel_size=None,
stride=1,
pad=1,
dilation=1,
group=4,
offset_scale=1.0,
act_layer='GELU',
norm_layer='LN',
center_feature_scale=False,
remove_center=False,
):
"""
DCNv3 Module
:param channels
:param kernel_size
:param stride
:param pad
:param dilation
:param group
:param offset_scale
:param act_layer
:param norm_layer
"""
super().__init__()
if channels % group != 0:
raise ValueError(
f'channels must be divisible by group, but got {channels} and {group}')
_d_per_group = channels // group
dw_kernel_size = dw_kernel_size if dw_kernel_size is not None else kernel_size
# _d_per_group should ideally be a power of 2, which is more efficient in our CUDA implementation
if not _is_power_of_2(_d_per_group):
warnings.warn(
"You'd better set channels in DCNv3 to make the dimension of each attention head a power of 2 "
'which is more efficient in our CUDA implementation.')
self.offset_scale = offset_scale
self.channels = channels
self.kernel_size = kernel_size
self.dw_kernel_size = dw_kernel_size
self.stride = stride
self.dilation = dilation
self.pad = pad
self.group = group
self.group_channels = channels // group
self.offset_scale = offset_scale
self.center_feature_scale = center_feature_scale
self.remove_center = int(remove_center)
self.dw_conv = nn.Sequential(
nn.Conv2d(
channels,
channels,
kernel_size=dw_kernel_size,
stride=1,
padding=(dw_kernel_size - 1) // 2,
groups=channels),
build_norm_layer(
channels,
norm_layer,
'channels_first',
'channels_last'),
build_act_layer(act_layer))
self.offset = nn.Linear(
channels,
group * (kernel_size * kernel_size - remove_center) * 2)
self.mask = nn.Linear(
channels,
group * (kernel_size * kernel_size - remove_center))
self.input_proj = nn.Linear(channels, channels)
self.output_proj = nn.Linear(channels, channels)
self._reset_parameters()
if center_feature_scale:
self.center_feature_scale_proj_weight = nn.Parameter(
torch.zeros((group, channels), dtype=torch.float))
self.center_feature_scale_proj_bias = nn.Parameter(
torch.tensor(0.0, dtype=torch.float).view((1,)).repeat(group, ))
self.center_feature_scale_module = CenterFeatureScaleModule()
def _reset_parameters(self):
constant_(self.offset.weight.data, 0.)
constant_(self.offset.bias.data, 0.)
constant_(self.mask.weight.data, 0.)
constant_(self.mask.bias.data, 0.)
xavier_uniform_(self.input_proj.weight.data)
constant_(self.input_proj.bias.data, 0.)
xavier_uniform_(self.output_proj.weight.data)
constant_(self.output_proj.bias.data, 0.)
def forward(self, input):
"""
:param input (N, H, W, C)
:return output (N, H, W, C)
"""
N, H, W, _ = input.shape
x = self.input_proj(input)
x_proj = x
x1 = input.permute(0, 3, 1, 2)
x1 = self.dw_conv(x1)
offset = self.offset(x1)
mask = self.mask(x1).reshape(N, H, W, self.group, -1)
mask = F.softmax(mask, -1).reshape(N, H, W, -1)
x = dcnv3_core_pytorch(
x, offset, mask,
self.kernel_size, self.kernel_size,
self.stride, self.stride,
self.pad, self.pad,
self.dilation, self.dilation,
self.group, self.group_channels,
self.offset_scale, self.remove_center)
if self.center_feature_scale:
center_feature_scale = self.center_feature_scale_module(
x1, self.center_feature_scale_proj_weight, self.center_feature_scale_proj_bias)
# N, H, W, groups -> N, H, W, groups, 1 -> N, H, W, groups, _d_per_group -> N, H, W, channels
center_feature_scale = center_feature_scale[..., None].repeat(
1, 1, 1, 1, self.channels // self.group).flatten(-2)
x = x * (1 - center_feature_scale) + x_proj * center_feature_scale
x = self.output_proj(x)
return x
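# Usage sketch (illustrative, pure-PyTorch path, no custom CUDA kernel required):
#   dcn = DCNv3_pytorch(channels=64, group=4)
#   dcn(torch.randn(2, 14, 14, 64)).shape  # channels-last in/out: torch.Size([2, 14, 14, 64])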
class DCNv3(nn.Module):
def __init__(
self,
channels=64,
kernel_size=3,
dw_kernel_size=None,
stride=1,
pad=1,
dilation=1,
group=4,
offset_scale=1.0,
act_layer='GELU',
norm_layer='LN',
center_feature_scale=False,
remove_center=False,
):
"""
DCNv3 Module
:param channels
:param kernel_size
:param stride
:param pad
:param dilation
:param group
:param offset_scale
:param act_layer
:param norm_layer
"""
super().__init__()
if channels % group != 0:
raise ValueError(
f'channels must be divisible by group, but got {channels} and {group}')
_d_per_group = channels // group
dw_kernel_size = dw_kernel_size if dw_kernel_size is not None else kernel_size
# _d_per_group should ideally be a power of 2, which is more efficient in our CUDA implementation
if not _is_power_of_2(_d_per_group):
warnings.warn(
"You'd better set channels in DCNv3 to make the dimension of each attention head a power of 2 "
'which is more efficient in our CUDA implementation.')
self.offset_scale = offset_scale
self.channels = channels
self.kernel_size = kernel_size
self.dw_kernel_size = dw_kernel_size
self.stride = stride
self.dilation = dilation
self.pad = pad
self.group = group
self.group_channels = channels // group
self.offset_scale = offset_scale
self.center_feature_scale = center_feature_scale
self.remove_center = int(remove_center)
if self.remove_center and self.kernel_size % 2 == 0:
raise ValueError('remove_center is only compatible with odd kernel size.')
self.dw_conv = nn.Sequential(
nn.Conv2d(
channels,
channels,
kernel_size=dw_kernel_size,
stride=1,
padding=(dw_kernel_size - 1) // 2,
groups=channels),
build_norm_layer(
channels,
norm_layer,
'channels_first',
'channels_last'),
build_act_layer(act_layer))
self.offset = nn.Linear(
channels,
group * (kernel_size * kernel_size - remove_center) * 2)
self.mask = nn.Linear(
channels,
group * (kernel_size * kernel_size - remove_center))
self.input_proj = nn.Linear(channels, channels)
self.output_proj = nn.Linear(channels, channels)
self._reset_parameters()
if center_feature_scale:
self.center_feature_scale_proj_weight = nn.Parameter(
torch.zeros((group, channels), dtype=torch.float))
self.center_feature_scale_proj_bias = nn.Parameter(
torch.tensor(0.0, dtype=torch.float).view((1,)).repeat(group, ))
self.center_feature_scale_module = CenterFeatureScaleModule()
def _reset_parameters(self):
constant_(self.offset.weight.data, 0.)
constant_(self.offset.bias.data, 0.)
constant_(self.mask.weight.data, 0.)
constant_(self.mask.bias.data, 0.)
xavier_uniform_(self.input_proj.weight.data)
constant_(self.input_proj.bias.data, 0.)
xavier_uniform_(self.output_proj.weight.data)
constant_(self.output_proj.bias.data, 0.)
def forward(self, input):
"""
:param input (N, H, W, C)
:return output (N, H, W, C)
"""
N, H, W, _ = input.shape
x = self.input_proj(input)
x_proj = x
dtype = x.dtype
x1 = input.permute(0, 3, 1, 2)
x1 = self.dw_conv(x1)
offset = self.offset(x1)
mask = self.mask(x1).reshape(N, H, W, self.group, -1)
mask = F.softmax(mask, -1)
mask = mask.reshape(N, H, W, -1).type(dtype)
x = DCNv3Function.apply(
x, offset, mask,
self.kernel_size, self.kernel_size,
self.stride, self.stride,
self.pad, self.pad,
self.dilation, self.dilation,
self.group, self.group_channels,
self.offset_scale,
256,
self.remove_center)
if self.center_feature_scale:
center_feature_scale = self.center_feature_scale_module(
x1, self.center_feature_scale_proj_weight, self.center_feature_scale_proj_bias)
# N, H, W, groups -> N, H, W, groups, 1 -> N, H, W, groups, _d_per_group -> N, H, W, channels
center_feature_scale = center_feature_scale[..., None].repeat(
1, 1, 1, 1, self.channels // self.group).flatten(-2)
x = x * (1 - center_feature_scale) + x_proj * center_feature_scale
x = self.output_proj(x)
return x
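# Selection sketch (illustrative): the CUDA-backed `DCNv3` needs the compiled DCNv3
# extension and tensors on the GPU; otherwise fall back to the pure-PyTorch module.
#   op_cls = DCNv3 if (has_cuda_kernel and torch.cuda.is_available()) else DCNv3_pytorch
#   dcn = op_cls(channels=64, group=4)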
# --------------------------------------------------------
# InternImage
# Copyright (c) 2025 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from __future__ import absolute_import, division, print_function
import pkg_resources
import torch
import torch.nn.functional as F
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.cuda.amp import custom_bwd, custom_fwd
try:
    import DCNv3
    dcn_version = float(pkg_resources.get_distribution('DCNv3').version)
    has_cuda_kernel = True
except Exception:
    has_cuda_kernel = False
class DCNv3Function(Function):
@staticmethod
@custom_fwd
def forward(
ctx, input, offset, mask,
kernel_h, kernel_w, stride_h, stride_w,
pad_h, pad_w, dilation_h, dilation_w,
group, group_channels, offset_scale, im2col_step, remove_center):
ctx.kernel_h = kernel_h
ctx.kernel_w = kernel_w
ctx.stride_h = stride_h
ctx.stride_w = stride_w
ctx.pad_h = pad_h
ctx.pad_w = pad_w
ctx.dilation_h = dilation_h
ctx.dilation_w = dilation_w
ctx.group = group
ctx.group_channels = group_channels
ctx.offset_scale = offset_scale
ctx.im2col_step = im2col_step
ctx.remove_center = remove_center
args = [
input, offset, mask, kernel_h,
kernel_w, stride_h, stride_w, pad_h,
pad_w, dilation_h, dilation_w, group,
group_channels, offset_scale, ctx.im2col_step
]
if remove_center or dcn_version > 1.0:
args.append(remove_center)
output = DCNv3.dcnv3_forward(*args)
ctx.save_for_backward(input, offset, mask)
return output
@staticmethod
@once_differentiable
@custom_bwd
def backward(ctx, grad_output):
input, offset, mask = ctx.saved_tensors
args = [
input, offset, mask, ctx.kernel_h,
ctx.kernel_w, ctx.stride_h, ctx.stride_w, ctx.pad_h,
ctx.pad_w, ctx.dilation_h, ctx.dilation_w, ctx.group,
ctx.group_channels, ctx.offset_scale, grad_output.contiguous(), ctx.im2col_step
]
if ctx.remove_center or dcn_version > 1.0:
args.append(ctx.remove_center)
grad_input, grad_offset, grad_mask = \
DCNv3.dcnv3_backward(*args)
return grad_input, grad_offset, grad_mask, \
None, None, None, None, None, None, None, None, None, None, None, None, None
@staticmethod
def symbolic(g, input, offset, mask, kernel_h, kernel_w, stride_h,
stride_w, pad_h, pad_w, dilation_h, dilation_w, group,
group_channels, offset_scale, im2col_step, remove_center):
"""Symbolic function for mmdeploy::DCNv3.
Returns:
DCNv3 op for onnx.
"""
return g.op(
'mmdeploy::TRTDCNv3',
input,
offset,
mask,
kernel_h_i=int(kernel_h),
kernel_w_i=int(kernel_w),
stride_h_i=int(stride_h),
stride_w_i=int(stride_w),
pad_h_i=int(pad_h),
pad_w_i=int(pad_w),
dilation_h_i=int(dilation_h),
dilation_w_i=int(dilation_w),
group_i=int(group),
group_channels_i=int(group_channels),
offset_scale_f=float(offset_scale),
im2col_step_i=int(im2col_step),
remove_center_i=int(remove_center),
)
def _get_reference_points(spatial_shapes, device, kernel_h, kernel_w, dilation_h, dilation_w, pad_h=0, pad_w=0, stride_h=1, stride_w=1):
_, H_, W_, _ = spatial_shapes
H_out = (H_ - (dilation_h * (kernel_h - 1) + 1)) // stride_h + 1
W_out = (W_ - (dilation_w * (kernel_w - 1) + 1)) // stride_w + 1
ref_y, ref_x = torch.meshgrid(
torch.linspace(
# pad_h + 0.5,
# H_ - pad_h - 0.5,
(dilation_h * (kernel_h - 1)) // 2 + 0.5,
(dilation_h * (kernel_h - 1)) // 2 + 0.5 + (H_out - 1) * stride_h,
H_out,
dtype=torch.float32,
device=device),
torch.linspace(
# pad_w + 0.5,
# W_ - pad_w - 0.5,
(dilation_w * (kernel_w - 1)) // 2 + 0.5,
(dilation_w * (kernel_w - 1)) // 2 + 0.5 + (W_out - 1) * stride_w,
W_out,
dtype=torch.float32,
device=device))
ref_y = ref_y.reshape(-1)[None] / H_
ref_x = ref_x.reshape(-1)[None] / W_
ref = torch.stack((ref_x, ref_y), -1).reshape(
1, H_out, W_out, 1, 2)
return ref
def _generate_dilation_grids(spatial_shapes, kernel_h, kernel_w, dilation_h, dilation_w, group, device):
_, H_, W_, _ = spatial_shapes
points_list = []
x, y = torch.meshgrid(
torch.linspace(
-((dilation_w * (kernel_w - 1)) // 2),
-((dilation_w * (kernel_w - 1)) // 2) + (kernel_w - 1) * dilation_w,
kernel_w,
dtype=torch.float32,
device=device),
torch.linspace(
-((dilation_h * (kernel_h - 1)) // 2),
-((dilation_h * (kernel_h - 1)) // 2) + (kernel_h - 1) * dilation_h,
kernel_h,
dtype=torch.float32,
device=device))
points_list.extend([x / W_, y / H_])
grid = torch.stack(points_list, -1).reshape(-1, 1, 2).\
repeat(1, group, 1).permute(1, 0, 2)
grid = grid.reshape(1, 1, 1, group * kernel_h * kernel_w, 2)
return grid
def remove_center_sampling_locations(sampling_locations, kernel_w, kernel_h):
idx = list(range(sampling_locations.shape[-2]))
C = (kernel_w * kernel_h - 1)//2
idx = [i for i in idx if i != C and (i-C) % (C*2+1) != 0]
sampling_locations = sampling_locations[:,:,:,idx, :]
return sampling_locations
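# Index sketch (illustrative): for a 3x3 kernel, C equals 4 and the center offset of
# every group is dropped, i.e. indices 4, 13, 22, ... along the sampling dimension.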
def dcnv3_core_pytorch(
input, offset, mask, kernel_h,
kernel_w, stride_h, stride_w, pad_h,
pad_w, dilation_h, dilation_w, group,
group_channels, offset_scale, remove_center):
# reference implementation for debugging and testing only;
# prefer the CUDA version for real workloads
if remove_center and (kernel_h % 2 == 0 or kernel_w % 2 == 0 or kernel_w != kernel_h):
raise ValueError('remove_center is only compatible with square odd kernel size.')
input = F.pad(
input,
[0, 0, pad_h, pad_h, pad_w, pad_w])
N_, H_in, W_in, _ = input.shape
_, H_out, W_out, _ = offset.shape
ref = _get_reference_points(
input.shape, input.device, kernel_h, kernel_w, dilation_h, dilation_w, pad_h, pad_w, stride_h, stride_w)
grid = _generate_dilation_grids(
input.shape, kernel_h, kernel_w, dilation_h, dilation_w, group, input.device)
spatial_norm = torch.tensor([W_in, H_in]).reshape(1, 1, 1, 2).\
repeat(1, 1, 1, group*(kernel_h*kernel_w-remove_center)).to(input.device)
sampling_locations = (ref + grid * offset_scale).repeat(N_, 1, 1, 1, 1)
if remove_center:
sampling_locations = remove_center_sampling_locations(sampling_locations, kernel_w=kernel_w, kernel_h=kernel_h)
sampling_locations = sampling_locations.flatten(3, 4)
sampling_locations = sampling_locations + offset * offset_scale / spatial_norm
P_ = kernel_h * kernel_w - remove_center
sampling_grids = 2 * sampling_locations - 1
# N_, H_in, W_in, group*group_channels -> N_, H_in*W_in, group*group_channels -> N_, group*group_channels, H_in*W_in -> N_*group, group_channels, H_in, W_in
input_ = input.view(N_, H_in*W_in, group*group_channels).transpose(1, 2).\
reshape(N_*group, group_channels, H_in, W_in)
# N_, H_out, W_out, group*P_*2 -> N_, H_out*W_out, group, P_, 2 -> N_, group, H_out*W_out, P_, 2 -> N_*group, H_out*W_out, P_, 2
sampling_grid_ = sampling_grids.view(N_, H_out*W_out, group, P_, 2).transpose(1, 2).\
flatten(0, 1)
# N_*group, group_channels, H_out*W_out, P_
sampling_input_ = F.grid_sample(
input_, sampling_grid_, mode='bilinear', padding_mode='zeros', align_corners=False)
# (N_, H_out, W_out, group*P_) -> N_, H_out*W_out, group, P_ -> (N_, group, H_out*W_out, P_) -> (N_*group, 1, H_out*W_out, P_)
mask = mask.view(N_, H_out*W_out, group, P_).transpose(1, 2).\
reshape(N_*group, 1, H_out*W_out, P_)
output = (sampling_input_ * mask).sum(-1).view(N_,
group*group_channels, H_out*W_out)
return output.transpose(1, 2).reshape(N_, H_out, W_out, -1).contiguous()
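# Shape contract (illustrative, matching the CUDA kernel's layout): for a 3x3 kernel,
# stride 1, pad 1, group=4, group_channels=16 and remove_center=0, the op expects
#   input  : (N, H, W, 64)
#   offset : (N, H, W, 4 * 9 * 2)
#   mask   : (N, H, W, 4 * 9)   # softmax-normalised within each group
# and the reference implementation returns a (N, H, W, 64) tensor.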
# --------------------------------------------------------
# InternImage
# Copyright (c) 2025 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from dataclasses import dataclass
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, trunc_normal_
from transformers import PreTrainedModel
from transformers.modeling_outputs import ModelOutput
from .configuration_internimage import InternImageConfig
from .dcnv3 import DCNv3, DCNv3_pytorch, has_cuda_kernel
from .dcnv3_func import dcnv3_core_pytorch
@dataclass
class BackboneOutput(ModelOutput):
"""
Base class for outputs of backbones.
"""
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
pooler_output: Optional[torch.FloatTensor] = None
last_hidden_state: Optional[torch.FloatTensor] = None
logits: Optional[torch.FloatTensor] = None
loss: Optional[torch.FloatTensor] = None
class to_channels_first(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return x.permute(0, 3, 1, 2)
class to_channels_last(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return x.permute(0, 2, 3, 1)
def build_norm_layer(dim,
norm_layer,
in_format='channels_last',
out_format='channels_last',
eps=1e-6):
layers = []
if norm_layer == 'BN':
if in_format == 'channels_last':
layers.append(to_channels_first())
layers.append(nn.BatchNorm2d(dim))
if out_format == 'channels_last':
layers.append(to_channels_last())
elif norm_layer == 'LN':
if in_format == 'channels_first':
layers.append(to_channels_last())
layers.append(nn.LayerNorm(dim, eps=eps))
if out_format == 'channels_first':
layers.append(to_channels_first())
else:
raise NotImplementedError(
f'build_norm_layer does not support {norm_layer}')
return nn.Sequential(*layers)
def build_act_layer(act_layer):
if act_layer == 'ReLU':
return nn.ReLU(inplace=True)
elif act_layer == 'SiLU':
return nn.SiLU(inplace=True)
elif act_layer == 'GELU':
return nn.GELU()
raise NotImplementedError(f'build_act_layer does not support {act_layer}')
class CrossAttention(nn.Module):
r""" Cross Attention Module
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads. Default: 8
qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
Default: False.
qk_scale (float | None, optional): Override default qk scale of
head_dim ** -0.5 if set. Default: None.
attn_drop (float, optional): Dropout ratio of attention weight.
Default: 0.0
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
attn_head_dim (int, optional): Dimension of attention head.
out_dim (int, optional): Dimension of output.
"""
def __init__(self,
dim,
num_heads=8,
qkv_bias=False,
qk_scale=None,
attn_drop=0.,
proj_drop=0.,
attn_head_dim=None,
out_dim=None):
super().__init__()
if out_dim is None:
out_dim = dim
self.num_heads = num_heads
head_dim = dim // num_heads
if attn_head_dim is not None:
head_dim = attn_head_dim
all_head_dim = head_dim * self.num_heads
self.scale = qk_scale or head_dim ** -0.5
assert all_head_dim == dim
self.q = nn.Linear(dim, all_head_dim, bias=False)
self.k = nn.Linear(dim, all_head_dim, bias=False)
self.v = nn.Linear(dim, all_head_dim, bias=False)
if qkv_bias:
self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
self.k_bias = nn.Parameter(torch.zeros(all_head_dim))
self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
else:
self.q_bias = None
self.k_bias = None
self.v_bias = None
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(all_head_dim, out_dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x, k=None, v=None):
B, N, C = x.shape
N_k = k.shape[1]
N_v = v.shape[1]
q_bias, k_bias, v_bias = None, None, None
if self.q_bias is not None:
q_bias = self.q_bias
k_bias = self.k_bias
v_bias = self.v_bias
q = F.linear(input=x, weight=self.q.weight, bias=q_bias)
q = q.reshape(B, N, 1, self.num_heads,
-1).permute(2, 0, 3, 1,
4).squeeze(0) # (B, N_head, N_q, dim)
k = F.linear(input=k, weight=self.k.weight, bias=k_bias)
k = k.reshape(B, N_k, 1, self.num_heads, -1).permute(2, 0, 3, 1,
4).squeeze(0)
v = F.linear(input=v, weight=self.v.weight, bias=v_bias)
v = v.reshape(B, N_v, 1, self.num_heads, -1).permute(2, 0, 3, 1,
4).squeeze(0)
q = q * self.scale
attn = (q @ k.transpose(-2, -1)) # (B, N_head, N_q, N_k)
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
x = self.proj(x)
x = self.proj_drop(x)
return x
class AttentiveBlock(nn.Module):
r"""Attentive Block
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads. Default: 8
qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
Default: False.
qk_scale (float | None, optional): Override default qk scale of
head_dim ** -0.5 if set. Default: None.
drop (float, optional): Dropout rate. Default: 0.0.
attn_drop (float, optional): Attention dropout rate. Default: 0.0.
drop_path (float | tuple[float], optional): Stochastic depth rate.
Default: 0.0.
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm.
attn_head_dim (int, optional): Dimension of attention head. Default: None.
out_dim (int, optional): Dimension of output. Default: None.
"""
def __init__(self,
dim,
num_heads,
qkv_bias=False,
qk_scale=None,
drop=0.,
attn_drop=0.,
drop_path=0.,
norm_layer='LN',
attn_head_dim=None,
out_dim=None):
super().__init__()
self.norm1_q = build_norm_layer(dim, norm_layer, eps=1e-6)
self.norm1_k = build_norm_layer(dim, norm_layer, eps=1e-6)
self.norm1_v = build_norm_layer(dim, norm_layer, eps=1e-6)
self.cross_dcn = CrossAttention(dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=drop,
attn_head_dim=attn_head_dim,
out_dim=out_dim)
self.drop_path = DropPath(
drop_path) if drop_path > 0. else nn.Identity()
def forward(self,
x_q,
x_kv,
pos_q,
pos_k,
bool_masked_pos,
rel_pos_bias=None):
x_q = self.norm1_q(x_q + pos_q)
x_k = self.norm1_k(x_kv + pos_k)
x_v = self.norm1_v(x_kv)
x = self.cross_dcn(x_q, k=x_k, v=x_v)
return x
class AttentionPoolingBlock(AttentiveBlock):
def forward(self, x):
x_q = x.mean(1, keepdim=True)
x_kv = x
pos_q, pos_k = 0, 0
x = super().forward(x_q, x_kv, pos_q, pos_k,
bool_masked_pos=None,
rel_pos_bias=None)
x = x.squeeze(1)
return x
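# Usage sketch (illustrative): attention pooling collapses a token sequence to a
# single embedding by letting the mean token cross-attend to all tokens.
#   pool = AttentionPoolingBlock(dim=768, num_heads=16, qkv_bias=True, out_dim=512)
#   pool(torch.randn(2, 196, 768)).shape  # torch.Size([2, 512])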
class StemLayer(nn.Module):
r"""Stem layer of InternImage
Args:
in_chans (int): number of input channels
out_chans (int): number of output channels
act_layer (str): activation layer
norm_layer (str): normalization layer
"""
def __init__(self,
in_chans=3,
out_chans=96,
act_layer='GELU',
norm_layer='BN'):
super().__init__()
self.conv1 = nn.Conv2d(in_chans,
out_chans // 2,
kernel_size=3,
stride=2,
padding=1)
self.norm1 = build_norm_layer(out_chans // 2, norm_layer,
'channels_first', 'channels_first')
self.act = build_act_layer(act_layer)
self.conv2 = nn.Conv2d(out_chans // 2,
out_chans,
kernel_size=3,
stride=2,
padding=1)
self.norm2 = build_norm_layer(out_chans, norm_layer, 'channels_first',
'channels_last')
def forward(self, x):
x = self.conv1(x)
x = self.norm1(x)
x = self.act(x)
x = self.conv2(x)
x = self.norm2(x)
return x
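# Usage sketch (illustrative): the stem downsamples by 4x and hands the rest of the
# network a channels-last feature map.
#   stem = StemLayer(in_chans=3, out_chans=96)
#   stem(torch.randn(2, 3, 224, 224)).shape  # torch.Size([2, 56, 56, 96])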
class DownsampleLayer(nn.Module):
r"""Downsample layer of InternImage
Args:
channels (int): number of input channels
norm_layer (str): normalization layer
"""
def __init__(self, channels, norm_layer='LN'):
super().__init__()
self.conv = nn.Conv2d(channels,
2 * channels,
kernel_size=3,
stride=2,
padding=1,
bias=False)
self.norm = build_norm_layer(2 * channels, norm_layer,
'channels_first', 'channels_last')
def forward(self, x):
x = self.conv(x.permute(0, 3, 1, 2))
x = self.norm(x)
return x
class MLPLayer(nn.Module):
r"""MLP layer of InternImage
Args:
in_features (int): number of input features
hidden_features (int): number of hidden features
out_features (int): number of output features
act_layer (str): activation layer
drop (float): dropout rate
"""
def __init__(self,
in_features,
hidden_features=None,
out_features=None,
act_layer='GELU',
drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = build_act_layer(act_layer)
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class InternImageLayer(nn.Module):
r"""Basic layer of InternImage
Args:
core_op (nn.Module): core operation of InternImage
channels (int): number of input channels
groups (list): Groups of each block.
mlp_ratio (float): ratio of mlp hidden features to input channels
drop (float): dropout rate
drop_path (float): drop path rate
act_layer (str): activation layer
norm_layer (str): normalization layer
post_norm (bool): whether to use post normalization
layer_scale (float): layer scale
offset_scale (float): offset scale
with_cp (bool): whether to use checkpoint
"""
def __init__(self,
core_op,
channels,
groups,
mlp_ratio=4.,
drop=0.,
drop_path=0.,
act_layer='GELU',
norm_layer='LN',
post_norm=False,
layer_scale=None,
offset_scale=1.0,
with_cp=False,
dw_kernel_size=None, # for InternImage-H/G
res_post_norm=False, # for InternImage-H/G
center_feature_scale=False, # for InternImage-H/G
remove_center=False, # for InternImage-H/G
):
super().__init__()
self.channels = channels
self.groups = groups
self.mlp_ratio = mlp_ratio
self.with_cp = with_cp
self.norm1 = build_norm_layer(channels, 'LN')
self.post_norm = post_norm
self.dcn = core_op(
channels=channels,
kernel_size=3,
stride=1,
pad=1,
dilation=1,
group=groups,
offset_scale=offset_scale,
act_layer=act_layer,
norm_layer=norm_layer,
dw_kernel_size=dw_kernel_size, # for InternImage-H/G
center_feature_scale=center_feature_scale, # for InternImage-H/G
remove_center=remove_center, # for InternImage-H/G
)
self.drop_path = DropPath(drop_path) if drop_path > 0. \
else nn.Identity()
self.norm2 = build_norm_layer(channels, 'LN')
self.mlp = MLPLayer(in_features=channels,
hidden_features=int(channels * mlp_ratio),
act_layer=act_layer,
drop=drop)
self.layer_scale = layer_scale is not None
if self.layer_scale:
self.layer_scale1 = nn.Parameter(layer_scale * torch.ones(channels),
requires_grad=True)
self.layer_scale2 = nn.Parameter(layer_scale * torch.ones(channels),
requires_grad=True)
self.res_post_norm = res_post_norm
if res_post_norm:
self.res_post_norm1 = build_norm_layer(channels, 'LN')
self.res_post_norm2 = build_norm_layer(channels, 'LN')
def forward(self, x):
def _inner_forward(x):
if not self.layer_scale:
if self.post_norm:
x = x + self.drop_path(self.norm1(self.dcn(x)))
x = x + self.drop_path(self.norm2(self.mlp(x)))
elif self.res_post_norm: # for InternImage-H/G
x = x + self.drop_path(self.res_post_norm1(self.dcn(self.norm1(x))))
x = x + self.drop_path(self.res_post_norm2(self.mlp(self.norm2(x))))
else:
x = x + self.drop_path(self.dcn(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
if self.post_norm:
x = x + self.drop_path(self.layer_scale1 * self.norm1(self.dcn(x)))
x = x + self.drop_path(self.layer_scale2 * self.norm2(self.mlp(x)))
else:
x = x + self.drop_path(self.layer_scale1 * self.dcn(self.norm1(x)))
x = x + self.drop_path(self.layer_scale2 * self.mlp(self.norm2(x)))
return x
if self.with_cp and x.requires_grad:
x = checkpoint.checkpoint(_inner_forward, x)
else:
x = _inner_forward(x)
return x
class InternImageBlock(nn.Module):
r"""Block of InternImage
Args:
core_op (nn.Module): core operation of InternImage
channels (int): number of input channels
depths (list): Depth of each block.
groups (list): Groups of each block.
mlp_ratio (float): ratio of mlp hidden features to input channels
drop (float): dropout rate
drop_path (float): drop path rate
act_layer (str): activation layer
norm_layer (str): normalization layer
post_norm (bool): whether to use post normalization
layer_scale (float): layer scale
offset_scale (float): offset scale
with_cp (bool): whether to use checkpoint
"""
def __init__(self,
core_op,
channels,
depth,
groups,
downsample=True,
mlp_ratio=4.,
drop=0.,
drop_path=0.,
act_layer='GELU',
norm_layer='LN',
post_norm=False,
offset_scale=1.0,
layer_scale=None,
with_cp=False,
dw_kernel_size=None, # for InternImage-H/G
post_norm_block_ids=None, # for InternImage-H/G
res_post_norm=False, # for InternImage-H/G
center_feature_scale=False, # for InternImage-H/G
remove_center=False, # for InternImage-H/G
):
super().__init__()
self.channels = channels
self.depth = depth
self.post_norm = post_norm
self.center_feature_scale = center_feature_scale
self.blocks = nn.ModuleList([
InternImageLayer(
core_op=core_op,
channels=channels,
groups=groups,
mlp_ratio=mlp_ratio,
drop=drop,
drop_path=drop_path[i] if isinstance(
drop_path, list) else drop_path,
act_layer=act_layer,
norm_layer=norm_layer,
post_norm=post_norm,
layer_scale=layer_scale,
offset_scale=offset_scale,
with_cp=with_cp,
dw_kernel_size=dw_kernel_size, # for InternImage-H/G
res_post_norm=res_post_norm, # for InternImage-H/G
center_feature_scale=center_feature_scale, # for InternImage-H/G
remove_center=remove_center, # for InternImage-H/G
) for i in range(depth)
])
if not self.post_norm or center_feature_scale:
self.norm = build_norm_layer(channels, 'LN')
self.post_norm_block_ids = post_norm_block_ids
if post_norm_block_ids is not None: # for InternImage-H/G
self.post_norms = nn.ModuleList(
[build_norm_layer(channels, 'LN', eps=1e-6) for _ in post_norm_block_ids]
)
self.downsample = DownsampleLayer(
channels=channels, norm_layer=norm_layer) if downsample else None
def forward(self, x, return_wo_downsample=False):
for i, blk in enumerate(self.blocks):
x = blk(x)
if (self.post_norm_block_ids is not None) and (i in self.post_norm_block_ids):
index = self.post_norm_block_ids.index(i)
x = self.post_norms[index](x) # for InternImage-H/G
if not self.post_norm or self.center_feature_scale:
x = self.norm(x)
if return_wo_downsample:
x_ = x
if self.downsample is not None:
x = self.downsample(x)
if return_wo_downsample:
return x, x_
return x
class InternImage(nn.Module):
r"""InternImage
A PyTorch implementation of: `InternImage: Exploring Large-Scale Vision Foundation Models with Deformable Convolutions` -
https://arxiv.org/abs/2211.05778
Args:
core_op (str): Core operator. Default: 'DCNv3'
channels (int): Number of channels in the first stage. Default: 64
depths (list): Depth of each block. Default: [3, 4, 18, 5]
groups (list): Groups of each block. Default: [3, 6, 12, 24]
num_classes (int): Number of classes. Default: 1000
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
drop_rate (float): Probability of an element to be zeroed. Default: 0.
drop_path_rate (float): Stochastic depth rate. Default: 0.
act_layer (str): Activation layer. Default: 'GELU'
norm_layer (str): Normalization layer. Default: 'LN'
layer_scale (float): The initial value of layer scale. Default: None
cls_scale (float): Width multiplier of the classification head. Default: 1.5
with_cp (bool): Use gradient checkpointing or not. Default: False
dw_kernel_size (int): Size of the dwconv. Default: None
use_clip_projector (bool): Whether to use clip projector. Default: False
level2_post_norm (bool): Whether to use level2 post norm. Default: False
level2_post_norm_block_ids (list): Indexes of post norm blocks. Default: None
res_post_norm (bool): Whether to use res post norm. Default: False
center_feature_scale (bool): Whether to use center feature scale. Default: False
"""
def __init__(self,
core_op='DCNv3',
channels=64,
depths=[3, 4, 18, 5],
groups=[3, 6, 12, 24],
num_classes=1000,
mlp_ratio=4.,
drop_rate=0.,
drop_path_rate=0.2,
drop_path_type='linear',
act_layer='GELU',
norm_layer='LN',
layer_scale=None,
offset_scale=1.0,
post_norm=False,
cls_scale=1.5,
with_cp=False,
dw_kernel_size=None, # for InternImage-H/G
use_clip_projector=False, # for InternImage-H/G
level2_post_norm=False, # for InternImage-H/G
level2_post_norm_block_ids=None, # for InternImage-H/G
res_post_norm=False, # for InternImage-H/G
center_feature_scale=False, # for InternImage-H/G
remove_center=False, # for InternImage-H/G
**kwargs):
super().__init__()
if core_op == 'DCNv3' and has_cuda_kernel:
self.core_op = DCNv3
print('DCNv3 is installed, using CUDA implementation.')
elif core_op == 'DCNv3' and not has_cuda_kernel:
self.core_op = DCNv3_pytorch
print('DCNv3 is not installed, using PyTorch implementation.')
else:
self.core_op = DCNv3_pytorch
print('Using DCNv3 PyTorch implementation.')
self.num_classes = num_classes
self.num_levels = len(depths)
self.depths = depths
self.channels = channels
self.num_features = int(channels * 2 ** (self.num_levels - 1))
self.post_norm = post_norm
self.mlp_ratio = mlp_ratio
self.use_clip_projector = use_clip_projector
self.level2_post_norm_block_ids = level2_post_norm_block_ids
self.remove_center = remove_center
print(f'using core type: {core_op}')
print(f'level2_post_norm: {level2_post_norm}')
print(f'level2_post_norm_block_ids: {level2_post_norm_block_ids}')
print(f'res_post_norm: {res_post_norm}')
print(f'remove_center: {remove_center}')
in_chans = 3
self.patch_embed = StemLayer(in_chans=in_chans,
out_chans=channels,
act_layer=act_layer,
norm_layer=norm_layer)
self.pos_drop = nn.Dropout(p=drop_rate)
dpr = [
x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
]
if drop_path_type == 'uniform':
for i in range(len(dpr)):
dpr[i] = drop_path_rate
self.levels = nn.ModuleList()
for i in range(self.num_levels):
post_norm_block_ids = level2_post_norm_block_ids if level2_post_norm and (
i == 2) else None # for InternImage-H/G
level = InternImageBlock(
core_op=self.core_op,
channels=int(channels * 2 ** i),
depth=depths[i],
groups=groups[i],
mlp_ratio=self.mlp_ratio,
drop=drop_rate,
drop_path=dpr[sum(depths[:i]):sum(depths[:i + 1])],
act_layer=act_layer,
norm_layer=norm_layer,
post_norm=post_norm,
downsample=(i < self.num_levels - 1),
layer_scale=layer_scale,
offset_scale=offset_scale,
with_cp=with_cp,
dw_kernel_size=dw_kernel_size, # for InternImage-H/G
post_norm_block_ids=post_norm_block_ids, # for InternImage-H/G
res_post_norm=res_post_norm, # for InternImage-H/G
center_feature_scale=center_feature_scale, # for InternImage-H/G
remove_center=remove_center, # for InternImage-H/G
)
self.levels.append(level)
if self.num_classes > 0:
if not use_clip_projector: # for InternImage-T/S/B/L/XL
self.conv_head = nn.Sequential(
nn.Conv2d(self.num_features,
int(self.num_features * cls_scale),
kernel_size=1,
bias=False),
build_norm_layer(int(self.num_features * cls_scale), 'BN',
'channels_first', 'channels_first'),
build_act_layer(act_layer))
self.head = nn.Linear(int(self.num_features * cls_scale), num_classes) \
if num_classes > 0 else nn.Identity()
else: # for InternImage-H/G
pretrain_embed_dim, _stride, attnpool_num_heads, clip_embed_dim = 1024, 2, 16, 768
self.dcnv3_head_x4 = nn.Sequential(
nn.Conv2d(in_channels=self.num_features,
out_channels=pretrain_embed_dim * (_stride ** 2),
kernel_size=1), nn.PixelShuffle(_stride))
self.dcnv3_head_x3 = nn.Conv2d(in_channels=self.num_features // 2,
out_channels=pretrain_embed_dim,
kernel_size=1)
self.clip_projector = AttentionPoolingBlock(
dim=pretrain_embed_dim,
num_heads=attnpool_num_heads,
qkv_bias=True,
qk_scale=None,
drop=0.,
attn_drop=0.,
norm_layer=norm_layer,
out_dim=clip_embed_dim)
self.fc_norm = build_norm_layer(clip_embed_dim, norm_layer, eps=1e-6)
self.head = nn.Linear(
clip_embed_dim, num_classes) if num_classes > 0 else nn.Identity()
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.num_layers = len(depths)
self.apply(self._init_weights)
self.apply(self._init_deform_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def _init_deform_weights(self, m):
if isinstance(m, self.core_op):
m._reset_parameters()
@torch.jit.ignore
def lr_decay_keywords(self, decay_ratio=0.87):
lr_ratios = {}
# blocks
idx = 0
for i in range(4):
layer_num = 3 - i # 3 2 1 0
for j in range(self.depths[layer_num]):
block_num = self.depths[layer_num] - j - 1
tag = 'levels.{}.blocks.{}.'.format(layer_num, block_num)
decay = 1.0 * (decay_ratio ** idx)
lr_ratios[tag] = decay
idx += 1
# patch_embed (before stage-1)
lr_ratios['patch_embed'] = lr_ratios['levels.0.blocks.0.']
# levels.0.downsample (between stage-1 and stage-2)
lr_ratios['levels.0.downsample'] = lr_ratios['levels.1.blocks.0.']
lr_ratios['levels.0.norm'] = lr_ratios['levels.1.blocks.0.']
# levels.1.downsample (between stage-2 and stage-3)
lr_ratios['levels.1.downsample'] = lr_ratios['levels.2.blocks.0.']
lr_ratios['levels.1.norm'] = lr_ratios['levels.2.blocks.0.']
# levels.2.downsample (between stage-3 and stage-4)
lr_ratios['levels.2.downsample'] = lr_ratios['levels.3.blocks.0.']
lr_ratios['levels.2.norm'] = lr_ratios['levels.3.blocks.0.']
return lr_ratios
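# Usage sketch (illustrative): turn the per-block ratios into optimizer parameter
# groups for layer-wise lr decay; `base_lr` and the AdamW settings are assumptions.
#   ratios = model.lr_decay_keywords(decay_ratio=0.87)
#   param_groups = []
#   for name, param in model.named_parameters():
#       scale = next((r for prefix, r in ratios.items() if name.startswith(prefix)), 1.0)
#       param_groups.append({'params': [param], 'lr': base_lr * scale})
#   optimizer = torch.optim.AdamW(param_groups, weight_decay=0.05)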
def forward_features_seq_out(self, x):
x = self.patch_embed(x)
x = self.pos_drop(x)
seq_out = []
for level in self.levels:
x, x_ = level(x, return_wo_downsample=True)
seq_out.append(x_)
return seq_out
def forward_features(self, x):
xs = self.forward_features_seq_out(x)
x1, x2, x3, x4 = xs
x1 = x1.permute(0, 3, 1, 2) # NHWC -> NCHW
x2 = x2.permute(0, 3, 1, 2) # NHWC -> NCHW
x3 = x3.permute(0, 3, 1, 2) # NHWC -> NCHW
x4 = x4.permute(0, 3, 1, 2) # NHWC -> NCHW
hidden_states = [x1, x2, x3, x4]
if self.num_classes > 0:
x = self.conv_head(x4)
x = self.avgpool(x)
x = torch.flatten(x, 1)
return {
'hidden_states': hidden_states,
'pooler_output': x if self.num_classes > 0 else None
}
def forward_clip_projector(self, x): # for InternImage-H/G
xs = self.forward_features_seq_out(x)
x1, x2, x3, x4 = xs
x1 = x1.permute(0, 3, 1, 2) # NHWC -> NCHW
x2 = x2.permute(0, 3, 1, 2) # NHWC -> NCHW
x3 = x3.permute(0, 3, 1, 2) # NHWC -> NCHW
x4 = x4.permute(0, 3, 1, 2) # NHWC -> NCHW
hidden_states = [x1, x2, x3, x4]
if self.num_classes > 0:
x4 = self.dcnv3_head_x4(x4)
x = x4
x3 = self.dcnv3_head_x3(x3)
x = x + x3
x = x.flatten(-2).transpose(1, 2).contiguous()
x = self.clip_projector(x)
x = self.fc_norm(x)
return {
'hidden_states': hidden_states,
'pooler_output': x if self.num_classes > 0 else None
}
def forward(self, x):
if self.use_clip_projector: # for InternImage-H/G
outputs = self.forward_clip_projector(x)
else: # for InternImage-T/S/B/L/XL
outputs = self.forward_features(x)
hidden_states = outputs['hidden_states']
pooler_output = outputs['pooler_output']
if self.num_classes > 0:
logits = self.head(pooler_output)
else:
logits = None
return BackboneOutput(
hidden_states=hidden_states,
last_hidden_state=hidden_states[-1],
pooler_output=pooler_output,
logits=logits
)
class InternImageModel(PreTrainedModel):
config_class = InternImageConfig
def __init__(self, config):
super().__init__(config)
self.model = InternImage(
core_op=config.core_op,
channels=config.channels,
depths=config.depths,
groups=config.groups,
num_classes=0,
mlp_ratio=config.mlp_ratio,
drop_rate=config.drop_rate,
drop_path_rate=config.drop_path_rate,
drop_path_type=config.drop_path_type,
act_layer=config.act_layer,
norm_layer=config.norm_layer,
layer_scale=config.layer_scale,
offset_scale=config.offset_scale,
post_norm=config.post_norm,
cls_scale=config.cls_scale,
with_cp=config.with_cp,
dw_kernel_size=config.dw_kernel_size, # for InternImage-H/G
use_clip_projector=config.use_clip_projector, # for InternImage-H/G
level2_post_norm=config.level2_post_norm, # for InternImage-H/G
level2_post_norm_block_ids=config.level2_post_norm_block_ids, # for InternImage-H/G
res_post_norm=config.res_post_norm, # for InternImage-H/G
center_feature_scale=config.center_feature_scale, # for InternImage-H/G
remove_center=config.remove_center, # for InternImage-H/G
)
def forward(self, tensor):
return self.model.forward_features(tensor)
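# Usage sketch (illustrative): when this code is registered for the auto classes of a
# hub repo, the backbone loads with `trust_remote_code=True`; the repo path below is
# a placeholder, not a published checkpoint.
#   from transformers import AutoModel
#   model = AutoModel.from_pretrained('<repo-with-this-code>', trust_remote_code=True)
#   out = model(torch.randn(1, 3, 224, 224))
#   [h.shape for h in out['hidden_states']]  # 4 NCHW feature maps at strides 4/8/16/32
#   out['pooler_output'] is None             # the backbone is built with num_classes=0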
class InternImageModelForImageClassification(PreTrainedModel):
config_class = InternImageConfig
def __init__(self, config):
super().__init__(config)
self.model = InternImage(
core_op=config.core_op,
channels=config.channels,
depths=config.depths,
groups=config.groups,
num_classes=config.num_classes,
mlp_ratio=config.mlp_ratio,
drop_rate=config.drop_rate,
drop_path_rate=config.drop_path_rate,
drop_path_type=config.drop_path_type,
act_layer=config.act_layer,
norm_layer=config.norm_layer,
layer_scale=config.layer_scale,
offset_scale=config.offset_scale,
post_norm=config.post_norm,
cls_scale=config.cls_scale,
with_cp=config.with_cp,
dw_kernel_size=config.dw_kernel_size, # for InternImage-H/G
use_clip_projector=config.use_clip_projector, # for InternImage-H/G
level2_post_norm=config.level2_post_norm, # for InternImage-H/G
level2_post_norm_block_ids=config.level2_post_norm_block_ids, # for InternImage-H/G
res_post_norm=config.res_post_norm, # for InternImage-H/G
center_feature_scale=config.center_feature_scale, # for InternImage-H/G
remove_center=config.remove_center, # for InternImage-H/G
)
def forward(self, tensor, labels=None):
outputs = self.model.forward(tensor)
if labels is not None:
logits = outputs['logits']
loss = F.cross_entropy(logits, labels)
outputs['loss'] = loss
return outputs
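# Usage sketch (illustrative): the classification wrapper returns logits and, when
# labels are passed, attaches a cross-entropy loss; the repo path is a placeholder.
#   from transformers import AutoModelForImageClassification
#   clf = AutoModelForImageClassification.from_pretrained('<repo-with-this-code>', trust_remote_code=True)
#   out = clf(torch.randn(2, 3, 224, 224), labels=torch.tensor([3, 7]))
#   out['logits'].shape, out['loss']  # torch.Size([2, 1000]) with the default 1000 classes, plus a scalar loss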
{
"crop_size": 384,
"do_center_crop": true,
"do_normalize": true,
"do_resize": true,
"feature_extractor_type": "CLIPFeatureExtractor",
"image_mean": [
0.485,
0.456,
0.406
],
"image_std": [
0.229,
0.224,
0.225
],
"resample": 3,
"size": 384
}