Commit a9f02204 authored by Kai Chen's avatar Kai Chen
Browse files

rename bbox_ops/mask_ops to bbox/mask

parent 5686a375
from .anchor import * # noqa: F401, F403 from .anchor import * # noqa: F401, F403
from .bbox_ops import * # noqa: F401, F403 from .bbox import * # noqa: F401, F403
from .mask_ops import * # noqa: F401, F403 from .mask import * # noqa: F401, F403
from .losses import * # noqa: F401, F403 from .losses import * # noqa: F401, F403
from .eval import * # noqa: F401, F403 from .eval import * # noqa: F401, F403
from .parallel import * # noqa: F401, F403 from .parallel import * # noqa: F401, F403
......
import torch import torch
from ..bbox_ops import bbox_assign, bbox2delta, bbox_sampling from ..bbox import bbox_assign, bbox2delta, bbox_sampling
from ..utils import multi_apply from ..utils import multi_apply
......
...@@ -5,6 +5,11 @@ from .geometry import bbox_overlaps ...@@ -5,6 +5,11 @@ from .geometry import bbox_overlaps
def random_choice(gallery, num): def random_choice(gallery, num):
"""Random select some elements from the gallery.
It seems that Pytorch's implementation is slower than numpy so we use numpy
to randperm the indices.
"""
assert len(gallery) >= num assert len(gallery) >= num
if isinstance(gallery, list): if isinstance(gallery, list):
gallery = np.array(gallery) gallery = np.array(gallery)
...@@ -12,9 +17,7 @@ def random_choice(gallery, num): ...@@ -12,9 +17,7 @@ def random_choice(gallery, num):
np.random.shuffle(cands) np.random.shuffle(cands)
rand_inds = cands[:num] rand_inds = cands[:num]
if not isinstance(gallery, np.ndarray): if not isinstance(gallery, np.ndarray):
rand_inds = torch.from_numpy(rand_inds).long() rand_inds = torch.from_numpy(rand_inds).long().to(gallery.device)
if gallery.is_cuda:
rand_inds = rand_inds.cuda(gallery.get_device())
return gallery[rand_inds] return gallery[rand_inds]
......
...@@ -3,7 +3,7 @@ import torch ...@@ -3,7 +3,7 @@ import torch
import numpy as np import numpy as np
from mmdet.ops import nms from mmdet.ops import nms
from ..bbox_ops import bbox_mapping_back from ..bbox import bbox_mapping_back
def merge_aug_proposals(aug_proposals, img_metas, rpn_test_cfg): def merge_aug_proposals(aug_proposals, img_metas, rpn_test_cfg):
......
...@@ -2,7 +2,7 @@ import mmcv ...@@ -2,7 +2,7 @@ import mmcv
import numpy as np import numpy as np
import torch import torch
from mmdet.core.mask_ops import segms from mmdet.core.mask import segms
__all__ = [ __all__ = [
'ImageTransform', 'BboxTransform', 'PolyMaskTransform', 'Numpy2Tensor' 'ImageTransform', 'BboxTransform', 'PolyMaskTransform', 'Numpy2Tensor'
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment