Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
SOLOv2-pytorch
Commits
c6fde230
Commit
c6fde230
authored
Dec 11, 2018
by
pangjm
Browse files
Merge branch 'master' of github.com:open-mmlab/mmdetection
Conflicts: tools/train.py
parents
e74519bb
826a5613
Changes
64
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
56 additions
and
38 deletions
+56
-38
mmdet/ops/nms/nms_wrapper.py
mmdet/ops/nms/nms_wrapper.py
+39
-24
setup.py
setup.py
+2
-2
tools/test.py
tools/test.py
+12
-3
tools/train.py
tools/train.py
+3
-9
No files found.
mmdet/ops/nms/nms_wrapper.py
View file @
c6fde230
...
@@ -6,43 +6,58 @@ from .cpu_nms import cpu_nms
...
@@ -6,43 +6,58 @@ from .cpu_nms import cpu_nms
from
.cpu_soft_nms
import
cpu_soft_nms
from
.cpu_soft_nms
import
cpu_soft_nms
def
nms
(
dets
,
thr
esh
,
device_id
=
None
):
def
nms
(
dets
,
iou_
thr
,
device_id
=
None
):
"""Dispatch to either CPU or GPU NMS implementations."""
"""Dispatch to either CPU or GPU NMS implementations."""
tensor_device
=
None
if
isinstance
(
dets
,
torch
.
Tensor
):
if
isinstance
(
dets
,
torch
.
Tensor
):
tensor
_device
=
dets
.
devic
e
is_
tensor
=
Tru
e
if
dets
.
is_cuda
:
if
dets
.
is_cuda
:
device_id
=
dets
.
get_device
()
device_id
=
dets
.
get_device
()
dets
=
dets
.
detach
().
cpu
().
numpy
()
dets_np
=
dets
.
detach
().
cpu
().
numpy
()
assert
isinstance
(
dets
,
np
.
ndarray
)
elif
isinstance
(
dets
,
np
.
ndarray
):
is_tensor
=
False
dets_np
=
dets
else
:
raise
TypeError
(
'dets must be either a Tensor or numpy array, but got {}'
.
format
(
type
(
dets
)))
if
dets
.
shape
[
0
]
==
0
:
if
dets
_np
.
shape
[
0
]
==
0
:
inds
=
[]
inds
=
[]
else
:
else
:
inds
=
(
gpu_nms
(
dets
,
thresh
,
device_id
=
device_id
)
inds
=
(
gpu_nms
(
dets
_np
,
iou_thr
,
device_id
=
device_id
)
if
device_id
is
not
None
else
cpu_nms
(
dets
,
thresh
))
if
device_id
is
not
None
else
cpu_nms
(
dets
_np
,
iou_thr
))
if
tensor
_device
:
if
is_
tensor
:
return
torch
.
Tensor
(
inds
).
long
().
to
(
tensor_device
)
inds
=
dets
.
new_tensor
(
inds
,
dtype
=
torch
.
long
)
else
:
else
:
return
np
.
array
(
inds
,
dtype
=
np
.
int
)
inds
=
np
.
array
(
inds
,
dtype
=
np
.
int64
)
return
dets
[
inds
,
:],
inds
def
soft_nms
(
dets
,
Nt
=
0.3
,
method
=
1
,
sigma
=
0.5
,
min_score
=
0
):
def
soft_nms
(
dets
,
iou_thr
,
method
=
'linear'
,
sigma
=
0.5
,
min_score
=
1e-3
):
if
isinstance
(
dets
,
torch
.
Tensor
):
if
isinstance
(
dets
,
torch
.
Tensor
):
_dets
=
dets
.
detach
().
cpu
().
numpy
()
is_tensor
=
True
dets_np
=
dets
.
detach
().
cpu
().
numpy
()
elif
isinstance
(
dets
,
np
.
ndarray
):
is_tensor
=
False
dets_np
=
dets
else
:
else
:
_dets
=
dets
.
copy
()
raise
TypeError
(
assert
isinstance
(
_dets
,
np
.
ndarray
)
'dets must be either a Tensor or numpy array, but got {}'
.
format
(
type
(
dets
)))
method_codes
=
{
'linear'
:
1
,
'gaussian'
:
2
}
if
method
not
in
method_codes
:
raise
ValueError
(
'Invalid method for SoftNMS: {}'
.
format
(
method
))
new_dets
,
inds
=
cpu_soft_nms
(
new_dets
,
inds
=
cpu_soft_nms
(
_dets
,
Nt
=
Nt
,
method
=
method
,
sigma
=
sigma
,
threshold
=
min_score
)
dets_np
,
iou_thr
,
if
isinstance
(
dets
,
torch
.
Tensor
):
method
=
method_codes
[
method
],
return
dets
.
new_tensor
(
sigma
=
sigma
,
inds
,
dtype
=
torch
.
long
),
dets
.
new_tensor
(
new_dets
)
min_score
=
min_score
)
if
is_tensor
:
return
dets
.
new_tensor
(
new_dets
),
dets
.
new_tensor
(
inds
,
dtype
=
torch
.
long
)
else
:
else
:
return
np
.
array
(
return
new_dets
.
astype
(
np
.
float32
),
inds
.
astype
(
np
.
int64
)
inds
,
dtype
=
np
.
int
),
np
.
array
(
new_dets
,
dtype
=
np
.
float32
)
setup.py
View file @
c6fde230
...
@@ -12,7 +12,7 @@ def readme():
...
@@ -12,7 +12,7 @@ def readme():
MAJOR
=
0
MAJOR
=
0
MINOR
=
5
MINOR
=
5
PATCH
=
2
PATCH
=
4
SUFFIX
=
''
SUFFIX
=
''
SHORT_VERSION
=
'{}.{}.{}{}'
.
format
(
MAJOR
,
MINOR
,
PATCH
,
SUFFIX
)
SHORT_VERSION
=
'{}.{}.{}{}'
.
format
(
MAJOR
,
MINOR
,
PATCH
,
SUFFIX
)
...
@@ -93,7 +93,7 @@ if __name__ == '__main__':
...
@@ -93,7 +93,7 @@ if __name__ == '__main__':
package_data
=
{
'mmdet.ops'
:
[
'*/*.so'
]},
package_data
=
{
'mmdet.ops'
:
[
'*/*.so'
]},
classifiers
=
[
classifiers
=
[
'Development Status :: 4 - Beta'
,
'Development Status :: 4 - Beta'
,
'License :: OSI Approved ::
GNU General Public License v3 (GPLv3)
'
,
'License :: OSI Approved ::
Apache Software License
'
,
'Operating System :: OS Independent'
,
'Operating System :: OS Independent'
,
'Programming Language :: Python :: 2'
,
'Programming Language :: Python :: 2'
,
'Programming Language :: Python :: 2.7'
,
'Programming Language :: Python :: 2.7'
,
...
...
tools/test.py
View file @
c6fde230
...
@@ -104,10 +104,19 @@ def main():
...
@@ -104,10 +104,19 @@ def main():
print
(
'Starting evaluate {}'
.
format
(
' and '
.
join
(
eval_types
)))
print
(
'Starting evaluate {}'
.
format
(
' and '
.
join
(
eval_types
)))
if
eval_types
==
[
'proposal_fast'
]:
if
eval_types
==
[
'proposal_fast'
]:
result_file
=
args
.
out
result_file
=
args
.
out
coco_eval
(
result_file
,
eval_types
,
dataset
.
coco
)
else
:
else
:
result_file
=
args
.
out
+
'.json'
if
not
isinstance
(
outputs
[
0
],
dict
):
results2json
(
dataset
,
outputs
,
result_file
)
result_file
=
args
.
out
+
'.json'
coco_eval
(
result_file
,
eval_types
,
dataset
.
coco
)
results2json
(
dataset
,
outputs
,
result_file
)
coco_eval
(
result_file
,
eval_types
,
dataset
.
coco
)
else
:
for
name
in
outputs
[
0
]:
print
(
'
\n
Evaluating {}'
.
format
(
name
))
outputs_
=
[
out
[
name
]
for
out
in
outputs
]
result_file
=
args
.
out
+
'.{}.json'
.
format
(
name
)
results2json
(
dataset
,
outputs_
,
result_file
)
coco_eval
(
result_file
,
eval_types
,
dataset
.
coco
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
tools/train.py
View file @
c6fde230
...
@@ -5,9 +5,9 @@ sys.path.insert(0, '/mnt/lustre/pangjiangmiao/codebase/mmdet')
...
@@ -5,9 +5,9 @@ sys.path.insert(0, '/mnt/lustre/pangjiangmiao/codebase/mmdet')
import
argparse
import
argparse
from
mmcv
import
Config
from
mmcv
import
Config
from
mmcv.runner
import
obj_from_dict
from
mmdet
import
datasets
,
__version__
from
mmdet
import
__version__
from
mmdet.datasets
import
get_dataset
from
mmdet.apis
import
(
train_detector
,
init_dist
,
get_root_logger
,
from
mmdet.apis
import
(
train_detector
,
init_dist
,
get_root_logger
,
set_random_seed
)
set_random_seed
)
from
mmdet.models
import
build_detector
from
mmdet.models
import
build_detector
...
@@ -74,13 +74,7 @@ def main():
...
@@ -74,13 +74,7 @@ def main():
model
=
build_detector
(
model
=
build_detector
(
cfg
.
model
,
train_cfg
=
cfg
.
train_cfg
,
test_cfg
=
cfg
.
test_cfg
)
cfg
.
model
,
train_cfg
=
cfg
.
train_cfg
,
test_cfg
=
cfg
.
test_cfg
)
import
torch.distributed
as
dist
train_dataset
=
get_dataset
(
cfg
.
data
.
train
)
if
dist
.
get_rank
()
==
0
:
with
open
(
'/mnt/lustre/pangjiangmiao/r50_32x4d_mmdet.txt'
,
'w'
)
as
f
:
for
k
in
model
.
state_dict
().
keys
():
if
'num_batches_tracked'
in
k
:
continue
f
.
writelines
(
'{}
\n
'
.
format
(
k
))
train_dataset
=
obj_from_dict
(
cfg
.
data
.
train
,
datasets
)
train_detector
(
train_detector
(
model
,
model
,
train_dataset
,
train_dataset
,
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment