test_device_utils.py 576 Bytes
Newer Older
1
2
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.device import get_device
3
4
from mmcv.utils import (IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MPS_AVAILABLE,
                        IS_NPU_AVAILABLE)
5
6
7
8


def test_get_device():
    """Check that ``get_device`` reports the expected backend name.

    The expected name is derived from the ``IS_*_AVAILABLE`` flags, checked
    in the priority order NPU > CUDA > MLU > MPS, falling back to ``'cpu'``
    when none of them is set.
    """
    current_device = get_device()

    # Branch order matters: the first available backend is the expected one.
    if IS_NPU_AVAILABLE:
        expected = 'npu'
    elif IS_CUDA_AVAILABLE:
        expected = 'cuda'
    elif IS_MLU_AVAILABLE:
        expected = 'mlu'
    elif IS_MPS_AVAILABLE:
        expected = 'mps'
    else:
        expected = 'cpu'

    assert current_device == expected, (
        f'get_device() returned {current_device!r}, expected {expected!r}')