test_device.py 831 Bytes
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
from unittest.mock import patch

import mmcv

from mmcls.utils import auto_select_device


class TestAutoSelectDevice(TestCase):
    """Unit tests for ``auto_select_device`` under patched mmcv versions."""

    @patch.object(mmcv, '__version__', '1.6.0')
    @patch('mmcv.device.get_device', create=True)
    def test_mmcv(self, get_device_mock):
        """With mmcv reporting 1.6.0, ``mmcv.device.get_device`` is invoked.

        ``create=True`` lets the patch install the attribute even when the
        installed mmcv does not define ``mmcv.device.get_device``.
        """
        auto_select_device()
        get_device_mock.assert_called_once()

    @patch.object(mmcv, '__version__', '1.5.0')
    @patch('torch.cuda.is_available', return_value=True)
    def test_cuda(self, _cuda_available):
        """With mmcv 1.5.0 and CUDA reported available, 'cuda' is returned."""
        self.assertEqual(auto_select_device(), 'cuda')

    @patch.object(mmcv, '__version__', '1.5.0')
    @patch('torch.cuda.is_available', return_value=False)
    def test_cpu(self, _cuda_available):
        """With mmcv 1.5.0 and CUDA reported unavailable, 'cpu' is returned."""
        self.assertEqual(auto_select_device(), 'cpu')