Commit b7894cbd authored by valuefish's avatar valuefish Committed by Kai Chen
Browse files

add multi nodes distributed test support (#1399)

* add multi nodes distributed test support

* fix bug in htc.py when keep_all_stages turn on

* remove package imported but unused in test.py

* reformat code in test.py

* support both cpu & gpu for gathering

* reformat

* clean code, add doc

* add docstring

* reformat doc string
parent 90357db2
import argparse
import os
import os.path as osp
import pickle
import shutil
import tempfile
......@@ -35,7 +36,25 @@ def single_gpu_test(model, data_loader, show=False):
return results
def multi_gpu_test(model, data_loader, tmpdir=None):
def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False):
"""Test model with multiple gpus.
This method tests model with multiple gpus and collects the results
under two different modes: gpu and cpu modes. By setting 'gpu_collect=True'
it encodes results to gpu tensors and use gpu communication for results
collection. On cpu mode it saves the results on different gpus to 'tmpdir'
and collects them by the rank 0 worker.
Args:
model (nn.Module): Model to be tested.
data_loader (nn.Dataloader): Pytorch data loader.
tmpdir (str): Path of directory to save the temporary results from
different gpus under cpu mode.
gpu_collect (bool): Option to use either gpu or cpu to collect results.
Returns:
list: The prediction results.
"""
model.eval()
results = []
dataset = data_loader.dataset
......@@ -53,12 +72,14 @@ def multi_gpu_test(model, data_loader, tmpdir=None):
prog_bar.update()
# collect results from all ranks
results = collect_results(results, len(dataset), tmpdir)
if gpu_collect:
results = collect_results_gpu(results, len(dataset))
else:
results = collect_results_cpu(results, len(dataset), tmpdir)
return results
def collect_results(result_part, size, tmpdir=None):
def collect_results_cpu(result_part, size, tmpdir=None):
rank, world_size = get_dist_info()
# create a tmp dir if it is not specified
if tmpdir is None:
......@@ -100,6 +121,39 @@ def collect_results(result_part, size, tmpdir=None):
return ordered_results
def collect_results_gpu(result_part, size):
rank, world_size = get_dist_info()
# dump result part to tensor with pickle
part_tensor = torch.tensor(
bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda')
# gather all result part tensor shape
shape_tensor = torch.tensor(part_tensor.shape, device='cuda')
shape_list = [shape_tensor.clone() for _ in range(world_size)]
dist.all_gather(shape_list, shape_tensor)
# padding result part tensor to max length
shape_max = torch.tensor(shape_list).max()
part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda')
part_send[:shape_tensor[0]] = part_tensor
part_recv_list = [
part_tensor.new_zeros(shape_max) for _ in range(world_size)
]
# gather all result part
dist.all_gather(part_recv_list, part_send)
if rank == 0:
part_list = []
for recv, shape in zip(part_recv_list, shape_list):
part_list.append(
pickle.loads(recv[:shape[0]].cpu().numpy().tobytes()))
# sort the results
ordered_results = []
for res in zip(*part_list):
ordered_results.extend(list(res))
# the dataloader may pad some samples
ordered_results = ordered_results[:size]
return ordered_results
def parse_args():
parser = argparse.ArgumentParser(description='MMDet test detector')
parser.add_argument('config', help='test config file path')
......@@ -116,6 +170,10 @@ def parse_args():
choices=['proposal', 'proposal_fast', 'bbox', 'segm', 'keypoints'],
help='eval types')
parser.add_argument('--show', action='store_true', help='show results')
parser.add_argument(
'--gpu_collect',
action='store_true',
help='whether to use gpu to collect results')
parser.add_argument('--tmpdir', help='tmp dir for writing some results')
parser.add_argument(
'--launcher',
......@@ -184,7 +242,8 @@ def main():
outputs = single_gpu_test(model, data_loader, args.show)
else:
model = MMDistributedDataParallel(model.cuda())
outputs = multi_gpu_test(model, data_loader, args.tmpdir)
outputs = multi_gpu_test(model, data_loader, args.tmpdir,
args.gpu_collect)
rank, _ = get_dist_info()
if args.out and rank == 0:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment