Commit 7eb02d29 authored by Kai Chen's avatar Kai Chen
Browse files

Merge branch 'dev' into single-stage

parents 20e75c22 01a03aab
......@@ -79,7 +79,7 @@ test_cfg = dict(
rcnn=dict(score_thr=0.05, max_per_img=100, nms_thr=0.5))
# dataset settings
dataset_type = 'CocoDataset'
data_root = '../data/coco/'
data_root = 'data/coco/'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
data = dict(
......@@ -140,7 +140,7 @@ log_config = dict(
# runtime settings
total_epochs = 12
device_ids = range(8)
dist_params = dict(backend='gloo')
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/fpn_faster_rcnn_r50_1x'
load_from = None
......
......@@ -92,7 +92,7 @@ test_cfg = dict(
score_thr=0.05, max_per_img=100, nms_thr=0.5, mask_thr_binary=0.5))
# dataset settings
dataset_type = 'CocoDataset'
data_root = '../data/coco/'
data_root = 'data/coco/'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
data = dict(
......@@ -153,7 +153,7 @@ log_config = dict(
# runtime settings
total_epochs = 12
device_ids = range(8)
dist_params = dict(backend='gloo')
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/fpn_mask_rcnn_r50_1x'
load_from = None
......
......@@ -50,7 +50,7 @@ test_cfg = dict(
min_bbox_size=0))
# dataset settings
dataset_type = 'CocoDataset'
data_root = '../data/coco/'
data_root = 'data/coco/'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
data = dict(
......@@ -110,7 +110,7 @@ log_config = dict(
# yapf:enable
# runtime settings
total_epochs = 12
dist_params = dict(backend='gloo')
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/fpn_rpn_r50_1x'
load_from = None
......
from .anchor import * # noqa: F401, F403
from .bbox_ops import * # noqa: F401, F403
from .mask_ops import * # noqa: F401, F403
from .losses import * # noqa: F401, F403
from .eval import * # noqa: F401, F403
from .parallel import * # noqa: F401, F403
from .bbox import * # noqa: F401, F403
from .mask import * # noqa: F401, F403
from .loss import * # noqa: F401, F403
from .evaluation import * # noqa: F401, F403
from .post_processing import * # noqa: F401, F403
from .utils import * # noqa: F401, F403
import torch
from ..bbox_ops import bbox_assign, bbox2delta, bbox_sampling
from ..bbox import bbox_assign, bbox2delta, bbox_sampling
from ..utils import multi_apply
......
......@@ -5,6 +5,11 @@ from .geometry import bbox_overlaps
def random_choice(gallery, num):
"""Random select some elements from the gallery.
It seems that Pytorch's implementation is slower than numpy so we use numpy
to randperm the indices.
"""
assert len(gallery) >= num
if isinstance(gallery, list):
gallery = np.array(gallery)
......@@ -12,9 +17,7 @@ def random_choice(gallery, num):
np.random.shuffle(cands)
rand_inds = cands[:num]
if not isinstance(gallery, np.ndarray):
rand_inds = torch.from_numpy(rand_inds).long()
if gallery.is_cuda:
rand_inds = rand_inds.cuda(gallery.get_device())
rand_inds = torch.from_numpy(rand_inds).long().to(gallery.device)
return gallery[rand_inds]
......
......@@ -7,13 +7,12 @@ import mmcv
import numpy as np
import torch
from mmcv.runner import Hook, obj_from_dict
from mmcv.parallel import scatter, collate
from pycocotools.cocoeval import COCOeval
from torch.utils.data import Dataset
from .coco_utils import results2json, fast_eval_recall
from ..parallel import scatter
from mmdet import datasets
from mmdet.datasets.loader import collate
class DistEvalHook(Hook):
......
from .utils import split_combined_polys
from .mask_target import mask_target
__all__ = ['split_combined_polys', 'mask_target']
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment