"vscode:/vscode.git/clone" did not exist on "98325b1097877b93dc872727d22ce2f402666e8f"
test_copy_to.py 336 Bytes
Newer Older
1
2
3
import unittest

import backend as F
4

5
6
7
8
9
10
11
12
13
14
15
import dgl.graphbolt
import torch


@unittest.skipIf(F._default_context_str == "cpu", "CopyTo needs GPU to test")
def test_CopyTo():
    """Check that ``CopyTo`` places every yielded item on a CUDA device.

    Builds a ``MinibatchSampler`` over 20 random values (the second argument,
    4, is presumably the batch size -- confirm against the graphbolt API),
    wraps it in ``CopyTo`` targeting ``"cuda"``, and asserts that each item
    produced by the pipeline reports a CUDA device. Skipped when the test
    backend's default context is CPU.
    """
    dp = dgl.graphbolt.MinibatchSampler(torch.randn(20), 4)
    dp = dgl.graphbolt.CopyTo(dp, "cuda")

    num_batches = 0
    for data in dp:
        assert data.device.type == "cuda"
        num_batches += 1
    # Guard against a vacuous pass: if the pipeline yielded nothing, the loop
    # body (and therefore the device assertion) would never have executed.
    assert num_batches > 0, "CopyTo pipeline yielded no minibatches"