Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
MMCV
Commits
010b1a0f
Unverified
Commit
010b1a0f
authored
Apr 22, 2020
by
Kai Chen
Committed by
GitHub
Apr 22, 2020
Browse files
remove parallel_test (#238)
parent
af02ac9f
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
3 additions
and
79 deletions
+3
-79
mmcv/runner/__init__.py
mmcv/runner/__init__.py
+3
-4
mmcv/runner/parallel_test.py
mmcv/runner/parallel_test.py
+0
-75
No files found.
mmcv/runner/__init__.py
View file @
010b1a0f
...
@@ -7,7 +7,6 @@ from .hooks import (HOOKS, CheckpointHook, ClosureHook, DistSamplerSeedHook,
...
@@ -7,7 +7,6 @@ from .hooks import (HOOKS, CheckpointHook, ClosureHook, DistSamplerSeedHook,
OptimizerHook
,
PaviLoggerHook
,
TensorboardLoggerHook
,
OptimizerHook
,
PaviLoggerHook
,
TensorboardLoggerHook
,
TextLoggerHook
,
WandbLoggerHook
)
TextLoggerHook
,
WandbLoggerHook
)
from
.log_buffer
import
LogBuffer
from
.log_buffer
import
LogBuffer
from
.parallel_test
import
parallel_test
from
.priority
import
Priority
,
get_priority
from
.priority
import
Priority
,
get_priority
from
.runner
import
Runner
from
.runner
import
Runner
from
.utils
import
get_host_info
,
get_time_str
,
obj_from_dict
from
.utils
import
get_host_info
,
get_time_str
,
obj_from_dict
...
@@ -17,7 +16,7 @@ __all__ = [
...
@@ -17,7 +16,7 @@ __all__ = [
'LrUpdaterHook'
,
'OptimizerHook'
,
'IterTimerHook'
,
'DistSamplerSeedHook'
,
'LrUpdaterHook'
,
'OptimizerHook'
,
'IterTimerHook'
,
'DistSamplerSeedHook'
,
'LoggerHook'
,
'PaviLoggerHook'
,
'TextLoggerHook'
,
'TensorboardLoggerHook'
,
'LoggerHook'
,
'PaviLoggerHook'
,
'TextLoggerHook'
,
'TensorboardLoggerHook'
,
'WandbLoggerHook'
,
'_load_checkpoint'
,
'load_state_dict'
,
'WandbLoggerHook'
,
'_load_checkpoint'
,
'load_state_dict'
,
'load_checkpoint'
,
'weights_to_cpu'
,
'save_checkpoint'
,
'
parallel_test
'
,
'load_checkpoint'
,
'weights_to_cpu'
,
'save_checkpoint'
,
'
Priority
'
,
'Priority'
,
'get_priority'
,
'get_host_info'
,
'get_time_str'
,
'get_priority'
,
'get_host_info'
,
'get_time_str'
,
'obj_from_dict'
,
'obj_from_dict'
,
'init_dist'
,
'get_dist_info'
,
'master_only'
'init_dist'
,
'get_dist_info'
,
'master_only'
]
]
mmcv/runner/parallel_test.py
deleted
100644 → 0
View file @
af02ac9f
# Copyright (c) Open-MMLab. All rights reserved.
import
multiprocessing
import
torch
import
mmcv
from
.checkpoint
import
load_checkpoint
def worker_func(model_cls, model_kwargs, checkpoint, dataset, data_func,
                gpu_id, idx_queue, result_queue):
    """Worker-process loop: build the model on one GPU and serve test requests.

    Instantiates ``model_cls(**model_kwargs)``, loads ``checkpoint`` weights
    on CPU, moves the model to ``gpu_id``, then loops forever: pull a sample
    index from ``idx_queue``, run inference on ``dataset[idx]`` (prepared by
    ``data_func``), and push ``(idx, result)`` onto ``result_queue``.

    Runs until the parent process terminates it; there is no poison-pill exit.
    """
    net = model_cls(**model_kwargs)
    # Load on CPU first so the checkpoint never lands on the wrong device.
    load_checkpoint(net, checkpoint, map_location='cpu')
    torch.cuda.set_device(gpu_id)
    net.cuda()
    net.eval()
    with torch.no_grad():
        while True:
            sample_idx = idx_queue.get()
            sample = dataset[sample_idx]
            inputs = data_func(sample, gpu_id)
            output = net(**inputs)
            result_queue.put((sample_idx, output))
def parallel_test(model_cls,
                  model_kwargs,
                  checkpoint,
                  dataset,
                  data_func,
                  gpus,
                  workers_per_gpu=1):
    """Parallel testing on multiple GPUs.

    Spawns ``len(gpus) * workers_per_gpu`` worker processes (each pinned to
    one GPU via round-robin assignment), feeds every dataset index through a
    shared queue, and gathers results back in dataset order.

    Args:
        model_cls (type): Model class type.
        model_kwargs (dict): Arguments to init the model.
        checkpoint (str): Checkpoint filepath.
        dataset (:obj:`Dataset`): The dataset to be tested.
        data_func (callable): The function that generates model inputs.
        gpus (list[int]): GPU ids to be used.
        workers_per_gpu (int): Number of processes on each GPU. It is possible
            to run multiple workers on each GPU.

    Returns:
        list: Test results, ordered to match ``dataset`` indexing.
    """
    # 'spawn' (not fork) is required: CUDA cannot be re-initialized in a
    # forked child process.
    ctx = multiprocessing.get_context('spawn')
    idx_queue = ctx.Queue()
    result_queue = ctx.Queue()
    num_samples = len(dataset)
    num_workers = len(gpus) * workers_per_gpu
    workers = [
        ctx.Process(
            target=worker_func,
            args=(model_cls, model_kwargs, checkpoint, dataset, data_func,
                  gpus[i % len(gpus)], idx_queue, result_queue))
        for i in range(num_workers)
    ]
    for w in workers:
        w.daemon = True
        w.start()

    results = [None] * num_samples
    # Workers loop forever, so make sure they are terminated even if feeding
    # the queue or collecting results raises (e.g. KeyboardInterrupt, a
    # worker crash leaving result_queue.get() blocked after a timeout-free
    # wait is interrupted).
    try:
        for i in range(num_samples):
            idx_queue.put(i)

        prog_bar = mmcv.ProgressBar(task_num=num_samples)
        for _ in range(num_samples):
            idx, res = result_queue.get()
            results[idx] = res
            prog_bar.update()
        # Single newline to move past the progress bar line.
        print()
    finally:
        for worker in workers:
            worker.terminate()

    return results
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment