test_multi_gpu.py 454 Bytes
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
5
6
7
8
9
10
11
12
import pytest
import torch
from torch_scatter import scatter_max


@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason='No multiple GPUS')
def test_multi_gpu():
    """Check that ``scatter_max`` works on a non-default GPU.

    The inputs are placed on ``cuda:1`` rather than the default ``cuda:0``,
    so the reduction is exercised on a secondary device. Besides the reduced
    values, the test checks that both returned tensors live on the same
    device as the inputs and that the argmax positions are correct.
    """
    device = torch.device('cuda:1')
    src = torch.tensor([2.0, 3.0, 4.0, 5.0], device=device)
    # Two groups: src positions 0-1 -> output slot 0, positions 2-3 -> slot 1.
    index = torch.tensor([0, 0, 1, 1], device=device)

    out, argmax = scatter_max(src, index)

    # Results must stay on the inputs' device, not silently land elsewhere.
    assert out.device == device
    assert argmax.device == device
    # Per-group maxima: max(2, 3) == 3 and max(4, 5) == 5.
    assert out.tolist() == [3, 5]
    # Positions in ``src`` where those maxima were found.
    assert argmax.tolist() == [1, 3]