Unverified Commit a481f5a8 authored by xiliu8006's avatar xiliu8006 Committed by GitHub
Browse files

[Enhance] Move train_cfg test_cfg to model (#307)

* Move train_cfg/test_cfg to model

* Move train_cfg/test_cfg to model

* Move train_cfg/test_cfg to model

* Move train_cfg/test_cfg to model

* Move train_cfg/test_cfg to model

* Move train_cfg/test_cfg to model

* Move train_cfg/test_cfg to model

* Move train_cfg and test_cfg into model

* modify centerpoint configs

* Modify docs

* modify build_detector

* modify test_config_build_detector

* modify build_detector parameters

* Adopt the same strategy in build_detector
parent a347ac75
......@@ -30,7 +30,8 @@ def init_detector(config, checkpoint=None, device='cuda:0'):
raise TypeError('config must be a filename or Config object, '
f'but got {type(config)}')
config.model.pretrained = None
model = build_detector(config.model, test_cfg=config.test_cfg)
config.model.train_cfg = None
model = build_detector(config.model, test_cfg=config.get('test_cfg'))
if checkpoint is not None:
checkpoint = load_checkpoint(model, checkpoint)
if 'CLASSES' in checkpoint['meta']:
......
import warnings
from mmdet.models.builder import (BACKBONES, DETECTORS, HEADS, LOSSES, NECKS,
ROI_EXTRACTORS, SHARED_HEADS, build)
from .registry import FUSION_LAYERS, MIDDLE_ENCODERS, VOXEL_ENCODERS
......@@ -35,6 +37,14 @@ def build_loss(cfg):
def build_detector(cfg, train_cfg=None, test_cfg=None):
"""Build detector."""
if train_cfg is not None or test_cfg is not None:
warnings.warn(
'train_cfg and test_cfg is deprecated, '
'please specify them in model', UserWarning)
assert cfg.get('train_cfg') is None or train_cfg is None, \
'train_cfg specified in both outer field and model field '
assert cfg.get('test_cfg') is None or test_cfg is None, \
'test_cfg specified in both outer field and model field '
return build(cfg, DETECTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg))
......
......@@ -62,8 +62,8 @@ def _get_detector_cfg(fname):
import mmcv
config = _get_config_module(fname)
model = copy.deepcopy(config.model)
train_cfg = mmcv.Config(copy.deepcopy(config.train_cfg))
test_cfg = mmcv.Config(copy.deepcopy(config.test_cfg))
train_cfg = mmcv.Config(copy.deepcopy(config.model.train_cfg))
test_cfg = mmcv.Config(copy.deepcopy(config.model.test_cfg))
model.update(train_cfg=train_cfg)
model.update(test_cfg=test_cfg)
......
......@@ -40,20 +40,17 @@ def _get_detector_cfg(fname):
These are deep copied to allow for safe modification of parameters without
influencing other tests.
"""
import mmcv
config = _get_config_module(fname)
model = copy.deepcopy(config.model)
train_cfg = mmcv.Config(copy.deepcopy(config.train_cfg))
test_cfg = mmcv.Config(copy.deepcopy(config.test_cfg))
return model, train_cfg, test_cfg
return model
def _test_two_stage_forward(cfg_file):
model, train_cfg, test_cfg = _get_detector_cfg(cfg_file)
model = _get_detector_cfg(cfg_file)
model['pretrained'] = None
from mmdet.models import build_detector
detector = build_detector(model, train_cfg=train_cfg, test_cfg=test_cfg)
detector = build_detector(model)
input_shape = (1, 3, 256, 256)
......@@ -107,11 +104,11 @@ def _test_two_stage_forward(cfg_file):
def _test_single_stage_forward(cfg_file):
model, train_cfg, test_cfg = _get_detector_cfg(cfg_file)
model = _get_detector_cfg(cfg_file)
model['pretrained'] = None
from mmdet.models import build_detector
detector = build_detector(model, train_cfg=train_cfg, test_cfg=test_cfg)
detector = build_detector(model)
input_shape = (1, 3, 300, 300)
mm_inputs = _demo_mm_inputs(input_shape)
......
......@@ -52,8 +52,8 @@ def _get_head_cfg(fname):
import mmcv
config = _get_config_module(fname)
model = copy.deepcopy(config.model)
train_cfg = mmcv.Config(copy.deepcopy(config.train_cfg))
test_cfg = mmcv.Config(copy.deepcopy(config.test_cfg))
train_cfg = mmcv.Config(copy.deepcopy(config.model.train_cfg))
test_cfg = mmcv.Config(copy.deepcopy(config.model.test_cfg))
bbox_head = model.bbox_head
bbox_head.update(train_cfg=train_cfg)
......@@ -70,8 +70,8 @@ def _get_rpn_head_cfg(fname):
import mmcv
config = _get_config_module(fname)
model = copy.deepcopy(config.model)
train_cfg = mmcv.Config(copy.deepcopy(config.train_cfg))
test_cfg = mmcv.Config(copy.deepcopy(config.test_cfg))
train_cfg = mmcv.Config(copy.deepcopy(config.model.train_cfg))
test_cfg = mmcv.Config(copy.deepcopy(config.model.test_cfg))
rpn_head = model.rpn_head
rpn_head.update(train_cfg=train_cfg.rpn)
......@@ -88,8 +88,8 @@ def _get_roi_head_cfg(fname):
import mmcv
config = _get_config_module(fname)
model = copy.deepcopy(config.model)
train_cfg = mmcv.Config(copy.deepcopy(config.train_cfg))
test_cfg = mmcv.Config(copy.deepcopy(config.test_cfg))
train_cfg = mmcv.Config(copy.deepcopy(config.model.train_cfg))
test_cfg = mmcv.Config(copy.deepcopy(config.model.test_cfg))
roi_head = model.roi_head
roi_head.update(train_cfg=train_cfg.rcnn)
......@@ -106,8 +106,8 @@ def _get_pts_bbox_head_cfg(fname):
import mmcv
config = _get_config_module(fname)
model = copy.deepcopy(config.model)
train_cfg = mmcv.Config(copy.deepcopy(config.train_cfg.pts))
test_cfg = mmcv.Config(copy.deepcopy(config.test_cfg.pts))
train_cfg = mmcv.Config(copy.deepcopy(config.model.train_cfg.pts))
test_cfg = mmcv.Config(copy.deepcopy(config.model.test_cfg.pts))
pts_bbox_head = model.pts_bbox_head
pts_bbox_head.update(train_cfg=train_cfg)
......@@ -124,8 +124,8 @@ def _get_vote_head_cfg(fname):
import mmcv
config = _get_config_module(fname)
model = copy.deepcopy(config.model)
train_cfg = mmcv.Config(copy.deepcopy(config.train_cfg))
test_cfg = mmcv.Config(copy.deepcopy(config.test_cfg))
train_cfg = mmcv.Config(copy.deepcopy(config.model.train_cfg))
test_cfg = mmcv.Config(copy.deepcopy(config.model.test_cfg))
vote_head = model.bbox_head
vote_head.update(train_cfg=train_cfg)
......
......@@ -51,7 +51,8 @@ def test_single_gpu_test():
if not torch.cuda.is_available():
pytest.skip('test requires GPU and torch+cuda')
cfg = _get_config_module('votenet/votenet_16x8_sunrgbd-3d-10class.py')
model = build_detector(cfg.model, train_cfg=None, test_cfg=cfg.test_cfg)
cfg.model.train_cfg = None
model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg'))
dataset_cfg = cfg.data.test
dataset_cfg.data_root = './tests/data/sunrgbd'
dataset_cfg.ann_file = 'tests/data/sunrgbd/sunrgbd_infos.pkl'
......
......@@ -38,18 +38,15 @@ def test_config_build_detector():
config_mod = Config.fromfile(config_fpath)
config_mod.model
config_mod.train_cfg
config_mod.test_cfg
config_mod.model.train_cfg
config_mod.model.test_cfg
print('Building detector, config_fpath = {!r}'.format(config_fpath))
# Remove pretrained keys to allow for testing in an offline environment
if 'pretrained' in config_mod.model:
config_mod.model['pretrained'] = None
detector = build_detector(
config_mod.model,
train_cfg=config_mod.train_cfg,
test_cfg=config_mod.test_cfg)
detector = build_detector(config_mod.model)
assert detector is not None
if 'roi_head' in config_mod.model.keys():
......
......@@ -48,7 +48,8 @@ def main():
shuffle=False)
# build the model and load checkpoint
model = build_detector(cfg.model, train_cfg=None, test_cfg=cfg.test_cfg)
cfg.model.train_cfg = None
model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg'))
fp16_cfg = cfg.get('fp16', None)
if fp16_cfg is not None:
wrap_fp16_model(model)
......
......@@ -77,7 +77,10 @@ def main():
checkpoint = torch.load(args.checkpoint)
cfg = parse_config(checkpoint['meta']['config'])
# Build the model and load checkpoint
model = build_detector(cfg.model, train_cfg=None, test_cfg=cfg.test_cfg)
model = build_detector(
cfg.model,
train_cfg=cfg.get('train_cfg'),
test_cfg=cfg.get('test_cfg'))
orig_ckpt = checkpoint['state_dict']
converted_ckpt = orig_ckpt.copy()
......
......@@ -111,7 +111,8 @@ def main():
shuffle=False)
# build the model and load checkpoint
model = build_detector(cfg.model, train_cfg=None, test_cfg=cfg.test_cfg)
cfg.model.train_cfg = None
model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg'))
fp16_cfg = cfg.get('fp16', None)
if fp16_cfg is not None:
wrap_fp16_model(model)
......
......@@ -136,7 +136,10 @@ def main():
meta['seed'] = args.seed
model = build_detector(
cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
cfg.model,
train_cfg=cfg.get('train_cfg'),
test_cfg=cfg.get('test_cfg'))
logger.info(f'Model:\n{model}')
datasets = [build_dataset(cfg.data.train)]
if len(cfg.workflow) == 2:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment