OpenDAS / OpenPCDet · Commit 1059a42b
Authored Jun 24, 2020 by Gus-Guo
Parent: acc9dd26

support multi-gpu testing

Showing 6 changed files with 116 additions and 15 deletions:

pcdet/datasets/__init__.py         +33   -3
pcdet/utils/common_utils.py        +41   -0
tools/eval_utils/eval_utils.py     +14   -4
tools/scripts/slurm_test_mgpu.sh   +20   -0
tools/test.py                       +6   -6
tools/train.py                      +2   -2
pcdet/datasets/__init__.py

@@ -2,13 +2,37 @@ import torch
 from torch.utils.data import DataLoader
 from .dataset import DatasetTemplate
 from .kitti.kitti_dataset import KittiDataset
+from torch.utils.data import DistributedSampler as _DistributedSampler
+from pcdet.utils import common_utils

 __all__ = {
     'DatasetTemplate': DatasetTemplate,
     'KittiDataset': KittiDataset,
 }
+
+
+class DistributedSampler(_DistributedSampler):
+
+    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
+        super().__init__(dataset, num_replicas=num_replicas, rank=rank)
+        self.shuffle = shuffle
+
+    def __iter__(self):
+        if self.shuffle:
+            g = torch.Generator()
+            g.manual_seed(self.epoch)
+            indices = torch.randperm(len(self.dataset), generator=g).tolist()
+        else:
+            indices = torch.arange(len(self.dataset)).tolist()
+
+        indices += indices[:(self.total_size - len(indices))]
+        assert len(indices) == self.total_size
+
+        indices = indices[self.rank:self.total_size:self.num_replicas]
+        assert len(indices) == self.num_samples
+
+        return iter(indices)
+
+
 def build_dataloader(dataset_cfg, class_names, batch_size, dist, root_path=None, workers=4,
                      logger=None, training=True):

@@ -20,8 +44,14 @@ def build_dataloader(dataset_cfg, class_names, batch_size, dist, root_path=None,
         training=training,
         logger=logger,
     )
-    sampler = torch.utils.data.distributed.DistributedSampler(dataset) if dist else None
+    if dist:
+        if training:
+            sampler = torch.utils.data.distributed.DistributedSampler(dataset)
+        else:
+            rank, world_size = common_utils.get_dist_info()
+            sampler = DistributedSampler(dataset, world_size, rank, shuffle=False)
+    else:
+        sampler = None
     dataloader = DataLoader(
         dataset, batch_size=batch_size, pin_memory=True, num_workers=workers,
         shuffle=(sampler is None) and training, collate_fn=dataset.collate_batch,
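Note on the eval sampler: with shuffle=False, each rank takes a deterministic stride over the padded index list, so rank r sees indices r, r + world_size, r + 2 * world_size, and so on. This interleaving is what lets merge_results_dist (added in common_utils.py below) restore dataset order with a simple zip. A minimal standalone sketch of the split; the dataset length (10) and world size (3) are invented example values:

    # Sketch of the rank striding performed by the eval DistributedSampler
    # above; dataset length and world size are made-up example values.
    import math

    def eval_indices(rank, num_replicas, dataset_len):
        num_samples = math.ceil(dataset_len / num_replicas)
        total_size = num_samples * num_replicas
        indices = list(range(dataset_len))
        indices += indices[:total_size - len(indices)]  # pad with head elements
        return indices[rank:total_size:num_replicas]

    for rank in range(3):
        print(rank, eval_indices(rank, 3, 10))
    # 0 [0, 3, 6, 9]
    # 1 [1, 4, 7, 0]   <- trailing 0 is padding
    # 2 [2, 5, 8, 1]   <- trailing 1 is padding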
pcdet/utils/common_utils.py

@@ -6,6 +6,8 @@ import os
 import torch.multiprocessing as mp
 import torch.distributed as dist
 import subprocess
+import pickle
+import shutil


 def check_numpy_to_torch(x):

@@ -153,3 +155,42 @@ def init_dist_pytorch(batch_size, tcp_port, local_rank, backend='nccl'):
     batch_size_each_gpu = batch_size // num_gpus
     rank = dist.get_rank()
     return batch_size_each_gpu, rank
+
+
+def get_dist_info():
+    if torch.__version__ < '1.0':
+        initialized = dist._initialized
+    else:
+        if dist.is_available():
+            initialized = dist.is_initialized()
+        else:
+            initialized = False
+    if initialized:
+        rank = dist.get_rank()
+        world_size = dist.get_world_size()
+    else:
+        rank = 0
+        world_size = 1
+    return rank, world_size
+
+
+def merge_results_dist(result_part, size, tmpdir):
+    rank, world_size = get_dist_info()
+    os.makedirs(tmpdir, exist_ok=True)
+
+    dist.barrier()
+    pickle.dump(result_part, open(os.path.join(tmpdir, 'result_part_{}.pkl'.format(rank)), 'wb'))
+    dist.barrier()
+
+    if rank != 0:
+        return None
+
+    part_list = []
+    for i in range(world_size):
+        part_file = os.path.join(tmpdir, 'result_part_{}.pkl'.format(i))
+        part_list.append(pickle.load(open(part_file, 'rb')))
+
+    ordered_results = []
+    for res in zip(*part_list):
+        ordered_results.extend(list(res))
+    ordered_results = ordered_results[:size]
+    shutil.rmtree(tmpdir)
+    return ordered_results
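How the merge restores order: every rank pickles its result part into a shared tmpdir, the two barriers ensure all parts exist before rank 0 reads them, and because the eval sampler strides indices by rank, zip(*part_list) re-interleaves the parts back into dataset order; the final [:size] slice drops the padded duplicates. A toy illustration with invented values:

    # Toy illustration of the zip-based re-interleave above; the lists
    # stand in for per-rank result parts from a 3-GPU run (values invented).
    part_list = [[0, 3, 6, 9], [1, 4, 7, 0], [2, 5, 8, 1]]
    ordered_results = []
    for res in zip(*part_list):
        ordered_results.extend(list(res))
    print(ordered_results[:10])  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]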
tools/eval_utils/eval_utils.py

@@ -3,7 +3,7 @@ import time
 import pickle
 import numpy as np
 import torch
-from mmpcdet.utils import common_utils
+from pcdet.utils import common_utils


 def statistics_info(cfg, ret_dict, metric, disp_dict):

@@ -38,7 +38,13 @@ def eval_one_epoch(cfg, model, dataloader, epoch_id, logger, dist_test=False, sa
     logger.info('*************** EPOCH %s EVALUATION *****************' % epoch_id)
     if dist_test:
-        raise NotImplementedError
+        num_gpus = torch.cuda.device_count()
+        local_rank = cfg.LOCAL_RANK % num_gpus
+        model = torch.nn.parallel.DistributedDataParallel(
+            model,
+            device_ids=[local_rank],
+            broadcast_buffers=False
+        )
     model.eval()

     if cfg.LOCAL_RANK == 0:

@@ -71,7 +77,8 @@ def eval_one_epoch(cfg, model, dataloader, epoch_id, logger, dist_test=False, sa
     if dist_test:
         rank, world_size = common_utils.get_dist_info()
-        raise NotImplementedError
+        det_annos = common_utils.merge_results_dist(det_annos, len(dataset), tmpdir=result_dir / 'tmpdir')
+        metric = common_utils.merge_results_dist([metric], world_size, tmpdir=result_dir / 'tmpdir')

     logger.info('*************** Performance of EPOCH %s *****************' % epoch_id)
     sec_per_example = (time.time() - start_time) / len(dataloader.dataset)

@@ -82,7 +89,10 @@ def eval_one_epoch(cfg, model, dataloader, epoch_id, logger, dist_test=False, sa
     ret_dict = {}
     if dist_test:
-        raise NotImplementedError
+        for key, val in metric[0].items():
+            for k in range(1, world_size):
+                metric[0][key] += metric[k][key]
+        metric = metric[0]

     gt_num_cnt = metric['gt_num']
     for cur_thresh in cfg.MODEL.POST_PROCESSING.RECALL_THRESH_LIST:
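Two details worth noting: the model is wrapped in DistributedDataParallel with broadcast_buffers=False, since evaluation never needs buffer synchronization across ranks, and metric is a plain dict of counters, so merge_results_dist([metric], world_size, ...) yields one counter dict per rank that the loop above sums key by key. A hedged sketch of that summation; the key names and counts are invented example values (the real keys come from statistics_info):

    # Hedged sketch of the per-rank metric merge above; keys and counts
    # are invented example values, not OpenPCDet's actual metric keys.
    metric = [{'gt_num': 40, 'recall_0.3': 35},   # gathered from rank 0
              {'gt_num': 38, 'recall_0.3': 33}]   # gathered from rank 1
    world_size = len(metric)
    for key, val in metric[0].items():
        for k in range(1, world_size):
            metric[0][key] += metric[k][key]
    metric = metric[0]
    print(metric)  # {'gt_num': 78, 'recall_0.3': 68}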
tools/scripts/slurm_test_mgpu.sh (new file, mode 100755)

#!/usr/bin/env bash

set -x
PARTITION=$1
GPUS=$2
GPUS_PER_NODE=$GPUS
PY_ARGS=${@:3}
JOB_NAME=eval
SRUN_ARGS=${SRUN_ARGS:-""}

srun -p ${PARTITION} \
    --job-name=${JOB_NAME} \
    --gres=gpu:${GPUS_PER_NODE} \
    --ntasks=${GPUS} \
    --ntasks-per-node=${GPUS_PER_NODE} \
    --kill-on-bad-exit=1 \
    ${SRUN_ARGS} \
    python -u test.py --launcher slurm ${PY_ARGS}
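Usage: the first two positional arguments are the slurm partition and the GPU count (one task per GPU, single node); everything after them is forwarded to test.py. A hypothetical invocation, run from tools/ since the script calls test.py directly; the partition name, config path, and the --ckpt flag are assumptions not shown in this diff (only --cfg_file appears here):

    # Evaluate on 8 GPUs of one node via slurm (names and paths hypothetical):
    cd tools
    sh scripts/slurm_test_mgpu.sh your_partition 8 \
        --cfg_file cfgs/pointpillar.yaml --ckpt ../output/checkpoint.pth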
tools/test.py

@@ -8,14 +8,14 @@ import datetime
 import argparse
 from pathlib import Path
 import torch.distributed as dist
-from mmpcdet.datasets import build_dataloader
-from mmpcdet.models import build_network
-from mmpcdet.utils import common_utils
-from mmpcdet.config import cfg, cfg_from_list, cfg_from_yaml_file, log_config_to_file
+from pcdet.datasets import build_dataloader
+from pcdet.models import build_network
+from pcdet.utils import common_utils
+from pcdet.config import cfg, cfg_from_list, cfg_from_yaml_file, log_config_to_file
 from eval_utils import eval_utils


-def parge_config():
+def parse_config():
     parser = argparse.ArgumentParser(description='arg parser')
     parser.add_argument('--cfg_file', type=str, default=None, help='specify the config for training')

@@ -128,7 +128,7 @@ def repeat_eval_ckpt(model, test_loader, args, eval_output_dir, logger, ckpt_dir

 def main():
-    args, cfg = parge_config()
+    args, cfg = parse_config()
     if args.launcher == 'none':
         dist_test = False
     else:
tools/train.py

@@ -16,7 +16,7 @@ import datetime
 import glob


-def parge_config():
+def parse_config():
     parser = argparse.ArgumentParser(description='arg parser')
     parser.add_argument('--cfg_file', type=str, default=None, help='specify the config for training')

@@ -54,7 +54,7 @@ def parge_config():

 def main():
-    args, cfg = parge_config()
+    args, cfg = parse_config()
     if args.launcher == 'none':
         dist_train = False
     else: