Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
SOLOv2-pytorch
Commits
7eb02d29
Commit
7eb02d29
authored
Oct 08, 2018
by
Kai Chen
Browse files
Merge branch 'dev' into single-stage
parents
20e75c22
01a03aab
Changes
44
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
15 additions
and
71 deletions
+15
-71
mmdet/datasets/utils/data_container.py
mmdet/datasets/utils/data_container.py
+0
-58
tools/dist_train.sh
tools/dist_train.sh
+1
-1
tools/test.py
tools/test.py
+5
-5
tools/train.py
tools/train.py
+9
-7
No files found.
mmdet/datasets/utils/data_container.py
deleted
100644 → 0
View file @
20e75c22
import
functools
import
torch
def
assert_tensor_type
(
func
):
@
functools
.
wraps
(
func
)
def
wrapper
(
*
args
,
**
kwargs
):
if
not
isinstance
(
args
[
0
].
data
,
torch
.
Tensor
):
raise
AttributeError
(
'{} has no attribute {} for type {}'
.
format
(
args
[
0
].
__class__
.
__name__
,
func
.
__name__
,
args
[
0
].
datatype
))
return
func
(
*
args
,
**
kwargs
)
return
wrapper
class
DataContainer
(
object
):
def
__init__
(
self
,
data
,
stack
=
False
,
padding_value
=
0
,
cpu_only
=
False
):
self
.
_data
=
data
self
.
_cpu_only
=
cpu_only
self
.
_stack
=
stack
self
.
_padding_value
=
padding_value
def
__repr__
(
self
):
return
'{}({})'
.
format
(
self
.
__class__
.
__name__
,
repr
(
self
.
data
))
@
property
def
data
(
self
):
return
self
.
_data
@
property
def
datatype
(
self
):
if
isinstance
(
self
.
data
,
torch
.
Tensor
):
return
self
.
data
.
type
()
else
:
return
type
(
self
.
data
)
@
property
def
cpu_only
(
self
):
return
self
.
_cpu_only
@
property
def
stack
(
self
):
return
self
.
_stack
@
property
def
padding_value
(
self
):
return
self
.
_padding_value
@
assert_tensor_type
def
size
(
self
,
*
args
,
**
kwargs
):
return
self
.
data
.
size
(
*
args
,
**
kwargs
)
@
assert_tensor_type
def
dim
(
self
):
return
self
.
data
.
dim
()
tools/dist_train.sh
View file @
7eb02d29
...
...
@@ -2,4 +2,4 @@
PYTHON
=
${
PYTHON
:-
"python"
}
$PYTHON
-m
torch.distributed.launch
--nproc_per_node
=
$2
train.py
$1
--launcher
pytorch
${
@
:3
}
$PYTHON
-m
torch.distributed.launch
--nproc_per_node
=
$2
$(
dirname
"
$0
"
)
/
train.py
$1
--launcher
pytorch
${
@
:3
}
tools/test.py
View file @
7eb02d29
...
...
@@ -3,9 +3,10 @@ import argparse
import
torch
import
mmcv
from
mmcv.runner
import
load_checkpoint
,
parallel_test
,
obj_from_dict
from
mmcv.parallel
import
scatter
,
MMDataParallel
from
mmdet
import
datasets
from
mmdet.core
import
scatter
,
MMDataParallel
,
results2json
,
coco_eval
from
mmdet.core
import
results2json
,
coco_eval
from
mmdet.datasets
import
collate
,
build_dataloader
from
mmdet.models
import
build_detector
,
detectors
...
...
@@ -44,17 +45,16 @@ def parse_args():
'--eval'
,
type
=
str
,
nargs
=
'+'
,
choices
=
[
'proposal'
,
'bbox'
,
'segm'
,
'keypoints'
],
choices
=
[
'proposal'
,
'proposal_fast'
,
'bbox'
,
'segm'
,
'keypoints'
],
help
=
'eval types'
)
parser
.
add_argument
(
'--show'
,
action
=
'store_true'
,
help
=
'show results'
)
args
=
parser
.
parse_args
()
return
args
args
=
parse_args
()
def
main
():
args
=
parse_args
()
cfg
=
mmcv
.
Config
.
fromfile
(
args
.
config
)
cfg
.
model
.
pretrained
=
None
cfg
.
data
.
test
.
test_mode
=
True
...
...
tools/train.py
View file @
7eb02d29
...
...
@@ -2,17 +2,18 @@ from __future__ import division
import
argparse
import
logging
import
random
from
collections
import
OrderedDict
import
numpy
as
np
import
torch
from
mmcv
import
Config
from
mmcv.runner
import
Runner
,
obj_from_dict
from
mmcv.runner
import
Runner
,
obj_from_dict
,
DistSamplerSeedHook
from
mmcv.parallel
import
MMDataParallel
,
MMDistributedDataParallel
from
mmdet
import
datasets
,
__version__
from
mmdet.core
import
(
init_dist
,
DistOptimizerHook
,
DistSamplerSeedHook
,
MMDataParallel
,
MMDistributedDataParallel
,
CocoDistEvalRecallHook
,
CocoDistEvalmAPHook
)
from
mmdet.core
import
(
init_dist
,
DistOptimizerHook
,
CocoDistEvalRecallHook
,
CocoDistEvalmAPHook
)
from
mmdet.datasets
import
build_dataloader
from
mmdet.models
import
build_detector
,
RPN
...
...
@@ -55,6 +56,7 @@ def get_logger(log_level):
def
set_random_seed
(
seed
):
random
.
seed
(
seed
)
np
.
random
.
seed
(
seed
)
torch
.
manual_seed
(
seed
)
torch
.
cuda
.
manual_seed_all
(
seed
)
...
...
@@ -89,7 +91,7 @@ def main():
if
args
.
work_dir
is
not
None
:
cfg
.
work_dir
=
args
.
work_dir
cfg
.
gpus
=
args
.
gpus
#
add
mmdet version
to
checkpoint as meta data
#
save
mmdet version
in
checkpoint as meta data
cfg
.
checkpoint_config
.
meta
=
dict
(
mmdet_version
=
__version__
,
config
=
cfg
.
text
)
...
...
@@ -103,13 +105,13 @@ def main():
# init distributed environment if necessary
if
args
.
launcher
==
'none'
:
dist
=
False
logger
.
info
(
'
Disabled
distributed training.'
)
logger
.
info
(
'
Non-
distributed training.'
)
else
:
dist
=
True
init_dist
(
args
.
launcher
,
**
cfg
.
dist_params
)
if
torch
.
distributed
.
get_rank
()
!=
0
:
logger
.
setLevel
(
'ERROR'
)
logger
.
info
(
'
Enabled d
istributed training.'
)
logger
.
info
(
'
D
istributed training.'
)
# prepare data loaders
train_dataset
=
obj_from_dict
(
cfg
.
data
.
train
,
datasets
)
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment