Commit c6fde230 authored by pangjm's avatar pangjm
Browse files

Merge branch 'master' of github.com:open-mmlab/mmdetection

Conflicts:
	tools/train.py
parents e74519bb 826a5613
...@@ -6,43 +6,58 @@ from .cpu_nms import cpu_nms ...@@ -6,43 +6,58 @@ from .cpu_nms import cpu_nms
from .cpu_soft_nms import cpu_soft_nms from .cpu_soft_nms import cpu_soft_nms
def nms(dets, thresh, device_id=None): def nms(dets, iou_thr, device_id=None):
"""Dispatch to either CPU or GPU NMS implementations.""" """Dispatch to either CPU or GPU NMS implementations."""
tensor_device = None
if isinstance(dets, torch.Tensor): if isinstance(dets, torch.Tensor):
tensor_device = dets.device is_tensor = True
if dets.is_cuda: if dets.is_cuda:
device_id = dets.get_device() device_id = dets.get_device()
dets = dets.detach().cpu().numpy() dets_np = dets.detach().cpu().numpy()
assert isinstance(dets, np.ndarray) elif isinstance(dets, np.ndarray):
is_tensor = False
dets_np = dets
else:
raise TypeError(
'dets must be either a Tensor or numpy array, but got {}'.format(
type(dets)))
if dets.shape[0] == 0: if dets_np.shape[0] == 0:
inds = [] inds = []
else: else:
inds = (gpu_nms(dets, thresh, device_id=device_id) inds = (gpu_nms(dets_np, iou_thr, device_id=device_id)
if device_id is not None else cpu_nms(dets, thresh)) if device_id is not None else cpu_nms(dets_np, iou_thr))
if tensor_device: if is_tensor:
return torch.Tensor(inds).long().to(tensor_device) inds = dets.new_tensor(inds, dtype=torch.long)
else: else:
return np.array(inds, dtype=np.int) inds = np.array(inds, dtype=np.int64)
return dets[inds, :], inds
def soft_nms(dets, Nt=0.3, method=1, sigma=0.5, min_score=0): def soft_nms(dets, iou_thr, method='linear', sigma=0.5, min_score=1e-3):
if isinstance(dets, torch.Tensor): if isinstance(dets, torch.Tensor):
_dets = dets.detach().cpu().numpy() is_tensor = True
dets_np = dets.detach().cpu().numpy()
elif isinstance(dets, np.ndarray):
is_tensor = False
dets_np = dets
else: else:
_dets = dets.copy() raise TypeError(
assert isinstance(_dets, np.ndarray) 'dets must be either a Tensor or numpy array, but got {}'.format(
type(dets)))
method_codes = {'linear': 1, 'gaussian': 2}
if method not in method_codes:
raise ValueError('Invalid method for SoftNMS: {}'.format(method))
new_dets, inds = cpu_soft_nms( new_dets, inds = cpu_soft_nms(
_dets, Nt=Nt, method=method, sigma=sigma, threshold=min_score) dets_np,
iou_thr,
if isinstance(dets, torch.Tensor): method=method_codes[method],
return dets.new_tensor( sigma=sigma,
inds, dtype=torch.long), dets.new_tensor(new_dets) min_score=min_score)
if is_tensor:
return dets.new_tensor(new_dets), dets.new_tensor(
inds, dtype=torch.long)
else: else:
return np.array( return new_dets.astype(np.float32), inds.astype(np.int64)
inds, dtype=np.int), np.array(
new_dets, dtype=np.float32)
...@@ -12,7 +12,7 @@ def readme(): ...@@ -12,7 +12,7 @@ def readme():
MAJOR = 0 MAJOR = 0
MINOR = 5 MINOR = 5
PATCH = 2 PATCH = 4
SUFFIX = '' SUFFIX = ''
SHORT_VERSION = '{}.{}.{}{}'.format(MAJOR, MINOR, PATCH, SUFFIX) SHORT_VERSION = '{}.{}.{}{}'.format(MAJOR, MINOR, PATCH, SUFFIX)
...@@ -93,7 +93,7 @@ if __name__ == '__main__': ...@@ -93,7 +93,7 @@ if __name__ == '__main__':
package_data={'mmdet.ops': ['*/*.so']}, package_data={'mmdet.ops': ['*/*.so']},
classifiers=[ classifiers=[
'Development Status :: 4 - Beta', 'Development Status :: 4 - Beta',
'License :: OSI Approved :: GNU General Public License v3 (GPLv3)', 'License :: OSI Approved :: Apache Software License',
'Operating System :: OS Independent', 'Operating System :: OS Independent',
'Programming Language :: Python :: 2', 'Programming Language :: Python :: 2',
'Programming Language :: Python :: 2.7', 'Programming Language :: Python :: 2.7',
......
...@@ -104,10 +104,19 @@ def main(): ...@@ -104,10 +104,19 @@ def main():
print('Starting evaluate {}'.format(' and '.join(eval_types))) print('Starting evaluate {}'.format(' and '.join(eval_types)))
if eval_types == ['proposal_fast']: if eval_types == ['proposal_fast']:
result_file = args.out result_file = args.out
coco_eval(result_file, eval_types, dataset.coco)
else: else:
result_file = args.out + '.json' if not isinstance(outputs[0], dict):
results2json(dataset, outputs, result_file) result_file = args.out + '.json'
coco_eval(result_file, eval_types, dataset.coco) results2json(dataset, outputs, result_file)
coco_eval(result_file, eval_types, dataset.coco)
else:
for name in outputs[0]:
print('\nEvaluating {}'.format(name))
outputs_ = [out[name] for out in outputs]
result_file = args.out + '.{}.json'.format(name)
results2json(dataset, outputs_, result_file)
coco_eval(result_file, eval_types, dataset.coco)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -5,9 +5,9 @@ sys.path.insert(0, '/mnt/lustre/pangjiangmiao/codebase/mmdet') ...@@ -5,9 +5,9 @@ sys.path.insert(0, '/mnt/lustre/pangjiangmiao/codebase/mmdet')
import argparse import argparse
from mmcv import Config from mmcv import Config
from mmcv.runner import obj_from_dict
from mmdet import datasets, __version__ from mmdet import __version__
from mmdet.datasets import get_dataset
from mmdet.apis import (train_detector, init_dist, get_root_logger, from mmdet.apis import (train_detector, init_dist, get_root_logger,
set_random_seed) set_random_seed)
from mmdet.models import build_detector from mmdet.models import build_detector
...@@ -74,13 +74,7 @@ def main(): ...@@ -74,13 +74,7 @@ def main():
model = build_detector( model = build_detector(
cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg) cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
import torch.distributed as dist train_dataset = get_dataset(cfg.data.train)
if dist.get_rank() == 0:
with open('/mnt/lustre/pangjiangmiao/r50_32x4d_mmdet.txt', 'w') as f:
for k in model.state_dict().keys():
if 'num_batches_tracked' in k: continue
f.writelines('{}\n'.format(k))
train_dataset = obj_from_dict(cfg.data.train, datasets)
train_detector( train_detector(
model, model,
train_dataset, train_dataset,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment