Unverified Commit 0448fcf9 authored by bilibilee's avatar bilibilee Committed by GitHub
Browse files

fix the scatter when input is cpu (#1621)



* fix the scatter when input is cpu

* Update _functions.py

Add spaces to comply with the code specification

* add unittests

* add a blank line

* fix unittet
Co-authored-by: default avatarzhouzaida <zhouzaida@163.com>
parent b8d78336
......@@ -22,10 +22,7 @@ def scatter(input, devices, streams=None):
if devices != [-1]:
with torch.cuda.device(devices[0]), torch.cuda.stream(stream):
output = output.cuda(devices[0], non_blocking=True)
else:
# unsqueeze the first dimension thus the tensor's shape is the
# same as those scattered with GPU.
output = output.unsqueeze(0)
return output
else:
raise Exception(f'Unknown type {type(input)}.')
......@@ -76,4 +73,4 @@ class Scatter:
if streams is not None:
synchronize_stream(outputs, target_gpus, streams)
return tuple(outputs)
return tuple(outputs) if isinstance(outputs, list) else (outputs, )
from unittest.mock import MagicMock, patch
import pytest
import torch
import torch.nn as nn
from torch.nn.parallel import DataParallel, DistributedDataParallel
from mmcv.parallel import (MODULE_WRAPPERS, MMDataParallel,
MMDistributedDataParallel, is_module_wrapper)
from mmcv.parallel._functions import Scatter, get_input_device, scatter
from mmcv.parallel.distributed_deprecated import \
MMDistributedDataParallel as DeprecatedMMDDP
......@@ -64,3 +66,83 @@ def test_is_module_wrapper():
module_wraper = ModuleWrapper(model)
assert is_module_wrapper(module_wraper)
def test_get_input_device():
# if the device is CPU, return -1
input = torch.zeros([1, 3, 3, 3])
assert get_input_device(input) == -1
inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
assert get_input_device(inputs) == -1
# if the device is GPU, return the index of device
if torch.cuda.is_available():
input = torch.zeros([1, 3, 3, 3]).cuda()
assert get_input_device(input) == 0
inputs = [
torch.zeros([1, 3, 3, 3]).cuda(),
torch.zeros([1, 4, 4, 4]).cuda()
]
assert get_input_device(inputs) == 0
# input should be a tensor or list of tensor
with pytest.raises(Exception):
get_input_device(5)
def test_scatter():
# if the device is CPU, just return the input
input = torch.zeros([1, 3, 3, 3])
output = scatter(input=input, devices=[-1])
assert torch.allclose(input, output)
inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
outputs = scatter(input=inputs, devices=[-1])
for input, output in zip(inputs, outputs):
assert torch.allclose(input, output)
# if the device is GPU, copy the input from CPU to GPU
if torch.cuda.is_available():
input = torch.zeros([1, 3, 3, 3])
output = scatter(input=input, devices=[0])
assert torch.allclose(input.cuda(), output)
inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
outputs = scatter(input=inputs, devices=[0])
for input, output in zip(inputs, outputs):
assert torch.allclose(input.cuda(), output)
# input should be a tensor or list of tensor
with pytest.raises(Exception):
scatter(5, [-1])
def test_Scatter():
# if the device is CPU, just return the input
target_gpus = [-1]
input = torch.zeros([1, 3, 3, 3])
outputs = Scatter.forward(target_gpus, input)
assert isinstance(outputs, tuple)
assert torch.allclose(input, outputs[0])
target_gpus = [-1]
inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
outputs = Scatter.forward(target_gpus, inputs)
assert isinstance(outputs, tuple)
for input, output in zip(inputs, outputs):
assert torch.allclose(input, output)
# if the device is GPU, copy the input from CPU to GPU
if torch.cuda.is_available():
target_gpus = [0]
input = torch.zeros([1, 3, 3, 3])
outputs = Scatter.forward(target_gpus, input)
assert isinstance(outputs, tuple)
assert torch.allclose(input.cuda(), outputs[0])
target_gpus = [0]
inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
outputs = Scatter.forward(target_gpus, inputs)
assert isinstance(outputs, tuple)
for input, output in zip(inputs, outputs):
assert torch.allclose(input.cuda(), output[0])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment