Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
SOLOv2-pytorch
Commits
7eb02d29
Commit
7eb02d29
authored
Oct 08, 2018
by
Kai Chen
Browse files
Merge branch 'dev' into single-stage
parents
20e75c22
01a03aab
Changes
44
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
15 additions
and
71 deletions
+15
-71
mmdet/datasets/utils/data_container.py
mmdet/datasets/utils/data_container.py
+0
-58
tools/dist_train.sh
tools/dist_train.sh
+1
-1
tools/test.py
tools/test.py
+5
-5
tools/train.py
tools/train.py
+9
-7
No files found.
mmdet/datasets/utils/data_container.py
deleted
100644 → 0
View file @
20e75c22
import
functools
import
torch
def
assert_tensor_type
(
func
):
@
functools
.
wraps
(
func
)
def
wrapper
(
*
args
,
**
kwargs
):
if
not
isinstance
(
args
[
0
].
data
,
torch
.
Tensor
):
raise
AttributeError
(
'{} has no attribute {} for type {}'
.
format
(
args
[
0
].
__class__
.
__name__
,
func
.
__name__
,
args
[
0
].
datatype
))
return
func
(
*
args
,
**
kwargs
)
return
wrapper
class
DataContainer
(
object
):
def
__init__
(
self
,
data
,
stack
=
False
,
padding_value
=
0
,
cpu_only
=
False
):
self
.
_data
=
data
self
.
_cpu_only
=
cpu_only
self
.
_stack
=
stack
self
.
_padding_value
=
padding_value
def
__repr__
(
self
):
return
'{}({})'
.
format
(
self
.
__class__
.
__name__
,
repr
(
self
.
data
))
@
property
def
data
(
self
):
return
self
.
_data
@
property
def
datatype
(
self
):
if
isinstance
(
self
.
data
,
torch
.
Tensor
):
return
self
.
data
.
type
()
else
:
return
type
(
self
.
data
)
@
property
def
cpu_only
(
self
):
return
self
.
_cpu_only
@
property
def
stack
(
self
):
return
self
.
_stack
@
property
def
padding_value
(
self
):
return
self
.
_padding_value
@
assert_tensor_type
def
size
(
self
,
*
args
,
**
kwargs
):
return
self
.
data
.
size
(
*
args
,
**
kwargs
)
@
assert_tensor_type
def
dim
(
self
):
return
self
.
data
.
dim
()
tools/dist_train.sh
View file @
7eb02d29
...
@@ -2,4 +2,4 @@
...
@@ -2,4 +2,4 @@
PYTHON
=
${
PYTHON
:-
"python"
}
PYTHON
=
${
PYTHON
:-
"python"
}
$PYTHON
-m
torch.distributed.launch
--nproc_per_node
=
$2
train.py
$1
--launcher
pytorch
${
@
:3
}
$PYTHON
-m
torch.distributed.launch
--nproc_per_node
=
$2
$(
dirname
"
$0
"
)
/
train.py
$1
--launcher
pytorch
${
@
:3
}
tools/test.py
View file @
7eb02d29
...
@@ -3,9 +3,10 @@ import argparse
...
@@ -3,9 +3,10 @@ import argparse
import
torch
import
torch
import
mmcv
import
mmcv
from
mmcv.runner
import
load_checkpoint
,
parallel_test
,
obj_from_dict
from
mmcv.runner
import
load_checkpoint
,
parallel_test
,
obj_from_dict
from
mmcv.parallel
import
scatter
,
MMDataParallel
from
mmdet
import
datasets
from
mmdet
import
datasets
from
mmdet.core
import
scatter
,
MMDataParallel
,
results2json
,
coco_eval
from
mmdet.core
import
results2json
,
coco_eval
from
mmdet.datasets
import
collate
,
build_dataloader
from
mmdet.datasets
import
collate
,
build_dataloader
from
mmdet.models
import
build_detector
,
detectors
from
mmdet.models
import
build_detector
,
detectors
...
@@ -44,17 +45,16 @@ def parse_args():
...
@@ -44,17 +45,16 @@ def parse_args():
'--eval'
,
'--eval'
,
type
=
str
,
type
=
str
,
nargs
=
'+'
,
nargs
=
'+'
,
choices
=
[
'proposal'
,
'bbox'
,
'segm'
,
'keypoints'
],
choices
=
[
'proposal'
,
'proposal_fast'
,
'bbox'
,
'segm'
,
'keypoints'
],
help
=
'eval types'
)
help
=
'eval types'
)
parser
.
add_argument
(
'--show'
,
action
=
'store_true'
,
help
=
'show results'
)
parser
.
add_argument
(
'--show'
,
action
=
'store_true'
,
help
=
'show results'
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
return
args
return
args
args
=
parse_args
()
def
main
():
def
main
():
args
=
parse_args
()
cfg
=
mmcv
.
Config
.
fromfile
(
args
.
config
)
cfg
=
mmcv
.
Config
.
fromfile
(
args
.
config
)
cfg
.
model
.
pretrained
=
None
cfg
.
model
.
pretrained
=
None
cfg
.
data
.
test
.
test_mode
=
True
cfg
.
data
.
test
.
test_mode
=
True
...
...
tools/train.py
View file @
7eb02d29
...
@@ -2,17 +2,18 @@ from __future__ import division
...
@@ -2,17 +2,18 @@ from __future__ import division
import
argparse
import
argparse
import
logging
import
logging
import
random
from
collections
import
OrderedDict
from
collections
import
OrderedDict
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
from
mmcv
import
Config
from
mmcv
import
Config
from
mmcv.runner
import
Runner
,
obj_from_dict
from
mmcv.runner
import
Runner
,
obj_from_dict
,
DistSamplerSeedHook
from
mmcv.parallel
import
MMDataParallel
,
MMDistributedDataParallel
from
mmdet
import
datasets
,
__version__
from
mmdet
import
datasets
,
__version__
from
mmdet.core
import
(
init_dist
,
DistOptimizerHook
,
DistSamplerSeedHook
,
from
mmdet.core
import
(
init_dist
,
DistOptimizerHook
,
CocoDistEvalRecallHook
,
MMDataParallel
,
MMDistributedDataParallel
,
CocoDistEvalmAPHook
)
CocoDistEvalRecallHook
,
CocoDistEvalmAPHook
)
from
mmdet.datasets
import
build_dataloader
from
mmdet.datasets
import
build_dataloader
from
mmdet.models
import
build_detector
,
RPN
from
mmdet.models
import
build_detector
,
RPN
...
@@ -55,6 +56,7 @@ def get_logger(log_level):
...
@@ -55,6 +56,7 @@ def get_logger(log_level):
def
set_random_seed
(
seed
):
def
set_random_seed
(
seed
):
random
.
seed
(
seed
)
np
.
random
.
seed
(
seed
)
np
.
random
.
seed
(
seed
)
torch
.
manual_seed
(
seed
)
torch
.
manual_seed
(
seed
)
torch
.
cuda
.
manual_seed_all
(
seed
)
torch
.
cuda
.
manual_seed_all
(
seed
)
...
@@ -89,7 +91,7 @@ def main():
...
@@ -89,7 +91,7 @@ def main():
if
args
.
work_dir
is
not
None
:
if
args
.
work_dir
is
not
None
:
cfg
.
work_dir
=
args
.
work_dir
cfg
.
work_dir
=
args
.
work_dir
cfg
.
gpus
=
args
.
gpus
cfg
.
gpus
=
args
.
gpus
#
add
mmdet version
to
checkpoint as meta data
#
save
mmdet version
in
checkpoint as meta data
cfg
.
checkpoint_config
.
meta
=
dict
(
cfg
.
checkpoint_config
.
meta
=
dict
(
mmdet_version
=
__version__
,
config
=
cfg
.
text
)
mmdet_version
=
__version__
,
config
=
cfg
.
text
)
...
@@ -103,13 +105,13 @@ def main():
...
@@ -103,13 +105,13 @@ def main():
# init distributed environment if necessary
# init distributed environment if necessary
if
args
.
launcher
==
'none'
:
if
args
.
launcher
==
'none'
:
dist
=
False
dist
=
False
logger
.
info
(
'
Disabled
distributed training.'
)
logger
.
info
(
'
Non-
distributed training.'
)
else
:
else
:
dist
=
True
dist
=
True
init_dist
(
args
.
launcher
,
**
cfg
.
dist_params
)
init_dist
(
args
.
launcher
,
**
cfg
.
dist_params
)
if
torch
.
distributed
.
get_rank
()
!=
0
:
if
torch
.
distributed
.
get_rank
()
!=
0
:
logger
.
setLevel
(
'ERROR'
)
logger
.
setLevel
(
'ERROR'
)
logger
.
info
(
'
Enabled d
istributed training.'
)
logger
.
info
(
'
D
istributed training.'
)
# prepare data loaders
# prepare data loaders
train_dataset
=
obj_from_dict
(
cfg
.
data
.
train
,
datasets
)
train_dataset
=
obj_from_dict
(
cfg
.
data
.
train
,
datasets
)
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment