Commit c47568a4 authored by WXinlong

support multi-gpu test

parent 357190f3
@@ -8,4 +8,4 @@ GPUS=$3
 PORT=${PORT:-29500}
 $PYTHON -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \
-    $(dirname "$0")/test.py $CONFIG $CHECKPOINT --launcher pytorch ${@:4}
+    $(dirname "$0")/test_ins.py $CONFIG $CHECKPOINT --launcher pytorch ${@:4}
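The only change here is pointing the distributed launcher at test_ins.py instead of test.py. As a hypothetical usage sketch (the script path, config, checkpoint and evaluation flags below are placeholders following the usual mmdetection test-script interface, not values taken from this commit):

```bash
# Run instance-segmentation evaluation on 4 GPUs; everything after the GPU
# count is forwarded to test_ins.py via ${@:4}.
./tools/dist_test.sh path/to/config.py path/to/checkpoint.pth 4 \
    --out results.pkl --eval segm
```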
@@ -63,12 +63,16 @@ def multi_gpu_test(model, data_loader, tmpdir=None):
     model.eval()
     results = []
     dataset = data_loader.dataset
+    num_classes = len(dataset.CLASSES)
     rank, world_size = get_dist_info()
     if rank == 0:
         prog_bar = mmcv.ProgressBar(len(dataset))
     for i, data in enumerate(data_loader):
         with torch.no_grad():
-            result = model(return_loss=False, rescale=True, **data)
+            seg_result = model(return_loss=False, rescale=True, **data)
+            result = get_masks(seg_result, num_classes=num_classes)
         results.append(result)
         if rank == 0:
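The new code routes the raw model output through get_masks before appending to results, so every rank collects results in the same per-class format. get_masks itself is not part of this hunk; as a rough, hypothetical sketch only (the per-image output format of mask tensor, category labels and scores, and the use of pycocotools RLE encoding, are assumptions rather than lines from this commit), such a helper could look like:

```python
import numpy as np
import pycocotools.mask as mask_util


def get_masks(seg_result, num_classes):
    """Hypothetical sketch: regroup one image's instance masks by class.

    Assumes seg_result is an iterable of (seg_pred, cate_label, cate_score)
    tuples (or None) and returns a per-class list of (RLE mask, score)
    pairs, the format mmdet-style segm evaluation expects.
    """
    masks = [[] for _ in range(num_classes)]
    for cur_result in seg_result:
        if cur_result is None:
            continue
        seg_pred, cate_label, cate_score = cur_result
        seg_pred = seg_pred.cpu().numpy().astype(np.uint8)
        cate_label = cate_label.cpu().numpy()
        cate_score = cate_score.cpu().numpy()
        for idx in range(seg_pred.shape[0]):
            # RLE-encode each binary instance mask and file it under its class.
            rle = mask_util.encode(np.asfortranarray(seg_pred[idx]))
            masks[cate_label[idx]].append((rle, float(cate_score[idx])))
    return masks
```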
@@ -208,7 +212,6 @@ def main():
     else:
         model.CLASSES = dataset.CLASSES
-    assert not distributed
     if not distributed:
         model = MMDataParallel(model, device_ids=[0])
         outputs = single_gpu_test(model, data_loader)
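Dropping the `assert not distributed` guard is what actually enables the multi-GPU path: main() can now fall through to a distributed branch instead of refusing to run. That branch is outside the lines shown here; a sketch of what it typically looks like in mmdetection-style test scripts (the MMDistributedDataParallel wrapping and the tmpdir argument are assumptions based on that convention, not lines from this commit):

```python
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel

if not distributed:
    # Single-GPU path, unchanged by this commit.
    model = MMDataParallel(model, device_ids=[0])
    outputs = single_gpu_test(model, data_loader)
else:
    # Multi-GPU path: each rank runs its shard of the data loader through
    # multi_gpu_test (the function patched above); per-rank results are then
    # gathered, typically via a shared tmpdir from the script's arg parser.
    model = MMDistributedDataParallel(model.cuda())
    outputs = multi_gpu_test(model, data_loader, args.tmpdir)
```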