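# NOTE: the next three lines are residue from the code-hosting file viewer,
# kept only as comments; they are not part of this test module.
#   "src/diffusers/models/controlnets/controlnet.py" did not exist on "c375903db58826494d858e02b44d21b42669ff5e"
#   test_feature_fetcher.py 8.93 KB
#   Newer Older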
1
import random
2
from enum import Enum
3

4
5
import dgl.graphbolt as gb
import gb_test_utils
6
import pytest
7
import torch
8
from torchdata.datapipes.iter import Mapper
9
10


11
12
13
14
15
16
17
18
19
# Selects which minibatch representation a test feeds into the feature
# fetcher. Members are numbered from 1 in declaration order, so
# MiniBatch == 1 and DGLMiniBatch == 2.
MiniBatchType = Enum("MiniBatchType", ["MiniBatch", "DGLMiniBatch"])


@pytest.mark.parametrize(
    "minibatch_type", [MiniBatchType.MiniBatch, MiniBatchType.DGLMiniBatch]
)
def test_FeatureFetcher_invoke(minibatch_type):
    """Check that FeatureFetcher can be built by constructor and functionally.

    Ten seed nodes are batched two at a time, so each pipeline is expected
    to yield five minibatches.

    Parameters
    ----------
    minibatch_type : MiniBatchType
        When ``DGLMiniBatch``, the constructor-form pipeline is passed
        through ``to_dgl()`` before features are fetched.
    """
    # Prepare graph and required datapipes.
    graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True)
    a = torch.tensor(
        [[random.randint(0, 10)] for _ in range(graph.total_num_nodes)]
    )
    b = torch.tensor(
        [[random.randint(0, 10)] for _ in range(graph.total_num_edges)]
    )

    # One node feature "a" and one edge feature "b", both untyped (None).
    features = {
        ("node", None, "a"): gb.TorchBasedFeature(a),
        ("edge", None, "b"): gb.TorchBasedFeature(b),
    }
    feature_store = gb.BasicFeatureStore(features)

    itemset = gb.ItemSet(torch.arange(10), names="seed_nodes")
    item_sampler = gb.ItemSampler(itemset, batch_size=2)
    num_layer = 2
    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]

    # Invoke FeatureFetcher via class constructor.
    datapipe = gb.NeighborSampler(item_sampler, graph, fanouts)
    if minibatch_type == MiniBatchType.DGLMiniBatch:
        datapipe = datapipe.to_dgl()

    datapipe = gb.FeatureFetcher(datapipe, feature_store, ["a"], ["b"])
    assert len(list(datapipe)) == 5

    # Invoke FeatureFetcher via functional form.
    datapipe = item_sampler.sample_neighbor(graph, fanouts).fetch_feature(
        feature_store, ["a"], ["b"]
    )
    assert len(list(datapipe)) == 5


55
56
57
58
@pytest.mark.parametrize(
    "minibatch_type", [MiniBatchType.MiniBatch, MiniBatchType.DGLMiniBatch]
)
def test_FeatureFetcher_homo(minibatch_type):
    """Fetch node and edge features after neighbor sampling on a random graph.

    Ten seed nodes are batched two at a time; the test asserts that the
    fetcher yields five minibatches.

    Parameters
    ----------
    minibatch_type : MiniBatchType
        When ``DGLMiniBatch``, the sampler output is passed through
        ``to_dgl()`` before features are fetched.
    """
    graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True)
    a = torch.tensor(
        [[random.randint(0, 10)] for _ in range(graph.total_num_nodes)]
    )
    b = torch.tensor(
        [[random.randint(0, 10)] for _ in range(graph.total_num_edges)]
    )

    # One node feature "a" and one edge feature "b", both untyped (None).
    features = {
        ("node", None, "a"): gb.TorchBasedFeature(a),
        ("edge", None, "b"): gb.TorchBasedFeature(b),
    }
    feature_store = gb.BasicFeatureStore(features)

    itemset = gb.ItemSet(torch.arange(10), names="seed_nodes")
    item_sampler = gb.ItemSampler(itemset, batch_size=2)
    num_layer = 2
    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
    sampler_dp = gb.NeighborSampler(item_sampler, graph, fanouts)
    if minibatch_type == MiniBatchType.DGLMiniBatch:
        sampler_dp = sampler_dp.to_dgl()
    fetcher_dp = gb.FeatureFetcher(sampler_dp, feature_store, ["a"], ["b"])

    assert len(list(fetcher_dp)) == 5


85
86
87
88
@pytest.mark.parametrize(
    "minibatch_type", [MiniBatchType.MiniBatch, MiniBatchType.DGLMiniBatch]
)
def test_FeatureFetcher_with_edges_homo(minibatch_type):
    """Fetch node and edge features from hand-built homogeneous minibatches.

    Instead of running a sampler, each batch of seeds is mapped to a
    ``gb.MiniBatch`` holding three synthetic subgraphs with ten edge IDs
    each. The test then checks the sizes of the fetched features.

    Parameters
    ----------
    minibatch_type : MiniBatchType
        When ``DGLMiniBatch``, the mapped datapipe is passed through
        ``to_dgl()`` before features are fetched.
    """
    graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True)
    a = torch.tensor(
        [[random.randint(0, 10)] for _ in range(graph.total_num_nodes)]
    )
    b = torch.tensor(
        [[random.randint(0, 10)] for _ in range(graph.total_num_edges)]
    )

    def add_node_and_edge_ids(seeds):
        """Wrap ``seeds`` in a MiniBatch with three synthetic subgraphs."""
        subgraphs = []
        for _ in range(3):
            range_tensor = torch.arange(10)
            subgraphs.append(
                gb.FusedSampledSubgraphImpl(
                    node_pairs=(range_tensor, range_tensor),
                    original_column_node_ids=range_tensor,
                    original_row_node_ids=range_tensor,
                    # Random but valid edge IDs into the edge feature "b".
                    original_edge_ids=torch.randint(
                        0, graph.total_num_edges, (10,)
                    ),
                )
            )
        return gb.MiniBatch(input_nodes=seeds, sampled_subgraphs=subgraphs)

    # One node feature "a" and one edge feature "b", both untyped (None).
    features = {
        ("node", None, "a"): gb.TorchBasedFeature(a),
        ("edge", None, "b"): gb.TorchBasedFeature(b),
    }
    feature_store = gb.BasicFeatureStore(features)

    itemset = gb.ItemSet(torch.arange(10))
    item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
    converter_dp = Mapper(item_sampler_dp, add_node_and_edge_ids)
    if minibatch_type == MiniBatchType.DGLMiniBatch:
        converter_dp = converter_dp.to_dgl()
    fetcher_dp = gb.FeatureFetcher(converter_dp, feature_store, ["a"], ["b"])

    assert len(list(fetcher_dp)) == 5
    for data in fetcher_dp:
        # Two seeds per batch; one edge-feature entry per subgraph.
        assert data.node_features["a"].size(0) == 2
        assert len(data.edge_features) == 3
        for edge_feature in data.edge_features:
            assert edge_feature["b"].size(0) == 10
133
134
135
136
137
138
139
140
141


def get_hetero_graph():
    """Build a small heterogeneous graph with two node and two edge types.

    Returns
    -------
    The object produced by ``gb.from_fused_csc`` for the fixed CSC arrays,
    node type offsets, per-edge types and metadata defined below.
    """
    # COO graph:
    # [0, 0, 1, 1, 2, 2, 3, 3, 4, 4]
    # [2, 4, 2, 3, 0, 1, 1, 0, 0, 1]
    # [1, 1, 1, 1, 0, 0, 0, 0, 0, 0] -> edge type.
    # num_nodes = 5, num_n1 = 2, num_n2 = 3
    ntypes = {"n1": 0, "n2": 1}
    etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
    metadata = gb.GraphMetadata(ntypes, etypes)
    indptr = torch.LongTensor([0, 2, 4, 6, 8, 10])
    indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 0, 1])
    type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0, 0])
    node_type_offset = torch.LongTensor([0, 2, 5])
    return gb.from_fused_csc(
        indptr,
        indices,
        node_type_offset=node_type_offset,
        type_per_edge=type_per_edge,
        metadata=metadata,
    )
155
156


157
158
159
160
@pytest.mark.parametrize(
    "minibatch_type", [MiniBatchType.MiniBatch, MiniBatchType.DGLMiniBatch]
)
def test_FeatureFetcher_hetero(minibatch_type):
    """Fetch per-node-type features after sampling a heterogeneous graph.

    The item set holds two "n1" seeds and three "n2" seeds, batched two at
    a time; the test asserts that the fetcher yields three minibatches.

    Parameters
    ----------
    minibatch_type : MiniBatchType
        When ``DGLMiniBatch``, the sampler output is passed through
        ``to_dgl()`` before features are fetched.
    """
    graph = get_hetero_graph()
    a = torch.tensor([[random.randint(0, 10)] for _ in range(2)])
    b = torch.tensor([[random.randint(0, 10)] for _ in range(3)])

    # Feature "a" exists for both node types, backed by different tensors.
    features = {
        ("node", "n1", "a"): gb.TorchBasedFeature(a),
        ("node", "n2", "a"): gb.TorchBasedFeature(b),
    }
    feature_store = gb.BasicFeatureStore(features)

    itemset = gb.ItemSetDict(
        {
            "n1": gb.ItemSet(torch.LongTensor([0, 1]), names="seed_nodes"),
            "n2": gb.ItemSet(torch.LongTensor([0, 1, 2]), names="seed_nodes"),
        }
    )
    item_sampler = gb.ItemSampler(itemset, batch_size=2)
    num_layer = 2
    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
    sampler_dp = gb.NeighborSampler(item_sampler, graph, fanouts)
    if minibatch_type == MiniBatchType.DGLMiniBatch:
        sampler_dp = sampler_dp.to_dgl()
    fetcher_dp = gb.FeatureFetcher(
        sampler_dp, feature_store, {"n1": ["a"], "n2": ["a"]}
    )

    assert len(list(fetcher_dp)) == 3


190
191
192
193
@pytest.mark.parametrize(
    "minibatch_type", [MiniBatchType.MiniBatch, MiniBatchType.DGLMiniBatch]
)
def test_FeatureFetcher_with_edges_hetero(minibatch_type):
    """Fetch typed node and edge features from hand-built hetero minibatches.

    Instead of running a sampler, each batch of seeds is mapped to a
    ``gb.MiniBatch`` holding three synthetic subgraphs keyed by edge type.
    The test then checks the sizes of the fetched features.

    Parameters
    ----------
    minibatch_type : MiniBatchType
        When ``DGLMiniBatch``, the mapped datapipe is passed through
        ``to_dgl()`` before features are fetched.
    """
    a = torch.tensor([[random.randint(0, 10)] for _ in range(20)])
    b = torch.tensor([[random.randint(0, 10)] for _ in range(50)])

    def add_node_and_edge_ids(seeds):
        """Wrap ``seeds`` in a MiniBatch with three synthetic subgraphs."""
        subgraphs = []
        # Random IDs within the bounds of the feature tensors above
        # (20 rows for node feature "a", 50 rows for edge feature "a").
        original_edge_ids = {
            "n1:e1:n2": torch.randint(0, 50, (10,)),
            "n2:e2:n1": torch.randint(0, 50, (10,)),
        }
        original_column_node_ids = {
            "n1": torch.randint(0, 20, (10,)),
            "n2": torch.randint(0, 20, (10,)),
        }
        original_row_node_ids = {
            "n1": torch.randint(0, 20, (10,)),
            "n2": torch.randint(0, 20, (10,)),
        }
        # The same ID dictionaries are reused by all three subgraphs.
        for _ in range(3):
            subgraphs.append(
                gb.FusedSampledSubgraphImpl(
                    node_pairs={
                        "n1:e1:n2": (
                            torch.arange(10),
                            torch.arange(10),
                        ),
                        "n2:e2:n1": (
                            torch.arange(10),
                            torch.arange(10),
                        ),
                    },
                    original_column_node_ids=original_column_node_ids,
                    original_row_node_ids=original_row_node_ids,
                    original_edge_ids=original_edge_ids,
                )
            )
        return gb.MiniBatch(input_nodes=seeds, sampled_subgraphs=subgraphs)

    # Only node type "n1" and edge type "n1:e1:n2" carry feature "a".
    features = {
        ("node", "n1", "a"): gb.TorchBasedFeature(a),
        ("edge", "n1:e1:n2", "a"): gb.TorchBasedFeature(b),
    }
    feature_store = gb.BasicFeatureStore(features)

    itemset = gb.ItemSetDict(
        {
            "n1": gb.ItemSet(torch.randint(0, 20, (10,))),
        }
    )
    item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
    converter_dp = Mapper(item_sampler_dp, add_node_and_edge_ids)
    if minibatch_type == MiniBatchType.DGLMiniBatch:
        converter_dp = converter_dp.to_dgl()
    fetcher_dp = gb.FeatureFetcher(
        converter_dp, feature_store, {"n1": ["a"]}, {"n1:e1:n2": ["a"]}
    )

    assert len(list(fetcher_dp)) == 5
    for data in fetcher_dp:
        # Two seeds per batch; one edge-feature entry per subgraph.
        assert data.node_features[("n1", "a")].size(0) == 2
        assert len(data.edge_features) == 3
        for edge_feature in data.edge_features:
            assert edge_feature[("n1:e1:n2", "a")].size(0) == 10