# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch

from mmcv.device._functions import Scatter, scatter
from mmcv.utils import IS_MLU_AVAILABLE, IS_MPS_AVAILABLE, IS_NPU_AVAILABLE


def test_scatter():
    # if the device is CPU, just return the input
    input = torch.zeros([1, 3, 3, 3])
    output = scatter(input=input, devices=[-1])
    assert torch.allclose(input, output)

    inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
    outputs = scatter(input=inputs, devices=[-1])
    for input, output in zip(inputs, outputs):
        assert torch.allclose(input, output)

    # if the device is MLU, copy the input from CPU to MLU
    if IS_MLU_AVAILABLE:
        input = torch.zeros([1, 3, 3, 3])
        output = scatter(input=input, devices=[0])
        assert torch.allclose(input.to('mlu'), output)

        inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
        outputs = scatter(input=inputs, devices=[0])
        for input, output in zip(inputs, outputs):
            assert torch.allclose(input.to('mlu'), output)

    # if the device is NPU, copy the input from CPU to NPU
    if IS_NPU_AVAILABLE:
        input = torch.zeros([1, 3, 3, 3])
        output = scatter(input=input, devices=[0])
        assert torch.allclose(input.to('npu'), output)

        inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
        outputs = scatter(input=inputs, devices=[0])
        for input, output in zip(inputs, outputs):
            assert torch.allclose(input.to('npu'), output)

    # if the device is MPS, copy the input from CPU to MPS
    if IS_MPS_AVAILABLE:
        input = torch.zeros([1, 3, 3, 3])
        output = scatter(input=input, devices=[0])
        assert torch.allclose(input.to('mps'), output)

        inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
        outputs = scatter(input=inputs, devices=[0])
        for input, output in zip(inputs, outputs):
            assert torch.allclose(input.to('mps'), output)

    # input must be a tensor or a list of tensors
    with pytest.raises(Exception):
        scatter(5, [-1])


def test_Scatter():
    # if the device is CPU, just return the input
    target_devices = [-1]
    input = torch.zeros([1, 3, 3, 3])
    outputs = Scatter.forward(target_devices, input)
    assert isinstance(outputs, tuple)
    assert torch.allclose(input, outputs[0])

    target_devices = [-1]
    inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
    outputs = Scatter.forward(target_devices, inputs)
    assert isinstance(outputs, tuple)
    for input, output in zip(inputs, outputs):
        assert torch.allclose(input, output)

    # if the device is MLU, copy the input from CPU to MLU
    if IS_MLU_AVAILABLE:
        target_devices = [0]
        input = torch.zeros([1, 3, 3, 3])
        outputs = Scatter.forward(target_devices, input)
        assert isinstance(outputs, tuple)
        assert torch.allclose(input.to('mlu'), outputs[0])

        target_devices = [0]
        inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
        outputs = Scatter.forward(target_devices, inputs)
        assert isinstance(outputs, tuple)
        for input, output in zip(inputs, outputs):
            assert torch.allclose(input.to('mlu'), output)
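
    # NPU counterpart of the MLU/MPS cases above. This is a sketch mirroring
    # the NPU branch of test_scatter; it assumes Scatter.forward moves data to
    # NPU the same way it does for MLU/MPS.
    if IS_NPU_AVAILABLE:
        target_devices = [0]
        input = torch.zeros([1, 3, 3, 3])
        outputs = Scatter.forward(target_devices, input)
        assert isinstance(outputs, tuple)
        assert torch.allclose(input.to('npu'), outputs[0])

        target_devices = [0]
        inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
        outputs = Scatter.forward(target_devices, inputs)
        assert isinstance(outputs, tuple)
        for input, output in zip(inputs, outputs):
            assert torch.allclose(input.to('npu'), output)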

    # if the device is MPS, copy the input from CPU to MPS
    if IS_MPS_AVAILABLE:
        target_devices = [0]
        input = torch.zeros([1, 3, 3, 3])
        outputs = Scatter.forward(target_devices, input)
        assert isinstance(outputs, tuple)
        assert torch.allclose(input.to('mps'), outputs[0])

        target_devices = [0]
        inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
        outputs = Scatter.forward(target_devices, inputs)
        assert isinstance(outputs, tuple)
        for input, output in zip(inputs, outputs):
            assert torch.allclose(input.to('mps'), output)