Commit f37f9c2a authored by zhe chen

Release code for iNaturalist 2018 (#197)

parent cfd24625
@@ -167,6 +167,27 @@ We use standard ImageNet dataset, you can download it from http://image-net.org/
</details>
<details>
<summary>iNaturalist 2018</summary>
- For iNaturalist 2018, download the dataset from the [official repository](https://github.com/visipedia/inat_comp/blob/master/2018/README.md).
The file structure should look like:
```bash
$ tree inat2018/
inat2018/
├── categories.json
├── test2018
├── test2018.json
├── train2018.json
├── train2018_locations.json
├── val2018
├── val2018.json
└── val2018_locations.json
```
</details>
## Released Models

<details open>
@@ -204,6 +225,19 @@ We use standard ImageNet dataset, you can download it from http://image-net.org/
</details>
<details open>
<summary> iNaturalist 2018 Image Classification </summary>
<br>
<div>
| name | pretrain | resolution | acc@1 | #param | download |
| :-----------: | :--------: | :--------: | :---: | :----: | :-----------------------------------------------------------------------------: |
| InternImage-H | Joint 427M | 384x384 | 92.6 | 1.1B | [ckpt](<>) \| [cfg](configs/inaturalist2018/internimage_h_22ktoinat18_384.yaml) |
</div>
</details>
## Evaluation

To evaluate a pretrained `InternImage` on ImageNet val, run:

...
 DATA:
   IMG_SIZE: 384
-  DATASET: inat18
   IMG_ON_MEMORY: False
-  DATA_PATH: "data/inat2018/"
+  DATASET: inat18
 AUG:
   MIXUP: 0.0
   CUTMIX: 0.0
+  REPROB: 0.0
 MODEL:
-  PRETRAINED: './pretrained/internimage_h_jointto22k_384.pth'
-  TYPE: intern_image_with_meta
-  DROP_PATH_RATE: 0.2
+  TYPE: intern_image_meta_former
+  DROP_PATH_RATE: 0.6
   LABEL_SMOOTHING: 0.3
   INTERN_IMAGE:
     CORE_OP: 'DCNv3'
@@ -26,22 +25,22 @@ MODEL:
     LEVEL2_POST_NORM_BLOCK_IDS: [5, 11, 17, 23, 29]
     CENTER_FEATURE_SCALE: True
     USE_CLIP_PROJECTOR: True
+  PRETRAINED: 'pretrained/internimage_h_jointto22k_384.pth'
 TRAIN:
   EMA:
-    ENABLE: false
-    DECAY: 0.9998
+    ENABLE: true
+    DECAY: 0.9999
   EPOCHS: 100
   WARMUP_EPOCHS: 0
-  WEIGHT_DECAY: 1e-8
-  BASE_LR: 3e-05 # 512
-  WARMUP_LR: 3e-08
-  MIN_LR: 3e-07
+  WEIGHT_DECAY: 0.05
+  BASE_LR: 2e-05 # 512
+  WARMUP_LR: .0
+  MIN_LR: .0
   LR_LAYER_DECAY: true
-  LR_LAYER_DECAY_RATIO: 0.8
+  LR_LAYER_DECAY_RATIO: 0.9
+  USE_CHECKPOINT: true
   RAND_INIT_FT_HEAD: true
-  USE_CHECKPOINT: false
 OPTIMIZER:
+  USE_ZERO: True
   DCN_LR_MUL: 0.1
 AMP_OPT_LEVEL: O0
 EVAL_FREQ: 1
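
The `# 512` trailing comment on `BASE_LR` marks the reference total batch size. This codebase derives from the Swin classification pipeline, which scales the learning rate linearly with the global batch; a hedged sketch of that convention (the exact scaling hook in this repo's config handling is an assumption):

```python
# Hedged sketch of linear LR scaling; variable names are illustrative.
base_lr, ref_batch = 2e-05, 512           # BASE_LR and its '# 512' reference
batch_per_gpu, world_size, accum = 32, 16, 1
effective_lr = base_lr * batch_per_gpu * world_size * accum / ref_batch
print(effective_lr)                       # == base_lr when the global batch is 512
```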
@@ -12,7 +12,9 @@ import torch.distributed as dist
 from timm.data import Mixup, create_transform
 from torchvision import transforms
-from .cached_image_folder import CachedImageFolder, ImageCephDataset
+from .cached_image_folder import (CachedImageFolder, ImageCephDataset,
+                                  INat18ImageCephDataset,
+                                  INat18ParserCephImage)
 from .samplers import NodeDistributedSampler, SubsetRandomSampler

 try:
@@ -229,6 +231,15 @@ def build_dataset(split, config):
         root = os.path.join(config.DATA.DATA_PATH, 'val')
         dataset = ImageCephDataset(root, 'val', transform=transform)
         nb_classes = 1000
+    elif config.DATA.DATASET == 'inat18':
+        if prefix == 'train' and not config.EVAL_MODE:
+            root = config.DATA.DATA_PATH
+            dataset = INat18ImageCephDataset(
+                root, 'train', transform=transform,
+                on_memory=config.DATA.IMG_ON_MEMORY)
+        elif prefix == 'val':
+            root = config.DATA.DATA_PATH
+            dataset = INat18ImageCephDataset(root, 'val', transform=transform)
+        nb_classes = 8142
     else:
         raise NotImplementedError(
             f'build_dataset does not support {config.DATA.DATASET}')
...
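
The iNat18 branch also changes the batch layout: each sample is now `([img, temporal_info, spatial_info], target)` rather than `(img, target)`, which is why the loops in `main.py` further below gain `if type(images) == list` handling. A minimal sketch of the collated batch, with shapes inferred from the parser and model code in this commit (the batch size is illustrative):

```python
import torch

B = 2  # illustrative batch size
images = [torch.randn(B, 3, 384, 384),             # image tensor
          torch.zeros(B, 4, dtype=torch.float32),  # temporal_info (4 features)
          torch.zeros(B, 3, dtype=torch.float32)]  # spatial_info (3 features)
target = torch.zeros(B, dtype=torch.long)          # class id in [0, 8142)

# mirrors the device-moving pattern added to main.py below
if torch.cuda.is_available():
    images = [item.cuda(non_blocking=True) for item in images]
    target = target.cuda(non_blocking=True)
```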
@@ -340,6 +340,55 @@ class ImageCephDataset(data.Dataset):
        return self.parser.filenames(basename, absolute)
class INat18ImageCephDataset(data.Dataset):

    def __init__(self,
                 root,
                 split,
                 parser=None,
                 transform=None,
                 target_transform=None,
                 on_memory=False):
        if split == 'train':
            annotation_root = osp.join(root, 'train2018.json')
        elif split == 'val':
            annotation_root = osp.join(root, 'val2018.json')
        elif split == 'test':
            annotation_root = osp.join(root, 'test2018.json')
        else:
            raise ValueError(f'unsupported split: {split}')
        if parser is None or isinstance(parser, str):
            parser = INat18ParserCephImage(root=root,
                                           split=split,
                                           annotation_root=annotation_root,
                                           on_memory=on_memory)
        self.parser = parser
        self.transform = transform
        self.target_transform = target_transform
        self._consecutive_errors = 0

    def __getitem__(self, index):
        img, temporal_info, spatial_info, target = self.parser[index]
        self._consecutive_errors = 0
        if self.transform is not None:
            img = self.transform(img)
        if target is None:
            target = -1
        elif self.target_transform is not None:
            target = self.target_transform(target)
        temporal_info = torch.tensor(temporal_info).to(torch.float32)
        spatial_info = torch.tensor(spatial_info).to(torch.float32)
        return [img, temporal_info, spatial_info], target

    def __len__(self):
        return len(self.parser)

    def filename(self, index, basename=False, absolute=False):
        return self.parser.filename(index, basename, absolute)

    def filenames(self, basename=False, absolute=False):
        return self.parser.filenames(basename, absolute)
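
# Usage sketch (assumes the data/inat2018/ layout from the README hunk above;
# each item is ([img, temporal, spatial], target), with float32 tensors of
# 4 temporal and 3 spatial features):
#
#   dataset = INat18ImageCephDataset('data/inat2018/', 'val')
#   (img, temporal, spatial), target = dataset[0]
#   assert temporal.shape == (4,) and spatial.shape == (3,)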
class Parser:

    def __init__(self):
@@ -372,7 +421,7 @@ class ParserCephImage(Parser):
         self.file_client = None
         self.kwargs = kwargs
-        self.root = root  # dataset:s3://imagenet22k
+        self.root = root
         if '22k' in root:
             self.io_backend = 'petrel'
             with open(osp.join(annotation_root, '22k_class_to_idx.json'),
@@ -497,7 +546,7 @@ class ParserCephImage(Parser):
             else:
                 target = int(target)
         except:
-            print('aaaaaaaaaaaa', filepath, target)
+            print(filepath, target)
             exit()
         return img, target
@@ -512,6 +561,87 @@ class ParserCephImage(Parser):
        return filename
class INat18ParserCephImage(Parser):

    def __init__(self,
                 root,
                 split,
                 annotation_root,
                 on_memory=False,
                 **kwargs):
        super().__init__()
        self.file_client = None
        self.kwargs = kwargs
        self.split = split
        self.root = root
        self.io_backend = 'disk'

        data = mmcv.load(annotation_root)
        self.samples = data['annotations']
        self.file_names = [each['file_name'] for each in data['images']]
        self.meta_data = mmcv.load(
            annotation_root.replace('2018.json', '2018_locations.json'))
        self.class_to_idx = {}
        for i, each in enumerate(data['categories']):
            self.class_to_idx[each['id']] = i

        self.on_memory = on_memory
        self._consecutive_errors = 0
        # TODO: support the on_memory code path (self.holder is never
        # populated, so on_memory=True would fail in __getitem__)

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend, **self.kwargs)

        anns = self.samples[index]
        filename = self.file_names[index]
        img_id = anns['image_id']
        target = anns['category_id']

        # load meta information from the *_locations.json file
        meta = self.meta_data[index]
        date = meta['date']
        latitude = meta['lat']
        longitude = meta['lon']
        location_uncertainty = meta['loc_uncert']
        temporal_info = get_temporal_info(date, miss_hour=True)
        spatial_info = get_spatial_info(latitude, longitude)

        filepath = osp.join(self.root, filename)
        try:
            if self.on_memory:
                img_bytes = self.holder[filepath]
            else:
                img_bytes = self.file_client.get(filepath)
            img = mmcv.imfrombytes(img_bytes)[:, :, ::-1]  # BGR -> RGB
        except Exception as e:
            _logger.warning(
                f'Skipped sample (index {index}, file {filepath}). {str(e)}')
            self._consecutive_errors += 1
            if self._consecutive_errors < _ERROR_RETRY:
                return self.__getitem__((index + 1) % len(self))
            else:
                raise e
        self._consecutive_errors = 0

        img = Image.fromarray(img)
        if self.class_to_idx is not None:
            target = self.class_to_idx[target]
        else:
            target = int(target)
        return img, temporal_info, spatial_info, target

    def __len__(self):
        return len(self.samples)

    def _filename(self, index, basename=False, absolute=False):
        # samples are COCO-style annotation dicts, so look the path up in
        # file_names rather than splitting a string record
        filename = osp.join(self.root, self.file_names[index])
        return filename
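
# Note: get_temporal_info() and get_spatial_info() (bodies largely elided in
# this diff) compress the raw metadata into 4 and 3 features respectively,
# matching meta_head_1 (nn.Linear(4, meta_dim)) and meta_head_2
# (nn.Linear(3, meta_dim)) in InternImageMetaFormer. A plausible sketch of a
# spherical encoding for the spatial part -- an illustration, NOT the
# released implementation:
#
#   import math
#   def spatial_sketch(lat, lon):  # hypothetical helper
#       phi, theta = math.radians(lat), math.radians(lon)
#       return [math.cos(phi) * math.cos(theta),
#               math.cos(phi) * math.sin(theta),
#               math.sin(phi)]  # point on the unit sphere, 3 features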
def get_temporal_info(date, miss_hour=False):
    try:
        if date:
...
@@ -74,8 +74,7 @@ def parse_option():
                         type=str,
                         help='dataset name',
                         default=None)
-    parser.add_argument('--data-path', type=str, help='path to dataset',
-                        default='data/imagenet')
+    parser.add_argument('--data-path', type=str, help='path to dataset')
     parser.add_argument('--zip',
                         action='store_true',
                         help='use zipped dataset instead of folder dataset')
@@ -146,7 +145,10 @@ def throughput(data_loader, model, logger):
     model.eval()
     for idx, (images, _) in enumerate(data_loader):
-        images = images.cuda(non_blocking=True)
+        if type(images) == list:
+            images = [item.cuda(non_blocking=True) for item in images]
+        else:
+            images = images.cuda(non_blocking=True)
         batch_size = images.shape[0]
         for i in range(50):
             model(images)
@@ -403,7 +405,10 @@ def train_one_epoch(config,
     amp_type = torch.float16 if config.AMP_TYPE == 'float16' else torch.bfloat16
     for idx, (samples, targets) in enumerate(data_loader):
         iter_begin_time = time.time()
-        samples = samples.cuda(non_blocking=True)
+        if type(samples) == list:
+            samples = [item.cuda(non_blocking=True) for item in samples]
+        else:
+            samples = samples.cuda(non_blocking=True)
         targets = targets.cuda(non_blocking=True)
         if mixup_fn is not None:
@@ -528,7 +533,10 @@ def validate(config, data_loader, model, epoch=None):
     end = time.time()
     for idx, (images, target) in enumerate(data_loader):
-        images = images.cuda(non_blocking=True)
+        if type(images) == list:
+            images = [item.cuda(non_blocking=True) for item in images]
+        else:
+            images = images.cuda(non_blocking=True)
         target = target.cuda(non_blocking=True)
         output = model(images)
...
@@ -5,6 +5,7 @@
 # --------------------------------------------------------
 from .intern_image import InternImage
+from .intern_image_meta_former import InternImageMetaFormer
def build_model(config):
@@ -30,6 +31,27 @@ def build_model(config):
            center_feature_scale=config.MODEL.INTERN_IMAGE.CENTER_FEATURE_SCALE, # for InternImage-H/G
            remove_center=config.MODEL.INTERN_IMAGE.REMOVE_CENTER,
        )
elif model_type == 'intern_image_meta_former':
model = InternImageMetaFormer(
core_op=config.MODEL.INTERN_IMAGE.CORE_OP,
num_classes=config.MODEL.NUM_CLASSES,
channels=config.MODEL.INTERN_IMAGE.CHANNELS,
depths=config.MODEL.INTERN_IMAGE.DEPTHS,
groups=config.MODEL.INTERN_IMAGE.GROUPS,
layer_scale=config.MODEL.INTERN_IMAGE.LAYER_SCALE,
offset_scale=config.MODEL.INTERN_IMAGE.OFFSET_SCALE,
post_norm=config.MODEL.INTERN_IMAGE.POST_NORM,
mlp_ratio=config.MODEL.INTERN_IMAGE.MLP_RATIO,
with_cp=config.TRAIN.USE_CHECKPOINT,
drop_path_rate=config.MODEL.DROP_PATH_RATE,
res_post_norm=config.MODEL.INTERN_IMAGE.RES_POST_NORM, # for InternImage-H/G
dw_kernel_size=config.MODEL.INTERN_IMAGE.DW_KERNEL_SIZE, # for InternImage-H/G
use_clip_projector=config.MODEL.INTERN_IMAGE.USE_CLIP_PROJECTOR, # for InternImage-H/G
level2_post_norm=config.MODEL.INTERN_IMAGE.LEVEL2_POST_NORM, # for InternImage-H/G
level2_post_norm_block_ids=config.MODEL.INTERN_IMAGE.LEVEL2_POST_NORM_BLOCK_IDS, # for InternImage-H/G
center_feature_scale=config.MODEL.INTERN_IMAGE.CENTER_FEATURE_SCALE, # for InternImage-H/G
remove_center=config.MODEL.INTERN_IMAGE.REMOVE_CENTER,
)
    else:
        raise NotImplementedError(f'Unknown model: {model_type}')
...
# --------------------------------------------------------
# InternImage
# Copyright (c) 2022 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from ops_dcnv3 import modules as opsm
from timm.models.layers import DropPath, trunc_normal_
class to_channels_first(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return x.permute(0, 3, 1, 2)
class to_channels_last(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return x.permute(0, 2, 3, 1)
def build_norm_layer(dim,
norm_layer,
in_format='channels_last',
out_format='channels_last',
eps=1e-6):
layers = []
if norm_layer == 'BN':
if in_format == 'channels_last':
layers.append(to_channels_first())
layers.append(nn.BatchNorm2d(dim))
if out_format == 'channels_last':
layers.append(to_channels_last())
elif norm_layer == 'LN':
if in_format == 'channels_first':
layers.append(to_channels_last())
layers.append(nn.LayerNorm(dim, eps=eps))
if out_format == 'channels_first':
layers.append(to_channels_first())
else:
raise NotImplementedError(
f'build_norm_layer does not support {norm_layer}')
return nn.Sequential(*layers)
def build_act_layer(act_layer):
if act_layer == 'ReLU':
return nn.ReLU(inplace=True)
elif act_layer == 'SiLU':
return nn.SiLU(inplace=True)
elif act_layer == 'GELU':
return nn.GELU()
raise NotImplementedError(f'build_act_layer does not support {act_layer}')
class MetaEncoder(nn.Module):
def __init__(self, linear_size):
super(MetaEncoder, self).__init__()
self.l_size = linear_size
self.nonlin1 = nn.ReLU(inplace=True)
self.nonlin2 = nn.ReLU(inplace=True)
self.norm_fn1 = nn.LayerNorm(self.l_size)
self.norm_fn2 = nn.LayerNorm(self.l_size)
self.w1 = nn.Linear(self.l_size, self.l_size)
self.w2 = nn.Linear(self.l_size, self.l_size)
def forward(self, x):
y = self.w1(x)
y = self.nonlin1(y)
y = self.norm_fn1(y)
y = self.w2(y)
y = self.nonlin2(y)
y = self.norm_fn2(y)
out = x + y
return out
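
# MetaEncoder is a size-preserving residual MLP; a quick shape check:
#
#   enc = MetaEncoder(64)
#   out = enc(torch.randn(2, 64))  # -> (2, 64); out = x + MLP(x)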
class CrossAttention(nn.Module):
r""" Cross Attention Module
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads. Default: 8
qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
Default: False.
qk_scale (float | None, optional): Override default qk scale of
head_dim ** -0.5 if set. Default: None.
attn_drop (float, optional): Dropout ratio of attention weight.
Default: 0.0
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
attn_head_dim (int, optional): Dimension of attention head.
out_dim (int, optional): Dimension of output.
"""
def __init__(self,
dim,
num_heads=8,
qkv_bias=False,
qk_scale=None,
attn_drop=0.,
proj_drop=0.,
attn_head_dim=None,
out_dim=None):
super().__init__()
if out_dim is None:
out_dim = dim
self.num_heads = num_heads
head_dim = dim // num_heads
if attn_head_dim is not None:
head_dim = attn_head_dim
all_head_dim = head_dim * self.num_heads
self.scale = qk_scale or head_dim ** -0.5
assert all_head_dim == dim
self.q = nn.Linear(dim, all_head_dim, bias=False)
self.k = nn.Linear(dim, all_head_dim, bias=False)
self.v = nn.Linear(dim, all_head_dim, bias=False)
if qkv_bias:
self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
self.k_bias = nn.Parameter(torch.zeros(all_head_dim))
self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
else:
self.q_bias = None
self.k_bias = None
self.v_bias = None
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(all_head_dim, out_dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x, k=None, v=None):
B, N, C = x.shape
N_k = k.shape[1]
N_v = v.shape[1]
q_bias, k_bias, v_bias = None, None, None
if self.q_bias is not None:
q_bias = self.q_bias
k_bias = self.k_bias
v_bias = self.v_bias
q = F.linear(input=x, weight=self.q.weight, bias=q_bias)
q = q.reshape(B, N, 1, self.num_heads,
-1).permute(2, 0, 3, 1,
4).squeeze(0) # (B, N_head, N_q, dim)
k = F.linear(input=k, weight=self.k.weight, bias=k_bias)
k = k.reshape(B, N_k, 1, self.num_heads, -1).permute(2, 0, 3, 1,
4).squeeze(0)
v = F.linear(input=v, weight=self.v.weight, bias=v_bias)
v = v.reshape(B, N_v, 1, self.num_heads, -1).permute(2, 0, 3, 1,
4).squeeze(0)
q = q * self.scale
attn = (q @ k.transpose(-2, -1)) # (B, N_head, N_q, N_k)
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
x = self.proj(x)
x = self.proj_drop(x)
return x
class AttentiveBlock(nn.Module):
r"""Attentive Block
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads. Default: 8
qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
Default: False.
qk_scale (float | None, optional): Override default qk scale of
head_dim ** -0.5 if set. Default: None.
drop (float, optional): Dropout rate. Default: 0.0.
attn_drop (float, optional): Attention dropout rate. Default: 0.0.
drop_path (float | tuple[float], optional): Stochastic depth rate.
Default: 0.0.
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm.
attn_head_dim (int, optional): Dimension of attention head. Default: None.
out_dim (int, optional): Dimension of output. Default: None.
"""
def __init__(self,
dim,
num_heads,
qkv_bias=False,
qk_scale=None,
drop=0.,
attn_drop=0.,
drop_path=0.,
norm_layer='LN',
attn_head_dim=None,
out_dim=None):
super().__init__()
self.norm1_q = build_norm_layer(dim, norm_layer, eps=1e-6)
self.norm1_k = build_norm_layer(dim, norm_layer, eps=1e-6)
self.norm1_v = build_norm_layer(dim, norm_layer, eps=1e-6)
self.cross_dcn = CrossAttention(dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=drop,
attn_head_dim=attn_head_dim,
out_dim=out_dim)
self.drop_path = DropPath(
drop_path) if drop_path > 0. else nn.Identity()
def forward(self,
x_q,
x_kv,
pos_q,
pos_k,
bool_masked_pos,
rel_pos_bias=None):
x_q = self.norm1_q(x_q + pos_q)
x_k = self.norm1_k(x_kv + pos_k)
x_v = self.norm1_v(x_kv)
x = self.cross_dcn(x_q, k=x_k, v=x_v)
return x
class AttentionPoolingBlock(AttentiveBlock):
def forward(self, x):
x_q = x.mean(1, keepdim=True)
x_kv = x
pos_q, pos_k = 0, 0
x = super().forward(x_q, x_kv, pos_q, pos_k,
bool_masked_pos=None,
rel_pos_bias=None)
x = x.squeeze(1)
return x
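
# Pooling sketch: the sequence mean acts as the single query, collapsing
# (B, N, C) -> (B, out_dim). With the head settings used by the CLIP
# projector further below:
#
#   pool = AttentionPoolingBlock(dim=1024, num_heads=16, qkv_bias=True,
#                                qk_scale=None, drop=0., attn_drop=0.,
#                                norm_layer='LN', out_dim=768)
#   y = pool(torch.randn(2, 196, 1024))  # -> (2, 768)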
class StemLayer(nn.Module):
r""" Stem layer of InternImage
Args:
in_chans (int): number of input channels
out_chans (int): number of output channels
act_layer (str): activation layer
norm_layer (str): normalization layer
"""
def __init__(self,
in_chans=3,
out_chans=96,
act_layer='GELU',
norm_layer='BN'):
super().__init__()
self.conv1 = nn.Conv2d(in_chans,
out_chans // 2,
kernel_size=3,
stride=2,
padding=1)
self.norm1 = build_norm_layer(out_chans // 2, norm_layer,
'channels_first', 'channels_first')
self.act = build_act_layer(act_layer)
self.conv2 = nn.Conv2d(out_chans // 2,
out_chans,
kernel_size=3,
stride=2,
padding=1)
self.norm2 = build_norm_layer(out_chans, norm_layer, 'channels_first',
'channels_last')
def forward(self, x):
x = self.conv1(x)
x = self.norm1(x)
x = self.act(x)
x = self.conv2(x)
x = self.norm2(x)
return x
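
# The stem applies two stride-2 convs (4x downsampling) and returns a
# channels-last tensor, e.g.:
#
#   stem = StemLayer(in_chans=3, out_chans=96)
#   y = stem(torch.randn(1, 3, 224, 224))  # -> (1, 56, 56, 96), NHWC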
class DownsampleLayer(nn.Module):
r""" Downsample layer of InternImage
Args:
channels (int): number of input channels
norm_layer (str): normalization layer
"""
def __init__(self, channels, norm_layer='LN'):
super().__init__()
self.conv = nn.Conv2d(channels,
2 * channels,
kernel_size=3,
stride=2,
padding=1,
bias=False)
self.norm = build_norm_layer(2 * channels, norm_layer,
'channels_first', 'channels_last')
def forward(self, x):
x = self.conv(x.permute(0, 3, 1, 2))
x = self.norm(x)
return x
class MLPLayer(nn.Module):
r""" MLP layer of InternImage
Args:
in_features (int): number of input features
hidden_features (int): number of hidden features
out_features (int): number of output features
act_layer (str): activation layer
drop (float): dropout rate
"""
def __init__(self,
in_features,
hidden_features=None,
out_features=None,
act_layer='GELU',
drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = build_act_layer(act_layer)
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class InternImageLayer(nn.Module):
r""" Basic layer of InternImage
Args:
core_op (nn.Module): core operation of InternImage
channels (int): number of input channels
groups (list): Groups of each block.
mlp_ratio (float): ratio of mlp hidden features to input channels
drop (float): dropout rate
drop_path (float): drop path rate
act_layer (str): activation layer
norm_layer (str): normalization layer
post_norm (bool): whether to use post normalization
layer_scale (float): layer scale
offset_scale (float): offset scale
with_cp (bool): whether to use checkpoint
"""
def __init__(self,
core_op,
channels,
groups,
mlp_ratio=4.,
drop=0.,
drop_path=0.,
act_layer='GELU',
norm_layer='LN',
post_norm=False,
layer_scale=None,
offset_scale=1.0,
with_cp=False,
dw_kernel_size=None, # for InternImage-H/G
res_post_norm=False, # for InternImage-H/G
center_feature_scale=False, # for InternImage-H/G
remove_center=False, # for InternImage-H/G
):
super().__init__()
self.channels = channels
self.groups = groups
self.mlp_ratio = mlp_ratio
self.with_cp = with_cp
self.norm1 = build_norm_layer(channels, 'LN')
self.post_norm = post_norm
self.dcn = core_op(
channels=channels,
kernel_size=3,
stride=1,
pad=1,
dilation=1,
group=groups,
offset_scale=offset_scale,
act_layer=act_layer,
norm_layer=norm_layer,
dw_kernel_size=dw_kernel_size, # for InternImage-H/G
center_feature_scale=center_feature_scale, # for InternImage-H/G
remove_center=remove_center, # for InternImage-H/G
)
self.drop_path = DropPath(drop_path) if drop_path > 0. \
else nn.Identity()
self.norm2 = build_norm_layer(channels, 'LN')
self.mlp = MLPLayer(in_features=channels,
hidden_features=int(channels * mlp_ratio),
act_layer=act_layer,
drop=drop)
self.layer_scale = layer_scale is not None
if self.layer_scale:
self.gamma1 = nn.Parameter(layer_scale * torch.ones(channels),
requires_grad=True)
self.gamma2 = nn.Parameter(layer_scale * torch.ones(channels),
requires_grad=True)
self.res_post_norm = res_post_norm
if res_post_norm:
self.res_post_norm1 = build_norm_layer(channels, 'LN')
self.res_post_norm2 = build_norm_layer(channels, 'LN')
def forward(self, x):
def _inner_forward(x):
if not self.layer_scale:
if self.post_norm:
x = x + self.drop_path(self.norm1(self.dcn(x)))
x = x + self.drop_path(self.norm2(self.mlp(x)))
elif self.res_post_norm: # for InternImage-H/G
x = x + self.drop_path(self.res_post_norm1(self.dcn(self.norm1(x))))
x = x + self.drop_path(self.res_post_norm2(self.mlp(self.norm2(x))))
else:
x = x + self.drop_path(self.dcn(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
if self.post_norm:
x = x + self.drop_path(self.gamma1 * self.norm1(self.dcn(x)))
x = x + self.drop_path(self.gamma2 * self.norm2(self.mlp(x)))
else:
x = x + self.drop_path(self.gamma1 * self.dcn(self.norm1(x)))
x = x + self.drop_path(self.gamma2 * self.mlp(self.norm2(x)))
return x
if self.with_cp and x.requires_grad:
x = checkpoint.checkpoint(_inner_forward, x)
else:
x = _inner_forward(x)
return x
class InternImageBlock(nn.Module):
r""" Block of InternImage
Args:
core_op (nn.Module): core operation of InternImage
channels (int): number of input channels
depth (int): Depth of the block.
groups (int): Number of DCNv3 groups used by the block.
mlp_ratio (float): ratio of mlp hidden features to input channels
drop (float): dropout rate
drop_path (float): drop path rate
act_layer (str): activation layer
norm_layer (str): normalization layer
post_norm (bool): whether to use post normalization
layer_scale (float): layer scale
offset_scale (float): offset scale
with_cp (bool): whether to use checkpoint
"""
def __init__(self,
core_op,
channels,
depth,
groups,
downsample=True,
mlp_ratio=4.,
drop=0.,
drop_path=0.,
act_layer='GELU',
norm_layer='LN',
post_norm=False,
offset_scale=1.0,
layer_scale=None,
with_cp=False,
dw_kernel_size=None, # for InternImage-H/G
post_norm_block_ids=None, # for InternImage-H/G
res_post_norm=False, # for InternImage-H/G
center_feature_scale=False, # for InternImage-H/G
remove_center=False, # for InternImage-H/G
):
super().__init__()
self.channels = channels
self.depth = depth
self.post_norm = post_norm
self.center_feature_scale = center_feature_scale
self.blocks = nn.ModuleList([
InternImageLayer(
core_op=core_op,
channels=channels,
groups=groups,
mlp_ratio=mlp_ratio,
drop=drop,
drop_path=drop_path[i] if isinstance(
drop_path, list) else drop_path,
act_layer=act_layer,
norm_layer=norm_layer,
post_norm=post_norm,
layer_scale=layer_scale,
offset_scale=offset_scale,
with_cp=with_cp,
dw_kernel_size=dw_kernel_size, # for InternImage-H/G
res_post_norm=res_post_norm, # for InternImage-H/G
center_feature_scale=center_feature_scale, # for InternImage-H/G
remove_center = remove_center, # for InternImage-H/G
) for i in range(depth)
])
if not self.post_norm or center_feature_scale:
self.norm = build_norm_layer(channels, 'LN')
self.post_norm_block_ids = post_norm_block_ids
if post_norm_block_ids is not None: # for InternImage-H/G
self.post_norms = nn.ModuleList(
[build_norm_layer(channels, 'LN', eps=1e-6) for _ in post_norm_block_ids]
)
self.downsample = DownsampleLayer(
channels=channels, norm_layer=norm_layer) if downsample else None
def forward(self, x, return_wo_downsample=False):
for i, blk in enumerate(self.blocks):
x = blk(x)
if (self.post_norm_block_ids is not None) and (i in self.post_norm_block_ids):
index = self.post_norm_block_ids.index(i)
x = self.post_norms[index](x) # for InternImage-H/G
if not self.post_norm or self.center_feature_scale:
x = self.norm(x)
if return_wo_downsample:
x_ = x
if self.downsample is not None:
x = self.downsample(x)
if return_wo_downsample:
return x, x_
return x
class InternImageMetaFormer(nn.Module):
r""" InternImage
A PyTorch impl of: `InternImage: Exploring Large-Scale Vision Foundation Models with Deformable Convolutions` -
https://arxiv.org/abs/2211.05778
Args:
core_op (str): Core operator. Default: 'DCNv3'
channels (int): Number of channels of the first stage. Default: 64
depths (list): Depth of each block. Default: [3, 4, 18, 5]
groups (list): Groups of each block. Default: [3, 6, 12, 24]
num_classes (int): Number of classes. Default: 1000
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
drop_rate (float): Probability of an element to be zeroed. Default: 0.
drop_path_rate (float): Stochastic depth rate. Default: 0.
act_layer (str): Activation layer. Default: 'GELU'
norm_layer (str): Normalization layer. Default: 'LN'
layer_scale (bool): Whether to use layer scale. Default: False
cls_scale (bool): Whether to use class scale. Default: False
with_cp (bool): Use checkpoint or not. Using checkpoint will save some memory while slowing down the training speed. Default: False
dw_kernel_size (int): Size of the dwconv. Default: None
use_clip_projector (bool): Whether to use clip projector. Default: False
level2_post_norm (bool): Whether to use level2 post norm. Default: False
level2_post_norm_block_ids (list): Indexes of post norm blocks. Default: None
res_post_norm (bool): Whether to use res post norm. Default: False
center_feature_scale (bool): Whether to use center feature scale. Default: False
"""
def __init__(self,
core_op='DCNv3',
channels=64,
depths=[3, 4, 18, 5],
groups=[3, 6, 12, 24],
num_classes=1000,
mlp_ratio=4.,
drop_rate=0.,
drop_path_rate=0.2,
drop_path_type='linear',
act_layer='GELU',
norm_layer='LN',
layer_scale=None,
offset_scale=1.0,
post_norm=False,
cls_scale=1.5,
with_cp=False,
dw_kernel_size=None, # for InternImage-H/G
use_clip_projector=False, # for InternImage-H/G
level2_post_norm=False, # for InternImage-H/G
level2_post_norm_block_ids=None, # for InternImage-H/G
res_post_norm=False, # for InternImage-H/G
center_feature_scale=False, # for InternImage-H/G
remove_center=False, # for InternImage-H/G
meta_dim=64, # for metaformer
**kwargs):
super().__init__()
self.core_op = core_op
self.num_classes = num_classes
self.num_levels = len(depths)
self.depths = depths
self.channels = channels
self.num_features = int(channels * 2**(self.num_levels - 1))
self.post_norm = post_norm
self.mlp_ratio = mlp_ratio
self.use_clip_projector = use_clip_projector
self.level2_post_norm_block_ids = level2_post_norm_block_ids
self.remove_center = remove_center
print(f'using core type: {core_op}')
print(f'using activation layer: {act_layer}')
print(f'using main norm layer: {norm_layer}')
print(f'using dpr: {drop_path_type}, {drop_path_rate}')
print(f'level2_post_norm: {level2_post_norm}')
print(f'level2_post_norm_block_ids: {level2_post_norm_block_ids}')
print(f'res_post_norm: {res_post_norm}')
print(f'remove_center: {remove_center}')
in_chans = 3
self.patch_embed = StemLayer(in_chans=in_chans,
out_chans=channels,
act_layer=act_layer,
norm_layer=norm_layer)
self.pos_drop = nn.Dropout(p=drop_rate)
dpr = [
x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
]
if drop_path_type == 'uniform':
for i in range(len(dpr)):
dpr[i] = drop_path_rate
self.levels = nn.ModuleList()
for i in range(self.num_levels):
post_norm_block_ids = level2_post_norm_block_ids if level2_post_norm and (
i == 2) else None # for InternImage-H/G
level = InternImageBlock(
core_op=getattr(opsm, core_op),
channels=int(channels * 2**i),
depth=depths[i],
groups=groups[i],
mlp_ratio=self.mlp_ratio,
drop=drop_rate,
drop_path=dpr[sum(depths[:i]):sum(depths[:i + 1])],
act_layer=act_layer,
norm_layer=norm_layer,
post_norm=post_norm,
downsample=(i < self.num_levels - 1),
layer_scale=layer_scale,
offset_scale=offset_scale,
with_cp=with_cp,
dw_kernel_size=dw_kernel_size, # for InternImage-H/G
post_norm_block_ids=post_norm_block_ids, # for InternImage-H/G
res_post_norm=res_post_norm, # for InternImage-H/G
center_feature_scale=center_feature_scale, # for InternImage-H/G
remove_center=remove_center, # for InternImage-H/G
)
self.levels.append(level)
self.meta_head_1 = nn.Sequential(
nn.Linear(4, meta_dim),
nn.ReLU(inplace=True),
nn.LayerNorm(meta_dim),
MetaEncoder(meta_dim),
)
self.meta_head_2 = nn.Sequential(
nn.Linear(3, meta_dim),
nn.ReLU(inplace=True),
nn.LayerNorm(meta_dim),
MetaEncoder(meta_dim),
)
self.meta_norm = nn.LayerNorm(meta_dim * 2, eps=1e-6)
self.meta_head = nn.Linear(meta_dim * 2, num_classes) if num_classes > 0 else nn.Identity()
if not use_clip_projector: # for InternImage-T/S/B/L/XL
self.conv_head = nn.Sequential(
nn.Conv2d(self.num_features,
int(self.num_features * cls_scale),
kernel_size=1,
bias=False),
build_norm_layer(int(self.num_features * cls_scale), 'BN',
'channels_first', 'channels_first'),
build_act_layer(act_layer))
self.head = nn.Linear(int(self.num_features * cls_scale), num_classes) \
if num_classes > 0 else nn.Identity()
else: # for InternImage-H/G
pretrain_embed_dim, _stride, attnpool_num_heads, clip_embed_dim = 1024, 2, 16, 768
self.dcnv3_head_x4 = nn.Sequential(
nn.Conv2d(in_channels=self.num_features,
out_channels=pretrain_embed_dim * (_stride ** 2),
kernel_size=1), nn.PixelShuffle(_stride))
self.dcnv3_head_x3 = nn.Conv2d(in_channels=self.num_features // 2,
out_channels=pretrain_embed_dim,
kernel_size=1)
self.clip_projector = AttentionPoolingBlock(
dim=pretrain_embed_dim,
num_heads=attnpool_num_heads,
qkv_bias=True,
qk_scale=None,
drop=0.,
attn_drop=0.,
norm_layer=norm_layer,
out_dim=clip_embed_dim)
self.fc_norm = build_norm_layer(clip_embed_dim, norm_layer, eps=1e-6)
self.head = nn.Linear(
clip_embed_dim, num_classes) if num_classes > 0 else nn.Identity()
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.num_layers = len(depths)
self.apply(self._init_weights)
self.apply(self._init_deform_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def _init_deform_weights(self, m):
if isinstance(m, getattr(opsm, self.core_op)):
m._reset_parameters()
@torch.jit.ignore
def lr_decay_keywords(self, decay_ratio=0.87):
lr_ratios = {}
# blocks
idx = 0
for i in range(4):
layer_num = 3 - i # 3 2 1 0
for j in range(self.depths[layer_num]):
block_num = self.depths[layer_num] - j - 1
tag = 'levels.{}.blocks.{}.'.format(layer_num, block_num)
decay = 1.0 * (decay_ratio ** idx)
lr_ratios[tag] = decay
idx += 1
# patch_embed (before stage-1)
lr_ratios['patch_embed'] = lr_ratios['levels.0.blocks.0.']
# levels.0.downsample (between stage-1 and stage-2)
lr_ratios['levels.0.downsample'] = lr_ratios['levels.1.blocks.0.']
lr_ratios['levels.0.norm'] = lr_ratios['levels.1.blocks.0.']
# levels.1.downsample (between stage-2 and stage-3)
lr_ratios['levels.1.downsample'] = lr_ratios['levels.2.blocks.0.']
lr_ratios['levels.1.norm'] = lr_ratios['levels.2.blocks.0.']
# levels.2.downsample (between stage-3 and stage-4)
lr_ratios['levels.2.downsample'] = lr_ratios['levels.3.blocks.0.']
lr_ratios['levels.2.norm'] = lr_ratios['levels.3.blocks.0.']
return lr_ratios
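
    # Decay sketch: blocks are visited from the last stage backwards, so the
    # deepest block keeps ratio 1.0 and each earlier block shrinks by
    # decay_ratio. A standalone rendition for the class-default depths:
    #
    #   depths, ratio, idx, ratios = [3, 4, 18, 5], 0.9, 0, {}
    #   for i in range(4):
    #       ln = 3 - i
    #       for j in range(depths[ln]):
    #           ratios[f'levels.{ln}.blocks.{depths[ln] - j - 1}.'] = ratio ** idx
    #           idx += 1
    #   # ratios['levels.3.blocks.4.'] == 1.0; ratios['levels.0.blocks.0.'] == 0.9**29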
def forward_features(self, x):
x = self.patch_embed(x)
x = self.pos_drop(x)
for level in self.levels:
x = level(x)
x = self.conv_head(x.permute(0, 3, 1, 2))
x = self.avgpool(x)
x = torch.flatten(x, 1)
return x
def forward_features_seq_out(self, x):
x = self.patch_embed(x)
x = self.pos_drop(x)
seq_out = []
for level in self.levels:
x, x_ = level(x, return_wo_downsample=True)
seq_out.append(x_)
return seq_out
def forward_clip_projector(self, x): # for InternImage-H/G
xs = self.forward_features_seq_out(x)
x1, x2, x3, x4 = xs
x1 = x1.permute(0, 3, 1, 2) # NHWC -> NCHW
x2 = x2.permute(0, 3, 1, 2) # NHWC -> NCHW
x3 = x3.permute(0, 3, 1, 2) # NHWC -> NCHW
x4 = x4.permute(0, 3, 1, 2) # NHWC -> NCHW
x4 = self.dcnv3_head_x4(x4)
x = x4
x3 = self.dcnv3_head_x3(x3)
x = x + x3
x = x.flatten(-2).transpose(1, 2).contiguous()
x = self.clip_projector(x)
x = self.fc_norm(x)
return x
def forward(self, x):
x, temporal_info, spatial_info = x
temporal_info = self.meta_head_1(temporal_info)
spatial_info = self.meta_head_2(spatial_info)
meta = torch.cat([temporal_info, spatial_info], dim=-1)
if self.use_clip_projector: # for InternImage-H/G
x = self.forward_clip_projector(x)
else: # for InternImage-T/S/B/L/XL
x = self.forward_features(x)
x = self.head(x)
meta = self.meta_norm(meta)
meta = self.meta_head(meta)
x = x + meta
return x
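
# End-to-end usage sketch (assumes the DCNv3 CUDA ops in ops_dcnv3 are built
# and a GPU is available; class-default sizes, not the InternImage-H config):
#
#   model = InternImageMetaFormer(num_classes=8142).cuda().eval()
#   img = torch.randn(1, 3, 224, 224).cuda()
#   temporal = torch.zeros(1, 4).cuda()   # from get_temporal_info
#   spatial = torch.zeros(1, 3).cuda()    # from get_spatial_info
#   logits = model([img, temporal, spatial])  # -> (1, 8142)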
#!/usr/bin/env bash
set -x
PARTITION=$1
JOB_NAME=$2
CONFIG=$3
GPUS=${GPUS:-8}
GPUS_PER_NODE=${GPUS_PER_NODE:-8}
CPUS_PER_TASK=${CPUS_PER_TASK:-12}
SRUN_ARGS=${SRUN_ARGS:-""}
PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
srun -p ${PARTITION} \
--job-name=${JOB_NAME} \
--gres=gpu:${GPUS_PER_NODE} \
--ntasks=${GPUS} \
--ntasks-per-node=${GPUS_PER_NODE} \
--cpus-per-task=${CPUS_PER_TASK} \
--kill-on-bad-exit=1 \
--quotatype=reserved \
${SRUN_ARGS} \
python -u main.py \
--cfg ${CONFIG} \
--accumulation-steps 1 \
--local-rank 0 \
--data-path data/inat2018 \
--output work_dirs ${@:4}
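
A usage sketch for the wrapper above; the script filename, partition, and job name are placeholders (`$1`/`$2`/`$3` map to partition, job name, and config):

```bash
# hypothetical filename for the script above
GPUS=8 GPUS_PER_NODE=8 sh train_inat18_h.sh my_partition inat18_ft \
    configs/inaturalist2018/internimage_h_22ktoinat18_384.yaml
```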