Commit 1059a42b authored by Gus-Guo

support multi-gpu testing

parent acc9dd26
@@ -2,13 +2,37 @@ import torch
 from torch.utils.data import DataLoader
 from .dataset import DatasetTemplate
 from .kitti.kitti_dataset import KittiDataset
+from torch.utils.data import DistributedSampler as _DistributedSampler
+
+from pcdet.utils import common_utils
 
 __all__ = {
     'DatasetTemplate': DatasetTemplate,
     'KittiDataset': KittiDataset,
 }
 
+
+class DistributedSampler(_DistributedSampler):
+    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
+        super().__init__(dataset, num_replicas=num_replicas, rank=rank)
+        self.shuffle = shuffle
+
+    def __iter__(self):
+        if self.shuffle:
+            g = torch.Generator()
+            g.manual_seed(self.epoch)
+            indices = torch.randperm(len(self.dataset), generator=g).tolist()
+        else:
+            indices = torch.arange(len(self.dataset)).tolist()
+
+        indices += indices[:(self.total_size - len(indices))]
+        assert len(indices) == self.total_size
+
+        indices = indices[self.rank:self.total_size:self.num_replicas]
+        assert len(indices) == self.num_samples
+
+        return iter(indices)
+
 
 def build_dataloader(dataset_cfg, class_names, batch_size, dist, root_path=None, workers=4,
                      logger=None, training=True):
@@ -20,8 +44,14 @@ def build_dataloader(dataset_cfg, class_names, batch_size, dist, root_path=None,
         training=training,
         logger=logger,
     )
-
-    sampler = torch.utils.data.distributed.DistributedSampler(dataset) if dist else None
+    if dist:
+        if training:
+            sampler = torch.utils.data.distributed.DistributedSampler(dataset)
+        else:
+            rank, world_size = common_utils.get_dist_info()
+            sampler = DistributedSampler(dataset, world_size, rank, shuffle=False)
+    else:
+        sampler = None
     dataloader = DataLoader(
         dataset, batch_size=batch_size, pin_memory=True, num_workers=workers,
         shuffle=(sampler is None) and training, collate_fn=dataset.collate_batch,
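The DistributedSampler subclass above exists because the stock torch.utils.data.DistributedSampler, in the PyTorch versions this code targets, shuffles unconditionally, which would scramble sample order during evaluation. With shuffle=False, the index list is padded by wrapping around to the front and each rank takes an interleaved slice, so every rank sees the same number of samples in a deterministic order. Here is a minimal sketch of that partitioning, not part of the commit (partition is a hypothetical helper):

import math

def partition(num_samples, world_size):
    # mirrors __iter__ above: num_samples per rank is ceil(N / world_size)
    per_rank = math.ceil(num_samples / world_size)
    total_size = per_rank * world_size
    indices = list(range(num_samples))
    indices += indices[:total_size - len(indices)]  # pad by wrapping to the front
    # rank r takes indices r, r + world_size, r + 2 * world_size, ...
    return [indices[r:total_size:world_size] for r in range(world_size)]

print(partition(10, 4))
# [[0, 4, 8], [1, 5, 9], [2, 6, 0], [3, 7, 1]] -- samples 0 and 1 appear twice

The duplicated tail entries are the price of an even split across ranks; the merge step introduced below truncates them away.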
@@ -6,6 +6,8 @@ import os
 import torch.multiprocessing as mp
 import torch.distributed as dist
 import subprocess
+import pickle
+import shutil
 
 
 def check_numpy_to_torch(x):
@@ -153,3 +155,42 @@ def init_dist_pytorch(batch_size, tcp_port, local_rank, backend='nccl'):
     batch_size_each_gpu = batch_size // num_gpus
     rank = dist.get_rank()
     return batch_size_each_gpu, rank
+
+
+def get_dist_info():
+    if torch.__version__ < '1.0':
+        initialized = dist._initialized
+    else:
+        if dist.is_available():
+            initialized = dist.is_initialized()
+        else:
+            initialized = False
+    if initialized:
+        rank = dist.get_rank()
+        world_size = dist.get_world_size()
+    else:
+        rank = 0
+        world_size = 1
+    return rank, world_size
+
+
+def merge_results_dist(result_part, size, tmpdir):
+    rank, world_size = get_dist_info()
+    os.makedirs(tmpdir, exist_ok=True)
+    dist.barrier()
+    pickle.dump(result_part, open(os.path.join(tmpdir, 'result_part_{}.pkl'.format(rank)), 'wb'))
+    dist.barrier()
+
+    if rank != 0:
+        return None
+
+    part_list = []
+    for i in range(world_size):
+        part_file = os.path.join(tmpdir, 'result_part_{}.pkl'.format(i))
+        part_list.append(pickle.load(open(part_file, 'rb')))
+
+    ordered_results = []
+    for res in zip(*part_list):
+        ordered_results.extend(list(res))
+    ordered_results = ordered_results[:size]
+    shutil.rmtree(tmpdir)
+    return ordered_results
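merge_results_dist gathers per-rank partial results through a temporary directory rather than collective communication: every rank pickles its share, the two barriers keep the ranks in step, and rank 0 reloads all parts while the others return None. Because the sampler above deals samples out round-robin, zip(*part_list) re-interleaves them back into dataset order, and slicing to size drops the wrap-around padding. One assumption worth noting: tmpdir must be on a filesystem visible to all ranks. A small sketch of the re-interleaving, reusing the toy split from the previous example:

parts = [[0, 4, 8], [1, 5, 9], [2, 6, 0], [3, 7, 1]]  # per-rank results (here: indices)
merged = []
for group in zip(*parts):  # yields (0, 1, 2, 3), (4, 5, 6, 7), (8, 9, 0, 1)
    merged.extend(group)
print(merged[:10])  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] -- padding cut from the tail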
@@ -3,7 +3,7 @@ import time
 import pickle
 import numpy as np
 import torch
-from mmpcdet.utils import common_utils
+from pcdet.utils import common_utils
 
 
 def statistics_info(cfg, ret_dict, metric, disp_dict):
@@ -38,7 +38,13 @@ def eval_one_epoch(cfg, model, dataloader, epoch_id, logger, dist_test=False, sa
     logger.info('*************** EPOCH %s EVALUATION *****************' % epoch_id)
     if dist_test:
-        raise NotImplementedError
+        num_gpus = torch.cuda.device_count()
+        local_rank = cfg.LOCAL_RANK % num_gpus
+        model = torch.nn.parallel.DistributedDataParallel(
+            model,
+            device_ids=[local_rank],
+            broadcast_buffers=False
+        )
     model.eval()
 
     if cfg.LOCAL_RANK == 0:
@@ -71,7 +77,8 @@ def eval_one_epoch(cfg, model, dataloader, epoch_id, logger, dist_test=False, sa
     if dist_test:
         rank, world_size = common_utils.get_dist_info()
-        raise NotImplementedError
+        det_annos = common_utils.merge_results_dist(det_annos, len(dataset), tmpdir=result_dir / 'tmpdir')
+        metric = common_utils.merge_results_dist([metric], world_size, tmpdir=result_dir / 'tmpdir')
 
     logger.info('*************** Performance of EPOCH %s *****************' % epoch_id)
     sec_per_example = (time.time() - start_time) / len(dataloader.dataset)
@@ -82,7 +89,10 @@ def eval_one_epoch(cfg, model, dataloader, epoch_id, logger, dist_test=False, sa
     ret_dict = {}
     if dist_test:
-        raise NotImplementedError
+        for key, val in metric[0].items():
+            for k in range(1, world_size):
+                metric[0][key] += metric[k][key]
+        metric = metric[0]
 
     gt_num_cnt = metric['gt_num']
     for cur_thresh in cfg.MODEL.POST_PROCESSING.RECALL_THRESH_LIST:
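In the distributed test path above, DistributedDataParallel is used only to pin each process to its own GPU during inference; no gradients flow, and broadcast_buffers=False skips the buffer synchronization that would otherwise happen on every forward pass. After the loop, det_annos from all ranks are merged and truncated to len(dataset), and the per-rank metric dicts (gathered as a list on rank 0) are summed key by key. A sketch of that reduction with made-up values and keys (the real keys come from statistics_info):

metric = [
    {'gt_num': 100, 'recall_0.3': 90},  # rank 0
    {'gt_num': 100, 'recall_0.3': 85},  # rank 1
]
world_size = len(metric)
for key, val in metric[0].items():
    for k in range(1, world_size):
        metric[0][key] += metric[k][key]
metric = metric[0]
print(metric)  # {'gt_num': 200, 'recall_0.3': 175}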
#!/usr/bin/env bash
set -x
PARTITION=$1
GPUS=$2
GPUS_PER_NODE=$GPUS
PY_ARGS=${@:3}
JOB_NAME=eval
SRUN_ARGS=${SRUN_ARGS:-""}
srun -p ${PARTITION} \
    --job-name=${JOB_NAME} \
    --gres=gpu:${GPUS_PER_NODE} \
    --ntasks=${GPUS} \
    --ntasks-per-node=${GPUS_PER_NODE} \
    --kill-on-bad-exit=1 \
    ${SRUN_ARGS} \
    python -u test.py --launcher slurm ${PY_ARGS}
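This script launches one slurm task per GPU on a single node (--ntasks equals the --gres count) and forwards everything after the first two positional arguments to test.py, whose slurm launcher derives ranks from the slurm environment. A hypothetical invocation, with placeholder partition, script, and config names not taken from this commit:

sh slurm_test.sh your_partition 8 --cfg_file cfgs/your_model.yaml

Extra srun options can be injected through the SRUN_ARGS environment variable.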
@@ -8,14 +8,14 @@ import datetime
 import argparse
 from pathlib import Path
 import torch.distributed as dist
-from mmpcdet.datasets import build_dataloader
-from mmpcdet.models import build_network
-from mmpcdet.utils import common_utils
-from mmpcdet.config import cfg, cfg_from_list, cfg_from_yaml_file, log_config_to_file
+from pcdet.datasets import build_dataloader
+from pcdet.models import build_network
+from pcdet.utils import common_utils
+from pcdet.config import cfg, cfg_from_list, cfg_from_yaml_file, log_config_to_file
 from eval_utils import eval_utils
 
 
-def parge_config():
+def parse_config():
     parser = argparse.ArgumentParser(description='arg parser')
     parser.add_argument('--cfg_file', type=str, default=None, help='specify the config for training')
@@ -128,7 +128,7 @@ def repeat_eval_ckpt(model, test_loader, args, eval_output_dir, logger, ckpt_dir
 
 def main():
-    args, cfg = parge_config()
+    args, cfg = parse_config()
     if args.launcher == 'none':
         dist_test = False
     else:
@@ -16,7 +16,7 @@ import datetime
 import glob
 
 
-def parge_config():
+def parse_config():
     parser = argparse.ArgumentParser(description='arg parser')
     parser.add_argument('--cfg_file', type=str, default=None, help='specify the config for training')
@@ -54,7 +54,7 @@ def parge_config():
 
 def main():
-    args, cfg = parge_config()
+    args, cfg = parse_config()
     if args.launcher == 'none':
         dist_train = False
     else: