Unverified Commit 51904cbc authored by Kai Chen's avatar Kai Chen Committed by GitHub
Browse files

Merge pull request #559 from hellock/dev

Check dataset type by subclass instead of names
parents 86187a20 ebc83122
......@@ -6,6 +6,7 @@ import torch
from mmcv.runner import Runner, DistSamplerSeedHook
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmdet import datasets
from mmdet.core import (DistOptimizerHook, DistEvalmAPHook,
CocoDistEvalRecallHook, CocoDistEvalmAPHook)
from mmdet.datasets import build_dataloader
......@@ -80,14 +81,16 @@ def _dist_train(model, dataset, cfg, validate=False):
runner.register_hook(DistSamplerSeedHook())
# register eval hooks
if validate:
val_dataset_cfg = cfg.data.val
if isinstance(model.module, RPN):
# TODO: implement recall hooks for other datasets
runner.register_hook(CocoDistEvalRecallHook(cfg.data.val))
runner.register_hook(CocoDistEvalRecallHook(val_dataset_cfg))
else:
if cfg.data.val.type == 'CocoDataset':
runner.register_hook(CocoDistEvalmAPHook(cfg.data.val))
dataset_type = getattr(datasets, val_dataset_cfg.type)
if issubclass(dataset_type, datasets.CocoDataset):
runner.register_hook(CocoDistEvalmAPHook(val_dataset_cfg))
else:
runner.register_hook(DistEvalmAPHook(cfg.data.val))
runner.register_hook(DistEvalmAPHook(val_dataset_cfg))
if cfg.resume_from:
runner.resume(cfg.resume_from)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment