Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
SOLOv2-pytorch
Commits
4c1da636
Commit
4c1da636
authored
Oct 11, 2018
by
myownskyW7
Browse files
add high level api
parent
d13997c3
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
188 additions
and
115 deletions
+188
-115
mmdet/api/__init__.py
mmdet/api/__init__.py
+4
-0
mmdet/api/inference.py
mmdet/api/inference.py
+54
-0
mmdet/api/train.py
mmdet/api/train.py
+120
-0
tools/train.py
tools/train.py
+10
-115
No files found.
mmdet/api/__init__.py
0 → 100644
View file @
4c1da636
from
.train
import
train_detector
from
.inference
import
inference_detector
__all__
=
[
'train_detector'
,
'inference_detector'
]
mmdet/api/inference.py
0 → 100644
View file @
4c1da636
import
mmcv
import
numpy
as
np
import
torch
from
mmdet.datasets
import
to_tensor
from
mmdet.datasets.transforms
import
ImageTransform
from
mmdet.core
import
get_classes
def
_prepare_data
(
img
,
img_transform
,
cfg
,
device
):
ori_shape
=
img
.
shape
img
,
img_shape
,
pad_shape
,
scale_factor
=
img_transform
(
img
,
scale
=
cfg
.
data
.
test
.
img_scale
)
img
=
to_tensor
(
img
).
to
(
device
).
unsqueeze
(
0
)
img_meta
=
[
dict
(
ori_shape
=
ori_shape
,
img_shape
=
img_shape
,
pad_shape
=
pad_shape
,
scale_factor
=
scale_factor
,
flip
=
False
)
]
return
dict
(
img
=
[
img
],
img_meta
=
[
img_meta
])
def
inference_detector
(
model
,
imgs
,
cfg
,
device
=
'cuda:0'
):
imgs
=
imgs
if
isinstance
(
imgs
,
list
)
else
[
imgs
]
img_transform
=
ImageTransform
(
**
cfg
.
img_norm_cfg
,
size_divisor
=
cfg
.
data
.
test
.
size_divisor
)
model
=
model
.
to
(
device
)
model
.
eval
()
for
img
in
imgs
:
img
=
mmcv
.
imread
(
img
)
data
=
_prepare_data
(
img
,
img_transform
,
cfg
,
device
)
with
torch
.
no_grad
():
result
=
model
(
**
data
,
return_loss
=
False
,
rescale
=
True
)
yield
result
def
show_result
(
img
,
result
,
dataset
=
'coco'
,
score_thr
=
0.3
):
class_names
=
get_classes
(
dataset
)
labels
=
[
np
.
full
(
bbox
.
shape
[
0
],
i
,
dtype
=
np
.
int32
)
for
i
,
bbox
in
enumerate
(
result
)
]
labels
=
np
.
concatenate
(
labels
)
bboxes
=
np
.
vstack
(
result
)
mmcv
.
imshow_det_bboxes
(
img
.
copy
(),
bboxes
,
labels
,
class_names
=
class_names
,
score_thr
=
score_thr
)
mmdet/api/train.py
0 → 100644
View file @
4c1da636
from
__future__
import
division
import
logging
import
random
from
collections
import
OrderedDict
import
numpy
as
np
import
torch
from
mmcv.runner
import
Runner
,
DistSamplerSeedHook
from
mmcv.parallel
import
MMDataParallel
,
MMDistributedDataParallel
from
mmdet
import
__version__
from
mmdet.core
import
(
init_dist
,
DistOptimizerHook
,
CocoDistEvalRecallHook
,
CocoDistEvalmAPHook
)
from
mmdet.datasets
import
build_dataloader
from
mmdet.models
import
RPN
def
parse_losses
(
losses
):
log_vars
=
OrderedDict
()
for
loss_name
,
loss_value
in
losses
.
items
():
if
isinstance
(
loss_value
,
torch
.
Tensor
):
log_vars
[
loss_name
]
=
loss_value
.
mean
()
elif
isinstance
(
loss_value
,
list
):
log_vars
[
loss_name
]
=
sum
(
_loss
.
mean
()
for
_loss
in
loss_value
)
else
:
raise
TypeError
(
'{} is not a tensor or list of tensors'
.
format
(
loss_name
))
loss
=
sum
(
_value
for
_key
,
_value
in
log_vars
.
items
()
if
'loss'
in
_key
)
log_vars
[
'loss'
]
=
loss
for
name
in
log_vars
:
log_vars
[
name
]
=
log_vars
[
name
].
item
()
return
loss
,
log_vars
def
batch_processor
(
model
,
data
,
train_mode
):
losses
=
model
(
**
data
)
loss
,
log_vars
=
parse_losses
(
losses
)
outputs
=
dict
(
loss
=
loss
,
log_vars
=
log_vars
,
num_samples
=
len
(
data
[
'img'
].
data
))
return
outputs
def
get_logger
(
log_level
):
logging
.
basicConfig
(
format
=
'%(asctime)s - %(levelname)s - %(message)s'
,
level
=
log_level
)
logger
=
logging
.
getLogger
()
return
logger
def
set_random_seed
(
seed
):
random
.
seed
(
seed
)
np
.
random
.
seed
(
seed
)
torch
.
manual_seed
(
seed
)
torch
.
cuda
.
manual_seed_all
(
seed
)
def
train_detector
(
model
,
dataset
,
cfg
):
# save mmdet version in checkpoint as meta data
cfg
.
checkpoint_config
.
meta
=
dict
(
mmdet_version
=
__version__
,
config
=
cfg
.
text
)
logger
=
get_logger
(
cfg
.
log_level
)
# set random seed if specified
if
cfg
.
seed
is
not
None
:
logger
.
info
(
'Set random seed to {}'
.
format
(
cfg
.
seed
))
set_random_seed
(
cfg
.
seed
)
# init distributed environment if necessary
if
cfg
.
launcher
==
'none'
:
dist
=
False
logger
.
info
(
'Non-distributed training.'
)
else
:
dist
=
True
init_dist
(
cfg
.
launcher
,
**
cfg
.
dist_params
)
if
torch
.
distributed
.
get_rank
()
!=
0
:
logger
.
setLevel
(
'ERROR'
)
logger
.
info
(
'Distributed training.'
)
# prepare data loaders
data_loaders
=
[
build_dataloader
(
dataset
,
cfg
.
data
.
imgs_per_gpu
,
cfg
.
data
.
workers_per_gpu
,
cfg
.
gpus
,
dist
)
]
# put model on gpus
if
dist
:
model
=
MMDistributedDataParallel
(
model
.
cuda
())
else
:
model
=
MMDataParallel
(
model
,
device_ids
=
range
(
cfg
.
gpus
)).
cuda
()
# build runner
runner
=
Runner
(
model
,
batch_processor
,
cfg
.
optimizer
,
cfg
.
work_dir
,
cfg
.
log_level
)
# register hooks
optimizer_config
=
DistOptimizerHook
(
**
cfg
.
optimizer_config
)
if
dist
else
cfg
.
optimizer_config
runner
.
register_training_hooks
(
cfg
.
lr_config
,
optimizer_config
,
cfg
.
checkpoint_config
,
cfg
.
log_config
)
if
dist
:
runner
.
register_hook
(
DistSamplerSeedHook
())
# register eval hooks
if
cfg
.
validate
:
if
isinstance
(
model
.
module
,
RPN
):
runner
.
register_hook
(
CocoDistEvalRecallHook
(
cfg
.
data
.
val
))
elif
cfg
.
data
.
val
.
type
==
'CocoDataset'
:
runner
.
register_hook
(
CocoDistEvalmAPHook
(
cfg
.
data
.
val
))
if
cfg
.
resume_from
:
runner
.
resume
(
cfg
.
resume_from
)
elif
cfg
.
load_from
:
runner
.
load_checkpoint
(
cfg
.
load_from
)
runner
.
run
(
data_loaders
,
cfg
.
workflow
,
cfg
.
total_epochs
)
\ No newline at end of file
tools/train.py
View file @
4c1da636
from
__future__
import
division
from
__future__
import
division
import
argparse
import
argparse
import
logging
import
random
from
collections
import
OrderedDict
import
numpy
as
np
import
torch
from
mmcv
import
Config
from
mmcv
import
Config
from
mmcv.runner
import
Runner
,
obj_from_dict
,
DistSamplerSeedHook
from
mmcv.runner
import
obj_from_dict
from
mmcv.parallel
import
MMDataParallel
,
MMDistributedDataParallel
from
mmdet
import
datasets
,
__version__
from
mmdet.core
import
(
init_dist
,
DistOptimizerHook
,
CocoDistEvalRecallHook
,
CocoDistEvalmAPHook
)
from
mmdet.datasets
import
build_dataloader
from
mmdet.models
import
build_detector
,
RPN
def
parse_losses
(
losses
):
log_vars
=
OrderedDict
()
for
loss_name
,
loss_value
in
losses
.
items
():
if
isinstance
(
loss_value
,
torch
.
Tensor
):
log_vars
[
loss_name
]
=
loss_value
.
mean
()
elif
isinstance
(
loss_value
,
list
):
log_vars
[
loss_name
]
=
sum
(
_loss
.
mean
()
for
_loss
in
loss_value
)
else
:
raise
TypeError
(
'{} is not a tensor or list of tensors'
.
format
(
loss_name
))
loss
=
sum
(
_value
for
_key
,
_value
in
log_vars
.
items
()
if
'loss'
in
_key
)
log_vars
[
'loss'
]
=
loss
for
name
in
log_vars
:
log_vars
[
name
]
=
log_vars
[
name
].
item
()
return
loss
,
log_vars
def
batch_processor
(
model
,
data
,
train_mode
):
losses
=
model
(
**
data
)
loss
,
log_vars
=
parse_losses
(
losses
)
outputs
=
dict
(
loss
=
loss
,
log_vars
=
log_vars
,
num_samples
=
len
(
data
[
'img'
].
data
))
return
outputs
def
get_logger
(
log_level
):
logging
.
basicConfig
(
format
=
'%(asctime)s - %(levelname)s - %(message)s'
,
level
=
log_level
)
logger
=
logging
.
getLogger
()
return
logger
from
mmdet
import
datasets
def
set_random_seed
(
seed
):
from
mmdet.api
import
train_detector
random
.
seed
(
seed
)
from
mmdet.models
import
build_detector
np
.
random
.
seed
(
seed
)
torch
.
manual_seed
(
seed
)
torch
.
cuda
.
manual_seed_all
(
seed
)
def
parse_args
():
def
parse_args
():
...
@@ -86,71 +33,19 @@ def parse_args():
...
@@ -86,71 +33,19 @@ def parse_args():
def
main
():
def
main
():
args
=
parse_args
()
args
=
parse_args
()
cfg
=
Config
.
fromfile
(
args
.
config
)
cfg
=
Config
.
fromfile
(
args
.
config
)
if
args
.
work_dir
is
not
None
:
if
args
.
work_dir
is
not
None
:
cfg
.
work_dir
=
args
.
work_dir
cfg
.
work_dir
=
args
.
work_dir
cfg
.
validate
=
args
.
validate
cfg
.
gpus
=
args
.
gpus
cfg
.
gpus
=
args
.
gpus
# save mmdet version in checkpoint as meta data
cfg
.
seed
=
args
.
seed
cfg
.
checkpoint_config
.
meta
=
dict
(
cfg
.
launcher
=
args
.
launcher
mmdet_version
=
__version__
,
config
=
cfg
.
text
)
cfg
.
local_rank
=
args
.
local_rank
logger
=
get_logger
(
cfg
.
log_level
)
# set random seed if specified
if
args
.
seed
is
not
None
:
logger
.
info
(
'Set random seed to {}'
.
format
(
args
.
seed
))
set_random_seed
(
args
.
seed
)
# init distributed environment if necessary
if
args
.
launcher
==
'none'
:
dist
=
False
logger
.
info
(
'Non-distributed training.'
)
else
:
dist
=
True
init_dist
(
args
.
launcher
,
**
cfg
.
dist_params
)
if
torch
.
distributed
.
get_rank
()
!=
0
:
logger
.
setLevel
(
'ERROR'
)
logger
.
info
(
'Distributed training.'
)
# prepare data loaders
train_dataset
=
obj_from_dict
(
cfg
.
data
.
train
,
datasets
)
data_loaders
=
[
build_dataloader
(
train_dataset
,
cfg
.
data
.
imgs_per_gpu
,
cfg
.
data
.
workers_per_gpu
,
cfg
.
gpus
,
dist
)
]
# build model
# build model
model
=
build_detector
(
model
=
build_detector
(
cfg
.
model
,
train_cfg
=
cfg
.
train_cfg
,
test_cfg
=
cfg
.
test_cfg
)
cfg
.
model
,
train_cfg
=
cfg
.
train_cfg
,
test_cfg
=
cfg
.
test_cfg
)
if
dist
:
train_dataset
=
obj_from_dict
(
cfg
.
data
.
train
,
datasets
)
model
=
MMDistributedDataParallel
(
model
.
cuda
())
train_detector
(
model
,
train_dataset
,
cfg
)
else
:
model
=
MMDataParallel
(
model
,
device_ids
=
range
(
cfg
.
gpus
)).
cuda
()
# build runner
runner
=
Runner
(
model
,
batch_processor
,
cfg
.
optimizer
,
cfg
.
work_dir
,
cfg
.
log_level
)
# register hooks
optimizer_config
=
DistOptimizerHook
(
**
cfg
.
optimizer_config
)
if
dist
else
cfg
.
optimizer_config
runner
.
register_training_hooks
(
cfg
.
lr_config
,
optimizer_config
,
cfg
.
checkpoint_config
,
cfg
.
log_config
)
if
dist
:
runner
.
register_hook
(
DistSamplerSeedHook
())
# register eval hooks
if
args
.
validate
:
if
isinstance
(
model
.
module
,
RPN
):
runner
.
register_hook
(
CocoDistEvalRecallHook
(
cfg
.
data
.
val
))
elif
cfg
.
data
.
val
.
type
==
'CocoDataset'
:
runner
.
register_hook
(
CocoDistEvalmAPHook
(
cfg
.
data
.
val
))
if
cfg
.
resume_from
:
runner
.
resume
(
cfg
.
resume_from
)
elif
cfg
.
load_from
:
runner
.
load_checkpoint
(
cfg
.
load_from
)
runner
.
run
(
data_loaders
,
cfg
.
workflow
,
cfg
.
total_epochs
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment