import dgl.graphbolt as gb
import gb_test_utils
import torch
from torchdata.datapipes.iter import Mapper


def test_FeatureFetcher_homo():
    """FeatureFetcher over a random homogeneous graph after neighbor sampling.

    Ten seed nodes are batched in pairs, converted with
    ``gb_test_utils.to_node_block`` and sampled through two layers (fanout 2
    per layer). Fetching node feature "a" and edge feature "b" must then
    yield one minibatch per batch, i.e. five in total.
    """
    graph = gb_test_utils.rand_csc_graph(20, 0.15)
    node_data = torch.randint(0, 10, (graph.num_nodes,))
    edge_data = torch.randint(0, 10, (graph.num_edges,))

    feature_store = gb.BasicFeatureStore(
        {
            ("node", None, "a"): gb.TorchBasedFeature(node_data),
            ("edge", None, "b"): gb.TorchBasedFeature(edge_data),
        }
    )

    # One independent fanout tensor per sampling layer.
    fanouts = [torch.LongTensor([2]), torch.LongTensor([2])]

    # Build the pipeline stage by stage:
    # seed batches -> converted blocks -> neighbor samples -> fetched features.
    datapipe = gb.ItemSampler(gb.ItemSet(torch.arange(10)), batch_size=2)
    datapipe = Mapper(datapipe, gb_test_utils.to_node_block)
    datapipe = gb.NeighborSampler(datapipe, graph, fanouts)
    datapipe = gb.FeatureFetcher(datapipe, feature_store, ["a"], ["b"])

    minibatches = list(datapipe)
    assert len(minibatches) == 5


def test_FeatureFetcher_with_edges_homo():
    """Fetch node and per-layer edge features from hand-built minibatches.

    Each minibatch holds the seed nodes as ``input_nodes`` plus three sampled
    subgraphs whose ``reverse_edge_ids`` are random edge IDs. The fetched
    features are checked by value, not only by shape: node feature "a" must
    equal the rows of ``a`` selected by the seeds, and each layer's edge
    feature "b" must equal the rows of ``b`` selected by that subgraph's
    ``reverse_edge_ids``.
    """
    graph = gb_test_utils.rand_csc_graph(20, 0.15)
    a = torch.randint(0, 10, (graph.num_nodes,))
    b = torch.randint(0, 10, (graph.num_edges,))

    def add_node_and_edge_ids(seeds):
        # Node pairs are left empty; this test only exercises fetching by
        # the seeds and by reverse_edge_ids.
        subgraphs = []
        for _ in range(3):
            subgraphs.append(
                gb.SampledSubgraphImpl(
                    node_pairs=(torch.tensor([]), torch.tensor([])),
                    reverse_edge_ids=torch.randint(0, graph.num_edges, (10,)),
                )
            )
        data = gb.NodeClassificationBlock(
            input_nodes=seeds, sampled_subgraphs=subgraphs
        )
        return data

    features = {}
    keys = [("node", None, "a"), ("edge", None, "b")]
    features[keys[0]] = gb.TorchBasedFeature(a)
    features[keys[1]] = gb.TorchBasedFeature(b)
    feature_store = gb.BasicFeatureStore(features)

    itemset = gb.ItemSet(torch.arange(10))
    item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
    converter_dp = Mapper(item_sampler_dp, add_node_and_edge_ids)
    fetcher_dp = gb.FeatureFetcher(converter_dp, feature_store, ["a"], ["b"])

    assert len(list(fetcher_dp)) == 5
    for data in fetcher_dp:
        assert data.node_features["a"].size(0) == 2
        # Shape alone would not catch rows fetched with the wrong IDs.
        assert torch.equal(data.node_features["a"], a[data.input_nodes])
        assert len(data.edge_features) == 3
        for subgraph, edge_feature in zip(
            data.sampled_subgraphs, data.edge_features
        ):
            assert edge_feature["b"].size(0) == 10
            assert torch.equal(
                edge_feature["b"], b[subgraph.reverse_edge_ids]
            )


def get_hetero_graph():
    """Build a small, fixed heterogeneous CSC graph used by the hetero tests.

    Node IDs 0-1 have type "n1" and IDs 2-4 have type "n2"
    (``node_type_offset = [0, 2, 5]``), so num_nodes = 5, num_n1 = 2 and
    num_n2 = 3. The same graph in COO form has 10 edges. The first row is
    the column (destination) derived from ``indptr``. The second row is
    ``indices``. The third row is ``type_per_edge``:

        [0, 0, 1, 1, 2, 2, 3, 3, 4, 4]
        [2, 4, 2, 3, 0, 1, 1, 0, 0, 1]
        [1, 1, 1, 1, 0, 0, 0, 0, 0, 0]  -> edge type

    Edge type 0 is "n1:e1:n2" and edge type 1 is "n2:e2:n1". This matches
    the node types above: type-1 edges point into n1 nodes (0-1) from n2
    nodes (2-4), and vice versa for type 0.
    """
    ntypes = {"n1": 0, "n2": 1}
    etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
    metadata = gb.GraphMetadata(ntypes, etypes)
    indptr = torch.LongTensor([0, 2, 4, 6, 8, 10])
    indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 0, 1])
    type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0, 0])
    node_type_offset = torch.LongTensor([0, 2, 5])
    return gb.from_csc(
        indptr,
        indices,
        node_type_offset=node_type_offset,
        type_per_edge=type_per_edge,
        metadata=metadata,
    )


def test_FeatureFetcher_hetero():
    """FeatureFetcher over the fixed heterogeneous graph after sampling.

    Both node types carry a feature named "a": 2 rows for "n1" and 3 rows
    for "n2". The five seeds (2 of type "n1", 3 of type "n2") are batched in
    pairs and sampled through two layers. This yields three minibatches.
    """
    graph = get_hetero_graph()
    n1_feat = torch.randint(0, 10, (2,))
    n2_feat = torch.randint(0, 10, (3,))

    feature_store = gb.BasicFeatureStore(
        {
            ("node", "n1", "a"): gb.TorchBasedFeature(n1_feat),
            ("node", "n2", "a"): gb.TorchBasedFeature(n2_feat),
        }
    )

    seeds = gb.ItemSetDict(
        {
            "n1": gb.ItemSet(torch.LongTensor([0, 1])),
            "n2": gb.ItemSet(torch.LongTensor([0, 1, 2])),
        }
    )
    # Two sampling layers, each with its own fanout-2 tensor.
    fanouts = [torch.LongTensor([2]) for _layer in range(2)]

    datapipe = gb.ItemSampler(seeds, batch_size=2)
    datapipe = Mapper(datapipe, gb_test_utils.to_node_block)
    datapipe = gb.NeighborSampler(datapipe, graph, fanouts)
    datapipe = gb.FeatureFetcher(
        datapipe, feature_store, {"n1": ["a"], "n2": ["a"]}
    )

    assert sum(1 for _ in datapipe) == 3


def test_FeatureFetcher_with_edges_hetero():
    """Fetch typed node and per-layer edge features from hand-built minibatches.

    The store holds node feature ("n1", "a") with 20 rows and edge feature
    ("n1:e1:n2", "a") with 50 rows. Each minibatch carries the seeds as
    ``input_nodes`` and three sampled subgraphs that share one random
    ``reverse_edge_ids`` dict. Fetched features are checked by value: they
    must be exactly the rows selected by the "n1" seeds and by the
    "n1:e1:n2" edge IDs.
    """
    a = torch.randint(0, 10, (20,))
    b = torch.randint(0, 10, (50,))

    def add_node_and_edge_ids(seeds):
        subgraphs = []
        # Edge IDs for both edge types. Only "n1:e1:n2" is requested from
        # the fetcher below; the same dict object is reused for all three
        # subgraphs.
        reverse_edge_ids = {
            "n1:e1:n2": torch.randint(0, 50, (10,)),
            "n2:e2:n1": torch.randint(0, 50, (10,)),
        }
        for _ in range(3):
            subgraphs.append(
                gb.SampledSubgraphImpl(
                    node_pairs=(torch.tensor([]), torch.tensor([])),
                    reverse_edge_ids=reverse_edge_ids,
                )
            )
        data = gb.NodeClassificationBlock(
            input_nodes=seeds, sampled_subgraphs=subgraphs
        )
        return data

    features = {}
    keys = [("node", "n1", "a"), ("edge", "n1:e1:n2", "a")]
    features[keys[0]] = gb.TorchBasedFeature(a)
    features[keys[1]] = gb.TorchBasedFeature(b)
    feature_store = gb.BasicFeatureStore(features)

    itemset = gb.ItemSetDict(
        {
            "n1": gb.ItemSet(torch.randint(0, 20, (10,))),
        }
    )
    item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
    converter_dp = Mapper(item_sampler_dp, add_node_and_edge_ids)
    fetcher_dp = gb.FeatureFetcher(
        converter_dp, feature_store, {"n1": ["a"]}, {"n1:e1:n2": ["a"]}
    )

    assert len(list(fetcher_dp)) == 5
    for data in fetcher_dp:
        node_feat = data.node_features[("n1", "a")]
        assert node_feat.size(0) == 2
        # Shape alone would not catch rows fetched with the wrong IDs.
        assert torch.equal(node_feat, a[data.input_nodes["n1"]])
        assert len(data.edge_features) == 3
        for subgraph, edge_feature in zip(
            data.sampled_subgraphs, data.edge_features
        ):
            edge_feat = edge_feature[("n1:e1:n2", "a")]
            assert edge_feat.size(0) == 10
            assert torch.equal(
                edge_feat, b[subgraph.reverse_edge_ids["n1:e1:n2"]]
            )