Commit 4985ef73 authored by Hang Zhang's avatar Hang Zhang Committed by Facebook GitHub Bot
Browse files

Refactor Code Base

Summary: Add the DETR_MODEL_REGISTRY registry to better support different variants of DETR (in a later diff).

Reviewed By: newstzpz

Differential Revision: D32874194

fbshipit-source-id: f8e9a61417ec66bec9f2d98631260a2f4e2af4cf
parent 189d83d7
MODEL:
  META_ARCHITECTURE: "Detr"
  # ImageNet mean/std in RGB channel order (must match INPUT.FORMAT below).
  PIXEL_MEAN: [123.675, 116.280, 103.530]
  PIXEL_STD: [58.395, 57.120, 57.375]
  MASK_ON: False
  BACKBONE:
    NAME: "FBNetV2C4Backbone"
  FBNET_V2:
    ARCH: "FBNetV3_A_dsmask_C5"
    NORM: "sync_bn"
    WIDTH_DIVISOR: 8
    SCALE_FACTOR: 1.0
    # Single feature level fed to DETR (see FBNetMaskedBackbone.out_features).
    OUT_FEATURES: ["trunk4"]
  DETR:
    # Key into DETR_MODEL_REGISTRY (selects the model variant to build).
    NAME: "DETR"
    # Hungarian-matching loss weights: generalized IoU and L1 box regression.
    GIOU_WEIGHT: 2.0
    L1_WEIGHT: 5.0
    NUM_OBJECT_QUERIES: 100
DATASETS:
  TRAIN: ("coco_2017_train",)
  TEST: ("coco_2017_val",)
SOLVER:
  IMS_PER_BATCH: 16
  BASE_LR: 0.0001
  # Single LR drop late in training; iteration counts are for IMS_PER_BATCH=16.
  STEPS: (1478400,)
  MAX_ITER: 2217600
  WARMUP_FACTOR: 1.0
  WARMUP_ITERS: 10
  WEIGHT_DECAY: 0.0001
  OPTIMIZER: "ADAMW"
  CLIP_GRADIENTS:
    ENABLED: True
    # Clip the global (full-model) L2 gradient norm, as in the DETR recipe.
    CLIP_TYPE: "full_model"
    CLIP_VALUE: 0.1
    NORM_TYPE: 2.0
  # Backbone is trained with 10x smaller LR than the transformer/heads.
  LR_MULTIPLIER_OVERWRITE: [{'backbone': 0.1}]
INPUT:
  # Multi-scale training: shortest side sampled from this list.
  MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800)
  CROP:
    ENABLED: True
    TYPE: "absolute_range"
    SIZE: (384, 600)
  FORMAT: "RGB"
D2GO_DATA:
  MAPPER:
    NAME: "DETRDatasetMapper"
TEST:
  EVAL_PERIOD: 4000
DATALOADER:
  # DETR trains on images without annotations as well (empty = all-background).
  FILTER_EMPTY_ANNOTATIONS: False
  NUM_WORKERS: 4
VERSION: 2
...@@ -10,7 +10,10 @@ def add_detr_config(cfg): ...@@ -10,7 +10,10 @@ def add_detr_config(cfg):
Add config for DETR. Add config for DETR.
""" """
cfg.MODEL.DETR = CN() cfg.MODEL.DETR = CN()
cfg.MODEL.DETR.NAME = "DETR"
cfg.MODEL.DETR.NUM_CLASSES = 80 cfg.MODEL.DETR.NUM_CLASSES = 80
# simple backbone
cfg.MODEL.BACKBONE.SIMPLE = False cfg.MODEL.BACKBONE.SIMPLE = False
cfg.MODEL.BACKBONE.STRIDE = 1 cfg.MODEL.BACKBONE.STRIDE = 1
cfg.MODEL.BACKBONE.CHANNEL = 0 cfg.MODEL.BACKBONE.CHANNEL = 0
......
...@@ -3,10 +3,11 @@ ...@@ -3,10 +3,11 @@
import numpy as np import numpy as np
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from detectron2.modeling import META_ARCH_REGISTRY, build_backbone, detector_postprocess from detectron2.modeling import META_ARCH_REGISTRY, detector_postprocess
from detectron2.structures import Boxes, ImageList, Instances, BitMasks from detectron2.structures import Boxes, ImageList, Instances, BitMasks
from detr.datasets.coco import convert_coco_poly_to_mask from detr.datasets.coco import convert_coco_poly_to_mask
from detr.models.backbone import Joiner from detr.models.backbone import Joiner
from detr.models.build import build_detr_model
from detr.models.deformable_detr import DeformableDETR from detr.models.deformable_detr import DeformableDETR
from detr.models.deformable_transformer import DeformableTransformer from detr.models.deformable_transformer import DeformableTransformer
from detr.models.detr import DETR from detr.models.detr import DETR
...@@ -14,7 +15,6 @@ from detr.models.matcher import HungarianMatcher ...@@ -14,7 +15,6 @@ from detr.models.matcher import HungarianMatcher
from detr.models.position_encoding import PositionEmbeddingSine from detr.models.position_encoding import PositionEmbeddingSine
from detr.models.segmentation import DETRsegm, PostProcessSegm from detr.models.segmentation import DETRsegm, PostProcessSegm
from detr.models.setcriterion import SetCriterion, FocalLossSetCriterion from detr.models.setcriterion import SetCriterion, FocalLossSetCriterion
from detr.models.transformer import Transformer
from detr.util.box_ops import box_cxcywh_to_xyxy, box_xyxy_to_cxcywh from detr.util.box_ops import box_cxcywh_to_xyxy, box_xyxy_to_cxcywh
from detr.util.misc import NestedTensor from detr.util.misc import NestedTensor
from torch import nn from torch import nn
...@@ -22,117 +22,6 @@ from torch import nn ...@@ -22,117 +22,6 @@ from torch import nn
__all__ = ["Detr"] __all__ = ["Detr"]
class ResNetMaskedBackbone(nn.Module):
    """Thin wrapper around a detectron2 ResNet backbone that adds padding masks.

    DETR consumes ``NestedTensor``s (feature map + boolean padding mask); this
    wrapper derives the mask for each feature level from the per-image sizes
    recorded in the batched ``ImageList``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.backbone = build_backbone(cfg)
        backbone_shape = self.backbone.output_shape()
        if cfg.MODEL.DETR.NUM_FEATURE_LEVELS > 1:
            self.strides = [8, 16, 32]
        else:
            self.strides = [32]

        if cfg.MODEL.RESNETS.RES5_DILATION == 2:
            # fix dilation from d2
            self.backbone.stages[-1][0].conv2.dilation = (1, 1)
            self.backbone.stages[-1][0].conv2.padding = (1, 1)
            self.strides[-1] = self.strides[-1] // 2

        self.feature_strides = [backbone_shape[f].stride for f in backbone_shape.keys()]
        self.num_channels = [backbone_shape[k].channels for k in backbone_shape.keys()]

    def forward(self, images):
        features = self.backbone(images.tensor)
        # One padding mask per feature level, each of shape (B, maxH, maxW).
        masks = self.mask_out_padding(
            [level.shape for level in features.values()],
            images.image_sizes,
            images.tensor.device,
        )
        assert len(features) == len(masks)
        for mask, key in zip(masks, list(features.keys())):
            features[key] = NestedTensor(features[key], mask)
        return features

    def mask_out_padding(self, feature_shapes, image_sizes, device):
        """Return one bool mask per level; True marks padded (invalid) pixels."""
        assert len(feature_shapes) == len(self.feature_strides)
        masks = []
        for stride, shape in zip(self.feature_strides, feature_shapes):
            batch, _, height, width = shape
            level_mask = torch.ones((batch, height, width), dtype=torch.bool, device=device)
            for img_idx, (h, w) in enumerate(image_sizes):
                # Valid region: image extent divided by the level's stride, rounded up.
                valid_h = int(np.ceil(float(h) / stride))
                valid_w = int(np.ceil(float(w) / stride))
                level_mask[img_idx, :valid_h, :valid_w] = 0
            masks.append(level_mask)
        return masks
class FBNetMaskedBackbone(ResNetMaskedBackbone):
    """This is a thin wrapper around D2's backbone to provide padding masking"""

    def __init__(self, cfg):
        # Deliberately skip ResNetMaskedBackbone.__init__ (it is ResNet-specific);
        # initialize nn.Module directly and set up FBNet attributes instead.
        nn.Module.__init__(self)
        self.backbone = build_backbone(cfg)
        self.out_features = cfg.MODEL.FBNET_V2.OUT_FEATURES
        # Strides for every feature level the backbone produces — used by the
        # inherited mask_out_padding, which masks all levels before filtering.
        self.feature_strides = list(self.backbone._out_feature_strides.values())
        # Channels/strides only for the configured output features.
        self.num_channels = [
            self.backbone._out_feature_channels[k] for k in self.out_features
        ]
        self.strides = [
            self.backbone._out_feature_strides[k] for k in self.out_features
        ]

    def forward(self, images):
        features = self.backbone(images.tensor)
        # One padding mask per produced feature level (True = padded pixel).
        masks = self.mask_out_padding(
            [features_per_level.shape for features_per_level in features.values()],
            images.image_sizes,
            images.tensor.device,
        )
        assert len(features) == len(masks)
        # Keep only the configured out_features, wrapped with their masks.
        ret_features = {}
        for i, k in enumerate(features.keys()):
            if k in self.out_features:
                ret_features[k] = NestedTensor(features[k], masks[i])
        return ret_features
class SimpleSingleStageBackbone(ResNetMaskedBackbone):
    """This is a simple wrapper for single stage backbone,
    please set the required configs:
    cfg.MODEL.BACKBONE.SIMPLE == True,
    cfg.MODEL.BACKBONE.STRIDE, cfg.MODEL.BACKBONE.CHANNEL
    """

    def __init__(self, cfg):
        # Skip the ResNet-specific parent __init__; initialize nn.Module directly.
        nn.Module.__init__(self)
        self.backbone = build_backbone(cfg)
        # Single output feature, exposed under a fixed name.
        self.out_features = ["out"]
        assert cfg.MODEL.BACKBONE.SIMPLE is True
        # Stride/channel must be provided explicitly since a generic backbone
        # does not expose an output_shape() the way D2's ResNet does.
        self.feature_strides = [cfg.MODEL.BACKBONE.STRIDE]
        self.num_channels = [cfg.MODEL.BACKBONE.CHANNEL]
        self.strides = [cfg.MODEL.BACKBONE.STRIDE]

    def forward(self, images):
        # Backbone is assumed to return a single tensor, not a dict of levels.
        y = self.backbone(images.tensor)
        masks = self.mask_out_padding(
            [y.shape],
            images.image_sizes,
            images.tensor.device,
        )
        assert len(masks) == 1
        ret_features = {}
        ret_features[self.out_features[0]] = NestedTensor(y, masks[0])
        return ret_features
@META_ARCH_REGISTRY.register() @META_ARCH_REGISTRY.register()
class Detr(nn.Module): class Detr(nn.Module):
""" """
...@@ -143,18 +32,10 @@ class Detr(nn.Module): ...@@ -143,18 +32,10 @@ class Detr(nn.Module):
super().__init__() super().__init__()
self.device = torch.device(cfg.MODEL.DEVICE) self.device = torch.device(cfg.MODEL.DEVICE)
self.use_focal_loss = cfg.MODEL.DETR.USE_FOCAL_LOSS
self.num_classes = cfg.MODEL.DETR.NUM_CLASSES self.num_classes = cfg.MODEL.DETR.NUM_CLASSES
self.mask_on = cfg.MODEL.MASK_ON self.mask_on = cfg.MODEL.MASK_ON
hidden_dim = cfg.MODEL.DETR.HIDDEN_DIM
num_queries = cfg.MODEL.DETR.NUM_OBJECT_QUERIES
# Transformer parameters:
nheads = cfg.MODEL.DETR.NHEADS
dropout = cfg.MODEL.DETR.DROPOUT
dim_feedforward = cfg.MODEL.DETR.DIM_FEEDFORWARD
enc_layers = cfg.MODEL.DETR.ENC_LAYERS
dec_layers = cfg.MODEL.DETR.DEC_LAYERS dec_layers = cfg.MODEL.DETR.DEC_LAYERS
pre_norm = cfg.MODEL.DETR.PRE_NORM
# Loss parameters: # Loss parameters:
giou_weight = cfg.MODEL.DETR.GIOU_WEIGHT giou_weight = cfg.MODEL.DETR.GIOU_WEIGHT
...@@ -162,75 +43,9 @@ class Detr(nn.Module): ...@@ -162,75 +43,9 @@ class Detr(nn.Module):
cls_weight = cfg.MODEL.DETR.CLS_WEIGHT cls_weight = cfg.MODEL.DETR.CLS_WEIGHT
deep_supervision = cfg.MODEL.DETR.DEEP_SUPERVISION deep_supervision = cfg.MODEL.DETR.DEEP_SUPERVISION
no_object_weight = cfg.MODEL.DETR.NO_OBJECT_WEIGHT no_object_weight = cfg.MODEL.DETR.NO_OBJECT_WEIGHT
centered_position_encoding = cfg.MODEL.DETR.CENTERED_POSITION_ENCODIND
num_feature_levels = cfg.MODEL.DETR.NUM_FEATURE_LEVELS
N_steps = hidden_dim // 2 self.detr = build_detr_model(cfg)
if "resnet" in cfg.MODEL.BACKBONE.NAME.lower():
d2_backbone = ResNetMaskedBackbone(cfg)
elif "fbnet" in cfg.MODEL.BACKBONE.NAME.lower():
d2_backbone = FBNetMaskedBackbone(cfg)
elif cfg.MODEL.BACKBONE.SIMPLE:
d2_backbone = SimpleSingleStageBackbone(cfg)
else:
raise NotImplementedError
backbone = Joiner(
d2_backbone,
PositionEmbeddingSine(
N_steps, normalize=True, centered=centered_position_encoding
),
)
backbone.num_channels = d2_backbone.num_channels
self.use_focal_loss = cfg.MODEL.DETR.USE_FOCAL_LOSS
if cfg.MODEL.DETR.DEFORMABLE:
transformer = DeformableTransformer(
d_model=hidden_dim,
nhead=nheads,
num_encoder_layers=enc_layers,
num_decoder_layers=dec_layers,
dim_feedforward=dim_feedforward,
dropout=dropout,
activation="relu",
return_intermediate_dec=True,
num_feature_levels=num_feature_levels,
dec_n_points=4,
enc_n_points=4,
two_stage=cfg.MODEL.DETR.TWO_STAGE,
two_stage_num_proposals=num_queries,
)
self.detr = DeformableDETR(
backbone,
transformer,
num_classes=self.num_classes,
num_queries=num_queries,
num_feature_levels=num_feature_levels,
aux_loss=deep_supervision,
with_box_refine=cfg.MODEL.DETR.WITH_BOX_REFINE,
two_stage=cfg.MODEL.DETR.TWO_STAGE,
)
else:
transformer = Transformer(
d_model=hidden_dim,
dropout=dropout,
nhead=nheads,
dim_feedforward=dim_feedforward,
num_encoder_layers=enc_layers,
num_decoder_layers=dec_layers,
normalize_before=pre_norm,
return_intermediate_dec=deep_supervision,
)
self.detr = DETR(
backbone,
transformer,
num_classes=self.num_classes,
num_queries=num_queries,
aux_loss=deep_supervision,
use_focal_loss=self.use_focal_loss,
)
if self.mask_on: if self.mask_on:
frozen_weights = cfg.MODEL.DETR.FROZEN_WEIGHTS frozen_weights = cfg.MODEL.DETR.FROZEN_WEIGHTS
if frozen_weights != "": if frozen_weights != "":
......
#!/usr/bin/env python3 #!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from .detr import build from .detr import build
......
import numpy as np
import torch
from detectron2.modeling import build_backbone
from detectron2.utils.registry import Registry
from detr.models.backbone import Joiner
from detr.models.position_encoding import PositionEmbeddingSine
from detr.util.misc import NestedTensor
from torch import nn
DETR_MODEL_REGISTRY = Registry("DETR_MODEL")
def build_detr_backbone(cfg):
    """Build the masked backbone + position-encoding ``Joiner`` used by DETR models.

    Selects a padding-mask wrapper by substring match on cfg.MODEL.BACKBONE.NAME
    ("resnet"/"fbnet", case-insensitive) or via cfg.MODEL.BACKBONE.SIMPLE, wraps
    it together with a sine position embedding, and copies the wrapper's channel
    counts onto the returned Joiner (DETR reads ``backbone.num_channels``).

    Raises:
        NotImplementedError: if the configured backbone is not supported.
    """
    backbone_name = cfg.MODEL.BACKBONE.NAME.lower()
    if "resnet" in backbone_name:
        d2_backbone = ResNetMaskedBackbone(cfg)
    elif "fbnet" in backbone_name:
        d2_backbone = FBNetMaskedBackbone(cfg)
    elif cfg.MODEL.BACKBONE.SIMPLE:
        d2_backbone = SimpleSingleStageBackbone(cfg)
    else:
        # Include the offending name so config mistakes are easy to diagnose.
        raise NotImplementedError(
            f"Unsupported backbone for DETR: {cfg.MODEL.BACKBONE.NAME}"
        )

    # Sine embedding uses half the hidden dim per spatial axis (x and y).
    N_steps = cfg.MODEL.DETR.HIDDEN_DIM // 2
    # NOTE(review): the key spelling "ENCODIND" matches the config definition
    # elsewhere in this codebase — do not "fix" it here alone.
    centered_position_encoding = cfg.MODEL.DETR.CENTERED_POSITION_ENCODIND
    backbone = Joiner(
        d2_backbone,
        PositionEmbeddingSine(
            N_steps, normalize=True, centered=centered_position_encoding
        ),
    )
    backbone.num_channels = d2_backbone.num_channels
    return backbone
def build_detr_model(cfg):
    """Instantiate the DETR variant registered under cfg.MODEL.DETR.NAME."""
    model_class = DETR_MODEL_REGISTRY.get(cfg.MODEL.DETR.NAME)
    return model_class(cfg)
class ResNetMaskedBackbone(nn.Module):
    """This is a thin wrapper around D2's backbone to provide padding masking"""

    def __init__(self, cfg):
        super().__init__()
        self.backbone = build_backbone(cfg)
        backbone_shape = self.backbone.output_shape()
        # Multiple feature levels — presumably for deformable attention; strides
        # are hard-coded for the standard res3/res4/res5 outputs. TODO confirm.
        if cfg.MODEL.DETR.NUM_FEATURE_LEVELS > 1:
            self.strides = [8, 16, 32]
        else:
            self.strides = [32]

        if cfg.MODEL.RESNETS.RES5_DILATION == 2:
            # fix dilation from d2
            self.backbone.stages[-1][0].conv2.dilation = (1, 1)
            self.backbone.stages[-1][0].conv2.padding = (1, 1)
            self.strides[-1] = self.strides[-1] // 2
        # Strides/channels for every level the backbone reports, in output_shape order.
        self.feature_strides = [backbone_shape[f].stride for f in backbone_shape.keys()]
        self.num_channels = [backbone_shape[k].channels for k in backbone_shape.keys()]

    def forward(self, images):
        features = self.backbone(images.tensor)
        # one tensor per feature level. Each tensor has shape (B, maxH, maxW)
        masks = self.mask_out_padding(
            [features_per_level.shape for features_per_level in features.values()],
            images.image_sizes,
            images.tensor.device,
        )
        assert len(features) == len(masks)
        # Wrap each level in a NestedTensor pairing the feature with its mask.
        for i, k in enumerate(features.keys()):
            features[k] = NestedTensor(features[k], masks[i])
        return features

    def mask_out_padding(self, feature_shapes, image_sizes, device):
        """Build one bool mask per level; pixels are initialized to True and
        zeroed over each image's valid (un-padded) region."""
        masks = []
        assert len(feature_shapes) == len(self.feature_strides)
        for idx, shape in enumerate(feature_shapes):
            N, _, H, W = shape
            masks_per_feature_level = torch.ones(
                (N, H, W), dtype=torch.bool, device=device
            )
            for img_idx, (h, w) in enumerate(image_sizes):
                # Valid extent at this level: image size / stride, rounded up.
                masks_per_feature_level[
                    img_idx,
                    : int(np.ceil(float(h) / self.feature_strides[idx])),
                    : int(np.ceil(float(w) / self.feature_strides[idx])),
                ] = 0
            masks.append(masks_per_feature_level)
        return masks
class FBNetMaskedBackbone(ResNetMaskedBackbone):
    """Padding-mask wrapper for FBNet backbones.

    Reuses ``ResNetMaskedBackbone.mask_out_padding`` but reads level strides and
    channel counts from the FBNet backbone's internal attributes, and filters the
    outputs down to the configured ``OUT_FEATURES``.
    """

    def __init__(self, cfg):
        # Bypass the ResNet-specific parent __init__ on purpose.
        nn.Module.__init__(self)
        self.backbone = build_backbone(cfg)
        self.out_features = cfg.MODEL.FBNET_V2.OUT_FEATURES
        stride_map = self.backbone._out_feature_strides
        channel_map = self.backbone._out_feature_channels
        # All produced levels (for masking) vs. only the requested outputs.
        self.feature_strides = list(stride_map.values())
        self.num_channels = [channel_map[name] for name in self.out_features]
        self.strides = [stride_map[name] for name in self.out_features]

    def forward(self, images):
        features = self.backbone(images.tensor)
        level_shapes = [t.shape for t in features.values()]
        masks = self.mask_out_padding(
            level_shapes, images.image_sizes, images.tensor.device
        )
        assert len(features) == len(masks)
        return {
            key: NestedTensor(features[key], mask)
            for key, mask in zip(features.keys(), masks)
            if key in self.out_features
        }
class SimpleSingleStageBackbone(ResNetMaskedBackbone):
    """Padding-mask wrapper for a generic single-stage backbone.

    Required configs: ``cfg.MODEL.BACKBONE.SIMPLE == True`` plus explicit
    ``cfg.MODEL.BACKBONE.STRIDE`` and ``cfg.MODEL.BACKBONE.CHANNEL``, since a
    generic backbone does not report its own output shape.
    """

    def __init__(self, cfg):
        # Bypass the ResNet-specific parent __init__ on purpose.
        nn.Module.__init__(self)
        self.backbone = build_backbone(cfg)
        self.out_features = ["out"]
        assert cfg.MODEL.BACKBONE.SIMPLE is True
        stride = cfg.MODEL.BACKBONE.STRIDE
        self.feature_strides = [stride]
        self.num_channels = [cfg.MODEL.BACKBONE.CHANNEL]
        self.strides = [stride]

    def forward(self, images):
        # The backbone is expected to return a single tensor, not a level dict.
        output = self.backbone(images.tensor)
        masks = self.mask_out_padding(
            [output.shape], images.image_sizes, images.tensor.device
        )
        assert len(masks) == 1
        return {self.out_features[0]: NestedTensor(output, masks[0])}
...@@ -15,6 +15,7 @@ import math ...@@ -15,6 +15,7 @@ import math
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from detectron2.config import configurable
from torch import nn from torch import nn
from ..util import box_ops from ..util import box_ops
...@@ -27,7 +28,8 @@ from ..util.misc import ( ...@@ -27,7 +28,8 @@ from ..util.misc import (
is_dist_avail_and_initialized, is_dist_avail_and_initialized,
) )
from .backbone import build_backbone from .backbone import build_backbone
from .deformable_transformer import build_deforamble_transformer from .build import DETR_MODEL_REGISTRY, build_detr_backbone
from .deformable_transformer import DeformableTransformer
from .matcher import build_matcher from .matcher import build_matcher
from .segmentation import ( from .segmentation import (
DETRsegm, DETRsegm,
...@@ -43,9 +45,11 @@ def _get_clones(module, N): ...@@ -43,9 +45,11 @@ def _get_clones(module, N):
return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
@DETR_MODEL_REGISTRY.register()
class DeformableDETR(nn.Module): class DeformableDETR(nn.Module):
"""This is the Deformable DETR module that performs object detection""" """This is the Deformable DETR module that performs object detection"""
@configurable
def __init__( def __init__(
self, self,
backbone, backbone,
...@@ -153,6 +157,53 @@ class DeformableDETR(nn.Module): ...@@ -153,6 +157,53 @@ class DeformableDETR(nn.Module):
hidden_dim, hidden_dim, 4, bbox_embed_num_layers hidden_dim, hidden_dim, 4, bbox_embed_num_layers
) )
@classmethod
def from_config(cls, cfg):
    """Translate a detectron2-style config into DeformableDETR constructor kwargs.

    Builds the masked backbone and the deformable transformer here so that
    ``__init__`` (decorated with ``@configurable``) stays config-agnostic.
    Returns the keyword-argument dict expected by ``__init__``.
    """
    num_queries = cfg.MODEL.DETR.NUM_OBJECT_QUERIES
    num_feature_levels = cfg.MODEL.DETR.NUM_FEATURE_LEVELS
    # Deep supervision = auxiliary losses on intermediate decoder layers.
    deep_supervision = cfg.MODEL.DETR.DEEP_SUPERVISION
    # NOTE: PRE_NORM and USE_FOCAL_LOSS are intentionally not read here — the
    # deformable transformer has no pre-norm switch and the criterion choice is
    # handled by the meta-architecture, not this model.

    backbone = build_detr_backbone(cfg)
    transformer = DeformableTransformer(
        d_model=cfg.MODEL.DETR.HIDDEN_DIM,
        nhead=cfg.MODEL.DETR.NHEADS,
        num_encoder_layers=cfg.MODEL.DETR.ENC_LAYERS,
        num_decoder_layers=cfg.MODEL.DETR.DEC_LAYERS,
        dim_feedforward=cfg.MODEL.DETR.DIM_FEEDFORWARD,
        dropout=cfg.MODEL.DETR.DROPOUT,
        activation="relu",
        return_intermediate_dec=True,
        num_feature_levels=num_feature_levels,
        dec_n_points=4,
        enc_n_points=4,
        two_stage=cfg.MODEL.DETR.TWO_STAGE,
        two_stage_num_proposals=num_queries,
    )
    return {
        "backbone": backbone,
        "transformer": transformer,
        "num_classes": cfg.MODEL.DETR.NUM_CLASSES,
        "num_queries": num_queries,
        "num_feature_levels": num_feature_levels,
        "aux_loss": deep_supervision,
        "with_box_refine": cfg.MODEL.DETR.WITH_BOX_REFINE,
        "two_stage": cfg.MODEL.DETR.TWO_STAGE,
    }
def forward(self, samples: NestedTensor): def forward(self, samples: NestedTensor):
"""The forward expects a NestedTensor, which consists of: """The forward expects a NestedTensor, which consists of:
- samples.tensor: batched images, of shape [batch_size x 3 x H x W] - samples.tensor: batched images, of shape [batch_size x 3 x H x W]
...@@ -337,58 +388,3 @@ class MLP(nn.Module): ...@@ -337,58 +388,3 @@ class MLP(nn.Module):
for i, layer in enumerate(self.layers): for i, layer in enumerate(self.layers):
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
return x return x
def build(args):
    """Build the Deformable DETR model, criterion, and postprocessors from argparse args.

    Returns:
        (model, criterion, postprocessors) — the detector (optionally wrapped for
        segmentation), the focal-loss set criterion moved to ``args.device``, and
        a dict of output postprocessors keyed by task ("bbox", "segm", "panoptic").
    """
    # COCO uses 91 class slots; other detection datasets here use 20, panoptic 250.
    if args.dataset_file == "coco":
        num_classes = 91
    elif args.dataset_file == "coco_panoptic":
        num_classes = 250
    else:
        num_classes = 20
    device = torch.device(args.device)

    backbone = build_backbone(args)
    transformer = build_deforamble_transformer(args)
    model = DeformableDETR(
        backbone,
        transformer,
        num_classes=num_classes,
        num_queries=args.num_queries,
        num_feature_levels=args.num_feature_levels,
        aux_loss=args.aux_loss,
        with_box_refine=args.with_box_refine,
        two_stage=args.two_stage,
    )
    if args.masks:
        model = DETRsegm(model, freeze_detr=(args.frozen_weights is not None))

    matcher = build_matcher(args)
    weight_dict = {"loss_ce": args.cls_loss_coef, "loss_bbox": args.bbox_loss_coef}
    weight_dict["loss_giou"] = args.giou_loss_coef
    if args.masks:
        weight_dict["loss_mask"] = args.mask_loss_coef
        weight_dict["loss_dice"] = args.dice_loss_coef
    # TODO this is a hack
    if args.aux_loss:
        # Replicate every loss weight for each intermediate decoder layer and
        # for the encoder ("_enc") auxiliary outputs.
        suffixes = [f"_{i}" for i in range(args.dec_layers - 1)] + ["_enc"]
        weight_dict.update(
            {key + suffix: coef for suffix in suffixes for key, coef in weight_dict.items()}
        )

    losses = ["labels", "boxes", "cardinality"]
    if args.masks:
        losses += ["masks"]
    # num_classes, matcher, weight_dict, losses, focal_alpha=0.25
    criterion = FocalLossSetCriterion(
        num_classes, matcher, weight_dict, losses, focal_alpha=args.focal_alpha
    )
    criterion.to(device)

    postprocessors = {"bbox": PostProcess()}
    if args.masks:
        postprocessors["segm"] = PostProcessSegm()
        if args.dataset_file == "coco_panoptic":
            is_thing_map = {i: i <= 90 for i in range(201)}
            postprocessors["panoptic"] = PostProcessPanoptic(
                is_thing_map, threshold=0.85
            )
    return model, criterion, postprocessors
...@@ -6,6 +6,7 @@ DETR model and criterion classes. ...@@ -6,6 +6,7 @@ DETR model and criterion classes.
""" """
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from detectron2.config import configurable
from detr.util import box_ops from detr.util import box_ops
from detr.util.misc import ( from detr.util.misc import (
NestedTensor, NestedTensor,
...@@ -18,6 +19,7 @@ from detr.util.misc import ( ...@@ -18,6 +19,7 @@ from detr.util.misc import (
from torch import nn from torch import nn
from .backbone import build_backbone from .backbone import build_backbone
from .build import DETR_MODEL_REGISTRY, build_detr_backbone
from .matcher import build_matcher from .matcher import build_matcher
from .segmentation import ( from .segmentation import (
DETRsegm, DETRsegm,
...@@ -27,12 +29,14 @@ from .segmentation import ( ...@@ -27,12 +29,14 @@ from .segmentation import (
sigmoid_focal_loss, sigmoid_focal_loss,
) )
from .setcriterion import SetCriterion from .setcriterion import SetCriterion
from .transformer import build_transformer from .transformer import Transformer, build_transformer
@DETR_MODEL_REGISTRY.register()
class DETR(nn.Module): class DETR(nn.Module):
"""This is the DETR module that performs object detection""" """This is the DETR module that performs object detection"""
@configurable
def __init__( def __init__(
self, self,
backbone, backbone,
...@@ -66,6 +70,46 @@ class DETR(nn.Module): ...@@ -66,6 +70,46 @@ class DETR(nn.Module):
self.backbone = backbone self.backbone = backbone
self.aux_loss = aux_loss self.aux_loss = aux_loss
@classmethod
def from_config(cls, cfg):
    """Map a detectron2-style config onto DETR constructor keyword arguments.

    Constructs the masked backbone and the vanilla transformer here so that
    ``__init__`` (decorated with ``@configurable``) receives ready-made modules.
    """
    detr_cfg = cfg.MODEL.DETR
    # Deep supervision also makes the decoder return intermediate layer outputs.
    deep_supervision = detr_cfg.DEEP_SUPERVISION

    backbone = build_detr_backbone(cfg)
    transformer = Transformer(
        d_model=detr_cfg.HIDDEN_DIM,
        dropout=detr_cfg.DROPOUT,
        nhead=detr_cfg.NHEADS,
        dim_feedforward=detr_cfg.DIM_FEEDFORWARD,
        num_encoder_layers=detr_cfg.ENC_LAYERS,
        num_decoder_layers=detr_cfg.DEC_LAYERS,
        normalize_before=detr_cfg.PRE_NORM,
        return_intermediate_dec=deep_supervision,
    )
    return {
        "backbone": backbone,
        "transformer": transformer,
        "num_classes": detr_cfg.NUM_CLASSES,
        "num_queries": detr_cfg.NUM_OBJECT_QUERIES,
        "aux_loss": deep_supervision,
        "use_focal_loss": detr_cfg.USE_FOCAL_LOSS,
    }
def forward(self, samples: NestedTensor): def forward(self, samples: NestedTensor):
"""The forward expects a NestedTensor, which consists of: """The forward expects a NestedTensor, which consists of:
- samples.tensor: batched images, of shape [batch_size x 3 x H x W] - samples.tensor: batched images, of shape [batch_size x 3 x H x W]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment