"vscode:/vscode.git/clone" did not exist on "ec38ba95b0cd6bf3dadfccf366cd8917acf59c4b"
Unverified commit 0448fcf9, authored by bilibilee, committed by GitHub

fix the scatter when input is cpu (#1621)



* fix the scatter when input is cpu

* Update _functions.py

Add spaces to comply with the code specification

* add unittests

* add a blank line

* fix unittest

Co-authored-by: zhouzaida <zhouzaida@163.com>
parent b8d78336
@@ -22,10 +22,7 @@ def scatter(input, devices, streams=None):
         if devices != [-1]:
             with torch.cuda.device(devices[0]), torch.cuda.stream(stream):
                 output = output.cuda(devices[0], non_blocking=True)
-        else:
-            # unsqueeze the first dimension thus the tensor's shape is the
-            # same as those scattered with GPU.
-            output = output.unsqueeze(0)
         return output
     else:
         raise Exception(f'Unknown type {type(input)}.')
@@ -76,4 +73,4 @@ class Scatter:
         if streams is not None:
             synchronize_stream(outputs, target_gpus, streams)
-        return tuple(outputs)
+        return tuple(outputs) if isinstance(outputs, list) else (outputs, )
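A minimal sketch (not part of the diff) of how the patched scatter and Scatter.forward behave for CPU input, assuming the mmcv.parallel._functions module changed above: the tensor keeps its original shape instead of gaining a leading dimension, and a single-tensor result is wrapped into a one-element tuple, mirroring the GPU path.

import torch

from mmcv.parallel._functions import Scatter, scatter

# Illustration only; this snippet is not part of the commit.
x = torch.zeros(2, 3)

# CPU path: the tensor is returned as-is (no unsqueeze(0) any more).
out = scatter(input=x, devices=[-1])
assert out.shape == (2, 3)

# Scatter.forward wraps the non-list result into a one-element tuple,
# matching what tuple(outputs) produces for scattered lists.
outs = Scatter.forward([-1], x)
assert isinstance(outs, tuple) and len(outs) == 1
assert torch.allclose(outs[0], x)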
 from unittest.mock import MagicMock, patch

+import pytest
 import torch
 import torch.nn as nn
 from torch.nn.parallel import DataParallel, DistributedDataParallel

 from mmcv.parallel import (MODULE_WRAPPERS, MMDataParallel,
                            MMDistributedDataParallel, is_module_wrapper)
+from mmcv.parallel._functions import Scatter, get_input_device, scatter
 from mmcv.parallel.distributed_deprecated import \
     MMDistributedDataParallel as DeprecatedMMDDP
@@ -64,3 +66,83 @@ def test_is_module_wrapper():
     module_wraper = ModuleWrapper(model)
     assert is_module_wrapper(module_wraper)
+
+
+def test_get_input_device():
+    # if the device is CPU, return -1
+    input = torch.zeros([1, 3, 3, 3])
+    assert get_input_device(input) == -1
+    inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
+    assert get_input_device(inputs) == -1
+
+    # if the device is GPU, return the index of device
+    if torch.cuda.is_available():
+        input = torch.zeros([1, 3, 3, 3]).cuda()
+        assert get_input_device(input) == 0
+        inputs = [
+            torch.zeros([1, 3, 3, 3]).cuda(),
+            torch.zeros([1, 4, 4, 4]).cuda()
+        ]
+        assert get_input_device(inputs) == 0
+
+    # input should be a tensor or list of tensor
+    with pytest.raises(Exception):
+        get_input_device(5)
+
+
+def test_scatter():
+    # if the device is CPU, just return the input
+    input = torch.zeros([1, 3, 3, 3])
+    output = scatter(input=input, devices=[-1])
+    assert torch.allclose(input, output)
+
+    inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
+    outputs = scatter(input=inputs, devices=[-1])
+    for input, output in zip(inputs, outputs):
+        assert torch.allclose(input, output)
+
+    # if the device is GPU, copy the input from CPU to GPU
+    if torch.cuda.is_available():
+        input = torch.zeros([1, 3, 3, 3])
+        output = scatter(input=input, devices=[0])
+        assert torch.allclose(input.cuda(), output)
+
+        inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
+        outputs = scatter(input=inputs, devices=[0])
+        for input, output in zip(inputs, outputs):
+            assert torch.allclose(input.cuda(), output)
+
+    # input should be a tensor or list of tensor
+    with pytest.raises(Exception):
+        scatter(5, [-1])
+
+
+def test_Scatter():
+    # if the device is CPU, just return the input
+    target_gpus = [-1]
+    input = torch.zeros([1, 3, 3, 3])
+    outputs = Scatter.forward(target_gpus, input)
+    assert isinstance(outputs, tuple)
+    assert torch.allclose(input, outputs[0])
+
+    target_gpus = [-1]
+    inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
+    outputs = Scatter.forward(target_gpus, inputs)
+    assert isinstance(outputs, tuple)
+    for input, output in zip(inputs, outputs):
+        assert torch.allclose(input, output)
+
+    # if the device is GPU, copy the input from CPU to GPU
+    if torch.cuda.is_available():
+        target_gpus = [0]
+        input = torch.zeros([1, 3, 3, 3])
+        outputs = Scatter.forward(target_gpus, input)
+        assert isinstance(outputs, tuple)
+        assert torch.allclose(input.cuda(), outputs[0])
+
+        target_gpus = [0]
+        inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
+        outputs = Scatter.forward(target_gpus, inputs)
+        assert isinstance(outputs, tuple)
+        for input, output in zip(inputs, outputs):
+            assert torch.allclose(input.cuda(), output[0])
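To run just the new tests, an invocation along these lines should work; the path tests/test_parallel.py is an assumption, since the diff does not show the file name, and the GPU branches are skipped automatically on CPU-only machines by the torch.cuda.is_available() guards.

import pytest

# Assumed test file location; adjust to wherever these tests live.
pytest.main(['tests/test_parallel.py', '-q',
             '-k', 'scatter or Scatter or get_input_device'])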