import os

import backend as F
import dgl
import numpy as np
import pytest
from dgl import function as fn
from dgl.distributed import (load_partition, load_partition_feats,
                             partition_graph)
from dgl.distributed.graph_partition_book import (BasicPartitionBook,
                                                  EdgePartitionPolicy,
                                                  HeteroDataName,
                                                  NodePartitionPolicy,
                                                  RangePartitionBook)
from dgl.distributed.partition import (
    RESERVED_FIELD_DTYPE,
    _get_inner_node_mask,
    _get_inner_edge_mask,
)
from scipy import sparse as spsp
from utils import reset_envs


def _verify_partition_data_types(part_g):
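    """Check that reserved node/edge data fields have the expected dtypes."""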
    for k, dtype in RESERVED_FIELD_DTYPE.items():
        if k in part_g.ndata:
            assert part_g.ndata[k].dtype == dtype
        if k in part_g.edata:
            assert part_g.edata[k].dtype == dtype


def _verify_partition_formats(part_g, formats):
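    """Check that the partition was saved with the requested sparse formats."""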
    # verify saved graph formats
    if formats is None:
        assert "coo" in part_g.formats()["created"]
    else:
        for fmt in formats:
            assert fmt in part_g.formats()["created"]


def create_random_graph(n):
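    """Create a random graph with n nodes from a random sparse matrix."""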
    arr = (
        spsp.random(n, n, density=0.001, format="coo", random_state=100) != 0
    ).astype(np.int64)
    return dgl.from_scipy(arr)


def create_random_hetero():
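    """Create a random heterograph with three node and edge types."""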
    num_nodes = {"n1": 1000, "n2": 1010, "n3": 1020}
    etypes = [("n1", "r1", "n2"), ("n1", "r2", "n3"), ("n2", "r3", "n3")]
    edges = {}
    for etype in etypes:
        src_ntype, _, dst_ntype = etype
        arr = spsp.random(
            num_nodes[src_ntype],
            num_nodes[dst_ntype],
            density=0.001,
            format="coo",
            random_state=100,
        )
        edges[etype] = (arr.row, arr.col)
    return dgl.heterograph(edges, num_nodes)


def verify_hetero_graph(g, parts):
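    """Check that the partitions cover every node and edge of g exactly once
    and that inner node/edge IDs form contiguous per-type ranges.
    """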
    num_nodes = {ntype: 0 for ntype in g.ntypes}
    num_edges = {etype: 0 for etype in g.etypes}
    for part in parts:
        assert len(g.ntypes) == len(F.unique(part.ndata[dgl.NTYPE]))
        assert len(g.etypes) == len(F.unique(part.edata[dgl.ETYPE]))
        for ntype in g.ntypes:
            ntype_id = g.get_ntype_id(ntype)
            inner_node_mask = _get_inner_node_mask(part, ntype_id)
            num_inner_nodes = F.sum(F.astype(inner_node_mask, F.int64), 0)
            num_nodes[ntype] += num_inner_nodes
        for etype in g.etypes:
            etype_id = g.get_etype_id(etype)
            inner_edge_mask = _get_inner_edge_mask(part, etype_id)
            num_inner_edges = F.sum(F.astype(inner_edge_mask, F.int64), 0)
            num_edges[etype] += num_inner_edges
    # Verify that the number of nodes is correct.
    for ntype in g.ntypes:
        print(
            "node {}: {}, {}".format(
                ntype, g.number_of_nodes(ntype), num_nodes[ntype]
            )
        )
        assert g.number_of_nodes(ntype) == num_nodes[ntype]
    # Verify that the number of edges is correct.
    for etype in g.etypes:
        print(
            "edge {}: {}, {}".format(
                etype, g.number_of_edges(etype), num_edges[etype]
            )
        )
        assert g.number_of_edges(etype) == num_edges[etype]

    nids = {ntype: [] for ntype in g.ntypes}
    eids = {etype: [] for etype in g.etypes}
    for part in parts:
        _, _, eid = part.edges(form="all")
        etype_arr = F.gather_row(part.edata[dgl.ETYPE], eid)
        eid_type = F.gather_row(part.edata[dgl.EID], eid)
        for etype in g.etypes:
            etype_id = g.get_etype_id(etype)
            eids[etype].append(F.boolean_mask(eid_type, etype_arr == etype_id))
            # Make sure edge IDs fall into a contiguous range.
            inner_edge_mask = _get_inner_edge_mask(part, etype_id)
            inner_eids = np.sort(
                F.asnumpy(F.boolean_mask(part.edata[dgl.EID], inner_edge_mask))
            )
            assert np.all(
                inner_eids == np.arange(inner_eids[0], inner_eids[-1] + 1)
            )

        for ntype in g.ntypes:
            ntype_id = g.get_ntype_id(ntype)
            # Make sure the IDs of inner nodes fall into a contiguous range.
            inner_node_mask = _get_inner_node_mask(part, ntype_id)
            inner_nids = F.boolean_mask(part.ndata[dgl.NID], inner_node_mask)
            assert np.all(
                F.asnumpy(
                    inner_nids
                    == F.arange(
                        F.as_scalar(inner_nids[0]),
                        F.as_scalar(inner_nids[-1]) + 1,
                    )
                )
            )
            nids[ntype].append(inner_nids)

    for ntype in nids:
        nids_type = F.cat(nids[ntype], 0)
        uniq_ids = F.unique(nids_type)
        # We should get all nodes.
        assert len(uniq_ids) == g.number_of_nodes(ntype)
    for etype in eids:
        eids_type = F.cat(eids[etype], 0)
        uniq_ids = F.unique(eids_type)
        assert len(uniq_ids) == g.number_of_edges(etype)
    # TODO(zhengda) this doesn't check 'part_id'


def verify_graph_feats(
    g, gpb, part, node_feats, edge_feats, orig_nids, orig_eids
):
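    """Check that the features stored with a partition match the features of
    the corresponding nodes/edges in the original graph g.
    """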
    for ntype in g.ntypes:
        ntype_id = g.get_ntype_id(ntype)
        inner_node_mask = _get_inner_node_mask(part, ntype_id)
        inner_nids = F.boolean_mask(part.ndata[dgl.NID], inner_node_mask)
        ntype_ids, inner_type_nids = gpb.map_to_per_ntype(inner_nids)
        partid = gpb.nid2partid(inner_type_nids, ntype)
        assert np.all(F.asnumpy(ntype_ids) == ntype_id)
        assert np.all(F.asnumpy(partid) == gpb.partid)

        orig_id = orig_nids[ntype][inner_type_nids]
        local_nids = gpb.nid2localnid(inner_type_nids, gpb.partid, ntype)

        for name in g.nodes[ntype].data:
            if name in [dgl.NID, "inner_node"]:
                continue
            true_feats = F.gather_row(g.nodes[ntype].data[name], orig_id)
            ndata = F.gather_row(node_feats[ntype + "/" + name], local_nids)
            assert np.all(F.asnumpy(ndata == true_feats))

    for etype in g.etypes:
        etype_id = g.get_etype_id(etype)
        inner_edge_mask = _get_inner_edge_mask(part, etype_id)
        inner_eids = F.boolean_mask(part.edata[dgl.EID], inner_edge_mask)
        etype_ids, inner_type_eids = gpb.map_to_per_etype(inner_eids)
        partid = gpb.eid2partid(inner_type_eids, etype)
        assert np.all(F.asnumpy(etype_ids) == etype_id)
        assert np.all(F.asnumpy(partid) == gpb.partid)

        orig_id = orig_eids[etype][inner_type_eids]
        local_eids = gpb.eid2localeid(inner_type_eids, gpb.partid, etype)

        for name in g.edges[etype].data:
            if name in [dgl.EID, "inner_edge"]:
                continue
            true_feats = F.gather_row(g.edges[etype].data[name], orig_id)
            edata = F.gather_row(edge_feats[etype + "/" + name], local_eids)
            assert np.all(F.asnumpy(edata == true_feats))


def check_hetero_partition(
    hg,
    part_method,
    num_parts=4,
    num_trainers_per_machine=1,
    load_feats=True,
    graph_formats=None,
):
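    """Partition a heterograph, reload every partition, and verify IDs,
    features, and the mapping between shuffled and original IDs.
    """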
    hg.nodes["n1"].data["labels"] = F.arange(0, hg.number_of_nodes("n1"))
    hg.nodes["n1"].data["feats"] = F.tensor(
        np.random.randn(hg.number_of_nodes("n1"), 10), F.float32
    )
    hg.edges["r1"].data["feats"] = F.tensor(
        np.random.randn(hg.number_of_edges("r1"), 10), F.float32
    )
    hg.edges["r1"].data["labels"] = F.arange(0, hg.number_of_edges("r1"))
    num_hops = 1

    orig_nids, orig_eids = partition_graph(
        hg,
        "test",
        num_parts,
        "/tmp/partition",
        num_hops=num_hops,
        part_method=part_method,
        reshuffle=True,
        return_mapping=True,
        num_trainers_per_machine=num_trainers_per_machine,
        graph_formats=graph_formats,
    )
    assert len(orig_nids) == len(hg.ntypes)
    assert len(orig_eids) == len(hg.etypes)
    for ntype in hg.ntypes:
        assert len(orig_nids[ntype]) == hg.number_of_nodes(ntype)
    for etype in hg.etypes:
        assert len(orig_eids[etype]) == hg.number_of_edges(etype)
    parts = []
    shuffled_labels = []
    shuffled_elabels = []
    for i in range(num_parts):
        part_g, node_feats, edge_feats, gpb, _, ntypes, etypes = load_partition(
            "/tmp/partition/test.json", i, load_feats=load_feats
        )
        _verify_partition_data_types(part_g)
        _verify_partition_formats(part_g, graph_formats)
        if not load_feats:
            assert not node_feats
            assert not edge_feats
            node_feats, edge_feats = load_partition_feats(
                "/tmp/partition/test.json", i
            )
        if num_trainers_per_machine > 1:
            for ntype in hg.ntypes:
                name = ntype + "/trainer_id"
                assert name in node_feats
                part_ids = F.floor_div(
                    node_feats[name], num_trainers_per_machine
                )
                assert np.all(F.asnumpy(part_ids) == i)

            for etype in hg.etypes:
                name = etype + "/trainer_id"
                assert name in edge_feats
                part_ids = F.floor_div(
                    edge_feats[name], num_trainers_per_machine
                )
                assert np.all(F.asnumpy(part_ids) == i)
        # Verify the mapping between the reshuffled IDs and the original IDs.
        # These are partition-local IDs.
        part_src_ids, part_dst_ids = part_g.edges()
        # These are reshuffled global homogeneous IDs.
        part_src_ids = F.gather_row(part_g.ndata[dgl.NID], part_src_ids)
        part_dst_ids = F.gather_row(part_g.ndata[dgl.NID], part_dst_ids)
        part_eids = part_g.edata[dgl.EID]
        # These are reshuffled per-type IDs.
        src_ntype_ids, part_src_ids = gpb.map_to_per_ntype(part_src_ids)
        dst_ntype_ids, part_dst_ids = gpb.map_to_per_ntype(part_dst_ids)
        etype_ids, part_eids = gpb.map_to_per_etype(part_eids)
        # These are original per-type IDs.
        for etype_id, etype in enumerate(hg.etypes):
            part_src_ids1 = F.boolean_mask(part_src_ids, etype_ids == etype_id)
            src_ntype_ids1 = F.boolean_mask(
                src_ntype_ids, etype_ids == etype_id
            )
            part_dst_ids1 = F.boolean_mask(part_dst_ids, etype_ids == etype_id)
            dst_ntype_ids1 = F.boolean_mask(
                dst_ntype_ids, etype_ids == etype_id
            )
            part_eids1 = F.boolean_mask(part_eids, etype_ids == etype_id)
            assert np.all(F.asnumpy(src_ntype_ids1 == src_ntype_ids1[0]))
            assert np.all(F.asnumpy(dst_ntype_ids1 == dst_ntype_ids1[0]))
            src_ntype = hg.ntypes[F.as_scalar(src_ntype_ids1[0])]
            dst_ntype = hg.ntypes[F.as_scalar(dst_ntype_ids1[0])]
            orig_src_ids1 = F.gather_row(orig_nids[src_ntype], part_src_ids1)
            orig_dst_ids1 = F.gather_row(orig_nids[dst_ntype], part_dst_ids1)
            orig_eids1 = F.gather_row(orig_eids[etype], part_eids1)
            orig_eids2 = hg.edge_ids(orig_src_ids1, orig_dst_ids1, etype=etype)
            assert len(orig_eids1) == len(orig_eids2)
            assert np.all(F.asnumpy(orig_eids1) == F.asnumpy(orig_eids2))
        parts.append(part_g)
        verify_graph_feats(
            hg, gpb, part_g, node_feats, edge_feats, orig_nids, orig_eids
        )

        shuffled_labels.append(node_feats["n1/labels"])
        shuffled_elabels.append(edge_feats["r1/labels"])
    verify_hetero_graph(hg, parts)

    shuffled_labels = F.asnumpy(F.cat(shuffled_labels, 0))
    shuffled_elabels = F.asnumpy(F.cat(shuffled_elabels, 0))
    orig_labels = np.zeros(shuffled_labels.shape, dtype=shuffled_labels.dtype)
    orig_elabels = np.zeros(
        shuffled_elabels.shape, dtype=shuffled_elabels.dtype
    )
    orig_labels[F.asnumpy(orig_nids["n1"])] = shuffled_labels
    orig_elabels[F.asnumpy(orig_eids["r1"])] = shuffled_elabels
    assert np.all(orig_labels == F.asnumpy(hg.nodes["n1"].data["labels"]))
    assert np.all(orig_elabels == F.asnumpy(hg.edges["r1"].data["labels"]))


def check_partition(
    g,
    part_method,
    reshuffle,
    num_parts=4,
    num_trainers_per_machine=1,
    load_feats=True,
    graph_formats=None,
):
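    """Partition a homogeneous graph, reload every partition, and verify the
    partition book, local ID mappings, features, and message passing results.
    """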
    g.ndata["labels"] = F.arange(0, g.number_of_nodes())
    g.ndata["feats"] = F.tensor(
        np.random.randn(g.number_of_nodes(), 10), F.float32
    )
    g.edata["feats"] = F.tensor(
        np.random.randn(g.number_of_edges(), 10), F.float32
    )
    g.update_all(fn.copy_src("feats", "msg"), fn.sum("msg", "h"))
    g.update_all(fn.copy_edge("feats", "msg"), fn.sum("msg", "eh"))
    num_hops = 2

    orig_nids, orig_eids = partition_graph(
        g,
        "test",
        num_parts,
        "/tmp/partition",
        num_hops=num_hops,
        part_method=part_method,
        reshuffle=reshuffle,
        return_mapping=True,
        num_trainers_per_machine=num_trainers_per_machine,
        graph_formats=graph_formats,
    )
    part_sizes = []
    shuffled_labels = []
    shuffled_edata = []
    for i in range(num_parts):
        part_g, node_feats, edge_feats, gpb, _, ntypes, etypes = load_partition(
            "/tmp/partition/test.json", i, load_feats=load_feats
        )
        _verify_partition_data_types(part_g)
        _verify_partition_formats(part_g, graph_formats)
        if not load_feats:
            assert not node_feats
            assert not edge_feats
            node_feats, edge_feats = load_partition_feats(
                "/tmp/partition/test.json", i
            )
        if num_trainers_per_machine > 1:
            for ntype in g.ntypes:
                name = ntype + "/trainer_id"
                assert name in node_feats
                part_ids = F.floor_div(
                    node_feats[name], num_trainers_per_machine
                )
                assert np.all(F.asnumpy(part_ids) == i)

            for etype in g.etypes:
                name = etype + "/trainer_id"
                assert name in edge_feats
                part_ids = F.floor_div(
                    edge_feats[name], num_trainers_per_machine
                )
                assert np.all(F.asnumpy(part_ids) == i)

        # Check the metadata.
        assert gpb._num_nodes() == g.number_of_nodes()
        assert gpb._num_edges() == g.number_of_edges()

        assert gpb.num_partitions() == num_parts
        gpb_meta = gpb.metadata()
        assert len(gpb_meta) == num_parts
        assert len(gpb.partid2nids(i)) == gpb_meta[i]["num_nodes"]
        assert len(gpb.partid2eids(i)) == gpb_meta[i]["num_edges"]
        part_sizes.append((gpb_meta[i]["num_nodes"], gpb_meta[i]["num_edges"]))

        nid = F.boolean_mask(part_g.ndata[dgl.NID], part_g.ndata["inner_node"])
        local_nid = gpb.nid2localnid(nid, i)
        assert F.dtype(local_nid) in (F.int64, F.int32)
        assert np.all(F.asnumpy(local_nid) == np.arange(0, len(local_nid)))
        eid = F.boolean_mask(part_g.edata[dgl.EID], part_g.edata["inner_edge"])
        local_eid = gpb.eid2localeid(eid, i)
        assert F.dtype(local_eid) in (F.int64, F.int32)
        assert np.all(F.asnumpy(local_eid) == np.arange(0, len(local_eid)))

        # Check the node map.
        local_nodes = F.boolean_mask(
            part_g.ndata[dgl.NID], part_g.ndata["inner_node"]
        )
        llocal_nodes = F.nonzero_1d(part_g.ndata["inner_node"])
        local_nodes1 = gpb.partid2nids(i)
        assert F.dtype(local_nodes1) in (F.int32, F.int64)
        assert np.all(
            np.sort(F.asnumpy(local_nodes)) == np.sort(F.asnumpy(local_nodes1))
        )
        assert np.all(F.asnumpy(llocal_nodes) == np.arange(len(llocal_nodes)))

        # Check the edge map.
        local_edges = F.boolean_mask(
            part_g.edata[dgl.EID], part_g.edata["inner_edge"]
        )
        llocal_edges = F.nonzero_1d(part_g.edata["inner_edge"])
        local_edges1 = gpb.partid2eids(i)
        assert F.dtype(local_edges1) in (F.int32, F.int64)
        assert np.all(
            np.sort(F.asnumpy(local_edges)) == np.sort(F.asnumpy(local_edges1))
        )
        assert np.all(F.asnumpy(llocal_edges) == np.arange(len(llocal_edges)))

        # Verify the mapping between the reshuffled IDs and the original IDs.
        part_src_ids, part_dst_ids = part_g.edges()
        part_src_ids = F.gather_row(part_g.ndata[dgl.NID], part_src_ids)
        part_dst_ids = F.gather_row(part_g.ndata[dgl.NID], part_dst_ids)
        part_eids = part_g.edata[dgl.EID]
        orig_src_ids = F.gather_row(orig_nids, part_src_ids)
        orig_dst_ids = F.gather_row(orig_nids, part_dst_ids)
        orig_eids1 = F.gather_row(orig_eids, part_eids)
        orig_eids2 = g.edge_ids(orig_src_ids, orig_dst_ids)
        assert F.shape(orig_eids1)[0] == F.shape(orig_eids2)[0]
        assert np.all(F.asnumpy(orig_eids1) == F.asnumpy(orig_eids2))

        if reshuffle:
            local_orig_nids = orig_nids[part_g.ndata[dgl.NID]]
            local_orig_eids = orig_eids[part_g.edata[dgl.EID]]
            part_g.ndata["feats"] = F.gather_row(
                g.ndata["feats"], local_orig_nids
            )
            part_g.edata["feats"] = F.gather_row(
                g.edata["feats"], local_orig_eids
            )
            local_nodes = orig_nids[local_nodes]
            local_edges = orig_eids[local_edges]
        else:
            part_g.ndata["feats"] = F.gather_row(
                g.ndata["feats"], part_g.ndata[dgl.NID]
            )
            part_g.edata["feats"] = F.gather_row(
                g.edata["feats"], part_g.edata[dgl.EID]
            )

        part_g.update_all(fn.copy_src("feats", "msg"), fn.sum("msg", "h"))
        part_g.update_all(fn.copy_edge("feats", "msg"), fn.sum("msg", "eh"))
        assert F.allclose(
            F.gather_row(g.ndata["h"], local_nodes),
            F.gather_row(part_g.ndata["h"], llocal_nodes),
        )
        assert F.allclose(
            F.gather_row(g.ndata["eh"], local_nodes),
            F.gather_row(part_g.ndata["eh"], llocal_nodes),
        )

        for name in ["labels", "feats"]:
            assert "_N/" + name in node_feats
            assert node_feats["_N/" + name].shape[0] == len(local_nodes)
            true_feats = F.gather_row(g.ndata[name], local_nodes)
            ndata = F.gather_row(node_feats["_N/" + name], local_nid)
            assert np.all(F.asnumpy(true_feats) == F.asnumpy(ndata))
        for name in ["feats"]:
            assert "_E/" + name in edge_feats
            assert edge_feats["_E/" + name].shape[0] == len(local_edges)
            true_feats = F.gather_row(g.edata[name], local_edges)
            edata = F.gather_row(edge_feats["_E/" + name], local_eid)
            assert np.all(F.asnumpy(true_feats) == F.asnumpy(edata))

        # This only works if node/edge IDs are shuffled.
        if reshuffle:
            shuffled_labels.append(node_feats["_N/labels"])
            shuffled_edata.append(edge_feats["_E/feats"])

    # Verify that we can reconstruct node/edge data for original IDs.
    if reshuffle:
        shuffled_labels = F.asnumpy(F.cat(shuffled_labels, 0))
        shuffled_edata = F.asnumpy(F.cat(shuffled_edata, 0))
        orig_labels = np.zeros(
            shuffled_labels.shape, dtype=shuffled_labels.dtype
        )
        orig_edata = np.zeros(shuffled_edata.shape, dtype=shuffled_edata.dtype)
        orig_labels[F.asnumpy(orig_nids)] = shuffled_labels
        orig_edata[F.asnumpy(orig_eids)] = shuffled_edata
        assert np.all(orig_labels == F.asnumpy(g.ndata["labels"]))
        assert np.all(orig_edata == F.asnumpy(g.edata["feats"]))

    if reshuffle:
        node_map = []
        edge_map = []
        for i, (num_nodes, num_edges) in enumerate(part_sizes):
            node_map.append(np.ones(num_nodes) * i)
            edge_map.append(np.ones(num_edges) * i)
        node_map = np.concatenate(node_map)
        edge_map = np.concatenate(edge_map)
        nid2pid = gpb.nid2partid(F.arange(0, len(node_map)))
        assert F.dtype(nid2pid) in (F.int32, F.int64)
        assert np.all(F.asnumpy(nid2pid) == node_map)
        eid2pid = gpb.eid2partid(F.arange(0, len(edge_map)))
        assert F.dtype(eid2pid) in (F.int32, F.int64)
        assert np.all(F.asnumpy(eid2pid) == edge_map)


@pytest.mark.parametrize("part_method", ["metis", "random"])
@pytest.mark.parametrize("reshuffle", [True, False])
@pytest.mark.parametrize("num_parts", [1, 4])
@pytest.mark.parametrize("num_trainers_per_machine", [1, 4])
@pytest.mark.parametrize("load_feats", [True, False])
@pytest.mark.parametrize(
    "graph_formats", [None, ["csc"], ["coo", "csc"], ["coo", "csc", "csr"]]
)
def test_partition(
    part_method,
    reshuffle,
    num_parts,
    num_trainers_per_machine,
    load_feats,
    graph_formats,
):
    os.environ["DGL_DIST_DEBUG"] = "1"
    if part_method == "random" and num_parts > 1:
        num_trainers_per_machine = 1
    g = create_random_graph(1000)
    check_partition(
        g,
        part_method,
        reshuffle,
        num_parts,
        num_trainers_per_machine,
        load_feats,
        graph_formats,
    )
    hg = create_random_hetero()
    check_hetero_partition(
        hg,
        part_method,
        num_parts,
        num_trainers_per_machine,
        load_feats,
        graph_formats,
    )
    reset_envs()


def test_BasicPartitionBook():
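    """Check that a BasicPartitionBook built from random node/edge maps
    exposes the default homogeneous types and partition policies.
    """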
    part_id = 0
    num_parts = 2
    node_map = np.random.choice(num_parts, 1000)
    edge_map = np.random.choice(num_parts, 5000)
    graph = dgl.rand_graph(1000, 5000)
    graph = dgl.node_subgraph(graph, F.arange(0, graph.num_nodes()))
    gpb = BasicPartitionBook(part_id, num_parts, node_map, edge_map, graph)
    c_etype = ("_N", "_E", "_N")
    assert gpb.etypes == ["_E"]
    assert gpb.canonical_etypes == [c_etype]

    node_policy = NodePartitionPolicy(gpb, "_N")
    assert node_policy.type_name == "_N"
    edge_policy = EdgePartitionPolicy(gpb, "_E")
    assert edge_policy.type_name == "_E"


def test_RangePartitionBook():
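    """Check RangePartitionBook construction from range-based ID maps, both
    homogeneous and heterogeneous (string and canonical etype keys).
    """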
    part_id = 0
    num_parts = 2
    # homogeneous
    node_map = {"_N": F.tensor([[0, 1000], [1000, 2000]])}
    edge_map = {"_E": F.tensor([[0, 5000], [5000, 10000]])}
    ntypes = {"_N": 0}
    etypes = {"_E": 0}
    gpb = RangePartitionBook(
        part_id, num_parts, node_map, edge_map, ntypes, etypes
    )
    assert gpb.etypes == ["_E"]
    assert gpb.canonical_etypes == [None]
    assert gpb._to_canonical_etype("_E") == "_E"
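    # With the ranges above, global node IDs [0, 1000) belong to partition 0
    # and [1000, 2000) to partition 1, so gpb.nid2partid is expected to map,
    # e.g., ID 10 to partition 0 and ID 1500 to partition 1.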

    node_policy = NodePartitionPolicy(gpb, "_N")
    assert node_policy.type_name == "_N"
    edge_policy = EdgePartitionPolicy(gpb, "_E")
    assert edge_policy.type_name == "_E"

    # heterogeneous, init via etype
    node_map = {
        "node1": F.tensor([[0, 1000], [1000, 2000]]),
        "node2": F.tensor([[0, 1000], [1000, 2000]]),
    }
    edge_map = {"edge1": F.tensor([[0, 5000], [5000, 10000]])}
    ntypes = {"node1": 0, "node2": 1}
    etypes = {"edge1": 0}
    gpb = RangePartitionBook(
        part_id, num_parts, node_map, edge_map, ntypes, etypes
    )
    assert gpb.etypes == ["edge1"]
    assert gpb.canonical_etypes == [None]
    assert gpb._to_canonical_etype("edge1") == "edge1"

    node_policy = NodePartitionPolicy(gpb, "node1")
    assert node_policy.type_name == "node1"
    edge_policy = EdgePartitionPolicy(gpb, "edge1")
    assert edge_policy.type_name == "edge1"

    # heterogeneous, init via canonical etype
    node_map = {
        "node1": F.tensor([[0, 1000], [1000, 2000]]),
        "node2": F.tensor([[0, 1000], [1000, 2000]]),
    }
    edge_map = {
        ("node1", "edge1", "node2"): F.tensor([[0, 5000], [5000, 10000]])
    }
    ntypes = {"node1": 0, "node2": 1}
    etypes = {("node1", "edge1", "node2"): 0}
    c_etype = list(etypes.keys())[0]
    gpb = RangePartitionBook(
        part_id, num_parts, node_map, edge_map, ntypes, etypes
    )
    assert gpb.etypes == ["edge1"]
    assert gpb.canonical_etypes == [c_etype]
    assert gpb._to_canonical_etype("edge1") == c_etype
    assert gpb._to_canonical_etype(c_etype) == c_etype
    expect_except = False
    try:
        gpb._to_canonical_etype(("node1", "edge2", "node2"))
    except dgl.DGLError:
        expect_except = True
    assert expect_except
    expect_except = False
    try:
        gpb._to_canonical_etype("edge2")
    except dgl.DGLError:
        expect_except = True
    assert expect_except

    node_policy = NodePartitionPolicy(gpb, "node1")
    assert node_policy.type_name == "node1"
    edge_policy = EdgePartitionPolicy(gpb, c_etype)
    assert edge_policy.type_name == c_etype

    data_name = HeteroDataName(False, "edge1", "edge1")
    assert data_name.get_type() == "edge1"
    data_name = HeteroDataName(False, c_etype, "edge1")
    assert data_name.get_type() == c_etype
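

if __name__ == "__main__":
    # A minimal manual driver (a sketch, not part of the original test file):
    # it assumes a writable /tmp/partition and a DGL build with METIS support,
    # and runs one representative configuration of each test.
    test_partition("metis", True, 4, 1, True, None)
    test_BasicPartitionBook()
    test_RangePartitionBook()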