Unverified Commit a481f5a8 authored by xiliu8006's avatar xiliu8006 Committed by GitHub
Browse files

[Enhance] Move train_cfg test_cfg to model (#307)

* Move train_cfg/test_cfg to model

* Move train_cfg/test_cfg to model

* Move train_cfg/test_cfg to model

* Move train_cfg/test_cfg to model

* Move train_cfg/test_cfg to model

* Move train_cfg/test_cfg to model

* Move train_cfg/test_cfg to model

* Move train_cfg and test_cfg into model

* modify centerpoint configs

* Modify docs

* modify build_detector

* modify test_config_build_detector

* modify build_detector parameters

* Adopt the same strategy in build_detector
parent a347ac75
...@@ -30,7 +30,8 @@ def init_detector(config, checkpoint=None, device='cuda:0'): ...@@ -30,7 +30,8 @@ def init_detector(config, checkpoint=None, device='cuda:0'):
raise TypeError('config must be a filename or Config object, ' raise TypeError('config must be a filename or Config object, '
f'but got {type(config)}') f'but got {type(config)}')
config.model.pretrained = None config.model.pretrained = None
model = build_detector(config.model, test_cfg=config.test_cfg) config.model.train_cfg = None
model = build_detector(config.model, test_cfg=config.get('test_cfg'))
if checkpoint is not None: if checkpoint is not None:
checkpoint = load_checkpoint(model, checkpoint) checkpoint = load_checkpoint(model, checkpoint)
if 'CLASSES' in checkpoint['meta']: if 'CLASSES' in checkpoint['meta']:
......
import warnings
from mmdet.models.builder import (BACKBONES, DETECTORS, HEADS, LOSSES, NECKS, from mmdet.models.builder import (BACKBONES, DETECTORS, HEADS, LOSSES, NECKS,
ROI_EXTRACTORS, SHARED_HEADS, build) ROI_EXTRACTORS, SHARED_HEADS, build)
from .registry import FUSION_LAYERS, MIDDLE_ENCODERS, VOXEL_ENCODERS from .registry import FUSION_LAYERS, MIDDLE_ENCODERS, VOXEL_ENCODERS
...@@ -35,6 +37,14 @@ def build_loss(cfg): ...@@ -35,6 +37,14 @@ def build_loss(cfg):
def build_detector(cfg, train_cfg=None, test_cfg=None): def build_detector(cfg, train_cfg=None, test_cfg=None):
"""Build detector.""" """Build detector."""
if train_cfg is not None or test_cfg is not None:
warnings.warn(
'train_cfg and test_cfg is deprecated, '
'please specify them in model', UserWarning)
assert cfg.get('train_cfg') is None or train_cfg is None, \
'train_cfg specified in both outer field and model field '
assert cfg.get('test_cfg') is None or test_cfg is None, \
'test_cfg specified in both outer field and model field '
return build(cfg, DETECTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg)) return build(cfg, DETECTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg))
......
...@@ -62,8 +62,8 @@ def _get_detector_cfg(fname): ...@@ -62,8 +62,8 @@ def _get_detector_cfg(fname):
import mmcv import mmcv
config = _get_config_module(fname) config = _get_config_module(fname)
model = copy.deepcopy(config.model) model = copy.deepcopy(config.model)
train_cfg = mmcv.Config(copy.deepcopy(config.train_cfg)) train_cfg = mmcv.Config(copy.deepcopy(config.model.train_cfg))
test_cfg = mmcv.Config(copy.deepcopy(config.test_cfg)) test_cfg = mmcv.Config(copy.deepcopy(config.model.test_cfg))
model.update(train_cfg=train_cfg) model.update(train_cfg=train_cfg)
model.update(test_cfg=test_cfg) model.update(test_cfg=test_cfg)
......
...@@ -40,20 +40,17 @@ def _get_detector_cfg(fname): ...@@ -40,20 +40,17 @@ def _get_detector_cfg(fname):
These are deep copied to allow for safe modification of parameters without These are deep copied to allow for safe modification of parameters without
influencing other tests. influencing other tests.
""" """
import mmcv
config = _get_config_module(fname) config = _get_config_module(fname)
model = copy.deepcopy(config.model) model = copy.deepcopy(config.model)
train_cfg = mmcv.Config(copy.deepcopy(config.train_cfg)) return model
test_cfg = mmcv.Config(copy.deepcopy(config.test_cfg))
return model, train_cfg, test_cfg
def _test_two_stage_forward(cfg_file): def _test_two_stage_forward(cfg_file):
model, train_cfg, test_cfg = _get_detector_cfg(cfg_file) model = _get_detector_cfg(cfg_file)
model['pretrained'] = None model['pretrained'] = None
from mmdet.models import build_detector from mmdet.models import build_detector
detector = build_detector(model, train_cfg=train_cfg, test_cfg=test_cfg) detector = build_detector(model)
input_shape = (1, 3, 256, 256) input_shape = (1, 3, 256, 256)
...@@ -107,11 +104,11 @@ def _test_two_stage_forward(cfg_file): ...@@ -107,11 +104,11 @@ def _test_two_stage_forward(cfg_file):
def _test_single_stage_forward(cfg_file): def _test_single_stage_forward(cfg_file):
model, train_cfg, test_cfg = _get_detector_cfg(cfg_file) model = _get_detector_cfg(cfg_file)
model['pretrained'] = None model['pretrained'] = None
from mmdet.models import build_detector from mmdet.models import build_detector
detector = build_detector(model, train_cfg=train_cfg, test_cfg=test_cfg) detector = build_detector(model)
input_shape = (1, 3, 300, 300) input_shape = (1, 3, 300, 300)
mm_inputs = _demo_mm_inputs(input_shape) mm_inputs = _demo_mm_inputs(input_shape)
......
...@@ -52,8 +52,8 @@ def _get_head_cfg(fname): ...@@ -52,8 +52,8 @@ def _get_head_cfg(fname):
import mmcv import mmcv
config = _get_config_module(fname) config = _get_config_module(fname)
model = copy.deepcopy(config.model) model = copy.deepcopy(config.model)
train_cfg = mmcv.Config(copy.deepcopy(config.train_cfg)) train_cfg = mmcv.Config(copy.deepcopy(config.model.train_cfg))
test_cfg = mmcv.Config(copy.deepcopy(config.test_cfg)) test_cfg = mmcv.Config(copy.deepcopy(config.model.test_cfg))
bbox_head = model.bbox_head bbox_head = model.bbox_head
bbox_head.update(train_cfg=train_cfg) bbox_head.update(train_cfg=train_cfg)
...@@ -70,8 +70,8 @@ def _get_rpn_head_cfg(fname): ...@@ -70,8 +70,8 @@ def _get_rpn_head_cfg(fname):
import mmcv import mmcv
config = _get_config_module(fname) config = _get_config_module(fname)
model = copy.deepcopy(config.model) model = copy.deepcopy(config.model)
train_cfg = mmcv.Config(copy.deepcopy(config.train_cfg)) train_cfg = mmcv.Config(copy.deepcopy(config.model.train_cfg))
test_cfg = mmcv.Config(copy.deepcopy(config.test_cfg)) test_cfg = mmcv.Config(copy.deepcopy(config.model.test_cfg))
rpn_head = model.rpn_head rpn_head = model.rpn_head
rpn_head.update(train_cfg=train_cfg.rpn) rpn_head.update(train_cfg=train_cfg.rpn)
...@@ -88,8 +88,8 @@ def _get_roi_head_cfg(fname): ...@@ -88,8 +88,8 @@ def _get_roi_head_cfg(fname):
import mmcv import mmcv
config = _get_config_module(fname) config = _get_config_module(fname)
model = copy.deepcopy(config.model) model = copy.deepcopy(config.model)
train_cfg = mmcv.Config(copy.deepcopy(config.train_cfg)) train_cfg = mmcv.Config(copy.deepcopy(config.model.train_cfg))
test_cfg = mmcv.Config(copy.deepcopy(config.test_cfg)) test_cfg = mmcv.Config(copy.deepcopy(config.model.test_cfg))
roi_head = model.roi_head roi_head = model.roi_head
roi_head.update(train_cfg=train_cfg.rcnn) roi_head.update(train_cfg=train_cfg.rcnn)
...@@ -106,8 +106,8 @@ def _get_pts_bbox_head_cfg(fname): ...@@ -106,8 +106,8 @@ def _get_pts_bbox_head_cfg(fname):
import mmcv import mmcv
config = _get_config_module(fname) config = _get_config_module(fname)
model = copy.deepcopy(config.model) model = copy.deepcopy(config.model)
train_cfg = mmcv.Config(copy.deepcopy(config.train_cfg.pts)) train_cfg = mmcv.Config(copy.deepcopy(config.model.train_cfg.pts))
test_cfg = mmcv.Config(copy.deepcopy(config.test_cfg.pts)) test_cfg = mmcv.Config(copy.deepcopy(config.model.test_cfg.pts))
pts_bbox_head = model.pts_bbox_head pts_bbox_head = model.pts_bbox_head
pts_bbox_head.update(train_cfg=train_cfg) pts_bbox_head.update(train_cfg=train_cfg)
...@@ -124,8 +124,8 @@ def _get_vote_head_cfg(fname): ...@@ -124,8 +124,8 @@ def _get_vote_head_cfg(fname):
import mmcv import mmcv
config = _get_config_module(fname) config = _get_config_module(fname)
model = copy.deepcopy(config.model) model = copy.deepcopy(config.model)
train_cfg = mmcv.Config(copy.deepcopy(config.train_cfg)) train_cfg = mmcv.Config(copy.deepcopy(config.model.train_cfg))
test_cfg = mmcv.Config(copy.deepcopy(config.test_cfg)) test_cfg = mmcv.Config(copy.deepcopy(config.model.test_cfg))
vote_head = model.bbox_head vote_head = model.bbox_head
vote_head.update(train_cfg=train_cfg) vote_head.update(train_cfg=train_cfg)
......
...@@ -51,7 +51,8 @@ def test_single_gpu_test(): ...@@ -51,7 +51,8 @@ def test_single_gpu_test():
if not torch.cuda.is_available(): if not torch.cuda.is_available():
pytest.skip('test requires GPU and torch+cuda') pytest.skip('test requires GPU and torch+cuda')
cfg = _get_config_module('votenet/votenet_16x8_sunrgbd-3d-10class.py') cfg = _get_config_module('votenet/votenet_16x8_sunrgbd-3d-10class.py')
model = build_detector(cfg.model, train_cfg=None, test_cfg=cfg.test_cfg) cfg.model.train_cfg = None
model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg'))
dataset_cfg = cfg.data.test dataset_cfg = cfg.data.test
dataset_cfg.data_root = './tests/data/sunrgbd' dataset_cfg.data_root = './tests/data/sunrgbd'
dataset_cfg.ann_file = 'tests/data/sunrgbd/sunrgbd_infos.pkl' dataset_cfg.ann_file = 'tests/data/sunrgbd/sunrgbd_infos.pkl'
......
...@@ -38,18 +38,15 @@ def test_config_build_detector(): ...@@ -38,18 +38,15 @@ def test_config_build_detector():
config_mod = Config.fromfile(config_fpath) config_mod = Config.fromfile(config_fpath)
config_mod.model config_mod.model
config_mod.train_cfg config_mod.model.train_cfg
config_mod.test_cfg config_mod.model.test_cfg
print('Building detector, config_fpath = {!r}'.format(config_fpath)) print('Building detector, config_fpath = {!r}'.format(config_fpath))
# Remove pretrained keys to allow for testing in an offline environment # Remove pretrained keys to allow for testing in an offline environment
if 'pretrained' in config_mod.model: if 'pretrained' in config_mod.model:
config_mod.model['pretrained'] = None config_mod.model['pretrained'] = None
detector = build_detector( detector = build_detector(config_mod.model)
config_mod.model,
train_cfg=config_mod.train_cfg,
test_cfg=config_mod.test_cfg)
assert detector is not None assert detector is not None
if 'roi_head' in config_mod.model.keys(): if 'roi_head' in config_mod.model.keys():
......
...@@ -48,7 +48,8 @@ def main(): ...@@ -48,7 +48,8 @@ def main():
shuffle=False) shuffle=False)
# build the model and load checkpoint # build the model and load checkpoint
model = build_detector(cfg.model, train_cfg=None, test_cfg=cfg.test_cfg) cfg.model.train_cfg = None
model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg'))
fp16_cfg = cfg.get('fp16', None) fp16_cfg = cfg.get('fp16', None)
if fp16_cfg is not None: if fp16_cfg is not None:
wrap_fp16_model(model) wrap_fp16_model(model)
......
...@@ -77,7 +77,10 @@ def main(): ...@@ -77,7 +77,10 @@ def main():
checkpoint = torch.load(args.checkpoint) checkpoint = torch.load(args.checkpoint)
cfg = parse_config(checkpoint['meta']['config']) cfg = parse_config(checkpoint['meta']['config'])
# Build the model and load checkpoint # Build the model and load checkpoint
model = build_detector(cfg.model, train_cfg=None, test_cfg=cfg.test_cfg) model = build_detector(
cfg.model,
train_cfg=cfg.get('train_cfg'),
test_cfg=cfg.get('test_cfg'))
orig_ckpt = checkpoint['state_dict'] orig_ckpt = checkpoint['state_dict']
converted_ckpt = orig_ckpt.copy() converted_ckpt = orig_ckpt.copy()
......
...@@ -111,7 +111,8 @@ def main(): ...@@ -111,7 +111,8 @@ def main():
shuffle=False) shuffle=False)
# build the model and load checkpoint # build the model and load checkpoint
model = build_detector(cfg.model, train_cfg=None, test_cfg=cfg.test_cfg) cfg.model.train_cfg = None
model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg'))
fp16_cfg = cfg.get('fp16', None) fp16_cfg = cfg.get('fp16', None)
if fp16_cfg is not None: if fp16_cfg is not None:
wrap_fp16_model(model) wrap_fp16_model(model)
......
...@@ -136,7 +136,10 @@ def main(): ...@@ -136,7 +136,10 @@ def main():
meta['seed'] = args.seed meta['seed'] = args.seed
model = build_detector( model = build_detector(
cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg) cfg.model,
train_cfg=cfg.get('train_cfg'),
test_cfg=cfg.get('test_cfg'))
logger.info(f'Model:\n{model}') logger.info(f'Model:\n{model}')
datasets = [build_dataset(cfg.data.train)] datasets = [build_dataset(cfg.data.train)]
if len(cfg.workflow) == 2: if len(cfg.workflow) == 2:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment