Commit 32dbb035 authored by Zhicheng Yan, committed by Facebook GitHub Bot

add EgoDETRRunner

Summary:
Major changes:
- Add a new runner `EgoDETRRunner` which inherits from the existing `DETRRunner` in the D2Go repo (https://github.com/facebookresearch/d2go/commit/62c21f252ad314961cf0157ee8f37cc4f7835e1d).
- Add a new data mapper `EgoDETRDatasetMapper` which has a custom crop transform generator and supports generic data augmentation (see the sketch below).
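
For context, here is a minimal sketch of what such a mapper's transform generators could look like, built from standard detectron2 transform APIs. The function name and the exact config keys used are illustrative assumptions, not the actual EgoDETRDatasetMapper code:

    # Hypothetical sketch: a transform-generator builder combining a crop
    # transform with generic augmentations, in the spirit of the mapper above.
    import detectron2.data.transforms as T

    def build_ego_detr_tfm_gens(cfg, is_train=True):
        """Build transform generators: a custom crop plus generic augmentations."""
        tfm_gens = []
        if is_train and cfg.INPUT.CROP.ENABLED:
            # Custom crop transform generator; type/size come from the config.
            tfm_gens.append(T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE))
        # Generic augmentation: multi-scale resize with an upper bound.
        tfm_gens.append(
            T.ResizeShortestEdge(
                cfg.INPUT.MIN_SIZE_TRAIN if is_train else cfg.INPUT.MIN_SIZE_TEST,
                cfg.INPUT.MAX_SIZE_TRAIN if is_train else cfg.INPUT.MAX_SIZE_TEST,
                cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING if is_train else "choice",
            )
        )
        if is_train:
            tfm_gens.append(T.RandomFlip())
        return tfm_gens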

Reviewed By: zhanghang1989

Differential Revision: D28895225

fbshipit-source-id: 4181ff8fce81df22a01d355fdff7e81e83d69e64
parent 62c21f25
@@ -28,7 +28,8 @@ def add_detr_config(cfg):
     cfg.MODEL.DETR.L1_WEIGHT = 5.0
     cfg.MODEL.DETR.DEEP_SUPERVISION = True
     cfg.MODEL.DETR.NO_OBJECT_WEIGHT = 0.1
+    cfg.MODEL.DETR.WITH_BOX_REFINE = False
+    cfg.MODEL.DETR.TWO_STAGE = False
     # TRANSFORMER
     cfg.MODEL.DETR.NHEADS = 8
...
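
The two new keys correspond to the box-refinement and two-stage options from Deformable DETR. A sketch of enabling them after registration; the import path for `add_detr_config` is an assumption and depends on the checkout:

    from detectron2.config import get_cfg
    from detr.d2.config import add_detr_config  # import path is an assumption

    cfg = get_cfg()
    add_detr_config(cfg)  # registers MODEL.DETR.* keys, including the two new ones
    cfg.MODEL.DETR.WITH_BOX_REFINE = True  # iterative box refinement
    cfg.MODEL.DETR.TWO_STAGE = True        # two-stage variant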
@@ -14,7 +14,7 @@ class DETRDatasetMapper(DetrDatasetMapper, D2GoDatasetMapper):
     def __init__(self, cfg, is_train=True, image_loader=None, tfm_gens=None):
         self.image_loader = None
         self.backfill_size = False
         self.retry = 3
         self.catch_exception = True
         self._error_count = 0
         self._total_counts = 0
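
The `retry`/`catch_exception` fields above suggest a retry-with-error-accounting pattern when mapping samples. A minimal sketch of that pattern, as an assumption rather than the actual D2GoDatasetMapper code:

    import logging
    import random

    logger = logging.getLogger(__name__)

    def map_with_retry(map_fn, dataset, idx, retry=3, catch_exception=True):
        """Try to map sample `idx`; on failure, retry with other random indices."""
        for attempt in range(retry):
            try:
                return map_fn(dataset[idx])
            except Exception as e:
                if not catch_exception:
                    raise
                logger.warning(
                    f"Failed to map sample {idx} (attempt {attempt + 1}): {e}"
                )
                idx = random.randrange(len(dataset))  # retry a different sample
        return None  # caller may skip this sample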
@@ -52,8 +52,8 @@ class DETRRunner(GeneralizedRCNNRunner):
             memo.add(value)
             lr = cfg.SOLVER.BASE_LR
             weight_decay = cfg.SOLVER.WEIGHT_DECAY
-            if "backbone.0" in key:
-                lr = lr * 0.1 #cfg.SOLVER.BACKBONE_MULTIPLIER
+            if "backbone.0" in key or "reference_points" in key or "sampling_offsets" in key:
+                lr = lr * 0.1
             params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}]

         def maybe_add_full_model_gradient_clipping(optim):  # optim: the optimizer class
...
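
This change extends the 10x learning-rate reduction from the backbone to the deformable-attention parameters (`reference_points`, `sampling_offsets`), matching the Deformable DETR training recipe. A condensed, self-contained sketch of the parameter-group logic (omitting the shared-parameter `memo` bookkeeping shown above):

    import torch

    def build_param_groups(model, base_lr, weight_decay):
        params = []
        for key, value in model.named_parameters():
            if not value.requires_grad:
                continue
            lr = base_lr
            if ("backbone.0" in key
                    or "reference_points" in key
                    or "sampling_offsets" in key):
                lr = base_lr * 0.1  # lower LR for backbone / deformable attention
            params.append(
                {"params": [value], "lr": lr, "weight_decay": weight_decay}
            )
        return params

    # Usage, e.g.:
    # optimizer = torch.optim.AdamW(build_param_groups(model, 1e-4, 1e-4))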