import re
import unittest

import backend as F

import dgl.graphbolt as gb
import gb_test_utils
import pytest
import torch


@unittest.skipIf(F._default_context_str == "cpu", "CopyTo needs GPU to test")
def test_CopyTo():
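    # CopyTo should move every mini-batch produced by the sampler onto the GPU.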
    item_sampler = gb.ItemSampler(gb.ItemSet(torch.randn(20)), 4)

    # Invoke CopyTo via class constructor.
    dp = gb.CopyTo(item_sampler, "cuda")
    for data in dp:
        assert data.device.type == "cuda"

    # Invoke CopyTo via functional form.
    dp = item_sampler.copy_to("cuda")
    for data in dp:
        assert data.device.type == "cuda"


@unittest.skipIf(F._default_context_str == "cpu", "CopyTo needs GPU to test")
def test_CopyToWithMiniBatches():
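    # Sample N seed nodes in batches of B from a random bidirectional CSC graph.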
    N = 16
    B = 2
    itemset = gb.ItemSet(torch.arange(N), names="seed_nodes")
    graph = gb_test_utils.rand_csc_graph(100, 0.15, bidirection_edge=True)

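    # Build a feature store holding two node features, "a" and "b".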
    features = {}
    keys = [("node", None, "a"), ("node", None, "b")]
    features[keys[0]] = gb.TorchBasedFeature(torch.randn(200, 4))
    features[keys[1]] = gb.TorchBasedFeature(torch.randn(200, 4))
    feature_store = gb.BasicFeatureStore(features)

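    # Chain ItemSampler -> NeighborSampler -> FeatureFetcher into one datapipe.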
    datapipe = gb.ItemSampler(itemset, batch_size=B)
    datapipe = gb.NeighborSampler(
        datapipe,
        graph,
        fanouts=[torch.LongTensor([2]) for _ in range(2)],
    )
    datapipe = gb.FeatureFetcher(
        datapipe,
        feature_store,
        ["a"],
    )

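    # Ensure every attribute of each mini-batch that carries a device is on the GPU.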
    def test_data_device(datapipe):
        for data in datapipe:
            for attr in dir(data):
                var = getattr(data, attr)
                if (
                    not callable(var)
                    and not attr.startswith("__")
                    and hasattr(var, "device")
                ):
                    assert var.device.type == "cuda"

    # Invoke CopyTo via class constructor.
    test_data_device(gb.CopyTo(datapipe, "cuda"))

    # Invoke CopyTo via functional form.
    test_data_device(datapipe.copy_to("cuda"))

    # Test for DGLMiniBatch.
    datapipe = gb.DGLMiniBatchConverter(datapipe)

    # Invoke CopyTo via class constructor.
    test_data_device(gb.CopyTo(datapipe, "cuda"))

    # Invoke CopyTo via functional form.
    test_data_device(datapipe.copy_to("cuda"))


def test_etype_tuple_to_str():
    """Convert etype from tuple to string."""
    # Test for expected input.
    c_etype = ("user", "like", "item")
    c_etype_str = gb.etype_tuple_to_str(c_etype)
    assert c_etype_str == "user:like:item"

    # Test for unexpected input: not a tuple.
    c_etype = "user:like:item"
    with pytest.raises(
        AssertionError,
        match=re.escape(
            "Passed-in canonical etype should be in format of (str, str, str). "
            "But got user:like:item."
        ),
    ):
        _ = gb.etype_tuple_to_str(c_etype)

    # Test for unexpected input: tuple with wrong length.
    c_etype = ("user", "like")
    with pytest.raises(
        AssertionError,
        match=re.escape(
            "Passed-in canonical etype should be in format of (str, str, str). "
            "But got ('user', 'like')."
        ),
    ):
        _ = gb.etype_tuple_to_str(c_etype)


def test_etype_str_to_tuple():
    """Convert etype from string to tuple."""
    # Test for expected input.
    c_etype_str = "user:like:item"
    c_etype = gb.etype_str_to_tuple(c_etype_str)
    assert c_etype == ("user", "like", "item")

    # Test for unexpected input: string with wrong format.
    c_etype_str = "user:like"
    with pytest.raises(
        AssertionError,
        match=re.escape(
            "Passed-in canonical etype should be in format of 'str:str:str'. "
            "But got user:like."
        ),
    ):
        _ = gb.etype_str_to_tuple(c_etype_str)


def test_isin():
    elements = torch.tensor([2, 3, 5, 5, 20, 13, 11])
    test_elements = torch.tensor([2, 5])
    res = gb.isin(elements, test_elements)
    expected = torch.tensor([True, False, True, True, False, False, False])
    assert torch.equal(res, expected)


def test_isin_big_data():
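    # Cross-check gb.isin against torch.isin on large random inputs.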
    elements = torch.randint(0, 10000, (10000000,))
    test_elements = torch.randint(0, 10000, (500000,))
    res = gb.isin(elements, test_elements)
    expected = torch.isin(elements, test_elements)
    assert torch.equal(res, expected)


def test_isin_non_1D_dim():
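    # gb.isin expects 1-D tensors; multi-dimensional inputs should raise.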
    elements = torch.tensor([[2, 3], [5, 5], [20, 13]])
    test_elements = torch.tensor([2, 5])
    with pytest.raises(Exception):
        gb.isin(elements, test_elements)
    elements = torch.tensor([2, 3, 5, 5, 20, 13])
    test_elements = torch.tensor([[2, 5]])
    with pytest.raises(Exception):
        gb.isin(elements, test_elements)