Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
SOLOv2-pytorch
Commits
4c1da636
Commit
4c1da636
authored
Oct 11, 2018
by
myownskyW7
Browse files
add high level api
parent
d13997c3
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
188 additions
and
115 deletions
+188
-115
mmdet/api/__init__.py
mmdet/api/__init__.py
+4
-0
mmdet/api/inference.py
mmdet/api/inference.py
+54
-0
mmdet/api/train.py
mmdet/api/train.py
+120
-0
tools/train.py
tools/train.py
+10
-115
No files found.
mmdet/api/__init__.py
0 → 100644
View file @
4c1da636
from
.train
import
train_detector
from
.inference
import
inference_detector
__all__
=
[
'train_detector'
,
'inference_detector'
]
mmdet/api/inference.py
0 → 100644
View file @
4c1da636
import
mmcv
import
numpy
as
np
import
torch
from
mmdet.datasets
import
to_tensor
from
mmdet.datasets.transforms
import
ImageTransform
from
mmdet.core
import
get_classes
def
_prepare_data
(
img
,
img_transform
,
cfg
,
device
):
ori_shape
=
img
.
shape
img
,
img_shape
,
pad_shape
,
scale_factor
=
img_transform
(
img
,
scale
=
cfg
.
data
.
test
.
img_scale
)
img
=
to_tensor
(
img
).
to
(
device
).
unsqueeze
(
0
)
img_meta
=
[
dict
(
ori_shape
=
ori_shape
,
img_shape
=
img_shape
,
pad_shape
=
pad_shape
,
scale_factor
=
scale_factor
,
flip
=
False
)
]
return
dict
(
img
=
[
img
],
img_meta
=
[
img_meta
])
def
inference_detector
(
model
,
imgs
,
cfg
,
device
=
'cuda:0'
):
imgs
=
imgs
if
isinstance
(
imgs
,
list
)
else
[
imgs
]
img_transform
=
ImageTransform
(
**
cfg
.
img_norm_cfg
,
size_divisor
=
cfg
.
data
.
test
.
size_divisor
)
model
=
model
.
to
(
device
)
model
.
eval
()
for
img
in
imgs
:
img
=
mmcv
.
imread
(
img
)
data
=
_prepare_data
(
img
,
img_transform
,
cfg
,
device
)
with
torch
.
no_grad
():
result
=
model
(
**
data
,
return_loss
=
False
,
rescale
=
True
)
yield
result
def
show_result
(
img
,
result
,
dataset
=
'coco'
,
score_thr
=
0.3
):
class_names
=
get_classes
(
dataset
)
labels
=
[
np
.
full
(
bbox
.
shape
[
0
],
i
,
dtype
=
np
.
int32
)
for
i
,
bbox
in
enumerate
(
result
)
]
labels
=
np
.
concatenate
(
labels
)
bboxes
=
np
.
vstack
(
result
)
mmcv
.
imshow_det_bboxes
(
img
.
copy
(),
bboxes
,
labels
,
class_names
=
class_names
,
score_thr
=
score_thr
)
mmdet/api/train.py
0 → 100644
View file @
4c1da636
from
__future__
import
division
import
logging
import
random
from
collections
import
OrderedDict
import
numpy
as
np
import
torch
from
mmcv.runner
import
Runner
,
DistSamplerSeedHook
from
mmcv.parallel
import
MMDataParallel
,
MMDistributedDataParallel
from
mmdet
import
__version__
from
mmdet.core
import
(
init_dist
,
DistOptimizerHook
,
CocoDistEvalRecallHook
,
CocoDistEvalmAPHook
)
from
mmdet.datasets
import
build_dataloader
from
mmdet.models
import
RPN
def
parse_losses
(
losses
):
log_vars
=
OrderedDict
()
for
loss_name
,
loss_value
in
losses
.
items
():
if
isinstance
(
loss_value
,
torch
.
Tensor
):
log_vars
[
loss_name
]
=
loss_value
.
mean
()
elif
isinstance
(
loss_value
,
list
):
log_vars
[
loss_name
]
=
sum
(
_loss
.
mean
()
for
_loss
in
loss_value
)
else
:
raise
TypeError
(
'{} is not a tensor or list of tensors'
.
format
(
loss_name
))
loss
=
sum
(
_value
for
_key
,
_value
in
log_vars
.
items
()
if
'loss'
in
_key
)
log_vars
[
'loss'
]
=
loss
for
name
in
log_vars
:
log_vars
[
name
]
=
log_vars
[
name
].
item
()
return
loss
,
log_vars
def
batch_processor
(
model
,
data
,
train_mode
):
losses
=
model
(
**
data
)
loss
,
log_vars
=
parse_losses
(
losses
)
outputs
=
dict
(
loss
=
loss
,
log_vars
=
log_vars
,
num_samples
=
len
(
data
[
'img'
].
data
))
return
outputs
def
get_logger
(
log_level
):
logging
.
basicConfig
(
format
=
'%(asctime)s - %(levelname)s - %(message)s'
,
level
=
log_level
)
logger
=
logging
.
getLogger
()
return
logger
def
set_random_seed
(
seed
):
random
.
seed
(
seed
)
np
.
random
.
seed
(
seed
)
torch
.
manual_seed
(
seed
)
torch
.
cuda
.
manual_seed_all
(
seed
)
def
train_detector
(
model
,
dataset
,
cfg
):
# save mmdet version in checkpoint as meta data
cfg
.
checkpoint_config
.
meta
=
dict
(
mmdet_version
=
__version__
,
config
=
cfg
.
text
)
logger
=
get_logger
(
cfg
.
log_level
)
# set random seed if specified
if
cfg
.
seed
is
not
None
:
logger
.
info
(
'Set random seed to {}'
.
format
(
cfg
.
seed
))
set_random_seed
(
cfg
.
seed
)
# init distributed environment if necessary
if
cfg
.
launcher
==
'none'
:
dist
=
False
logger
.
info
(
'Non-distributed training.'
)
else
:
dist
=
True
init_dist
(
cfg
.
launcher
,
**
cfg
.
dist_params
)
if
torch
.
distributed
.
get_rank
()
!=
0
:
logger
.
setLevel
(
'ERROR'
)
logger
.
info
(
'Distributed training.'
)
# prepare data loaders
data_loaders
=
[
build_dataloader
(
dataset
,
cfg
.
data
.
imgs_per_gpu
,
cfg
.
data
.
workers_per_gpu
,
cfg
.
gpus
,
dist
)
]
# put model on gpus
if
dist
:
model
=
MMDistributedDataParallel
(
model
.
cuda
())
else
:
model
=
MMDataParallel
(
model
,
device_ids
=
range
(
cfg
.
gpus
)).
cuda
()
# build runner
runner
=
Runner
(
model
,
batch_processor
,
cfg
.
optimizer
,
cfg
.
work_dir
,
cfg
.
log_level
)
# register hooks
optimizer_config
=
DistOptimizerHook
(
**
cfg
.
optimizer_config
)
if
dist
else
cfg
.
optimizer_config
runner
.
register_training_hooks
(
cfg
.
lr_config
,
optimizer_config
,
cfg
.
checkpoint_config
,
cfg
.
log_config
)
if
dist
:
runner
.
register_hook
(
DistSamplerSeedHook
())
# register eval hooks
if
cfg
.
validate
:
if
isinstance
(
model
.
module
,
RPN
):
runner
.
register_hook
(
CocoDistEvalRecallHook
(
cfg
.
data
.
val
))
elif
cfg
.
data
.
val
.
type
==
'CocoDataset'
:
runner
.
register_hook
(
CocoDistEvalmAPHook
(
cfg
.
data
.
val
))
if
cfg
.
resume_from
:
runner
.
resume
(
cfg
.
resume_from
)
elif
cfg
.
load_from
:
runner
.
load_checkpoint
(
cfg
.
load_from
)
runner
.
run
(
data_loaders
,
cfg
.
workflow
,
cfg
.
total_epochs
)
\ No newline at end of file
tools/train.py
View file @
4c1da636
from
__future__
import
division
import
argparse
import
logging
import
random
from
collections
import
OrderedDict
import
numpy
as
np
import
torch
from
mmcv
import
Config
from
mmcv.runner
import
Runner
,
obj_from_dict
,
DistSamplerSeedHook
from
mmcv.parallel
import
MMDataParallel
,
MMDistributedDataParallel
from
mmdet
import
datasets
,
__version__
from
mmdet.core
import
(
init_dist
,
DistOptimizerHook
,
CocoDistEvalRecallHook
,
CocoDistEvalmAPHook
)
from
mmdet.datasets
import
build_dataloader
from
mmdet.models
import
build_detector
,
RPN
def
parse_losses
(
losses
):
log_vars
=
OrderedDict
()
for
loss_name
,
loss_value
in
losses
.
items
():
if
isinstance
(
loss_value
,
torch
.
Tensor
):
log_vars
[
loss_name
]
=
loss_value
.
mean
()
elif
isinstance
(
loss_value
,
list
):
log_vars
[
loss_name
]
=
sum
(
_loss
.
mean
()
for
_loss
in
loss_value
)
else
:
raise
TypeError
(
'{} is not a tensor or list of tensors'
.
format
(
loss_name
))
loss
=
sum
(
_value
for
_key
,
_value
in
log_vars
.
items
()
if
'loss'
in
_key
)
log_vars
[
'loss'
]
=
loss
for
name
in
log_vars
:
log_vars
[
name
]
=
log_vars
[
name
].
item
()
return
loss
,
log_vars
def
batch_processor
(
model
,
data
,
train_mode
):
losses
=
model
(
**
data
)
loss
,
log_vars
=
parse_losses
(
losses
)
outputs
=
dict
(
loss
=
loss
,
log_vars
=
log_vars
,
num_samples
=
len
(
data
[
'img'
].
data
))
return
outputs
def
get_logger
(
log_level
):
logging
.
basicConfig
(
format
=
'%(asctime)s - %(levelname)s - %(message)s'
,
level
=
log_level
)
logger
=
logging
.
getLogger
()
return
logger
from
mmcv.runner
import
obj_from_dict
def
set_random_seed
(
seed
):
random
.
seed
(
seed
)
np
.
random
.
seed
(
seed
)
torch
.
manual_seed
(
seed
)
torch
.
cuda
.
manual_seed_all
(
seed
)
from
mmdet
import
datasets
from
mmdet.api
import
train_detector
from
mmdet.models
import
build_detector
def
parse_args
():
...
...
@@ -86,71 +33,19 @@ def parse_args():
def
main
():
args
=
parse_args
()
cfg
=
Config
.
fromfile
(
args
.
config
)
if
args
.
work_dir
is
not
None
:
cfg
.
work_dir
=
args
.
work_dir
cfg
.
validate
=
args
.
validate
cfg
.
gpus
=
args
.
gpus
# save mmdet version in checkpoint as meta data
cfg
.
checkpoint_config
.
meta
=
dict
(
mmdet_version
=
__version__
,
config
=
cfg
.
text
)
logger
=
get_logger
(
cfg
.
log_level
)
# set random seed if specified
if
args
.
seed
is
not
None
:
logger
.
info
(
'Set random seed to {}'
.
format
(
args
.
seed
))
set_random_seed
(
args
.
seed
)
# init distributed environment if necessary
if
args
.
launcher
==
'none'
:
dist
=
False
logger
.
info
(
'Non-distributed training.'
)
else
:
dist
=
True
init_dist
(
args
.
launcher
,
**
cfg
.
dist_params
)
if
torch
.
distributed
.
get_rank
()
!=
0
:
logger
.
setLevel
(
'ERROR'
)
logger
.
info
(
'Distributed training.'
)
# prepare data loaders
train_dataset
=
obj_from_dict
(
cfg
.
data
.
train
,
datasets
)
data_loaders
=
[
build_dataloader
(
train_dataset
,
cfg
.
data
.
imgs_per_gpu
,
cfg
.
data
.
workers_per_gpu
,
cfg
.
gpus
,
dist
)
]
cfg
.
seed
=
args
.
seed
cfg
.
launcher
=
args
.
launcher
cfg
.
local_rank
=
args
.
local_rank
# build model
model
=
build_detector
(
cfg
.
model
,
train_cfg
=
cfg
.
train_cfg
,
test_cfg
=
cfg
.
test_cfg
)
if
dist
:
model
=
MMDistributedDataParallel
(
model
.
cuda
())
else
:
model
=
MMDataParallel
(
model
,
device_ids
=
range
(
cfg
.
gpus
)).
cuda
()
# build runner
runner
=
Runner
(
model
,
batch_processor
,
cfg
.
optimizer
,
cfg
.
work_dir
,
cfg
.
log_level
)
# register hooks
optimizer_config
=
DistOptimizerHook
(
**
cfg
.
optimizer_config
)
if
dist
else
cfg
.
optimizer_config
runner
.
register_training_hooks
(
cfg
.
lr_config
,
optimizer_config
,
cfg
.
checkpoint_config
,
cfg
.
log_config
)
if
dist
:
runner
.
register_hook
(
DistSamplerSeedHook
())
# register eval hooks
if
args
.
validate
:
if
isinstance
(
model
.
module
,
RPN
):
runner
.
register_hook
(
CocoDistEvalRecallHook
(
cfg
.
data
.
val
))
elif
cfg
.
data
.
val
.
type
==
'CocoDataset'
:
runner
.
register_hook
(
CocoDistEvalmAPHook
(
cfg
.
data
.
val
))
if
cfg
.
resume_from
:
runner
.
resume
(
cfg
.
resume_from
)
elif
cfg
.
load_from
:
runner
.
load_checkpoint
(
cfg
.
load_from
)
runner
.
run
(
data_loaders
,
cfg
.
workflow
,
cfg
.
total_epochs
)
train_dataset
=
obj_from_dict
(
cfg
.
data
.
train
,
datasets
)
train_detector
(
model
,
train_dataset
,
cfg
)
if
__name__
==
'__main__'
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment