Commit 82295dbf authored by Yanghan Wang, committed by Facebook GitHub Bot

enable black for mobile-vision

Summary:
https://fb.workplace.com/groups/pythonfoundation/posts/2990917737888352

Remove `mobile-vision` from the opt-out list; leave `mobile-vision/SNPE` opted out because it contains third-party code.

arc lint --take BLACK --apply-patches --paths-cmd 'hg files mobile-vision'

allow-large-files
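
Roughly the same formatting can be reproduced outside arc with upstream Black (a sketch, assuming Black's stock CLI; the exclude regex is illustrative and mirrors the SNPE opt-out):

black mobile-vision --exclude 'mobile-vision/SNPE'
black --check --diff mobile-vision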

Reviewed By: sstsai-adl

Differential Revision: D30721093

fbshipit-source-id: 9e5c16d988b315b93a28038443ecfb92efd18ef8
parent a56c7e15
# code adapted from https://www.internalfb.com/intern/diffusion/FBS/browse/master/fbcode/mobile-vision/experimental/deit/models.py
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
import json
import math
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from aml.multimodal_video.utils.einops.lib import rearrange
from detectron2.modeling import Backbone, BACKBONE_REGISTRY
from detectron2.utils.file_io import PathManager
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.layers import trunc_normal_
from timm.models.registry import register_model
from timm.models.vision_transformer import VisionTransformer, PatchEmbed


def monkey_patch_forward(self, x):
    x = self.proj(x).flatten(2).transpose(1, 2)
    return x


PatchEmbed.forward = monkey_patch_forward
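# NOTE: a minimal sketch of why the patch matters, assuming the timm
# PatchEmbed of this era asserts that inputs match img_size (and may apply
# a norm layer); the patched forward keeps only proj -> flatten -> transpose,
# so detection-sized inputs pass through:
#
#   embed = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=192)
#   tokens = embed(torch.randn(1, 3, 320, 256))  # no 224x224 size check
#   tokens.shape  # (1, (320 // 16) * (256 // 16), 192) == (1, 320, 192)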


class DistilledVisionTransformer(VisionTransformer, Backbone):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
        num_patches = self.patch_embed.num_patches
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, self.embed_dim))
        self.head_dist = (
            nn.Linear(self.embed_dim, self.num_classes)
            if self.num_classes > 0
            else nn.Identity()
        )

        trunc_normal_(self.dist_token, std=0.02)
        trunc_normal_(self.pos_embed, std=0.02)
        self.head_dist.apply(self._init_weights)
        self.norm = None
@@ -48,10 +54,13 @@ class DistilledVisionTransformer(VisionTransformer, Backbone):
        pos_tokens = pos_tokens.transpose(1, 2).reshape(-1, embed_size, H0, W0)
        # interp
        pos_tokens = F.interpolate(
            pos_tokens,
            size=(H, W),
            mode="bilinear",
            align_corners=False,
        )
        # flatten and reshape back
        pos_tokens = pos_tokens.reshape(-1, embed_size, H * W).transpose(1, 2)
        pos_embed = torch.cat((self.pos_embed[:, :2, :], pos_tokens), dim=1)
        return pos_embed
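        # NOTE (shape sketch, dims assumed): pos_tokens start as the trained
        # grid, e.g. 14x14 for a 224 input with patch 16; for a 320x256 test
        # input the grid is bilinearly resized to 20x16 above, re-flattened,
        # and re-attached to the two class/distillation token embeddings.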
@@ -65,7 +74,9 @@ class DistilledVisionTransformer(VisionTransformer, Backbone):
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(
            B, -1, -1
        )  # stole cls_tokens impl from Phil Wang, thanks
        dist_token = self.dist_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, dist_token, x), dim=1)
@@ -77,8 +88,8 @@ class DistilledVisionTransformer(VisionTransformer, Backbone):
        for blk in self.blocks:
            x = blk(x)

        # x = self.norm(x)
        spatial = rearrange(x[:, 2:], "b (h w) c -> b c h w", h=H, w=W)
        return x[:, 0], x[:, 1], spatial

    def forward(self, x):
@@ -92,16 +103,23 @@ class DistilledVisionTransformer(VisionTransformer, Backbone):
        # # during inference, return the average of both classifier predictions
        # return (x + x_dist) / 2


def _cfg(input_size=224, url="", **kwargs):
    return {
        "url": url,
        "num_classes": 1000,
        "input_size": (3, input_size, input_size),
        "pool_size": None,
        "crop_pct": 0.9,
        "interpolation": "bilinear",
        "mean": IMAGENET_DEFAULT_MEAN,
        "std": IMAGENET_DEFAULT_STD,
        "first_conv": "patch_embed.proj",
        "classifier": "head",
        **kwargs,
    }


def deit_scalable_distilled(model_config, pretrained=False, **kwargs):
    assert not pretrained
    model = DistilledVisionTransformer(
@@ -120,19 +138,21 @@ def deit_scalable_distilled(model_config, pretrained=False, **kwargs):
    print("model train config: {}".format(model.default_cfg))
    return model


def add_deit_backbone_config(cfg):
    cfg.MODEL.DEIT = type(cfg)()
    cfg.MODEL.DEIT.MODEL_CONFIG = None
    cfg.MODEL.DEIT.WEIGHTS = None


@BACKBONE_REGISTRY.register()
def deit_d2go_model_wrapper(cfg, _):
    assert cfg.MODEL.DEIT.MODEL_CONFIG is not None
    with PathManager.open(cfg.MODEL.DEIT.MODEL_CONFIG) as f:
        model_config = json.load(f)
    model = deit_scalable_distilled(
        model_config,
        num_classes=0,  # set num_classes=0 to avoid building cls head
        drop_rate=0,
        drop_path_rate=0.1,
    )
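
A minimal usage sketch for the wrapper above (assuming detectron2's standard config/build flow; the JSON path is hypothetical, and its schema lives in the referenced MODEL_CONFIG file):

    from detectron2.config import get_cfg
    from detectron2.modeling import build_backbone

    cfg = get_cfg()
    add_deit_backbone_config(cfg)
    cfg.MODEL.BACKBONE.NAME = "deit_d2go_model_wrapper"
    cfg.MODEL.DEIT.MODEL_CONFIG = "path/to/deit_config.json"  # hypothetical path
    backbone = build_backbone(cfg)  # dispatches through BACKBONE_REGISTRY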
...
@@ -4,23 +4,31 @@
# Apache License v2.0
import json
import math
from functools import partial

import torch
import torch.nn.functional as F
from aml.multimodal_video.utils.einops.lib import rearrange
from detectron2.modeling import Backbone, BACKBONE_REGISTRY
from detectron2.utils.file_io import PathManager
from timm.models.layers import trunc_normal_
from timm.models.registry import register_model
from timm.models.vision_transformer import Block as transformer_block
from torch import nn


class Transformer(nn.Module):
    def __init__(
        self,
        base_dim,
        depth,
        heads,
        mlp_ratio,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_prob=None,
    ):
        super(Transformer, self).__init__()
        self.layers = nn.ModuleList([])
        embed_dim = base_dim * heads
@@ -28,22 +36,25 @@ class Transformer(nn.Module):
        if drop_path_prob is None:
            drop_path_prob = [0.0 for _ in range(depth)]

        self.blocks = nn.ModuleList(
            [
                transformer_block(
                    dim=embed_dim,
                    num_heads=heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=True,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=drop_path_prob[i],
                    norm_layer=partial(nn.LayerNorm, eps=1e-6),
                )
                for i in range(depth)
            ]
        )

    def forward(self, x, cls_tokens):
        h, w = x.shape[2:4]
        x = rearrange(x, "b c h w -> b (h w) c")

        token_length = cls_tokens.shape[1]
        x = torch.cat((cls_tokens, x), dim=1)
@@ -52,23 +63,37 @@ class Transformer(nn.Module):
        cls_tokens = x[:, :token_length]
        x = x[:, token_length:]
        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)

        return x, cls_tokens
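    # NOTE (shape sketch, dims assumed): forward flattens x from (B, C, H, W)
    # to (B, H*W, C), prepends the class/distillation tokens, runs the
    # transformer blocks on the joint sequence, then splits the tokens back
    # off and restores the spatial (B, C, H, W) layout for the next stage.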


class conv_head_pooling(nn.Module):
    def __init__(
        self,
        in_feature,
        out_feature,
        stride,
        conv_type,
        padding_mode="zeros",
        dilation=1,
    ):
        super(conv_head_pooling, self).__init__()

        if conv_type == "depthwise":
            _groups = in_feature
        else:
            _groups = 1
        print("_groups in conv_head_pooling: ", _groups)
        self.conv = nn.Conv2d(
            in_feature,
            out_feature,
            kernel_size=3,
            padding=dilation,
            dilation=dilation,
            stride=stride,
            padding_mode=padding_mode,
            groups=_groups,
        )
        self.fc = nn.Linear(in_feature, out_feature)

    def forward(self, x, cls_token):
@@ -80,11 +105,16 @@ class conv_head_pooling(nn.Module):


class conv_embedding(nn.Module):
    def __init__(self, in_channels, out_channels, patch_size, stride, padding):
        super(conv_embedding, self).__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=patch_size,
            stride=stride,
            padding=padding,
            bias=True,
        )

    def forward(self, x):
        x = self.conv(x)
@@ -92,10 +122,23 @@ class conv_embedding(nn.Module):


class PoolingTransformer(Backbone):
    def __init__(
        self,
        image_size,
        patch_size,
        stride,
        base_dims,
        depth,
        heads,
        mlp_ratio,
        conv_type="depthwise",
        num_classes=1000,
        in_chans=3,
        attn_drop_rate=0.0,
        drop_rate=0.0,
        drop_path_rate=0.0,
        dilated=False,
    ):
        super(PoolingTransformer, self).__init__()

        total_block = sum(depth)
@@ -104,8 +147,7 @@ class PoolingTransformer(Backbone):
        self.padding = padding
        self.stride = stride
        width = math.floor((image_size + 2 * padding - patch_size) / stride + 1)
        self.conv_type = conv_type
        self.base_dims = base_dims
@@ -114,15 +156,14 @@ class PoolingTransformer(Backbone):
        self.patch_size = patch_size

        self.pos_embed = nn.Parameter(
            torch.randn(1, base_dims[0] * heads[0], width, width), requires_grad=True
        )
        self.patch_embed = conv_embedding(
            in_chans, base_dims[0] * heads[0], patch_size, stride, padding
        )

        self.cls_token = nn.Parameter(
            torch.randn(1, 1, base_dims[0] * heads[0]), requires_grad=True
        )
        self.pos_drop = nn.Dropout(p=drop_rate)
@@ -130,14 +171,22 @@ class PoolingTransformer(Backbone):
        self.pools = nn.ModuleList([])

        for stage in range(len(depth)):
            drop_path_prob = [
                drop_path_rate * i / total_block
                for i in range(block_idx, block_idx + depth[stage])
            ]
            block_idx += depth[stage]

            self.transformers.append(
                Transformer(
                    base_dims[stage],
                    depth[stage],
                    heads[stage],
                    mlp_ratio,
                    drop_rate,
                    attn_drop_rate,
                    drop_path_prob,
                )
            )
            if stage < len(heads) - 1:
                if stage == len(heads) - 2 and dilated:
@@ -147,14 +196,16 @@ class PoolingTransformer(Backbone):
                    pool_dilation = 1
                    pool_stride = 2
                self.pools.append(
                    conv_head_pooling(
                        base_dims[stage] * heads[stage],
                        base_dims[stage + 1] * heads[stage + 1],
                        stride=pool_stride,
                        dilation=pool_dilation,
                        conv_type=self.conv_type,
                    )
                )

        # self.norm = nn.LayerNorm(base_dims[-1] * heads[-1], eps=1e-6)
        self.embed_dim = base_dims[-1] * heads[-1]

        # Classifier head
@@ -163,8 +214,8 @@ class PoolingTransformer(Backbone):
        else:
            self.head = nn.Identity()

        trunc_normal_(self.pos_embed, std=0.02)
        trunc_normal_(self.cls_token, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
@@ -174,12 +225,12 @@ class PoolingTransformer(Backbone):
    @torch.jit.ignore
    def no_weight_decay(self):
        return {"pos_embed", "cls_token"}

    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=""):
        self.num_classes = num_classes
        if num_classes > 0:
            self.head = nn.Linear(self.embed_dim, num_classes)
@@ -192,7 +243,10 @@ class PoolingTransformer(Backbone):
            return self.pos_embed
        # interp
        pos_embed = F.interpolate(
            self.pos_embed,
            size=(H, W),
            mode="bilinear",
            align_corners=False,
        )
        return pos_embed
@@ -202,10 +256,8 @@ class PoolingTransformer(Backbone):
        x = self.patch_embed(x)

        # featuremap size after patch embedding
        H = math.floor((H + 2 * self.padding - self.patch_size) / self.stride + 1)
        W = math.floor((W + 2 * self.padding - self.patch_size) / self.stride + 1)

        pos_embed = self._get_pos_embed(H, W)
@@ -217,7 +269,7 @@ class PoolingTransformer(Backbone):
            x, cls_tokens = self.pools[stage](x, cls_tokens)
        x, cls_tokens = self.transformers[-1](x, cls_tokens)

        # cls_tokens = self.norm(cls_tokens)  # no gradient for layer norm, which causes failures
        return cls_tokens, x
@@ -231,27 +283,29 @@ class DistilledPoolingTransformer(PoolingTransformer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.cls_token = nn.Parameter(
            torch.randn(1, 2, self.base_dims[0] * self.heads[0]), requires_grad=True
        )
        if self.num_classes > 0:
            self.head_dist = nn.Linear(
                self.base_dims[-1] * self.heads[-1], self.num_classes
            )
        else:
            self.head_dist = nn.Identity()

        trunc_normal_(self.cls_token, std=0.02)
        self.head_dist.apply(self._init_weights)

    def forward(self, x):
        cls_token, x = self.forward_features(x)
        return x
        # x_cls = self.head(cls_token[:, 0])
        # x_dist = self.head_dist(cls_token[:, 1])
        # if self.training:
        #     return x_cls, x_dist
        # else:
        #     return (x_cls + x_dist) / 2


def pit_scalable_distilled(model_config, pretrained=False, print_info=True, **kwargs):
    if "conv_type" in model_config:
        conv_type = model_config["conv_type"]
@@ -266,13 +320,14 @@ def pit_scalable_distilled(model_config, pretrained=False, print_info=True, **kw
        heads=model_config["h"],
        mlp_ratio=model_config["r"],
        conv_type=conv_type,
        **kwargs,
    )
    if print_info:
        print("model arch config: {}".format(model_config))
    assert pretrained == False, "pretrained must be False"
    return model


def add_pit_backbone_config(cfg):
    cfg.MODEL.PIT = type(cfg)()
    cfg.MODEL.PIT.MODEL_CONFIG = None
@@ -288,7 +343,7 @@ def pit_d2go_model_wrapper(cfg, _):
        model_config = json.load(f)
    model = pit_scalable_distilled(
        model_config,
        num_classes=0,  # set num_classes=0 to avoid building cls head
        drop_rate=0,
        drop_path_rate=0.1,
        dilated=dilated,
    )
...
@@ -2,7 +2,7 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from .config import add_detr_config
from .dataset_mapper import DetrDatasetMapper
from .detr import Detr

__all__ = ["add_detr_config", "Detr", "DetrDatasetMapper"]
@@ -19,7 +19,7 @@ def add_detr_config(cfg):
    cfg.MODEL.FBNET_V2.OUT_FEATURES = ["trunk3"]

    # For Segmentation
    cfg.MODEL.DETR.FROZEN_WEIGHTS = ""

    # LOSS
    cfg.MODEL.DETR.DEFORMABLE = False
...
@@ -6,7 +6,6 @@ import logging
import numpy as np
import torch
from detectron2.data import detection_utils as utils
from detectron2.data import transforms as T
@@ -28,7 +27,9 @@ def build_transform_gen(cfg, is_train):
        max_size = cfg.INPUT.MAX_SIZE_TEST
        sample_style = "choice"
    if sample_style == "range":
        assert (
            len(min_size) == 2
        ), "more than 2 ({}) min_size(s) are provided for ranges".format(len(min_size))

    logger = logging.getLogger(__name__)
    tfm_gens = []
@@ -65,7 +66,9 @@ class DetrDatasetMapper:
        self.mask_on = cfg.MODEL.MASK_ON
        self.tfm_gens = build_transform_gen(cfg, is_train)
        logging.getLogger(__name__).info(
            "Full TransformGens used in training: {}, crop: {}".format(
                str(self.tfm_gens), str(self.crop_gen)
            )
        )

        self.img_format = cfg.INPUT.FORMAT
@@ -98,7 +101,9 @@ class DetrDatasetMapper:
        # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
        # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
        # Therefore it's important to use torch.Tensor.
        dataset_dict["image"] = torch.as_tensor(
            np.ascontiguousarray(image.transpose(2, 0, 1))
        )

        if not self.is_train:
            # USER: Modify this if you want to keep them for some reason.
...
@@ -75,7 +75,8 @@ class ResNetMaskedBackbone(nn.Module):


class FBNetMaskedBackbone(ResNetMaskedBackbone):
    """This is a thin wrapper around D2's backbone to provide padding masking"""

    def __init__(self, cfg):
        nn.Module.__init__(self)
        self.backbone = build_backbone(cfg)
@@ -102,16 +103,18 @@ class FBNetMaskedBackbone(ResNetMaskedBackbone):
            ret_features[k] = NestedTensor(features[k], masks[i])
        return ret_features


class SimpleSingleStageBackbone(ResNetMaskedBackbone):
    """This is a simple wrapper for single stage backbone,
    please set the required configs:
    cfg.MODEL.BACKBONE.SIMPLE == True,
    cfg.MODEL.BACKBONE.STRIDE, cfg.MODEL.BACKBONE.CHANNEL
    """

    def __init__(self, cfg):
        nn.Module.__init__(self)
        self.backbone = build_backbone(cfg)
        self.out_features = ["out"]
        assert cfg.MODEL.BACKBONE.SIMPLE is True
        self.feature_strides = [cfg.MODEL.BACKBONE.STRIDE]
        self.num_channels = [cfg.MODEL.BACKBONE.CHANNEL]
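        # A minimal config sketch for this wrapper (values hypothetical),
        # matching the docstring above:
        #   cfg.MODEL.BACKBONE.SIMPLE = True
        #   cfg.MODEL.BACKBONE.STRIDE = 32
        #   cfg.MODEL.BACKBONE.CHANNEL = 256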
@@ -165,7 +168,7 @@ class Detr(nn.Module):
        N_steps = hidden_dim // 2
        if "resnet" in cfg.MODEL.BACKBONE.NAME.lower():
            d2_backbone = ResNetMaskedBackbone(cfg)
        elif "fbnet" in cfg.MODEL.BACKBONE.NAME.lower():
            d2_backbone = FBNetMaskedBackbone(cfg)
        elif cfg.MODEL.BACKBONE.SIMPLE:
            d2_backbone = SimpleSingleStageBackbone(cfg)
...
@@ -4,8 +4,9 @@
import torch.utils.data
import torchvision

from .ade import build as build_ade
from .coco import build as build_coco


def get_coco_api_from_dataset(dataset):
    for _ in range(10):
@@ -18,14 +19,15 @@ def get_coco_api_from_dataset(dataset):


def build_dataset(image_set, args):
    if args.dataset_file == "coco":
        dataset = build_coco(image_set, args)
    elif args.dataset_file == "coco_panoptic":
        # to avoid making panopticapi required for coco
        from .coco_panoptic import build as build_coco_panoptic

        dataset = build_coco_panoptic(image_set, args)
    elif args.dataset_file == "ade":
        dataset = build_ade(image_set, args)
    else:
        raise ValueError(f"dataset {args.dataset_file} not supported")
    return dataset
import math
import os
import random
import sys

import numpy as np
import skimage.morphology as morp
import torch
import torch.utils.data as data
import torchvision
import torchvision.transforms as transform
from detectron2.utils.file_io import PathManager
from PIL import Image, ImageOps, ImageFilter

from .coco import make_coco_transforms


class ADE20KParsing(torchvision.datasets.VisionDataset):
    def __init__(self, root, split, transforms=None):
        super(ADE20KParsing, self).__init__(root)
        # assert exists and prepare dataset automatically
        assert PathManager.exists(root), "Please setup the dataset"
        self.images, self.masks = _get_ade20k_pairs(root, split)
        assert len(self.images) == len(self.masks)
        if len(self.images) == 0:
            raise (
                RuntimeError(
                    "Found 0 images in subfolders of: \
                "
                    + root
                    + "\n"
                )
            )
        self._transforms = transforms

    def _mask_transform(self, mask):
        target = np.array(mask).astype("int64") - 1
        return target

    def __getitem__(self, index):
        with PathManager.open(self.images[index], "rb") as f:
            img = Image.open(f).convert("RGB")
        with PathManager.open(self.masks[index], "rb") as f:
            mask = Image.open(f).convert("P")
        w, h = img.size
        ## generating bbox and masks
        # get different classes
@@ -43,29 +49,35 @@ class ADE20KParsing(torchvision.datasets.VisionDataset):
        classes = np.unique(mask)
        if -1 in classes:
            classes = classes[1:]
        segmasks = mask == classes[:, None, None]

        # find connected component
        detr_masks = []
        labels = []
        for i in range(len(classes)):
            mask = segmasks[i]
            mclass = classes[i]
            connected, nslice = morp.label(
                mask, connectivity=2, background=0, return_num=True
            )
            for j in range(1, nslice + 1):
                detr_masks.append(connected == j)
                labels.append(mclass)

        target = {}
        target["image_id"] = torch.tensor(
            int(os.path.basename(self.images[index])[10:-4])
        )
        if len(detr_masks) > 0:
            target["masks"] = torch.as_tensor(
                np.stack(detr_masks, axis=0), dtype=torch.uint8
            )
            target["boxes"] = masks_to_boxes(target["masks"])
        else:
            target["masks"] = torch.as_tensor(detr_masks, dtype=torch.uint8)
            target["boxes"] = target["masks"]
        target["labels"] = torch.tensor(labels)
        target["orig_size"] = torch.as_tensor([int(h), int(w)])
        target["size"] = torch.as_tensor([int(h), int(w)])

        if self._transforms is not None:
            img, target = self._transforms(img, target)
@@ -78,6 +90,7 @@ class ADE20KParsing(torchvision.datasets.VisionDataset):
    def pred_offset(self):
        return 1


def masks_to_boxes(masks):
    """Compute the bounding boxes around the provided masks
    The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions.
@@ -92,18 +105,18 @@ def masks_to_boxes(masks):
    x = torch.arange(0, w, dtype=torch.float)
    y, x = torch.meshgrid(y, x)

    x_mask = masks * x.unsqueeze(0)
    x_max = x_mask.flatten(1).max(-1)[0]
    x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]

    y_mask = masks * y.unsqueeze(0)
    y_max = y_mask.flatten(1).max(-1)[0]
    y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]

    return torch.stack([x_min, y_min, x_max, y_max], 1)
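# NOTE: a worked example with an assumed input: for one 4x4 mask covering
# rows 1..2 and cols 0..1, x_min = 0.0, y_min = 1.0, x_max = 1.0, y_max = 2.0,
# i.e. the box is [0., 1., 1., 2.] in pixel-index coordinates (no +1 width).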


def _get_ade20k_pairs(folder, split="train"):
    def get_path_pairs(img_folder, mask_folder):
        img_paths = []
        mask_paths = []
@@ -114,33 +127,35 @@ def _get_ade20k_pairs(folder, split='train'):
            basename, _ = os.path.splitext(filename)
            if filename.endswith(".jpg"):
                imgpath = os.path.join(img_folder, filename)
                maskname = basename + ".png"
                maskpath = os.path.join(mask_folder, maskname)
                img_paths.append(imgpath)
                mask_paths.append(maskpath)
                # if PathManager.isfile(maskpath):
                # else:
                #     print('cannot find the mask:', maskpath)
        return img_paths, mask_paths

    if split == "train":
        img_folder = os.path.join(folder, "images/training")
        mask_folder = os.path.join(folder, "annotations/training")
        img_paths, mask_paths = get_path_pairs(img_folder, mask_folder)
        print("len(img_paths):", len(img_paths))
        assert len(img_paths) == 20210
    elif split == "val":
        img_folder = os.path.join(folder, "images/validation")
        mask_folder = os.path.join(folder, "annotations/validation")
        img_paths, mask_paths = get_path_pairs(img_folder, mask_folder)
        assert len(img_paths) == 2000
    else:
        assert split == "trainval"
        train_img_folder = os.path.join(folder, "images/training")
        train_mask_folder = os.path.join(folder, "annotations/training")
        val_img_folder = os.path.join(folder, "images/validation")
        val_mask_folder = os.path.join(folder, "annotations/validation")
        train_img_paths, train_mask_paths = get_path_pairs(
            train_img_folder, train_mask_folder
        )
        val_img_paths, val_mask_paths = get_path_pairs(val_img_folder, val_mask_folder)
        img_paths = train_img_paths + val_img_paths
        mask_paths = train_mask_paths + val_mask_paths
@@ -149,5 +164,7 @@ def _get_ade20k_pairs(folder, split='train'):


def build(image_set, args):
    dataset = ADE20KParsing(
        args.ade_path, image_set, transforms=make_coco_transforms(image_set)
    )
    return dataset
@@ -8,15 +8,14 @@ Mostly copy-paste from https://github.com/pytorch/vision/blob/13b35ff/references
"""
import os
from pathlib import Path

import detr.datasets.transforms as T
import torch
import torch.utils.data
import torchvision
from detectron2.utils.file_io import PathManager
from PIL import Image
from pycocotools import mask as coco_mask


class CocoDetection(torchvision.datasets.CocoDetection):
@@ -35,7 +34,7 @@ class CocoDetection(torchvision.datasets.CocoDetection):
    def __getitem__(self, idx):
        img, target = super(CocoDetection, self).__getitem__(idx)
        image_id = self.ids[idx]
        target = {"image_id": image_id, "annotations": target}
        img, target = self.prepare(img, target)
        if self._transforms is not None:
            img, target = self._transforms(img, target)
@@ -71,7 +70,7 @@ class ConvertCocoPolysToMask(object):
        anno = target["annotations"]

        anno = [obj for obj in anno if "iscrowd" not in obj or obj["iscrowd"] == 0]

        boxes = [obj["bbox"] for obj in anno]
        # guard against no boxes via resizing
@@ -114,7 +113,9 @@ class ConvertCocoPolysToMask(object):
        # for conversion to coco api
        area = torch.tensor([obj["area"] for obj in anno])
        iscrowd = torch.tensor(
            [obj["iscrowd"] if "iscrowd" in obj else 0 for obj in anno]
        )
        target["area"] = area[keep]
        target["iscrowd"] = iscrowd[keep]
@@ -126,52 +127,71 @@ class ConvertCocoPolysToMask(object):


def make_coco_transforms(image_set):

    normalize = T.Compose(
        [T.ToTensor(), T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]
    )

    scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800]

    if image_set == "train":
        return T.Compose(
            [
                T.RandomHorizontalFlip(),
                T.RandomSelect(
                    T.RandomResize(scales, max_size=1333),
                    T.Compose(
                        [
                            T.RandomResize([400, 500, 600]),
                            T.RandomSizeCrop(384, 600),
                            T.RandomResize(scales, max_size=1333),
                        ]
                    ),
                ),
                normalize,
            ]
        )

    if image_set == "val":
        return T.Compose(
            [
                T.RandomResize([800], max_size=1333),
                normalize,
            ]
        )

    raise ValueError(f"unknown {image_set}")


def build(image_set, args):
    if "manifold" in args.coco_path:
        root = args.coco_path
        PATHS = {
            "train": (
                os.path.join(root, "coco_train2017"),
                "manifold://fair_vision_data/tree/detectron2/json_dataset_annotations/coco/instances_train2017.json",
            ),
            "val": (
                os.path.join(root, "coco_val2017"),
                "manifold://fair_vision_data/tree/detectron2/json_dataset_annotations/coco/instances_val2017.json",
            ),
        }
    else:
        root = Path(args.coco_path)
        assert root.exists(), f"provided COCO path {root} does not exist"
        mode = "instances"
        PATHS = {
            "train": (
                root / "train2017",
                root / "annotations" / f"{mode}_train2017.json",
            ),
            "val": (root / "val2017", root / "annotations" / f"{mode}_val2017.json"),
        }

    img_folder, ann_file = PATHS[image_set]
    dataset = CocoDetection(
        img_folder,
        ann_file,
        transforms=make_coco_transforms(image_set),
        return_masks=args.masks,
    )
    return dataset
@@ -8,17 +8,16 @@ Mostly copy-paste from https://github.com/pytorch/vision/blob/edfd5a7/references
The difference is that there is less copy-pasting from pycocotools
in the end of the file, as python3 can suppress prints with contextlib
"""
import contextlib
import copy
import os

import numpy as np
import pycocotools.mask as mask_util
import torch
from detr.util.misc import all_gather
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval


class CocoEvaluator(object):
@@ -43,7 +42,7 @@ class CocoEvaluator(object):
            results = self.prepare(predictions, iou_type)

            # suppress pycocotools prints
            with open(os.devnull, "w") as devnull:
                with contextlib.redirect_stdout(devnull):
                    coco_dt = COCO.loadRes(self.coco_gt, results) if results else COCO()
            coco_eval = self.coco_eval[iou_type]
@@ -57,7 +56,9 @@ class CocoEvaluator(object):
    def synchronize_between_processes(self):
        for iou_type in self.iou_types:
            self.eval_imgs[iou_type] = np.concatenate(self.eval_imgs[iou_type], 2)
            create_common_coco_eval(
                self.coco_eval[iou_type], self.img_ids, self.eval_imgs[iou_type]
            )

    def accumulate(self):
        for coco_eval in self.coco_eval.values():
@@ -118,7 +119,9 @@ class CocoEvaluator(object):
            labels = prediction["labels"].tolist()

            rles = [
                mask_util.encode(
                    np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F")
                )[0]
                for mask in masks
            ]
            for rle in rles:
@@ -155,7 +158,7 @@ class CocoEvaluator(object):
                {
                    "image_id": original_id,
                    "category_id": labels[k],
                    "keypoints": keypoint,
                    "score": scores[k],
                }
                for k, keypoint in enumerate(keypoints)
@@ -208,17 +211,19 @@ def create_common_coco_eval(coco_eval, img_ids, eval_imgs):


def evaluate(self):
    """
    Run per image evaluation on given images and store results (a list of dict) in self.evalImgs
    :return: None
    """
    # tic = time.time()
    # print('Running per image evaluation...')
    p = self.params
    # add backward compatibility if useSegm is specified in params
    if p.useSegm is not None:
        p.iouType = "segm" if p.useSegm == 1 else "bbox"
        print(
            "useSegm (deprecated) is not None. Running {} evaluation".format(p.iouType)
        )
    # print('Evaluate annotation type *{}*'.format(p.iouType))
    p.imgIds = list(np.unique(p.imgIds))
    if p.useCats:
@@ -230,14 +235,15 @@ def evaluate(self):
    # loop through images, area range, max detection number
    catIds = p.catIds if p.useCats else [-1]

    if p.iouType == "segm" or p.iouType == "bbox":
        computeIoU = self.computeIoU
    elif p.iouType == "keypoints":
        computeIoU = self.computeOks
    self.ious = {
        (imgId, catId): computeIoU(imgId, catId)
        for imgId in p.imgIds
        for catId in catIds
    }

    evaluateImg = self.evaluateImg
    maxDet = p.maxDets[-1]
@@ -254,6 +260,7 @@ def evaluate(self):
    # print('DONE (t={:0.2f}s).'.format(toc-tic))
    return p.imgIds, evalImgs


#################################################################
# end of straight copy from pycocotools, just removing the prints
#################################################################
#!/usr/bin/env python3 #!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import os
import json import json
import os
from pathlib import Path from pathlib import Path
import numpy as np import numpy as np
import torch import torch
from PIL import Image from detectron2.utils.file_io import PathManager
from panopticapi.utils import rgb2id
from detr.util.box_ops import masks_to_boxes from detr.util.box_ops import masks_to_boxes
from panopticapi.utils import rgb2id
from PIL import Image
from .coco import make_coco_transforms from .coco import make_coco_transforms
from detectron2.utils.file_io import PathManager
class CocoPanoptic: class CocoPanoptic:
def __init__(self, img_folder, ann_folder, ann_file, transforms=None, return_masks=True): def __init__(
with PathManager.open(ann_file, 'r') as f: self, img_folder, ann_folder, ann_file, transforms=None, return_masks=True
):
with PathManager.open(ann_file, "r") as f:
self.coco = json.load(f) self.coco = json.load(f)
# sort 'images' field so that they are aligned with 'annotations' # sort 'images' field so that they are aligned with 'annotations'
# i.e., in alphabetical order # i.e., in alphabetical order
self.coco['images'] = sorted(self.coco['images'], key=lambda x: x['id']) self.coco["images"] = sorted(self.coco["images"], key=lambda x: x["id"])
# sanity check # sanity check
if "annotations" in self.coco: if "annotations" in self.coco:
for img, ann in zip(self.coco['images'], self.coco['annotations']): for img, ann in zip(self.coco["images"], self.coco["annotations"]):
assert img['file_name'][:-4] == ann['file_name'][:-4] assert img["file_name"][:-4] == ann["file_name"][:-4]
self.img_folder = img_folder self.img_folder = img_folder
self.ann_folder = ann_folder self.ann_folder = ann_folder
...@@ -36,37 +37,50 @@ class CocoPanoptic: ...@@ -36,37 +37,50 @@ class CocoPanoptic:
self.return_masks = return_masks self.return_masks = return_masks
def __getitem__(self, idx): def __getitem__(self, idx):
ann_info = self.coco['annotations'][idx] if "annotations" in self.coco else self.coco['images'][idx] ann_info = (
img_path = os.path.join(self.img_folder, ann_info['file_name'].replace('.png', '.jpg')) self.coco["annotations"][idx]
ann_path = os.path.join(self.ann_folder, ann_info['file_name']) if "annotations" in self.coco
else self.coco["images"][idx]
)
img_path = os.path.join(
self.img_folder, ann_info["file_name"].replace(".png", ".jpg")
)
ann_path = os.path.join(self.ann_folder, ann_info["file_name"])
with PathManager.open(img_path, "rb") as f: with PathManager.open(img_path, "rb") as f:
img = Image.open(f).convert('RGB') img = Image.open(f).convert("RGB")
w, h = img.size w, h = img.size
if "segments_info" in ann_info: if "segments_info" in ann_info:
with PathManager.open(ann_path, "rb") as f: with PathManager.open(ann_path, "rb") as f:
masks = np.asarray(Image.open(f), dtype=np.uint32) masks = np.asarray(Image.open(f), dtype=np.uint32)
masks = rgb2id(masks) masks = rgb2id(masks)
ids = np.array([ann['id'] for ann in ann_info['segments_info']]) ids = np.array([ann["id"] for ann in ann_info["segments_info"]])
masks = masks == ids[:, None, None] masks = masks == ids[:, None, None]
masks = torch.as_tensor(masks, dtype=torch.uint8) masks = torch.as_tensor(masks, dtype=torch.uint8)
labels = torch.tensor([ann['category_id'] for ann in ann_info['segments_info']], dtype=torch.int64) labels = torch.tensor(
[ann["category_id"] for ann in ann_info["segments_info"]],
dtype=torch.int64,
)
target = {} target = {}
target['image_id'] = torch.tensor([ann_info['image_id'] if "image_id" in ann_info else ann_info["id"]]) target["image_id"] = torch.tensor(
[ann_info["image_id"] if "image_id" in ann_info else ann_info["id"]]
)
if self.return_masks: if self.return_masks:
target['masks'] = masks target["masks"] = masks
target['labels'] = labels target["labels"] = labels
target["boxes"] = masks_to_boxes(masks) target["boxes"] = masks_to_boxes(masks)
target['size'] = torch.as_tensor([int(h), int(w)]) target["size"] = torch.as_tensor([int(h), int(w)])
target['orig_size'] = torch.as_tensor([int(h), int(w)]) target["orig_size"] = torch.as_tensor([int(h), int(w)])
if "segments_info" in ann_info: if "segments_info" in ann_info:
for name in ['iscrowd', 'area']: for name in ["iscrowd", "area"]:
target[name] = torch.tensor([ann[name] for ann in ann_info['segments_info']]) target[name] = torch.tensor(
[ann[name] for ann in ann_info["segments_info"]]
)
if self.transforms is not None: if self.transforms is not None:
img, target = self.transforms(img, target) img, target = self.transforms(img, target)
...@@ -74,12 +88,12 @@ class CocoPanoptic: ...@@ -74,12 +88,12 @@ class CocoPanoptic:
return img, target return img, target
def __len__(self): def __len__(self):
return len(self.coco['images']) return len(self.coco["images"])
def get_height_and_width(self, idx): def get_height_and_width(self, idx):
img_info = self.coco['images'][idx] img_info = self.coco["images"][idx]
height = img_info['height'] height = img_info["height"]
width = img_info['width'] width = img_info["width"]
return height, width return height, width
@@ -87,28 +101,43 @@ def build(image_set, args):
    if "manifold" in args.coco_path:
        root = args.coco_path
        PATHS = {
            "train": (
                os.path.join(root, "coco_train2017"),
                "manifold://fair_vision_data/tree/detectron2/json_dataset_annotations/coco/panoptic_train2017.json",
            ),
            "val": (
                os.path.join(root, "coco_val2017"),
                "manifold://fair_vision_data/tree/detectron2/json_dataset_annotations/coco/panoptic_val2017.json",
            ),
        }
        img_folder_path, ann_file = PATHS[image_set]
        ann_folder = os.path.join(root, f"coco_panoptic_{image_set}2017")
    else:
        img_folder_root = Path(args.coco_path)
        ann_folder_root = Path(args.coco_panoptic_path)
        assert (
            img_folder_root.exists()
        ), f"provided COCO path {img_folder_root} does not exist"
        assert (
            ann_folder_root.exists()
        ), f"provided COCO path {ann_folder_root} does not exist"
        mode = "panoptic"
        PATHS = {
            "train": ("train2017", Path("annotations") / f"{mode}_train2017.json"),
            "val": ("val2017", Path("annotations") / f"{mode}_val2017.json"),
        }

        img_folder, ann_file = PATHS[image_set]
        img_folder_path = img_folder_root / img_folder
        ann_folder = ann_folder_root / f"{mode}_{img_folder}"
        ann_file = ann_folder_root / ann_file

    dataset = CocoPanoptic(
        img_folder_path,
        ann_folder,
        ann_file,
        transforms=make_coco_transforms(image_set),
        return_masks=args.masks,
    )

    return dataset
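A minimal usage sketch of build() with local (non-manifold) paths; the attribute names mirror the ones read above, and the paths are hypothetical:

from argparse import Namespace

args = Namespace(
    coco_path="/datasets/coco",                    # hypothetical local path
    coco_panoptic_path="/datasets/coco_panoptic",  # hypothetical local path
    masks=True,
)
dataset = build("val", args)
img, target = dataset[0]  # PIL image plus a dict with masks/labels/boxes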
@@ -25,7 +25,9 @@ class PanopticEvaluator(object):
    def update(self, predictions):
        for p in predictions:
            with PathManager.open(
                os.path.join(self.output_dir, p["file_name"]), "wb"
            ) as f:
                f.write(p.pop("png_string"))

        self.predictions += predictions

@@ -43,5 +45,10 @@ class PanopticEvaluator(object):
            predictions_json = os.path.join(self.output_dir, "predictions.json")
            with PathManager.open(predictions_json, "w") as f:
                f.write(json.dumps(json_data))
            return pq_compute(
                self.gt_json,
                predictions_json,
                gt_folder=self.gt_folder,
                pred_folder=self.output_dir,
            )
        return None
@@ -10,7 +10,6 @@ import PIL
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as F
from detr.util.box_ops import box_xyxy_to_cxcywh
from detr.util.misc import interpolate

@@ -39,7 +38,7 @@ def crop(image, target, region):

    if "masks" in target:
        # FIXME should we update the area here if there are no boxes?
        target["masks"] = target["masks"][:, i : i + h, j : j + w]
        fields.append("masks")

    # remove elements for which the boxes or masks have zero area
@@ -47,10 +46,10 @@ def crop(image, target, region):
        # favor boxes selection when defining which elements to keep
        # this is compatible with previous implementation
        if "boxes" in target:
            cropped_boxes = target["boxes"].reshape(-1, 2, 2)
            keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1)
        else:
            keep = target["masks"].flatten(1).any(1)

        for field in fields:
            target[field] = target[field][keep]

@@ -66,11 +65,13 @@ def hflip(image, target):
    target = target.copy()
    if "boxes" in target:
        boxes = target["boxes"]
        boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor(
            [-1, 1, -1, 1]
        ) + torch.as_tensor([w, 0, w, 0])
        target["boxes"] = boxes

    if "masks" in target:
        target["masks"] = target["masks"].flip(-1)

    return flipped_image, target
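Worked check of the flip arithmetic above: for an xyxy box, a horizontal flip maps (x0, y0, x1, y1) to (w - x1, y0, w - x0, y1). A self-contained sketch with toy numbers:

import torch

w = 100
boxes = torch.tensor([[10.0, 20.0, 30.0, 40.0]])  # one box in xyxy format
flipped = boxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor(
    [w, 0, w, 0]
)
assert torch.equal(flipped, torch.tensor([[70.0, 20.0, 90.0, 40.0]]))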
@@ -110,13 +111,17 @@ def resize(image, target, size, max_size=None):
    if target is None:
        return rescaled_image, None

    ratios = tuple(
        float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size)
    )
    ratio_width, ratio_height = ratios

    target = target.copy()
    if "boxes" in target:
        boxes = target["boxes"]
        scaled_boxes = boxes * torch.as_tensor(
            [ratio_width, ratio_height, ratio_width, ratio_height]
        )
        target["boxes"] = scaled_boxes

    if "area" in target:
@@ -128,8 +133,10 @@ def resize(image, target, size, max_size=None):
    target["size"] = torch.tensor([h, w])

    if "masks" in target:
        target["masks"] = (
            interpolate(target["masks"][:, None].float(), size, mode="nearest")[:, 0]
            > 0.5
        )

    return rescaled_image, target

@@ -143,7 +150,9 @@ def pad(image, target, padding):
    # should we do something wrt the original size?
    target["size"] = torch.tensor(padded_image.size[::-1])
    if "masks" in target:
        target["masks"] = torch.nn.functional.pad(
            target["masks"], (0, padding[0], 0, padding[1])
        )
    return padded_image, target
@@ -161,7 +170,7 @@ class RandomSizeCrop(object):
        self.min_size = min_size
        self.max_size = max_size

    def __call__(self, img: PIL.Image.Image, target: dict):  # noqa: P210
        w = random.randint(self.min_size, min(img.width, self.max_size))
        h = random.randint(self.min_size, min(img.height, self.max_size))
        region = T.RandomCrop.get_params(img, [h, w])

@@ -175,8 +184,8 @@ class CenterCrop(object):
    def __call__(self, img, target):
        image_width, image_height = img.size
        crop_height, crop_width = self.size
        crop_top = int(round((image_height - crop_height) / 2.0))
        crop_left = int(round((image_width - crop_width) / 2.0))
        return crop(img, target, (crop_top, crop_left, crop_height, crop_width))

@@ -216,6 +225,7 @@ class RandomSelect(object):
    Randomly selects between transforms1 and transforms2,
    with probability p for transforms1 and (1 - p) for transforms2
    """

    def __init__(self, transforms1, transforms2, p=0.5):
        self.transforms1 = transforms1
        self.transforms2 = transforms2

@@ -233,7 +243,6 @@ class ToTensor(object):
class RandomErasing(object):
    def __init__(self, *args, **kwargs):
        self.eraser = T.RandomErasing(*args, **kwargs)
...
@@ -7,22 +7,29 @@ import os
import sys
from typing import Iterable

import detr.util.misc as utils
import torch
from detr.datasets.coco_eval import CocoEvaluator
from detr.datasets.panoptic_eval import PanopticEvaluator


def train_one_epoch(
    model: torch.nn.Module,
    criterion: torch.nn.Module,
    data_loader: Iterable,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    epoch: int,
    max_norm: float = 0,
):
    model.train()
    criterion.train()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
    metric_logger.add_meter(
        "class_error", utils.SmoothedValue(window_size=1, fmt="{value:.2f}")
    )
    header = "Epoch: [{}]".format(epoch)
    print_freq = 10

    for samples, targets in metric_logger.log_every(data_loader, print_freq, header):
@@ -32,14 +39,20 @@ def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
        outputs = model(samples)
        loss_dict = criterion(outputs, targets)
        weight_dict = criterion.weight_dict
        losses = sum(
            loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict
        )

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = utils.reduce_dict(loss_dict)
        loss_dict_reduced_unscaled = {
            f"{k}_unscaled": v for k, v in loss_dict_reduced.items()
        }
        loss_dict_reduced_scaled = {
            k: v * weight_dict[k]
            for k, v in loss_dict_reduced.items()
            if k in weight_dict
        }
        losses_reduced_scaled = sum(loss_dict_reduced_scaled.values())

        loss_value = losses_reduced_scaled.item()

@@ -55,8 +68,10 @@ def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        optimizer.step()

        metric_logger.update(
            loss=loss_value, **loss_dict_reduced_scaled, **loss_dict_reduced_unscaled
        )
        metric_logger.update(class_error=loss_dict_reduced["class_error"])
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
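Toy illustration of the weighted-sum pattern in train_one_epoch above: only keys present in weight_dict contribute to the optimized loss, while extras such as class_error are logged only. The numbers below are made up:

import torch

loss_dict = {
    "loss_ce": torch.tensor(0.5),
    "loss_bbox": torch.tensor(0.2),
    "class_error": torch.tensor(12.0),  # logged, never optimized
}
weight_dict = {"loss_ce": 1.0, "loss_bbox": 5.0}
losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
assert torch.isclose(losses, torch.tensor(1.5))  # 0.5 * 1.0 + 0.2 * 5.0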
@@ -65,20 +80,24 @@ def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,

@torch.no_grad()
def evaluate(
    model, criterion, postprocessors, data_loader, base_ds, device, output_dir
):
    model.eval()
    criterion.eval()

    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter(
        "class_error", utils.SmoothedValue(window_size=1, fmt="{value:.2f}")
    )
    header = "Test:"

    iou_types = tuple(k for k in ("segm", "bbox") if k in postprocessors.keys())
    coco_evaluator = CocoEvaluator(base_ds, iou_types)
    # coco_evaluator.coco_eval[iou_types[0]].params.iouThrs = [0, 0.1, 0.5, 0.75]

    panoptic_evaluator = None
    if "panoptic" in postprocessors.keys():
        panoptic_evaluator = PanopticEvaluator(
            data_loader.dataset.ann_file,
            data_loader.dataset.ann_folder,
@@ -95,26 +114,39 @@ def evaluate(model, criterion, postprocessors, data_loader, base_ds, device, out
        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = utils.reduce_dict(loss_dict)
        loss_dict_reduced_scaled = {
            k: v * weight_dict[k]
            for k, v in loss_dict_reduced.items()
            if k in weight_dict
        }
        loss_dict_reduced_unscaled = {
            f"{k}_unscaled": v for k, v in loss_dict_reduced.items()
        }
        metric_logger.update(
            loss=sum(loss_dict_reduced_scaled.values()),
            **loss_dict_reduced_scaled,
            **loss_dict_reduced_unscaled,
        )
        metric_logger.update(class_error=loss_dict_reduced["class_error"])

        orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)
        results = postprocessors["bbox"](outputs, orig_target_sizes)
        if "segm" in postprocessors.keys():
            target_sizes = torch.stack([t["size"] for t in targets], dim=0)
            results = postprocessors["segm"](
                results, outputs, orig_target_sizes, target_sizes
            )
        res = {
            target["image_id"].item(): output
            for target, output in zip(targets, results)
        }
        if coco_evaluator is not None:
            coco_evaluator.update(res)

        if panoptic_evaluator is not None:
            res_pano = postprocessors["panoptic"](
                outputs, target_sizes, orig_target_sizes
            )
            for i, target in enumerate(targets):
                image_id = target["image_id"].item()
                file_name = f"{image_id:012d}.png"

@@ -140,12 +172,12 @@ def evaluate(model, criterion, postprocessors, data_loader, base_ds, device, out
        panoptic_res = panoptic_evaluator.summarize()
    stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()}
    if coco_evaluator is not None:
        if "bbox" in postprocessors.keys():
            stats["coco_eval_bbox"] = coco_evaluator.coco_eval["bbox"].stats.tolist()
        if "segm" in postprocessors.keys():
            stats["coco_eval_masks"] = coco_evaluator.coco_eval["segm"].stats.tolist()
    if panoptic_res is not None:
        stats["PQ_all"] = panoptic_res["All"]
        stats["PQ_th"] = panoptic_res["Things"]
        stats["PQ_st"] = panoptic_res["Stuff"]
    return stats, coco_evaluator
@@ -9,4 +9,3 @@
# ------------------------------------------------------------------------------------------------

from .ms_deform_attn_func import MSDeformAttnFunction

@@ -9,17 +9,16 @@
# ------------------------------------------------------------------------------------------------

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch
import torch.nn.functional as F
from detr import _C as MSDA
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.cuda.amp.autocast_mode import custom_bwd, custom_fwd


class MSDeformAttnFunction(Function):
@@ -32,26 +31,60 @@ class MSDeformAttnFunction(Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(
        ctx,
        value,
        value_spatial_shapes,
        value_level_start_index,
        sampling_locations,
        attention_weights,
        im2col_step,
    ):
        ctx.im2col_step = im2col_step
        output = MSDA.ms_deform_attn_forward(
            value,
            value_spatial_shapes,
            value_level_start_index,
            sampling_locations,
            attention_weights,
            ctx.im2col_step,
        )
        ctx.save_for_backward(
            value,
            value_spatial_shapes,
            value_level_start_index,
            sampling_locations,
            attention_weights,
        )
        return output

    @staticmethod
    @once_differentiable
    @custom_bwd
    def backward(ctx, grad_output):
        (
            value,
            value_spatial_shapes,
            value_level_start_index,
            sampling_locations,
            attention_weights,
        ) = ctx.saved_tensors
        grad_value, grad_sampling_loc, grad_attn_weight = MSDA.ms_deform_attn_backward(
            value,
            value_spatial_shapes,
            value_level_start_index,
            sampling_locations,
            attention_weights,
            grad_output,
            ctx.im2col_step,
        )

        return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
def ms_deform_attn_core_pytorch(
    value, value_spatial_shapes, sampling_locations, attention_weights
):
    # for debug and test only,
    # need to use cuda version instead
    # value shape (N, K, num_heads, channels_per_head)
@@ -64,14 +97,27 @@ def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations,
    sampling_value_list = []
    for lid_, (H_, W_) in enumerate(value_spatial_shapes):
        # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_
        value_l_ = (
            value_list[lid_].flatten(2).transpose(1, 2).reshape(N_ * M_, D_, H_, W_)
        )
        # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2
        sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1)
        # N_*M_, D_, Lq_, P_
        sampling_value_l_ = F.grid_sample(
            value_l_,
            sampling_grid_l_,
            mode="bilinear",
            padding_mode="zeros",
            align_corners=False,
        )
        sampling_value_list.append(sampling_value_l_)
    # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_)
    attention_weights = attention_weights.transpose(1, 2).reshape(
        N_ * M_, 1, Lq_, L_ * P_
    )
    output = (
        (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
        .sum(-1)
        .view(N_, M_ * D_, Lq_)
    )
    return output.transpose(1, 2).contiguous()
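Shape walkthrough for the grid_sample call at the heart of the fallback above, with assumed toy sizes (N_=1 batch, M_=2 heads, D_=4 channels per head, one 8x8 level, Lq_=5 queries, P_=3 sampling points):

import torch
import torch.nn.functional as F

N_, M_, D_, H_, W_, Lq_, P_ = 1, 2, 4, 8, 8, 5, 3
value_l_ = torch.randn(N_ * M_, D_, H_, W_)
sampling_grid_l_ = torch.rand(N_ * M_, Lq_, P_, 2) * 2 - 1  # grid_sample wants [-1, 1]
out = F.grid_sample(
    value_l_,
    sampling_grid_l_,
    mode="bilinear",
    padding_mode="zeros",
    align_corners=False,
)
assert out.shape == (N_ * M_, D_, Lq_, P_)  # one sampled D_-vector per query point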
@@ -2,7 +2,6 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import torch
from detr.models.backbone import Backbone, Joiner
from detr.models.detr import DETR, PostProcess
from detr.models.position_encoding import PositionEmbeddingSine

@@ -14,12 +13,16 @@ dependencies = ["torch", "torchvision"]
def _make_detr(backbone_name: str, dilation=False, num_classes=91, mask=False):
    hidden_dim = 256
    backbone = Backbone(
        backbone_name, train_backbone=True, return_interm_layers=mask, dilation=dilation
    )
    pos_enc = PositionEmbeddingSine(hidden_dim // 2, normalize=True)
    backbone_with_pos_enc = Joiner(backbone, pos_enc)
    backbone_with_pos_enc.num_channels = backbone.num_channels
    transformer = Transformer(d_model=hidden_dim, return_intermediate_dec=True)
    detr = DETR(
        backbone_with_pos_enc, transformer, num_classes=num_classes, num_queries=100
    )
    if mask:
        return DETRsegm(detr)
    return detr
@@ -34,7 +37,9 @@ def detr_resnet50(pretrained=False, num_classes=91, return_postprocessor=False):
    model = _make_detr("resnet50", dilation=False, num_classes=num_classes)
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/detr/detr-r50-e632da11.pth",
            map_location="cpu",
            check_hash=True,
        )
        model.load_state_dict(checkpoint["model"])
    if return_postprocessor:

@@ -53,7 +58,9 @@ def detr_resnet50_dc5(pretrained=False, num_classes=91, return_postprocessor=Fal
    model = _make_detr("resnet50", dilation=True, num_classes=num_classes)
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/detr/detr-r50-dc5-f0fb7ef5.pth",
            map_location="cpu",
            check_hash=True,
        )
        model.load_state_dict(checkpoint["model"])
    if return_postprocessor:

@@ -70,7 +77,9 @@ def detr_resnet101(pretrained=False, num_classes=91, return_postprocessor=False)
    model = _make_detr("resnet101", dilation=False, num_classes=num_classes)
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/detr/detr-r101-2c7b67e5.pth",
            map_location="cpu",
            check_hash=True,
        )
        model.load_state_dict(checkpoint["model"])
    if return_postprocessor:

@@ -89,7 +98,9 @@ def detr_resnet101_dc5(pretrained=False, num_classes=91, return_postprocessor=Fa
    model = _make_detr("resnet101", dilation=True, num_classes=num_classes)
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/detr/detr-r101-dc5-a2e86def.pth",
            map_location="cpu",
            check_hash=True,
        )
        model.load_state_dict(checkpoint["model"])
    if return_postprocessor:
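Hypothetical use of the hub entry points above; with pretrained=False nothing is downloaded, so the sketch stays offline:

model, postprocess = detr_resnet50(pretrained=False, return_postprocessor=True)
model.eval()  # postprocess maps raw outputs to COCO-style scores/labels/boxes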
@@ -101,10 +112,10 @@ def detr_resnet50_panoptic(
    pretrained=False, num_classes=250, threshold=0.85, return_postprocessor=False
):
    """
    DETR R50 with 6 encoder and 6 decoder layers.
    Achieves 43.4 PQ on COCO val5k.
    threshold is the minimum confidence required for keeping segments in the prediction
    """
    model = _make_detr("resnet50", dilation=False, num_classes=num_classes, mask=True)
    is_thing_map = {i: i <= 90 for i in range(250)}

@@ -124,13 +135,13 @@ def detr_resnet50_dc5_panoptic(
    pretrained=False, num_classes=250, threshold=0.85, return_postprocessor=False
):
    """
    DETR-DC5 R50 with 6 encoder and 6 decoder layers.
    The last block of ResNet-50 has dilation to increase
    output resolution.
    Achieves 44.6 on COCO val5k.
    threshold is the minimum confidence required for keeping segments in the prediction
    """
    model = _make_detr("resnet50", dilation=True, num_classes=num_classes, mask=True)
    is_thing_map = {i: i <= 90 for i in range(250)}

@@ -150,11 +161,11 @@ def detr_resnet101_panoptic(
    pretrained=False, num_classes=250, threshold=0.85, return_postprocessor=False
):
    """
    DETR-DC5 R101 with 6 encoder and 6 decoder layers.
    Achieves 45.1 PQ on COCO val5k.
    threshold is the minimum confidence required for keeping segments in the prediction
    """
    model = _make_detr("resnet101", dilation=False, num_classes=num_classes, mask=True)
    is_thing_map = {i: i <= 90 for i in range(250)}
...
@@ -13,15 +13,14 @@
Backbone modules.
"""
from collections import OrderedDict
from typing import Dict, List

import torch
import torch.nn.functional as F
import torchvision
from detr.util.misc import NestedTensor, is_main_process
from torch import nn
from torchvision.models._utils import IntermediateLayerGetter

from .position_encoding import build_position_encoding
@@ -43,15 +42,29 @@ class FrozenBatchNorm2d(torch.nn.Module):
        self.register_buffer("running_var", torch.ones(n))
        self.eps = eps

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        num_batches_tracked_key = prefix + "num_batches_tracked"
        if num_batches_tracked_key in state_dict:
            del state_dict[num_batches_tracked_key]

        super(FrozenBatchNorm2d, self)._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def forward(self, x):
        # move reshapes to the beginning
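The rest of forward is collapsed by this hunk; for reference, a minimal sketch of the affine fold it performs (assuming the standard DETR implementation, where the frozen statistics reduce batch norm to y = x * scale + shift):

import torch

x = torch.randn(2, 3, 4, 4)
weight, bias = torch.ones(3), torch.zeros(3)
running_mean, running_var, eps = torch.zeros(3), torch.ones(3), 1e-5
scale = weight * (running_var + eps).rsqrt()       # 1 / sqrt(var + eps), scaled
shift = bias - running_mean * scale                # fold the mean into the bias
y = x * scale.reshape(1, -1, 1, 1) + shift.reshape(1, -1, 1, 1)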
@@ -67,11 +80,17 @@ class FrozenBatchNorm2d(torch.nn.Module):

class BackboneBase(nn.Module):
    def __init__(
        self, backbone: nn.Module, train_backbone: bool, return_interm_layers: bool
    ):
        super().__init__()
        for name, parameter in backbone.named_parameters():
            if (
                not train_backbone
                or "layer2" not in name
                and "layer3" not in name
                and "layer4" not in name
            ):
                parameter.requires_grad_(False)
        if return_interm_layers:
            return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
@@ -79,7 +98,7 @@ class BackboneBase(nn.Module):
            self.strides = [8, 16, 32]
            self.num_channels = [512, 1024, 2048]
        else:
            return_layers = {"layer4": "0"}
            self.strides = [32]
            self.num_channels = [2048]
        self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
@@ -97,15 +116,21 @@ class BackboneBase(nn.Module):
class Backbone(BackboneBase):
    """ResNet backbone with frozen BatchNorm."""

    def __init__(
        self,
        name: str,
        train_backbone: bool,
        return_interm_layers: bool,
        dilation: bool,
    ):
        norm_layer = FrozenBatchNorm2d
        backbone = getattr(torchvision.models, name)(
            replace_stride_with_dilation=[False, False, dilation],
            pretrained=is_main_process(),
            norm_layer=norm_layer,
        )
        assert name not in ("resnet18", "resnet34"), "number of channels are hard coded"
        super().__init__(backbone, train_backbone, return_interm_layers)
        if dilation:
            self.strides[-1] = self.strides[-1] // 2

@@ -139,6 +164,8 @@ def build_backbone(args):
    position_embedding = build_position_encoding(args)
    train_backbone = args.lr_backbone > 0
    return_interm_layers = args.masks
    backbone = Backbone(
        args.backbone, train_backbone, return_interm_layers, args.dilation
    )
    model = Joiner(backbone, position_embedding)
    return model
@@ -149,7 +149,9 @@ class DeformableDETR(nn.Module):
        for box_embed in self.bbox_embed:
            nn.init.constant_(box_embed.layers[-1].bias.data[2:], 0.0)

        self.transformer.encoder.bbox_embed = MLP(
            hidden_dim, hidden_dim, 4, bbox_embed_num_layers
        )

    def forward(self, samples: NestedTensor):
        """The forward expects a NestedTensor, which consists of:
...
@@ -6,25 +6,43 @@ DETR model and criterion classes.
"""
import torch
import torch.nn.functional as F
from detr.util import box_ops
from detr.util.misc import (
    NestedTensor,
    nested_tensor_from_tensor_list,
    accuracy,
    get_world_size,
    interpolate,
    is_dist_avail_and_initialized,
)
from torch import nn

from .backbone import build_backbone
from .matcher import build_matcher
from .segmentation import (
    DETRsegm,
    PostProcessPanoptic,
    PostProcessSegm,
    dice_loss,
    sigmoid_focal_loss,
)
from .setcriterion import SetCriterion
from .transformer import build_transformer
class DETR(nn.Module):
    """This is the DETR module that performs object detection"""

    def __init__(
        self,
        backbone,
        transformer,
        num_classes,
        num_queries,
        aux_loss=False,
        use_focal_loss=False,
    ):
        """Initializes the model.
        Parameters:
            backbone: torch module of the backbone to be used. See backbone.py
            transformer: torch module of the transformer architecture. See transformer.py
@@ -37,27 +55,31 @@ class DETR(nn.Module):
        self.num_queries = num_queries
        self.transformer = transformer
        hidden_dim = transformer.d_model
        self.class_embed = nn.Linear(
            hidden_dim, num_classes if use_focal_loss else num_classes + 1
        )
        self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
        self.query_embed = nn.Embedding(num_queries, hidden_dim)
        self.input_proj = nn.Conv2d(
            backbone.num_channels[-1], hidden_dim, kernel_size=1
        )
        self.backbone = backbone
        self.aux_loss = aux_loss
    def forward(self, samples: NestedTensor):
        """The forward expects a NestedTensor, which consists of:
           - samples.tensor: batched images, of shape [batch_size x 3 x H x W]
           - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels
        It returns a dict with the following elements:
           - "pred_logits": the classification logits (including no-object) for all queries.
                            Shape= [batch_size x num_queries x (num_classes + 1)]
           - "pred_boxes": The normalized box coordinates for all queries, represented as
                           (center_x, center_y, width, height). These values are normalized in [0, 1],
                           relative to the size of each individual image (disregarding possible padding).
                           See PostProcess for information on how to retrieve the unnormalized bounding box.
           - "aux_outputs": Optional, only returned when auxiliary losses are activated. It is a list of
                            dictionaries containing the two above keys for each decoder layer.
        """
        if isinstance(samples, (list, torch.Tensor)):
            samples = nested_tensor_from_tensor_list(samples)
@@ -68,16 +90,18 @@ class DETR(nn.Module):
        src, mask = features[-1].decompose()
        assert mask is not None
        # hs shape (NUM_LAYER, B, S, hidden_dim)
        hs = self.transformer(
            self.input_proj(src), mask, self.query_embed.weight, pos[-1]
        )[0]
        # shape (NUM_LAYER, B, S, NUM_CLASS + 1)
        outputs_class = self.class_embed(hs)
        # shape (NUM_LAYER, B, S, 4)
        outputs_coord = self.bbox_embed(hs).sigmoid()
        # pred_logits shape (B, S, NUM_CLASS + 1)
        # pred_boxes shape (B, S, 4)
        out = {"pred_logits": outputs_class[-1], "pred_boxes": outputs_coord[-1]}
        if self.aux_loss:
            out["aux_outputs"] = self._set_aux_loss(outputs_class, outputs_coord)
        return out

    @torch.jit.unused
@@ -85,23 +109,25 @@ class DETR(nn.Module):
        # this is a workaround to make torchscript happy, as torchscript
        # doesn't support dictionary with non-homogeneous values, such
        # as a dict having both a Tensor and a list.
        return [
            {"pred_logits": a, "pred_boxes": b}
            for a, b in zip(outputs_class[:-1], outputs_coord[:-1])
        ]
class PostProcess(nn.Module):
    """This module converts the model's output into the format expected by the coco api"""

    @torch.no_grad()
    def forward(self, outputs, target_sizes):
        """Perform the computation
        Parameters:
            outputs: raw outputs of the model
            target_sizes: tensor of dimension [batch_size x 2] containing the size of each image of the batch
                          For evaluation, this must be the original image size (before any data augmentation)
                          For visualization, this should be the image size after data augment, but before padding
        """
        out_logits, out_bbox = outputs["pred_logits"], outputs["pred_boxes"]

        assert len(out_logits) == len(target_sizes)
        assert target_sizes.shape[1] == 2
@@ -116,19 +142,24 @@ class PostProcess(nn.Module):
        scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
        boxes = boxes * scale_fct[:, None, :]

        results = [
            {"scores": s, "labels": l, "boxes": b}
            for s, l, b in zip(scores, labels, boxes)
        ]

        return results
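The boxes above are converted from the model's normalized cxcywh format to corner format (via detr.util.box_ops) before being scaled to pixel coordinates; a minimal equivalent sketch of that conversion (hypothetical name):

import torch

def cxcywh_to_xyxy_sketch(b: torch.Tensor) -> torch.Tensor:
    # (center_x, center_y, width, height) -> (x0, y0, x1, y1)
    cx, cy, w, h = b.unbind(-1)
    return torch.stack([cx - 0.5 * w, cy - 0.5 * h, cx + 0.5 * w, cy + 0.5 * h], dim=-1)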
class MLP(nn.Module):
    """Very simple multi-layer perceptron (also called FFN)"""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(
            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
        )

    def forward(self, x):
        for i, layer in enumerate(self.layers):
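Example use of MLP as the box head (assuming the standard DETR forward, which applies ReLU between all but the last layer):

import torch

head = MLP(input_dim=256, hidden_dim=256, output_dim=4, num_layers=3)
out = head(torch.randn(2, 100, 256))  # (batch, queries, 4) box parameters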
@@ -145,7 +176,7 @@ def build(args):
    # you should pass `num_classes` to be 2 (max_obj_id + 1).
    # For more details on this, check the following discussion
    # https://github.com/facebookresearch/detr/issues/108#issuecomment-650269223
    num_classes = 20 if args.dataset_file != "coco" else 91
    if args.dataset_file == "coco_panoptic":
        # for panoptic, we just add a num_classes that is large enough to hold
        # max_obj_id + 1, but the exact value doesn't really matter
@@ -166,8 +197,8 @@ def build(args):
    if args.masks:
        model = DETRsegm(model, freeze_detr=(args.frozen_weights is not None))
    matcher = build_matcher(args)
    weight_dict = {"loss_ce": 1, "loss_bbox": args.bbox_loss_coef}
    weight_dict["loss_giou"] = args.giou_loss_coef
    if args.masks:
        weight_dict["loss_mask"] = args.mask_loss_coef
        weight_dict["loss_dice"] = args.dice_loss_coef
@@ -175,20 +206,27 @@ def build(args):
    if args.aux_loss:
        aux_weight_dict = {}
        for i in range(args.dec_layers - 1):
            aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
        weight_dict.update(aux_weight_dict)

    losses = ["labels", "boxes", "cardinality"]
    if args.masks:
        losses += ["masks"]
    criterion = SetCriterion(
        num_classes,
        matcher=matcher,
        weight_dict=weight_dict,
        eos_coef=args.eos_coef,
        losses=losses,
    )
    criterion.to(device)
    postprocessors = {"bbox": PostProcess()}
    if args.masks:
        postprocessors["segm"] = PostProcessSegm()
        if args.dataset_file == "coco_panoptic":
            is_thing_map = {i: i <= 90 for i in range(201)}
            postprocessors["panoptic"] = PostProcessPanoptic(
                is_thing_map, threshold=0.85
            )

    return model, criterion, postprocessors