Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
SOLOv2-pytorch
Commits
54b54d88
Unverified
Commit
54b54d88
authored
Oct 12, 2018
by
Kai Chen
Committed by
GitHub
Oct 12, 2018
Browse files
Merge pull request #19 from hellock/dev
Update inference APIs
parents
abc440fc
459d5ebc
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
25 additions
and
13 deletions
+25
-13
mmdet/apis/__init__.py
mmdet/apis/__init__.py
+2
-2
mmdet/apis/inference.py
mmdet/apis/inference.py
+19
-8
mmdet/models/builder.py
mmdet/models/builder.py
+2
-1
mmdet/models/rpn_heads/rpn_head.py
mmdet/models/rpn_heads/rpn_head.py
+2
-2
No files found.
mmdet/apis/__init__.py
View file @
54b54d88
from
.env
import
init_dist
,
get_root_logger
,
set_random_seed
from
.env
import
init_dist
,
get_root_logger
,
set_random_seed
from
.train
import
train_detector
from
.train
import
train_detector
from
.inference
import
inference_detector
from
.inference
import
inference_detector
,
show_result
__all__
=
[
__all__
=
[
'init_dist'
,
'get_root_logger'
,
'set_random_seed'
,
'train_detector'
,
'init_dist'
,
'get_root_logger'
,
'set_random_seed'
,
'train_detector'
,
'inference_detector'
'inference_detector'
,
'show_result'
]
]
mmdet/apis/inference.py
View file @
54b54d88
...
@@ -23,19 +23,29 @@ def _prepare_data(img, img_transform, cfg, device):
...
@@ -23,19 +23,29 @@ def _prepare_data(img, img_transform, cfg, device):
return
dict
(
img
=
[
img
],
img_meta
=
[
img_meta
])
return
dict
(
img
=
[
img
],
img_meta
=
[
img_meta
])
def
inference_detector
(
model
,
imgs
,
cfg
,
device
=
'cuda:0'
):
def
_inference_single
(
model
,
img
,
img_transform
,
cfg
,
device
):
img
=
mmcv
.
imread
(
img
)
data
=
_prepare_data
(
img
,
img_transform
,
cfg
,
device
)
with
torch
.
no_grad
():
result
=
model
(
return_loss
=
False
,
rescale
=
True
,
**
data
)
return
result
def
_inference_generator
(
model
,
imgs
,
img_transform
,
cfg
,
device
):
for
img
in
imgs
:
yield
_inference_single
(
model
,
img
,
img_transform
,
cfg
,
device
)
imgs
=
imgs
if
isinstance
(
imgs
,
list
)
else
[
imgs
]
def
inference_detector
(
model
,
imgs
,
cfg
,
device
=
'cuda:0'
):
img_transform
=
ImageTransform
(
img_transform
=
ImageTransform
(
size_divisor
=
cfg
.
data
.
test
.
size_divisor
,
**
cfg
.
img_norm_cfg
)
size_divisor
=
cfg
.
data
.
test
.
size_divisor
,
**
cfg
.
img_norm_cfg
)
model
=
model
.
to
(
device
)
model
=
model
.
to
(
device
)
model
.
eval
()
model
.
eval
()
for
img
in
imgs
:
img
=
mmcv
.
imread
(
img
)
if
not
isinstance
(
imgs
,
list
):
data
=
_prepare_data
(
img
,
img_transform
,
cfg
,
device
)
return
_inference_single
(
model
,
imgs
,
img_transform
,
cfg
,
device
)
with
torch
.
no_grad
():
else
:
result
=
model
(
return_loss
=
False
,
rescale
=
True
,
**
data
)
return
_inference_generator
(
model
,
imgs
,
img_transform
,
cfg
,
device
)
yield
result
def
show_result
(
img
,
result
,
dataset
=
'coco'
,
score_thr
=
0.3
):
def
show_result
(
img
,
result
,
dataset
=
'coco'
,
score_thr
=
0.3
):
...
@@ -46,6 +56,7 @@ def show_result(img, result, dataset='coco', score_thr=0.3):
...
@@ -46,6 +56,7 @@ def show_result(img, result, dataset='coco', score_thr=0.3):
]
]
labels
=
np
.
concatenate
(
labels
)
labels
=
np
.
concatenate
(
labels
)
bboxes
=
np
.
vstack
(
result
)
bboxes
=
np
.
vstack
(
result
)
img
=
mmcv
.
imread
(
img
)
mmcv
.
imshow_det_bboxes
(
mmcv
.
imshow_det_bboxes
(
img
.
copy
(),
img
.
copy
(),
bboxes
,
bboxes
,
...
...
mmdet/models/builder.py
View file @
54b54d88
...
@@ -2,7 +2,7 @@ from mmcv.runner import obj_from_dict
...
@@ -2,7 +2,7 @@ from mmcv.runner import obj_from_dict
from
torch
import
nn
from
torch
import
nn
from
.
import
(
backbones
,
necks
,
roi_extractors
,
rpn_heads
,
bbox_heads
,
from
.
import
(
backbones
,
necks
,
roi_extractors
,
rpn_heads
,
bbox_heads
,
mask_heads
,
detectors
)
mask_heads
)
__all__
=
[
__all__
=
[
'build_backbone'
,
'build_neck'
,
'build_rpn_head'
,
'build_roi_extractor'
,
'build_backbone'
,
'build_neck'
,
'build_rpn_head'
,
'build_roi_extractor'
,
...
@@ -48,4 +48,5 @@ def build_mask_head(cfg):
...
@@ -48,4 +48,5 @@ def build_mask_head(cfg):
def
build_detector
(
cfg
,
train_cfg
=
None
,
test_cfg
=
None
):
def
build_detector
(
cfg
,
train_cfg
=
None
,
test_cfg
=
None
):
from
.
import
detectors
return
build
(
cfg
,
detectors
,
dict
(
train_cfg
=
train_cfg
,
test_cfg
=
test_cfg
))
return
build
(
cfg
,
detectors
,
dict
(
train_cfg
=
train_cfg
,
test_cfg
=
test_cfg
))
mmdet/models/rpn_heads/rpn_head.py
View file @
54b54d88
...
@@ -48,8 +48,8 @@ class RPNHead(nn.Module):
...
@@ -48,8 +48,8 @@ class RPNHead(nn.Module):
self
.
anchor_scales
=
anchor_scales
self
.
anchor_scales
=
anchor_scales
self
.
anchor_ratios
=
anchor_ratios
self
.
anchor_ratios
=
anchor_ratios
self
.
anchor_strides
=
anchor_strides
self
.
anchor_strides
=
anchor_strides
self
.
anchor_base_sizes
=
anchor_strides
.
copy
(
self
.
anchor_base_sizes
=
list
(
)
if
anchor_base_sizes
is
None
else
anchor_base_sizes
anchor_strides
)
if
anchor_base_sizes
is
None
else
anchor_base_sizes
self
.
target_means
=
target_means
self
.
target_means
=
target_means
self
.
target_stds
=
target_stds
self
.
target_stds
=
target_stds
self
.
use_sigmoid_cls
=
use_sigmoid_cls
self
.
use_sigmoid_cls
=
use_sigmoid_cls
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment