Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
SOLOv2-pytorch
Commits
c47568a4
Commit
c47568a4
authored
Apr 04, 2020
by
WXinlong
Browse files
support multi-gpu test
parent
357190f3
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
6 additions
and
3 deletions
+6
-3
tools/dist_test.sh
tools/dist_test.sh
+1
-1
tools/test_ins.py
tools/test_ins.py
+5
-2
No files found.
tools/dist_test.sh
View file @
c47568a4
...
@@ -8,4 +8,4 @@ GPUS=$3
...
@@ -8,4 +8,4 @@ GPUS=$3
PORT
=
${
PORT
:-
29500
}
PORT
=
${
PORT
:-
29500
}
$PYTHON
-m
torch.distributed.launch
--nproc_per_node
=
$GPUS
--master_port
=
$PORT
\
$PYTHON
-m
torch.distributed.launch
--nproc_per_node
=
$GPUS
--master_port
=
$PORT
\
$(
dirname
"
$0
"
)
/test.py
$CONFIG
$CHECKPOINT
--launcher
pytorch
${
@
:4
}
$(
dirname
"
$0
"
)
/test
_ins
.py
$CONFIG
$CHECKPOINT
--launcher
pytorch
${
@
:4
}
tools/test_ins.py
View file @
c47568a4
...
@@ -63,12 +63,16 @@ def multi_gpu_test(model, data_loader, tmpdir=None):
...
@@ -63,12 +63,16 @@ def multi_gpu_test(model, data_loader, tmpdir=None):
model
.
eval
()
model
.
eval
()
results
=
[]
results
=
[]
dataset
=
data_loader
.
dataset
dataset
=
data_loader
.
dataset
num_classes
=
len
(
dataset
.
CLASSES
)
rank
,
world_size
=
get_dist_info
()
rank
,
world_size
=
get_dist_info
()
if
rank
==
0
:
if
rank
==
0
:
prog_bar
=
mmcv
.
ProgressBar
(
len
(
dataset
))
prog_bar
=
mmcv
.
ProgressBar
(
len
(
dataset
))
for
i
,
data
in
enumerate
(
data_loader
):
for
i
,
data
in
enumerate
(
data_loader
):
with
torch
.
no_grad
():
with
torch
.
no_grad
():
result
=
model
(
return_loss
=
False
,
rescale
=
True
,
**
data
)
seg_result
=
model
(
return_loss
=
False
,
rescale
=
True
,
**
data
)
result
=
get_masks
(
seg_result
,
num_classes
=
num_classes
)
results
.
append
(
result
)
results
.
append
(
result
)
if
rank
==
0
:
if
rank
==
0
:
...
@@ -208,7 +212,6 @@ def main():
...
@@ -208,7 +212,6 @@ def main():
else
:
else
:
model
.
CLASSES
=
dataset
.
CLASSES
model
.
CLASSES
=
dataset
.
CLASSES
assert
not
distributed
if
not
distributed
:
if
not
distributed
:
model
=
MMDataParallel
(
model
,
device_ids
=
[
0
])
model
=
MMDataParallel
(
model
,
device_ids
=
[
0
])
outputs
=
single_gpu_test
(
model
,
data_loader
)
outputs
=
single_gpu_test
(
model
,
data_loader
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment