Unverified Commit d95727b2 authored by Kai Chen's avatar Kai Chen Committed by GitHub
Browse files

add a field to support the evaluation interval (#849)

parent e4917130
......@@ -171,6 +171,7 @@ log_config = dict(
# dict(type='TensorboardLoggerHook')
])
# yapf:enable
evaluation = dict(interval=1)
# runtime settings
total_epochs = 12
dist_params = dict(backend='nccl')
......
......@@ -91,8 +91,8 @@ def build_optimizer(model, optimizer_cfg):
paramwise_options = optimizer_cfg.pop('paramwise_options', None)
# if no paramwise option is specified, just use the global setting
if paramwise_options is None:
return obj_from_dict(
optimizer_cfg, torch.optim, dict(params=model.parameters()))
return obj_from_dict(optimizer_cfg, torch.optim,
dict(params=model.parameters()))
else:
assert isinstance(paramwise_options, dict)
# get base lr and weight decay
......@@ -154,15 +154,19 @@ def _dist_train(model, dataset, cfg, validate=False):
# register eval hooks
if validate:
val_dataset_cfg = cfg.data.val
eval_cfg = cfg.get('evaluation', {})
if isinstance(model.module, RPN):
# TODO: implement recall hooks for other datasets
runner.register_hook(CocoDistEvalRecallHook(val_dataset_cfg))
runner.register_hook(
CocoDistEvalRecallHook(val_dataset_cfg, **eval_cfg))
else:
dataset_type = getattr(datasets, val_dataset_cfg.type)
if issubclass(dataset_type, datasets.CocoDataset):
runner.register_hook(CocoDistEvalmAPHook(val_dataset_cfg))
runner.register_hook(
CocoDistEvalmAPHook(val_dataset_cfg, **eval_cfg))
else:
runner.register_hook(DistEvalmAPHook(val_dataset_cfg))
runner.register_hook(
DistEvalmAPHook(val_dataset_cfg, **eval_cfg))
if cfg.resume_from:
runner.resume(cfg.resume_from)
......
......@@ -116,9 +116,11 @@ class CocoDistEvalRecallHook(DistEvalHook):
def __init__(self,
dataset,
interval=1,
proposal_nums=(100, 300, 1000),
iou_thrs=np.arange(0.5, 0.96, 0.05)):
super(CocoDistEvalRecallHook, self).__init__(dataset)
super(CocoDistEvalRecallHook, self).__init__(
dataset, interval=interval)
self.proposal_nums = np.array(proposal_nums, dtype=np.int32)
self.iou_thrs = np.array(iou_thrs, dtype=np.float32)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment