Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
SOLOv2-pytorch
Commits
76168f9c
Commit
76168f9c
authored
Dec 24, 2018
by
ThangVu
Browse files
resolve conflict GN-dev with master
parents
8a086f02
c5d8f002
Changes
83
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
78 additions
and
3 deletions
+78
-3
tools/test.py
tools/test.py
+7
-3
tools/train.py
tools/train.py
+9
-0
tools/voc_eval.py
tools/voc_eval.py
+62
-0
No files found.
tools/test.py
View file @
76168f9c
...
...
@@ -14,15 +14,16 @@ from mmdet.models import build_detector, detectors
def
single_test
(
model
,
data_loader
,
show
=
False
):
model
.
eval
()
results
=
[]
prog_bar
=
mmcv
.
ProgressBar
(
len
(
data_loader
.
dataset
))
dataset
=
data_loader
.
dataset
prog_bar
=
mmcv
.
ProgressBar
(
len
(
dataset
))
for
i
,
data
in
enumerate
(
data_loader
):
with
torch
.
no_grad
():
result
=
model
(
return_loss
=
False
,
rescale
=
not
show
,
**
data
)
results
.
append
(
result
)
if
show
:
model
.
module
.
show_result
(
data
,
result
,
data
_loader
.
dataset
.
img_norm_cfg
)
model
.
module
.
show_result
(
data
,
result
,
dataset
.
img_norm_cfg
,
data
set
.
CLASSES
)
batch_size
=
data
[
'img'
][
0
].
size
(
0
)
for
_
in
range
(
batch_size
):
...
...
@@ -65,6 +66,9 @@ def main():
raise
ValueError
(
'The output file must be a pkl file.'
)
cfg
=
mmcv
.
Config
.
fromfile
(
args
.
config
)
# set cudnn_benchmark
if
cfg
.
get
(
'cudnn_benchmark'
,
False
):
torch
.
backends
.
cudnn
.
benchmark
=
True
cfg
.
model
.
pretrained
=
None
cfg
.
data
.
test
.
test_mode
=
True
...
...
tools/train.py
View file @
76168f9c
...
...
@@ -8,12 +8,15 @@ from mmdet.datasets import get_dataset
from
mmdet.apis
import
(
train_detector
,
init_dist
,
get_root_logger
,
set_random_seed
)
from
mmdet.models
import
build_detector
import
torch
def
parse_args
():
parser
=
argparse
.
ArgumentParser
(
description
=
'Train a detector'
)
parser
.
add_argument
(
'config'
,
help
=
'train config file path'
)
parser
.
add_argument
(
'--work_dir'
,
help
=
'the dir to save logs and models'
)
parser
.
add_argument
(
'--resume_from'
,
help
=
'the checkpoint file to resume from'
)
parser
.
add_argument
(
'--validate'
,
action
=
'store_true'
,
...
...
@@ -40,9 +43,14 @@ def main():
args
=
parse_args
()
cfg
=
Config
.
fromfile
(
args
.
config
)
# set cudnn_benchmark
if
cfg
.
get
(
'cudnn_benchmark'
,
False
):
torch
.
backends
.
cudnn
.
benchmark
=
True
# update configs according to CLI args
if
args
.
work_dir
is
not
None
:
cfg
.
work_dir
=
args
.
work_dir
if
args
.
resume_from
is
not
None
:
cfg
.
resume_from
=
args
.
resume_from
cfg
.
gpus
=
args
.
gpus
if
cfg
.
checkpoint_config
is
not
None
:
# save mmdet version in checkpoints as meta data
...
...
@@ -67,6 +75,7 @@ def main():
model
=
build_detector
(
cfg
.
model
,
train_cfg
=
cfg
.
train_cfg
,
test_cfg
=
cfg
.
test_cfg
)
train_dataset
=
get_dataset
(
cfg
.
data
.
train
)
train_detector
(
model
,
...
...
tools/voc_eval.py
0 → 100644
View file @
76168f9c
from
argparse
import
ArgumentParser
import
mmcv
import
numpy
as
np
from
mmdet
import
datasets
from
mmdet.core
import
eval_map
def
voc_eval
(
result_file
,
dataset
,
iou_thr
=
0.5
):
det_results
=
mmcv
.
load
(
result_file
)
gt_bboxes
=
[]
gt_labels
=
[]
gt_ignore
=
[]
for
i
in
range
(
len
(
dataset
)):
ann
=
dataset
.
get_ann_info
(
i
)
bboxes
=
ann
[
'bboxes'
]
labels
=
ann
[
'labels'
]
if
'bboxes_ignore'
in
ann
:
ignore
=
np
.
concatenate
([
np
.
zeros
(
bboxes
.
shape
[
0
],
dtype
=
np
.
bool
),
np
.
ones
(
ann
[
'bboxes_ignore'
].
shape
[
0
],
dtype
=
np
.
bool
)
])
gt_ignore
.
append
(
ignore
)
bboxes
=
np
.
vstack
([
bboxes
,
ann
[
'bboxes_ignore'
]])
labels
=
np
.
concatenate
([
labels
,
ann
[
'labels_ignore'
]])
gt_bboxes
.
append
(
bboxes
)
gt_labels
.
append
(
labels
)
if
not
gt_ignore
:
gt_ignore
=
gt_ignore
if
hasattr
(
dataset
,
'year'
)
and
dataset
.
year
==
2007
:
dataset_name
=
'voc07'
else
:
dataset_name
=
dataset
.
CLASSES
eval_map
(
det_results
,
gt_bboxes
,
gt_labels
,
gt_ignore
=
gt_ignore
,
scale_ranges
=
None
,
iou_thr
=
iou_thr
,
dataset
=
dataset_name
,
print_summary
=
True
)
def
main
():
parser
=
ArgumentParser
(
description
=
'VOC Evaluation'
)
parser
.
add_argument
(
'result'
,
help
=
'result file path'
)
parser
.
add_argument
(
'config'
,
help
=
'config file path'
)
parser
.
add_argument
(
'--iou-thr'
,
type
=
float
,
default
=
0.5
,
help
=
'IoU threshold for evaluation'
)
args
=
parser
.
parse_args
()
cfg
=
mmcv
.
Config
.
fromfile
(
args
.
config
)
test_dataset
=
mmcv
.
runner
.
obj_from_dict
(
cfg
.
data
.
test
,
datasets
)
voc_eval
(
args
.
result
,
test_dataset
,
args
.
iou_thr
)
if
__name__
==
'__main__'
:
main
()
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment