import os

import backend as F
import torch as th
import dgl
import numpy as np
import pytest
from dgl import function as fn
from dgl.distributed import (
    load_partition,
    load_partition_feats,
    partition_graph,
)
from dgl.distributed.graph_partition_book import (
    DEFAULT_ETYPE,
    DEFAULT_NTYPE,
    BasicPartitionBook,
    EdgePartitionPolicy,
    HeteroDataName,
    NodePartitionPolicy,
    RangePartitionBook,
    _etype_tuple_to_str,
)
from dgl.distributed.partition import (
    RESERVED_FIELD_DTYPE,
    _get_inner_edge_mask,
    _get_inner_node_mask,
)
from scipy import sparse as spsp
from utils import reset_envs


def _verify_partition_data_types(part_g):
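    # Reserved bookkeeping fields (e.g. inner_node/inner_edge, NID/EID) must
    # be stored with the dtypes declared in RESERVED_FIELD_DTYPE.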
    for k, dtype in RESERVED_FIELD_DTYPE.items():
        if k in part_g.ndata:
            assert part_g.ndata[k].dtype == dtype
        if k in part_g.edata:
            assert part_g.edata[k].dtype == dtype


def _verify_partition_formats(part_g, formats):
    # verify saved graph formats
    if formats is None:
        assert "coo" in part_g.formats()["created"]
    else:
        for format in formats:
            assert format in part_g.formats()["created"]


def create_random_graph(n):
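    # Sample a sparse random adjacency matrix and turn it into a DGLGraph.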
    arr = (
        spsp.random(n, n, density=0.001, format="coo", random_state=100) != 0
    ).astype(np.int64)
    return dgl.from_scipy(arr)


def create_random_hetero():
    num_nodes = {"n1": 1000, "n2": 1010, "n3": 1020}
    etypes = [
        ("n1", "r1", "n2"),
        ("n2", "r1", "n1"),
        ("n1", "r2", "n3"),
        ("n2", "r3", "n3"),
    ]
    edges = {}
    for etype in etypes:
        src_ntype, _, dst_ntype = etype
        arr = spsp.random(
            num_nodes[src_ntype],
            num_nodes[dst_ntype],
            density=0.001,
            format="coo",
            random_state=100,
        )
        edges[etype] = (arr.row, arr.col)
    return dgl.heterograph(edges, num_nodes)


def verify_hetero_graph(g, parts):
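    # Every partition must contain all node/edge types, and the per-type
    # counts of inner nodes/edges must add up to the original totals.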
    num_nodes = {ntype: 0 for ntype in g.ntypes}
    num_edges = {etype: 0 for etype in g.canonical_etypes}
    for part in parts:
        assert len(g.ntypes) == len(F.unique(part.ndata[dgl.NTYPE]))
        assert len(g.canonical_etypes) == len(F.unique(part.edata[dgl.ETYPE]))
        for ntype in g.ntypes:
            ntype_id = g.get_ntype_id(ntype)
            inner_node_mask = _get_inner_node_mask(part, ntype_id)
            num_inner_nodes = F.sum(F.astype(inner_node_mask, F.int64), 0)
            num_nodes[ntype] += num_inner_nodes
        for etype in g.canonical_etypes:
            etype_id = g.get_etype_id(etype)
            inner_edge_mask = _get_inner_edge_mask(part, etype_id)
            num_inner_edges = F.sum(F.astype(inner_edge_mask, F.int64), 0)
            num_edges[etype] += num_inner_edges
    # Verify that the number of nodes is correct.
    for ntype in g.ntypes:
        print(
            "node {}: {}, {}".format(
                ntype, g.number_of_nodes(ntype), num_nodes[ntype]
            )
        )
        assert g.number_of_nodes(ntype) == num_nodes[ntype]
    # Verify that the number of edges is correct.
    for etype in g.canonical_etypes:
        print(
            "edge {}: {}, {}".format(
                etype, g.number_of_edges(etype), num_edges[etype]
            )
        )
        assert g.number_of_edges(etype) == num_edges[etype]

    nids = {ntype: [] for ntype in g.ntypes}
    eids = {etype: [] for etype in g.canonical_etypes}
    for part in parts:
        _, _, eid = part.edges(form="all")
        etype_arr = F.gather_row(part.edata[dgl.ETYPE], eid)
        eid_type = F.gather_row(part.edata[dgl.EID], eid)
        for etype in g.canonical_etypes:
            etype_id = g.get_etype_id(etype)
            eids[etype].append(F.boolean_mask(eid_type, etype_arr == etype_id))
            # Make sure the edge IDs fall into a contiguous range.
            inner_edge_mask = _get_inner_edge_mask(part, etype_id)
            inner_eids = np.sort(
                F.asnumpy(F.boolean_mask(part.edata[dgl.EID], inner_edge_mask))
            )
            assert np.all(
                inner_eids == np.arange(inner_eids[0], inner_eids[-1] + 1)
            )

        for ntype in g.ntypes:
            ntype_id = g.get_ntype_id(ntype)
            # Make sure the IDs of inner nodes fall into a contiguous range.
            inner_node_mask = _get_inner_node_mask(part, ntype_id)
            inner_nids = F.boolean_mask(part.ndata[dgl.NID], inner_node_mask)
            assert np.all(
                F.asnumpy(
                    inner_nids
                    == F.arange(
                        F.as_scalar(inner_nids[0]),
                        F.as_scalar(inner_nids[-1]) + 1,
                    )
                )
            )
            nids[ntype].append(inner_nids)

    for ntype in nids:
        nids_type = F.cat(nids[ntype], 0)
        uniq_ids = F.unique(nids_type)
        # We should get all nodes.
        assert len(uniq_ids) == g.number_of_nodes(ntype)
    for etype in eids:
        eids_type = F.cat(eids[etype], 0)
        uniq_ids = F.unique(eids_type)
        assert len(uniq_ids) == g.number_of_edges(etype)
    # TODO(zhengda) this doesn't check 'part_id'


def verify_graph_feats(
    g, gpb, part, node_feats, edge_feats, orig_nids, orig_eids
):
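    # Map each partition's inner nodes/edges back to per-type original IDs and
    # compare the stored features with those of the original graph.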
    for ntype in g.ntypes:
        ntype_id = g.get_ntype_id(ntype)
        inner_node_mask = _get_inner_node_mask(part, ntype_id)
        inner_nids = F.boolean_mask(part.ndata[dgl.NID], inner_node_mask)
        ntype_ids, inner_type_nids = gpb.map_to_per_ntype(inner_nids)
        partid = gpb.nid2partid(inner_type_nids, ntype)
        assert np.all(F.asnumpy(ntype_ids) == ntype_id)
        assert np.all(F.asnumpy(partid) == gpb.partid)

        orig_id = orig_nids[ntype][inner_type_nids]
        local_nids = gpb.nid2localnid(inner_type_nids, gpb.partid, ntype)

        for name in g.nodes[ntype].data:
            if name in [dgl.NID, "inner_node"]:
                continue
            true_feats = F.gather_row(g.nodes[ntype].data[name], orig_id)
            ndata = F.gather_row(node_feats[ntype + "/" + name], local_nids)
            assert np.all(F.asnumpy(ndata == true_feats))

    for etype in g.canonical_etypes:
        etype_id = g.get_etype_id(etype)
        inner_edge_mask = _get_inner_edge_mask(part, etype_id)
        inner_eids = F.boolean_mask(part.edata[dgl.EID], inner_edge_mask)
        etype_ids, inner_type_eids = gpb.map_to_per_etype(inner_eids)
        partid = gpb.eid2partid(inner_type_eids, etype)
        assert np.all(F.asnumpy(etype_ids) == etype_id)
        assert np.all(F.asnumpy(partid) == gpb.partid)

        orig_id = orig_eids[etype][inner_type_eids]
        local_eids = gpb.eid2localeid(inner_type_eids, gpb.partid, etype)

        for name in g.edges[etype].data:
            if name in [dgl.EID, "inner_edge"]:
                continue
            true_feats = F.gather_row(g.edges[etype].data[name], orig_id)
            edata = F.gather_row(
                edge_feats[_etype_tuple_to_str(etype) + "/" + name], local_eids
            )
            assert np.all(F.asnumpy(edata == true_feats))


def check_hetero_partition(
    hg,
    part_method,
    num_parts=4,
    num_trainers_per_machine=1,
    load_feats=True,
    graph_formats=None,
):
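    # Partition a heterograph, reload every partition and validate IDs,
    # features and ID mappings against the original graph.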
    test_ntype = "n1"
    test_etype = ("n1", "r1", "n2")
    hg.nodes[test_ntype].data["labels"] = F.arange(0, hg.num_nodes(test_ntype))
    hg.nodes[test_ntype].data["feats"] = F.tensor(
        np.random.randn(hg.num_nodes(test_ntype), 10), F.float32
    )
    hg.edges[test_etype].data["feats"] = F.tensor(
        np.random.randn(hg.num_edges(test_etype), 10), F.float32
    )
    hg.edges[test_etype].data["labels"] = F.arange(0, hg.num_edges(test_etype))
    num_hops = 1

    orig_nids, orig_eids = partition_graph(
        hg,
        "test",
        num_parts,
        "/tmp/partition",
        num_hops=num_hops,
        part_method=part_method,
        reshuffle=True,
        return_mapping=True,
        num_trainers_per_machine=num_trainers_per_machine,
        graph_formats=graph_formats,
    )
    assert len(orig_nids) == len(hg.ntypes)
    assert len(orig_eids) == len(hg.canonical_etypes)
    for ntype in hg.ntypes:
        assert len(orig_nids[ntype]) == hg.number_of_nodes(ntype)
    for etype in hg.canonical_etypes:
        assert len(orig_eids[etype]) == hg.number_of_edges(etype)
    parts = []
    shuffled_labels = []
    shuffled_elabels = []
    for i in range(num_parts):
        part_g, node_feats, edge_feats, gpb, _, ntypes, etypes = load_partition(
            "/tmp/partition/test.json", i, load_feats=load_feats
        )
        _verify_partition_data_types(part_g)
        _verify_partition_formats(part_g, graph_formats)
        if not load_feats:
            assert not node_feats
            assert not edge_feats
            node_feats, edge_feats = load_partition_feats(
                "/tmp/partition/test.json", i
            )
        if num_trainers_per_machine > 1:
            for ntype in hg.ntypes:
                name = ntype + "/trainer_id"
                assert name in node_feats
                part_ids = F.floor_div(
                    node_feats[name], num_trainers_per_machine
                )
                assert np.all(F.asnumpy(part_ids) == i)

            for etype in hg.canonical_etypes:
                name = _etype_tuple_to_str(etype) + "/trainer_id"
                assert name in edge_feats
                part_ids = F.floor_div(
                    edge_feats[name], num_trainers_per_machine
                )
                assert np.all(F.asnumpy(part_ids) == i)

        # Verify the mapping between the reshuffled IDs and the original IDs.
        # These are partition-local IDs.
        part_src_ids, part_dst_ids = part_g.edges()
        # These are reshuffled global homogeneous IDs.
        part_src_ids = F.gather_row(part_g.ndata[dgl.NID], part_src_ids)
        part_dst_ids = F.gather_row(part_g.ndata[dgl.NID], part_dst_ids)
        part_eids = part_g.edata[dgl.EID]
        # These are reshuffled per-type IDs.
        src_ntype_ids, part_src_ids = gpb.map_to_per_ntype(part_src_ids)
        dst_ntype_ids, part_dst_ids = gpb.map_to_per_ntype(part_dst_ids)
        etype_ids, part_eids = gpb.map_to_per_etype(part_eids)
        # These are original per-type IDs.
        for etype_id, etype in enumerate(hg.canonical_etypes):
            part_src_ids1 = F.boolean_mask(part_src_ids, etype_ids == etype_id)
            src_ntype_ids1 = F.boolean_mask(
                src_ntype_ids, etype_ids == etype_id
            )
            part_dst_ids1 = F.boolean_mask(part_dst_ids, etype_ids == etype_id)
            dst_ntype_ids1 = F.boolean_mask(
                dst_ntype_ids, etype_ids == etype_id
            )
            part_eids1 = F.boolean_mask(part_eids, etype_ids == etype_id)
            assert np.all(F.asnumpy(src_ntype_ids1 == src_ntype_ids1[0]))
            assert np.all(F.asnumpy(dst_ntype_ids1 == dst_ntype_ids1[0]))
            src_ntype = hg.ntypes[F.as_scalar(src_ntype_ids1[0])]
            dst_ntype = hg.ntypes[F.as_scalar(dst_ntype_ids1[0])]
            orig_src_ids1 = F.gather_row(orig_nids[src_ntype], part_src_ids1)
            orig_dst_ids1 = F.gather_row(orig_nids[dst_ntype], part_dst_ids1)
            orig_eids1 = F.gather_row(orig_eids[etype], part_eids1)
            orig_eids2 = hg.edge_ids(orig_src_ids1, orig_dst_ids1, etype=etype)
            assert len(orig_eids1) == len(orig_eids2)
            assert np.all(F.asnumpy(orig_eids1) == F.asnumpy(orig_eids2))
        parts.append(part_g)
        verify_graph_feats(
            hg, gpb, part_g, node_feats, edge_feats, orig_nids, orig_eids
        )

        shuffled_labels.append(node_feats[test_ntype + "/labels"])
        shuffled_elabels.append(
            edge_feats[_etype_tuple_to_str(test_etype) + "/labels"]
        )
    verify_hetero_graph(hg, parts)

    shuffled_labels = F.asnumpy(F.cat(shuffled_labels, 0))
    shuffled_elabels = F.asnumpy(F.cat(shuffled_elabels, 0))
    orig_labels = np.zeros(shuffled_labels.shape, dtype=shuffled_labels.dtype)
    orig_elabels = np.zeros(
        shuffled_elabels.shape, dtype=shuffled_elabels.dtype
    )
    orig_labels[F.asnumpy(orig_nids[test_ntype])] = shuffled_labels
    orig_elabels[F.asnumpy(orig_eids[test_etype])] = shuffled_elabels
    assert np.all(orig_labels == F.asnumpy(hg.nodes[test_ntype].data["labels"]))
    assert np.all(
        orig_elabels == F.asnumpy(hg.edges[test_etype].data["labels"])
    )


def check_partition(
    g,
    part_method,
    reshuffle,
    num_parts=4,
    num_trainers_per_machine=1,
    load_feats=True,
    graph_formats=None,
):
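    # Attach node/edge features and precompute reference aggregations
    # ("h" and "eh") on the full graph for later comparison.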
    g.ndata["labels"] = F.arange(0, g.number_of_nodes())
    g.ndata["feats"] = F.tensor(
        np.random.randn(g.number_of_nodes(), 10), F.float32
    )
    g.edata["feats"] = F.tensor(
        np.random.randn(g.number_of_edges(), 10), F.float32
    )
    g.update_all(fn.copy_u("feats", "msg"), fn.sum("msg", "h"))
    g.update_all(fn.copy_e("feats", "msg"), fn.sum("msg", "eh"))
    num_hops = 2

    orig_nids, orig_eids = partition_graph(
        g,
        "test",
        num_parts,
        "/tmp/partition",
        num_hops=num_hops,
        part_method=part_method,
        reshuffle=reshuffle,
        return_mapping=True,
        num_trainers_per_machine=num_trainers_per_machine,
        graph_formats=graph_formats,
    )
    part_sizes = []
    shuffled_labels = []
    shuffled_edata = []
    for i in range(num_parts):
        part_g, node_feats, edge_feats, gpb, _, _, _ = load_partition(
            "/tmp/partition/test.json", i, load_feats=load_feats
        )
        _verify_partition_data_types(part_g)
        _verify_partition_formats(part_g, graph_formats)
        if not load_feats:
            assert not node_feats
            assert not edge_feats
            node_feats, edge_feats = load_partition_feats(
                "/tmp/partition/test.json", i
            )
        if num_trainers_per_machine > 1:
            for ntype in g.ntypes:
                name = ntype + "/trainer_id"
                assert name in node_feats
                part_ids = F.floor_div(
                    node_feats[name], num_trainers_per_machine
                )
                assert np.all(F.asnumpy(part_ids) == i)

            for etype in g.canonical_etypes:
                name = _etype_tuple_to_str(etype) + "/trainer_id"
                assert name in edge_feats
                part_ids = F.floor_div(
                    edge_feats[name], num_trainers_per_machine
                )
                assert np.all(F.asnumpy(part_ids) == i)

        # Check the metadata
        assert gpb._num_nodes() == g.number_of_nodes()
        assert gpb._num_edges() == g.number_of_edges()

        assert gpb.num_partitions() == num_parts
        gpb_meta = gpb.metadata()
        assert len(gpb_meta) == num_parts
        assert len(gpb.partid2nids(i)) == gpb_meta[i]["num_nodes"]
        assert len(gpb.partid2eids(i)) == gpb_meta[i]["num_edges"]
        part_sizes.append((gpb_meta[i]["num_nodes"], gpb_meta[i]["num_edges"]))

        nid = F.boolean_mask(part_g.ndata[dgl.NID], part_g.ndata["inner_node"])
        local_nid = gpb.nid2localnid(nid, i)
        assert F.dtype(local_nid) in (F.int64, F.int32)
        assert np.all(F.asnumpy(local_nid) == np.arange(0, len(local_nid)))
        eid = F.boolean_mask(part_g.edata[dgl.EID], part_g.edata["inner_edge"])
        local_eid = gpb.eid2localeid(eid, i)
        assert F.dtype(local_eid) in (F.int64, F.int32)
        assert np.all(F.asnumpy(local_eid) == np.arange(0, len(local_eid)))

        # Check the node map.
        local_nodes = F.boolean_mask(
            part_g.ndata[dgl.NID], part_g.ndata["inner_node"]
        )
        llocal_nodes = F.nonzero_1d(part_g.ndata["inner_node"])
        local_nodes1 = gpb.partid2nids(i)
        assert F.dtype(local_nodes1) in (F.int32, F.int64)
        assert np.all(
            np.sort(F.asnumpy(local_nodes)) == np.sort(F.asnumpy(local_nodes1))
        )
        assert np.all(F.asnumpy(llocal_nodes) == np.arange(len(llocal_nodes)))

        # Check the edge map.
        local_edges = F.boolean_mask(
            part_g.edata[dgl.EID], part_g.edata["inner_edge"]
        )
        llocal_edges = F.nonzero_1d(part_g.edata["inner_edge"])
        local_edges1 = gpb.partid2eids(i)
        assert F.dtype(local_edges1) in (F.int32, F.int64)
        assert np.all(
            np.sort(F.asnumpy(local_edges)) == np.sort(F.asnumpy(local_edges1))
        )
        assert np.all(F.asnumpy(llocal_edges) == np.arange(len(llocal_edges)))

        # Verify the mapping between the reshuffled IDs and the original IDs.
        part_src_ids, part_dst_ids = part_g.edges()
        part_src_ids = F.gather_row(part_g.ndata[dgl.NID], part_src_ids)
        part_dst_ids = F.gather_row(part_g.ndata[dgl.NID], part_dst_ids)
        part_eids = part_g.edata[dgl.EID]
        orig_src_ids = F.gather_row(orig_nids, part_src_ids)
        orig_dst_ids = F.gather_row(orig_nids, part_dst_ids)
        orig_eids1 = F.gather_row(orig_eids, part_eids)
        orig_eids2 = g.edge_ids(orig_src_ids, orig_dst_ids)
        assert F.shape(orig_eids1)[0] == F.shape(orig_eids2)[0]
        assert np.all(F.asnumpy(orig_eids1) == F.asnumpy(orig_eids2))

        if reshuffle:
            local_orig_nids = orig_nids[part_g.ndata[dgl.NID]]
            local_orig_eids = orig_eids[part_g.edata[dgl.EID]]
            part_g.ndata["feats"] = F.gather_row(
                g.ndata["feats"], local_orig_nids
            )
            part_g.edata["feats"] = F.gather_row(
                g.edata["feats"], local_orig_eids
            )
            local_nodes = orig_nids[local_nodes]
            local_edges = orig_eids[local_edges]
        else:
            part_g.ndata["feats"] = F.gather_row(
                g.ndata["feats"], part_g.ndata[dgl.NID]
            )
            part_g.edata["feats"] = F.gather_row(
                g.edata["feats"], part_g.edata[dgl.EID]
            )

        part_g.update_all(fn.copy_u("feats", "msg"), fn.sum("msg", "h"))
        part_g.update_all(fn.copy_e("feats", "msg"), fn.sum("msg", "eh"))
        assert F.allclose(
            F.gather_row(g.ndata["h"], local_nodes),
            F.gather_row(part_g.ndata["h"], llocal_nodes),
        )
        assert F.allclose(
            F.gather_row(g.ndata["eh"], local_nodes),
            F.gather_row(part_g.ndata["eh"], llocal_nodes),
        )

        for name in ["labels", "feats"]:
            assert "_N/" + name in node_feats
            assert node_feats["_N/" + name].shape[0] == len(local_nodes)
            true_feats = F.gather_row(g.ndata[name], local_nodes)
            ndata = F.gather_row(node_feats["_N/" + name], local_nid)
            assert np.all(F.asnumpy(true_feats) == F.asnumpy(ndata))
        for name in ["feats"]:
            efeat_name = _etype_tuple_to_str(DEFAULT_ETYPE) + "/" + name
            assert efeat_name in edge_feats
            assert edge_feats[efeat_name].shape[0] == len(local_edges)
            true_feats = F.gather_row(g.edata[name], local_edges)
            edata = F.gather_row(edge_feats[efeat_name], local_eid)
            assert np.all(F.asnumpy(true_feats) == F.asnumpy(edata))

        # This only works if node/edge IDs are shuffled.
        if reshuffle:
            shuffled_labels.append(node_feats["_N/labels"])
            shuffled_edata.append(edge_feats["_N:_E:_N/feats"])

    # Verify that we can reconstruct node/edge data for original IDs.
    if reshuffle:
        shuffled_labels = F.asnumpy(F.cat(shuffled_labels, 0))
        shuffled_edata = F.asnumpy(F.cat(shuffled_edata, 0))
        orig_labels = np.zeros(
            shuffled_labels.shape, dtype=shuffled_labels.dtype
        )
        orig_edata = np.zeros(shuffled_edata.shape, dtype=shuffled_edata.dtype)
        orig_labels[F.asnumpy(orig_nids)] = shuffled_labels
        orig_edata[F.asnumpy(orig_eids)] = shuffled_edata
        assert np.all(orig_labels == F.asnumpy(g.ndata["labels"]))
        assert np.all(orig_edata == F.asnumpy(g.edata["feats"]))

    if reshuffle:
        node_map = []
        edge_map = []
        for i, (num_nodes, num_edges) in enumerate(part_sizes):
            node_map.append(np.ones(num_nodes) * i)
            edge_map.append(np.ones(num_edges) * i)
        node_map = np.concatenate(node_map)
        edge_map = np.concatenate(edge_map)
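        # Global ID -> partition ID lookups must match the partition
        # assignment implied by the recorded per-partition sizes.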
        nid2pid = gpb.nid2partid(F.arange(0, len(node_map)))
        assert F.dtype(nid2pid) in (F.int32, F.int64)
        assert np.all(F.asnumpy(nid2pid) == node_map)
        eid2pid = gpb.eid2partid(F.arange(0, len(edge_map)))
        assert F.dtype(eid2pid) in (F.int32, F.int64)
        assert np.all(F.asnumpy(eid2pid) == edge_map)


@pytest.mark.parametrize("part_method", ["metis", "random"])
@pytest.mark.parametrize("reshuffle", [True, False])
@pytest.mark.parametrize("num_parts", [1, 4])
@pytest.mark.parametrize("num_trainers_per_machine", [1, 4])
@pytest.mark.parametrize("load_feats", [True, False])
@pytest.mark.parametrize(
    "graph_formats", [None, ["csc"], ["coo", "csc"], ["coo", "csc", "csr"]]
)
def test_partition(
    part_method,
    reshuffle,
    num_parts,
    num_trainers_per_machine,
    load_feats,
    graph_formats,
):
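    # Enable DGL's distributed debug mode for stricter internal checks.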
    os.environ["DGL_DIST_DEBUG"] = "1"
    if part_method == "random" and num_parts > 1:
        num_trainers_per_machine = 1
    g = create_random_graph(1000)
    check_partition(
        g,
        part_method,
        reshuffle,
        num_parts,
        num_trainers_per_machine,
        load_feats,
        graph_formats,
    )
    hg = create_random_hetero()
    check_hetero_partition(
        hg,
        part_method,
        num_parts,
        num_trainers_per_machine,
        load_feats,
        graph_formats,
    )
    reset_envs()


def test_BasicPartitionBook():
    part_id = 0
    num_parts = 2
    node_map = np.random.choice(num_parts, 1000)
    edge_map = np.random.choice(num_parts, 5000)
    graph = dgl.rand_graph(1000, 5000)
    graph = dgl.node_subgraph(graph, F.arange(0, graph.num_nodes()))
    gpb = BasicPartitionBook(part_id, num_parts, node_map, edge_map, graph)
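    # A homogeneous graph gets the default type names "_N" and "_E".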
    c_etype = ("_N", "_E", "_N")
    assert gpb.etypes == ["_E"]
    assert gpb.canonical_etypes == [c_etype]

    node_policy = NodePartitionPolicy(gpb, "_N")
    assert node_policy.type_name == "_N"
    expect_except = False
    try:
        edge_policy = EdgePartitionPolicy(gpb, "_E")
    except AssertionError:
        expect_except = True
    assert expect_except
    edge_policy = EdgePartitionPolicy(gpb, c_etype)
    assert edge_policy.type_name == c_etype


def test_RangePartitionBook():
    part_id = 1
    num_parts = 2

    # homogeneous
    node_map = {DEFAULT_NTYPE: F.tensor([[0, 1000], [1000, 2000]])}
    edge_map = {DEFAULT_ETYPE: F.tensor([[0, 5000], [5000, 10000]])}
    ntypes = {DEFAULT_NTYPE: 0}
    etypes = {DEFAULT_ETYPE: 0}
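    # Each map entry lists one [start, end) global-ID range per partition.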
    gpb = RangePartitionBook(
        part_id, num_parts, node_map, edge_map, ntypes, etypes
    )
    assert gpb.etypes == [DEFAULT_ETYPE[1]]
    assert gpb.canonical_etypes == [DEFAULT_ETYPE]
    assert gpb.to_canonical_etype(DEFAULT_ETYPE[1]) == DEFAULT_ETYPE

    node_policy = NodePartitionPolicy(gpb, DEFAULT_NTYPE)
    assert node_policy.type_name == DEFAULT_NTYPE
    edge_policy = EdgePartitionPolicy(gpb, DEFAULT_ETYPE)
    assert edge_policy.type_name == DEFAULT_ETYPE

    # Init via etype is not supported
    node_map = {
        "node1": F.tensor([[0, 1000], [1000, 2000]]),
        "node2": F.tensor([[0, 1000], [1000, 2000]]),
    }
    edge_map = {"edge1": F.tensor([[0, 5000], [5000, 10000]])}
    ntypes = {"node1": 0, "node2": 1}
    etypes = {"edge1": 0}
    expect_except = False
    try:
        RangePartitionBook(
            part_id, num_parts, node_map, edge_map, ntypes, etypes
        )
    except AssertionError:
        expect_except = True
    assert expect_except
    expect_except = False
    try:
        EdgePartitionPolicy(gpb, "edge1")
    except AssertionError:
        expect_except = True
    assert expect_except

    # heterogeneous, init via canonical etype
    node_map = {
        "node1": F.tensor([[0, 1000], [1000, 2000]]),
        "node2": F.tensor([[0, 1000], [1000, 2000]]),
    }
    edge_map = {
        ("node1", "edge1", "node2"): F.tensor([[0, 5000], [5000, 10000]])
    }
    ntypes = {"node1": 0, "node2": 1}
    etypes = {("node1", "edge1", "node2"): 0}
    c_etype = list(etypes.keys())[0]
    gpb = RangePartitionBook(
        part_id, num_parts, node_map, edge_map, ntypes, etypes
    )
    assert gpb.etypes == ["edge1"]
    assert gpb.canonical_etypes == [c_etype]
    assert gpb.to_canonical_etype("edge1") == c_etype
    assert gpb.to_canonical_etype(c_etype) == c_etype
    expect_except = False
    try:
        gpb.to_canonical_etype(("node1", "edge2", "node2"))
    except:
        expect_except = True
    assert expect_except
    expect_except = False
    try:
        gpb.to_canonical_etype("edge2")
    except:
        expect_except = True
    assert expect_except

    # NodePartitionPolicy
    node_policy = NodePartitionPolicy(gpb, "node1")
    assert node_policy.type_name == "node1"
    assert node_policy.policy_str == "node~node1"
    assert node_policy.part_id == part_id
    assert node_policy.is_node
    assert node_policy.get_data_name("x").is_node()
    local_ids = th.arange(0, 1000)
    global_ids = local_ids + 1000
    assert th.equal(node_policy.to_local(global_ids), local_ids)
    assert th.all(node_policy.to_partid(global_ids) == part_id)
    assert node_policy.get_part_size() == 1000
    assert node_policy.get_size() == 2000

    # EdgePartitionPolicy
    edge_policy = EdgePartitionPolicy(gpb, c_etype)
    assert edge_policy.type_name == c_etype
    assert edge_policy.policy_str == "edge~node1:edge1:node2"
    assert edge_policy.part_id == part_id
    assert not edge_policy.is_node
    assert not edge_policy.get_data_name("x").is_node()
    local_ids = th.arange(0, 5000)
    global_ids = local_ids + 5000
    assert th.equal(edge_policy.to_local(global_ids), local_ids)
    assert th.all(edge_policy.to_partid(global_ids) == part_id)
    assert edge_policy.get_part_size() == 5000
    assert edge_policy.get_size() == 10000

    expect_except = False
    try:
        HeteroDataName(False, "edge1", "feat")
    except:
        expect_except = True
    assert expect_except
    data_name = HeteroDataName(False, c_etype, "feat")
    assert data_name.get_type() == c_etype