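"""Tests for DGL's distributed graph partitioning and partition books."""
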
import json
import os
import tempfile

import backend as F
import dgl
import numpy as np
import pytest
import torch as th
from dgl import function as fn

from dgl.distributed import (
    load_partition,
    load_partition_book,
    load_partition_feats,
    partition_graph,
)
from dgl.distributed.graph_partition_book import (
    _etype_tuple_to_str,
    DEFAULT_ETYPE,
    DEFAULT_NTYPE,
    EdgePartitionPolicy,
    HeteroDataName,
    NodePartitionPolicy,
    RangePartitionBook,
)
from dgl.distributed.partition import (
    _get_inner_edge_mask,
    _get_inner_node_mask,
    RESERVED_FIELD_DTYPE,
)
from scipy import sparse as spsp
from utils import reset_envs


def _verify_partition_data_types(part_g):
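    """Check that reserved node/edge fields have the dtypes DGL expects."""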
    for k, dtype in RESERVED_FIELD_DTYPE.items():
        if k in part_g.ndata:
            assert part_g.ndata[k].dtype == dtype
        if k in part_g.edata:
            assert part_g.edata[k].dtype == dtype


def _verify_partition_formats(part_g, formats):
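    """Check that the partition graph was saved in the requested formats."""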
    # verify saved graph formats
    if formats is None:
        assert "coo" in part_g.formats()["created"]
    else:
        for fmt in formats:
            assert fmt in part_g.formats()["created"]


def create_random_graph(n):
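    """Create a random homogeneous graph with n nodes (~0.1% edge density)."""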
    arr = (
        spsp.random(n, n, density=0.001, format="coo", random_state=100) != 0
    ).astype(np.int64)
    return dgl.from_scipy(arr)


def create_random_hetero():
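    """Create a random heterograph with 3 node types and 4 relations."""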
    num_nodes = {"n1": 1000, "n2": 1010, "n3": 1020}
    etypes = [
        ("n1", "r1", "n2"),
        ("n2", "r1", "n1"),
        ("n1", "r2", "n3"),
        ("n2", "r3", "n3"),
    ]
    edges = {}
    for etype in etypes:
        src_ntype, _, dst_ntype = etype
        arr = spsp.random(
            num_nodes[src_ntype],
            num_nodes[dst_ntype],
            density=0.001,
            format="coo",
            random_state=100,
        )
        edges[etype] = (arr.row, arr.col)
    return dgl.heterograph(edges, num_nodes)


def verify_hetero_graph(g, parts):
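    """Check that the partitions exactly cover the original heterograph.

    Per-type node/edge counts must match, inner IDs must form contiguous
    ranges, and the union of inner IDs must cover every node and edge.
    """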
    num_nodes = {ntype: 0 for ntype in g.ntypes}
    num_edges = {etype: 0 for etype in g.canonical_etypes}
    for part in parts:
        assert len(g.ntypes) == len(F.unique(part.ndata[dgl.NTYPE]))
        assert len(g.canonical_etypes) == len(F.unique(part.edata[dgl.ETYPE]))
        for ntype in g.ntypes:
            ntype_id = g.get_ntype_id(ntype)
            inner_node_mask = _get_inner_node_mask(part, ntype_id)
            num_inner_nodes = F.sum(F.astype(inner_node_mask, F.int64), 0)
            num_nodes[ntype] += num_inner_nodes
        for etype in g.canonical_etypes:
            etype_id = g.get_etype_id(etype)
            inner_edge_mask = _get_inner_edge_mask(part, etype_id)
            num_inner_edges = F.sum(F.astype(inner_edge_mask, F.int64), 0)
            num_edges[etype] += num_inner_edges
    # Verify that the number of nodes is correct.
    for ntype in g.ntypes:
        print(
            "node {}: {}, {}".format(
                ntype, g.number_of_nodes(ntype), num_nodes[ntype]
            )
        )
        assert g.number_of_nodes(ntype) == num_nodes[ntype]
    # Verify that the number of edges is correct.
    for etype in g.canonical_etypes:
        print(
            "edge {}: {}, {}".format(
                etype, g.number_of_edges(etype), num_edges[etype]
            )
        )
        assert g.number_of_edges(etype) == num_edges[etype]

    nids = {ntype: [] for ntype in g.ntypes}
    eids = {etype: [] for etype in g.canonical_etypes}
    for part in parts:
        _, _, eid = part.edges(form="all")
        etype_arr = F.gather_row(part.edata[dgl.ETYPE], eid)
        eid_type = F.gather_row(part.edata[dgl.EID], eid)
        for etype in g.canonical_etypes:
            etype_id = g.get_etype_id(etype)
            eids[etype].append(F.boolean_mask(eid_type, etype_arr == etype_id))
            # Make sure inner edge IDs fall into a contiguous range.
            inner_edge_mask = _get_inner_edge_mask(part, etype_id)
            inner_eids = np.sort(
                F.asnumpy(F.boolean_mask(part.edata[dgl.EID], inner_edge_mask))
            )
            assert np.all(
                inner_eids == np.arange(inner_eids[0], inner_eids[-1] + 1)
            )

        for ntype in g.ntypes:
            ntype_id = g.get_ntype_id(ntype)
            # Make sure inner node IDs fall into a contiguous range.
            inner_node_mask = _get_inner_node_mask(part, ntype_id)
            inner_nids = F.boolean_mask(part.ndata[dgl.NID], inner_node_mask)
            assert np.all(
                F.asnumpy(
                    inner_nids
                    == F.arange(
                        F.as_scalar(inner_nids[0]),
                        F.as_scalar(inner_nids[-1]) + 1,
                    )
                )
            )
            nids[ntype].append(inner_nids)

    for ntype in nids:
        nids_type = F.cat(nids[ntype], 0)
        uniq_ids = F.unique(nids_type)
        # We should get all nodes.
        assert len(uniq_ids) == g.number_of_nodes(ntype)
    for etype in eids:
        eids_type = F.cat(eids[etype], 0)
        uniq_ids = F.unique(eids_type)
        assert len(uniq_ids) == g.number_of_edges(etype)
    # TODO(zhengda) this doesn't check 'part_id'


def verify_graph_feats(
    g, gpb, part, node_feats, edge_feats, orig_nids, orig_eids
):
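    """Check that partitioned node/edge features match the original graph."""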
    for ntype in g.ntypes:
        ntype_id = g.get_ntype_id(ntype)
        inner_node_mask = _get_inner_node_mask(part, ntype_id)
        inner_nids = F.boolean_mask(part.ndata[dgl.NID], inner_node_mask)
        ntype_ids, inner_type_nids = gpb.map_to_per_ntype(inner_nids)
        partid = gpb.nid2partid(inner_type_nids, ntype)
        assert np.all(F.asnumpy(ntype_ids) == ntype_id)
        assert np.all(F.asnumpy(partid) == gpb.partid)

        orig_id = orig_nids[ntype][inner_type_nids]
        local_nids = gpb.nid2localnid(inner_type_nids, gpb.partid, ntype)

        for name in g.nodes[ntype].data:
            if name in [dgl.NID, "inner_node"]:
                continue
            true_feats = F.gather_row(g.nodes[ntype].data[name], orig_id)
            ndata = F.gather_row(node_feats[ntype + "/" + name], local_nids)
            assert np.all(F.asnumpy(ndata == true_feats))

    for etype in g.canonical_etypes:
        etype_id = g.get_etype_id(etype)
        inner_edge_mask = _get_inner_edge_mask(part, etype_id)
        inner_eids = F.boolean_mask(part.edata[dgl.EID], inner_edge_mask)
        etype_ids, inner_type_eids = gpb.map_to_per_etype(inner_eids)
        partid = gpb.eid2partid(inner_type_eids, etype)
        assert np.all(F.asnumpy(etype_ids) == etype_id)
        assert np.all(F.asnumpy(partid) == gpb.partid)

        orig_id = orig_eids[etype][inner_type_eids]
        local_eids = gpb.eid2localeid(inner_type_eids, gpb.partid, etype)

        for name in g.edges[etype].data:
            if name in [dgl.EID, "inner_edge"]:
                continue
            true_feats = F.gather_row(g.edges[etype].data[name], orig_id)
            edata = F.gather_row(
                edge_feats[_etype_tuple_to_str(etype) + "/" + name], local_eids
            )
            assert np.all(F.asnumpy(edata == true_feats))


def check_hetero_partition(
    hg,
    part_method,
    num_parts=4,
    num_trainers_per_machine=1,
    load_feats=True,
    graph_formats=None,
):
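    """Partition a heterograph, then reload and validate every partition."""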
    test_ntype = "n1"
    test_etype = ("n1", "r1", "n2")
    hg.nodes[test_ntype].data["labels"] = F.arange(0, hg.num_nodes(test_ntype))
    hg.nodes[test_ntype].data["feats"] = F.tensor(
        np.random.randn(hg.num_nodes(test_ntype), 10), F.float32
    )
    hg.edges[test_etype].data["feats"] = F.tensor(
        np.random.randn(hg.num_edges(test_etype), 10), F.float32
    )
    hg.edges[test_etype].data["labels"] = F.arange(0, hg.num_edges(test_etype))
    num_hops = 1

    orig_nids, orig_eids = partition_graph(
        hg,
        "test",
        num_parts,
        "/tmp/partition",
        num_hops=num_hops,
        part_method=part_method,
        return_mapping=True,
        num_trainers_per_machine=num_trainers_per_machine,
        graph_formats=graph_formats,
    )
    assert len(orig_nids) == len(hg.ntypes)
    assert len(orig_eids) == len(hg.canonical_etypes)
    for ntype in hg.ntypes:
        assert len(orig_nids[ntype]) == hg.number_of_nodes(ntype)
    for etype in hg.canonical_etypes:
        assert len(orig_eids[etype]) == hg.number_of_edges(etype)
    parts = []
    shuffled_labels = []
    shuffled_elabels = []
    for i in range(num_parts):
        part_g, node_feats, edge_feats, gpb, _, ntypes, etypes = load_partition(
            "/tmp/partition/test.json", i, load_feats=load_feats
        )
        _verify_partition_data_types(part_g)
        _verify_partition_formats(part_g, graph_formats)
        if not load_feats:
            assert not node_feats
            assert not edge_feats
            node_feats, edge_feats = load_partition_feats(
                "/tmp/partition/test.json", i
            )
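        # trainer_id // num_trainers_per_machine must equal the partition ID.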
        if num_trainers_per_machine > 1:
            for ntype in hg.ntypes:
                name = ntype + "/trainer_id"
                assert name in node_feats
                part_ids = F.floor_div(
                    node_feats[name], num_trainers_per_machine
                )
                assert np.all(F.asnumpy(part_ids) == i)

            for etype in hg.canonical_etypes:
                name = _etype_tuple_to_str(etype) + "/trainer_id"
                assert name in edge_feats
                part_ids = F.floor_div(
                    edge_feats[name], num_trainers_per_machine
                )
                assert np.all(F.asnumpy(part_ids) == i)
        # Verify the mapping between the reshuffled IDs and the original IDs.
        # These are partition-local IDs.
        part_src_ids, part_dst_ids = part_g.edges()
        # These are reshuffled global homogeneous IDs.
        part_src_ids = F.gather_row(part_g.ndata[dgl.NID], part_src_ids)
        part_dst_ids = F.gather_row(part_g.ndata[dgl.NID], part_dst_ids)
        part_eids = part_g.edata[dgl.EID]
        # These are reshuffled per-type IDs.
        src_ntype_ids, part_src_ids = gpb.map_to_per_ntype(part_src_ids)
        dst_ntype_ids, part_dst_ids = gpb.map_to_per_ntype(part_dst_ids)
        etype_ids, part_eids = gpb.map_to_per_etype(part_eids)
        # These are original per-type IDs.
        for etype_id, etype in enumerate(hg.canonical_etypes):
            part_src_ids1 = F.boolean_mask(part_src_ids, etype_ids == etype_id)
            src_ntype_ids1 = F.boolean_mask(
                src_ntype_ids, etype_ids == etype_id
            )
            part_dst_ids1 = F.boolean_mask(part_dst_ids, etype_ids == etype_id)
            dst_ntype_ids1 = F.boolean_mask(
                dst_ntype_ids, etype_ids == etype_id
            )
            part_eids1 = F.boolean_mask(part_eids, etype_ids == etype_id)
            assert np.all(F.asnumpy(src_ntype_ids1 == src_ntype_ids1[0]))
            assert np.all(F.asnumpy(dst_ntype_ids1 == dst_ntype_ids1[0]))
            src_ntype = hg.ntypes[F.as_scalar(src_ntype_ids1[0])]
            dst_ntype = hg.ntypes[F.as_scalar(dst_ntype_ids1[0])]
            orig_src_ids1 = F.gather_row(orig_nids[src_ntype], part_src_ids1)
            orig_dst_ids1 = F.gather_row(orig_nids[dst_ntype], part_dst_ids1)
            orig_eids1 = F.gather_row(orig_eids[etype], part_eids1)
            orig_eids2 = hg.edge_ids(orig_src_ids1, orig_dst_ids1, etype=etype)
            assert len(orig_eids1) == len(orig_eids2)
            assert np.all(F.asnumpy(orig_eids1) == F.asnumpy(orig_eids2))
        parts.append(part_g)
        verify_graph_feats(
            hg, gpb, part_g, node_feats, edge_feats, orig_nids, orig_eids
        )

        shuffled_labels.append(node_feats[test_ntype + "/labels"])
        shuffled_elabels.append(
            edge_feats[_etype_tuple_to_str(test_etype) + "/labels"]
        )
    verify_hetero_graph(hg, parts)

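    # Verify that node/edge labels can be reconstructed in original ID order.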
    shuffled_labels = F.asnumpy(F.cat(shuffled_labels, 0))
    shuffled_elabels = F.asnumpy(F.cat(shuffled_elabels, 0))
    orig_labels = np.zeros(shuffled_labels.shape, dtype=shuffled_labels.dtype)
    orig_elabels = np.zeros(
        shuffled_elabels.shape, dtype=shuffled_elabels.dtype
    )
    orig_labels[F.asnumpy(orig_nids[test_ntype])] = shuffled_labels
    orig_elabels[F.asnumpy(orig_eids[test_etype])] = shuffled_elabels
    assert np.all(orig_labels == F.asnumpy(hg.nodes[test_ntype].data["labels"]))
    assert np.all(
        orig_elabels == F.asnumpy(hg.edges[test_etype].data["labels"])
    )


def check_partition(
    g,
    part_method,
    num_parts=4,
    num_trainers_per_machine=1,
    load_feats=True,
    graph_formats=None,
):
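    """Partition a homogeneous graph, then reload and validate each part."""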
    g.ndata["labels"] = F.arange(0, g.number_of_nodes())
    g.ndata["feats"] = F.tensor(
        np.random.randn(g.number_of_nodes(), 10), F.float32
    )
    g.edata["feats"] = F.tensor(
        np.random.randn(g.number_of_edges(), 10), F.float32
    )
    g.update_all(fn.copy_u("feats", "msg"), fn.sum("msg", "h"))
    g.update_all(fn.copy_e("feats", "msg"), fn.sum("msg", "eh"))
    num_hops = 2

    orig_nids, orig_eids = partition_graph(
        g,
        "test",
        num_parts,
        "/tmp/partition",
        num_hops=num_hops,
        part_method=part_method,
        return_mapping=True,
        num_trainers_per_machine=num_trainers_per_machine,
        graph_formats=graph_formats,
    )
    part_sizes = []
    shuffled_labels = []
    shuffled_edata = []
    for i in range(num_parts):
        part_g, node_feats, edge_feats, gpb, _, _, _ = load_partition(
            "/tmp/partition/test.json", i, load_feats=load_feats
        )
        _verify_partition_data_types(part_g)
        _verify_partition_formats(part_g, graph_formats)
        if not load_feats:
            assert not node_feats
            assert not edge_feats
            node_feats, edge_feats = load_partition_feats(
                "/tmp/partition/test.json", i
            )
        if num_trainers_per_machine > 1:
            for ntype in g.ntypes:
                name = ntype + "/trainer_id"
                assert name in node_feats
                part_ids = F.floor_div(
                    node_feats[name], num_trainers_per_machine
                )
                assert np.all(F.asnumpy(part_ids) == i)

            for etype in g.canonical_etypes:
                name = _etype_tuple_to_str(etype) + "/trainer_id"
                assert name in edge_feats
                part_ids = F.floor_div(
                    edge_feats[name], num_trainers_per_machine
                )
                assert np.all(F.asnumpy(part_ids) == i)

        # Check the partition book metadata.
        assert gpb._num_nodes() == g.number_of_nodes()
        assert gpb._num_edges() == g.number_of_edges()

        assert gpb.num_partitions() == num_parts
        gpb_meta = gpb.metadata()
        assert len(gpb_meta) == num_parts
        assert len(gpb.partid2nids(i)) == gpb_meta[i]["num_nodes"]
        assert len(gpb.partid2eids(i)) == gpb_meta[i]["num_edges"]
        part_sizes.append((gpb_meta[i]["num_nodes"], gpb_meta[i]["num_edges"]))

        nid = F.boolean_mask(part_g.ndata[dgl.NID], part_g.ndata["inner_node"])
        local_nid = gpb.nid2localnid(nid, i)
        assert F.dtype(local_nid) in (F.int64, F.int32)
        assert np.all(F.asnumpy(local_nid) == np.arange(0, len(local_nid)))
        eid = F.boolean_mask(part_g.edata[dgl.EID], part_g.edata["inner_edge"])
        local_eid = gpb.eid2localeid(eid, i)
        assert F.dtype(local_eid) in (F.int64, F.int32)
        assert np.all(F.asnumpy(local_eid) == np.arange(0, len(local_eid)))

        # Check the node map.
        local_nodes = F.boolean_mask(
            part_g.ndata[dgl.NID], part_g.ndata["inner_node"]
        )
        llocal_nodes = F.nonzero_1d(part_g.ndata["inner_node"])
        local_nodes1 = gpb.partid2nids(i)
        assert F.dtype(local_nodes1) in (F.int32, F.int64)
        assert np.all(
            np.sort(F.asnumpy(local_nodes)) == np.sort(F.asnumpy(local_nodes1))
        )
        assert np.all(F.asnumpy(llocal_nodes) == np.arange(len(llocal_nodes)))

        # Check the edge map.
        local_edges = F.boolean_mask(
            part_g.edata[dgl.EID], part_g.edata["inner_edge"]
        )
        llocal_edges = F.nonzero_1d(part_g.edata["inner_edge"])
        local_edges1 = gpb.partid2eids(i)
        assert F.dtype(local_edges1) in (F.int32, F.int64)
        assert np.all(
            np.sort(F.asnumpy(local_edges)) == np.sort(F.asnumpy(local_edges1))
        )
        assert np.all(F.asnumpy(llocal_edges) == np.arange(len(llocal_edges)))

        # Verify the mapping between the reshuffled IDs and the original IDs.
        part_src_ids, part_dst_ids = part_g.edges()
        part_src_ids = F.gather_row(part_g.ndata[dgl.NID], part_src_ids)
        part_dst_ids = F.gather_row(part_g.ndata[dgl.NID], part_dst_ids)
        part_eids = part_g.edata[dgl.EID]
        orig_src_ids = F.gather_row(orig_nids, part_src_ids)
        orig_dst_ids = F.gather_row(orig_nids, part_dst_ids)
        orig_eids1 = F.gather_row(orig_eids, part_eids)
        orig_eids2 = g.edge_ids(orig_src_ids, orig_dst_ids)
        assert F.shape(orig_eids1)[0] == F.shape(orig_eids2)[0]
        assert np.all(F.asnumpy(orig_eids1) == F.asnumpy(orig_eids2))

        local_orig_nids = orig_nids[part_g.ndata[dgl.NID]]
        local_orig_eids = orig_eids[part_g.edata[dgl.EID]]
        part_g.ndata["feats"] = F.gather_row(g.ndata["feats"], local_orig_nids)
        part_g.edata["feats"] = F.gather_row(g.edata["feats"], local_orig_eids)
        local_nodes = orig_nids[local_nodes]
        local_edges = orig_eids[local_edges]

        part_g.update_all(fn.copy_u("feats", "msg"), fn.sum("msg", "h"))
        part_g.update_all(fn.copy_e("feats", "msg"), fn.sum("msg", "eh"))
        assert F.allclose(
            F.gather_row(g.ndata["h"], local_nodes),
            F.gather_row(part_g.ndata["h"], llocal_nodes),
        )
        assert F.allclose(
            F.gather_row(g.ndata["eh"], local_nodes),
            F.gather_row(part_g.ndata["eh"], llocal_nodes),
        )

        for name in ["labels", "feats"]:
            assert "_N/" + name in node_feats
            assert node_feats["_N/" + name].shape[0] == len(local_nodes)
            true_feats = F.gather_row(g.ndata[name], local_nodes)
            ndata = F.gather_row(node_feats["_N/" + name], local_nid)
            assert np.all(F.asnumpy(true_feats) == F.asnumpy(ndata))
        for name in ["feats"]:
            efeat_name = _etype_tuple_to_str(DEFAULT_ETYPE) + "/" + name
            assert efeat_name in edge_feats
            assert edge_feats[efeat_name].shape[0] == len(local_edges)
            true_feats = F.gather_row(g.edata[name], local_edges)
            edata = F.gather_row(edge_feats[efeat_name], local_eid)
            assert np.all(F.asnumpy(true_feats) == F.asnumpy(edata))

        # This only works if node/edge IDs are shuffled.
        shuffled_labels.append(node_feats["_N/labels"])
        shuffled_edata.append(edge_feats["_N:_E:_N/feats"])

    # Verify that we can reconstruct node/edge data for original IDs.
    shuffled_labels = F.asnumpy(F.cat(shuffled_labels, 0))
    shuffled_edata = F.asnumpy(F.cat(shuffled_edata, 0))
    orig_labels = np.zeros(shuffled_labels.shape, dtype=shuffled_labels.dtype)
    orig_edata = np.zeros(shuffled_edata.shape, dtype=shuffled_edata.dtype)
    orig_labels[F.asnumpy(orig_nids)] = shuffled_labels
    orig_edata[F.asnumpy(orig_eids)] = shuffled_edata
    assert np.all(orig_labels == F.asnumpy(g.ndata["labels"]))
    assert np.all(orig_edata == F.asnumpy(g.edata["feats"]))

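    # Check that nid2partid/eid2partid map every global ID to its partition.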
    node_map = []
    edge_map = []
    for i, (num_nodes, num_edges) in enumerate(part_sizes):
        node_map.append(np.ones(num_nodes) * i)
        edge_map.append(np.ones(num_edges) * i)
    node_map = np.concatenate(node_map)
    edge_map = np.concatenate(edge_map)
    nid2pid = gpb.nid2partid(F.arange(0, len(node_map)))
    assert F.dtype(nid2pid) in (F.int32, F.int64)
    assert np.all(F.asnumpy(nid2pid) == node_map)
    eid2pid = gpb.eid2partid(F.arange(0, len(edge_map)))
    assert F.dtype(eid2pid) in (F.int32, F.int64)
    assert np.all(F.asnumpy(eid2pid) == edge_map)


@pytest.mark.parametrize("part_method", ["metis", "random"])
@pytest.mark.parametrize("num_parts", [1, 4])
@pytest.mark.parametrize("num_trainers_per_machine", [1, 4])
@pytest.mark.parametrize("load_feats", [True, False])
@pytest.mark.parametrize(
    "graph_formats", [None, ["csc"], ["coo", "csc"], ["coo", "csc", "csr"]]
)
def test_partition(
    part_method,
    num_parts,
    num_trainers_per_machine,
    load_feats,
    graph_formats,
):
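    """End-to-end test over random homogeneous and heterogeneous graphs."""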
    os.environ["DGL_DIST_DEBUG"] = "1"
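    # Random partitioning is only exercised with one trainer per machine.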
    if part_method == "random" and num_parts > 1:
        num_trainers_per_machine = 1
    g = create_random_graph(1000)
    check_partition(
        g,
        part_method,
        num_parts,
        num_trainers_per_machine,
        load_feats,
        graph_formats,
    )
    hg = create_random_hetero()
    check_hetero_partition(
        hg,
        part_method,
        num_parts,
        num_trainers_per_machine,
        load_feats,
        graph_formats,
    )
    reset_envs()


def test_RangePartitionBook():
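    """Unit tests for RangePartitionBook and its partition policies."""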
    part_id = 1
    num_parts = 2

    # homogeneous
    node_map = {DEFAULT_NTYPE: F.tensor([[0, 1000], [1000, 2000]])}
    edge_map = {DEFAULT_ETYPE: F.tensor([[0, 5000], [5000, 10000]])}
    ntypes = {DEFAULT_NTYPE: 0}
    etypes = {DEFAULT_ETYPE: 0}
    gpb = RangePartitionBook(
        part_id, num_parts, node_map, edge_map, ntypes, etypes
    )
    assert gpb.etypes == [DEFAULT_ETYPE[1]]
    assert gpb.canonical_etypes == [DEFAULT_ETYPE]
    assert gpb.to_canonical_etype(DEFAULT_ETYPE[1]) == DEFAULT_ETYPE

    node_policy = NodePartitionPolicy(gpb, DEFAULT_NTYPE)
    assert node_policy.type_name == DEFAULT_NTYPE
    edge_policy = EdgePartitionPolicy(gpb, DEFAULT_ETYPE)
    assert edge_policy.type_name == DEFAULT_ETYPE

    # heterogeneous, init via plain etype strings is not supported
    node_map = {
        "node1": F.tensor([[0, 1000], [1000, 2000]]),
        "node2": F.tensor([[0, 1000], [1000, 2000]]),
    }
    edge_map = {"edge1": F.tensor([[0, 5000], [5000, 10000]])}
    ntypes = {"node1": 0, "node2": 1}
    etypes = {"edge1": 0}
    with pytest.raises(AssertionError):
        RangePartitionBook(
            part_id, num_parts, node_map, edge_map, ntypes, etypes
        )
    with pytest.raises(AssertionError):
        EdgePartitionPolicy(gpb, "edge1")

    # heterogeneous, init via canonical etype
    node_map = {
        "node1": F.tensor([[0, 1000], [1000, 2000]]),
        "node2": F.tensor([[0, 1000], [1000, 2000]]),
    }
    edge_map = {
        ("node1", "edge1", "node2"): F.tensor([[0, 5000], [5000, 10000]])
    }
    ntypes = {"node1": 0, "node2": 1}
    etypes = {("node1", "edge1", "node2"): 0}
    c_etype = list(etypes.keys())[0]
    gpb = RangePartitionBook(
        part_id, num_parts, node_map, edge_map, ntypes, etypes
    )
    assert gpb.etypes == ["edge1"]
    assert gpb.canonical_etypes == [c_etype]
    assert gpb.to_canonical_etype("edge1") == c_etype
    assert gpb.to_canonical_etype(c_etype) == c_etype
    with pytest.raises(Exception):
        gpb.to_canonical_etype(("node1", "edge2", "node2"))
    with pytest.raises(Exception):
        gpb.to_canonical_etype("edge2")

    # NodePartitionPolicy
    node_policy = NodePartitionPolicy(gpb, "node1")
    assert node_policy.type_name == "node1"
    assert node_policy.policy_str == "node~node1"
    assert node_policy.part_id == part_id
    assert node_policy.is_node
    assert node_policy.get_data_name("x").is_node()
    local_ids = th.arange(0, 1000)
    global_ids = local_ids + 1000
    assert th.equal(node_policy.to_local(global_ids), local_ids)
    assert th.all(node_policy.to_partid(global_ids) == part_id)
    assert node_policy.get_part_size() == 1000
    assert node_policy.get_size() == 2000

    # EdgePartitionPolicy
    edge_policy = EdgePartitionPolicy(gpb, c_etype)
    assert edge_policy.type_name == c_etype
    assert edge_policy.policy_str == "edge~node1:edge1:node2"
    assert edge_policy.part_id == part_id
    assert not edge_policy.is_node
    assert not edge_policy.get_data_name("x").is_node()
    local_ids = th.arange(0, 5000)
    global_ids = local_ids + 5000
    assert th.equal(edge_policy.to_local(global_ids), local_ids)
    assert th.all(edge_policy.to_partid(global_ids) == part_id)
    assert edge_policy.get_part_size() == 5000
    assert edge_policy.get_size() == 10000

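    # HeteroDataName requires a canonical etype for edge data.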
    with pytest.raises(Exception):
        HeteroDataName(False, "edge1", "feat")
    data_name = HeteroDataName(False, c_etype, "feat")
    assert data_name.get_type() == c_etype


def test_UnknownPartitionBook():
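    """An unsupported partition book format should raise a TypeError."""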
    node_map = {"_N": {0: 0, 1: 1, 2: 2}}
    edge_map = {"_N:_E:_N": {0: 0, 1: 1, 2: 2}}

    part_metadata = {
        "num_parts": 1,
        "num_nodes": len(node_map),
        "num_edges": len(edge_map),
        "node_map": node_map,
        "edge_map": edge_map,
        "graph_name": "test_graph",
    }

    with tempfile.TemporaryDirectory() as test_dir:
        part_config = os.path.join(test_dir, "test_graph.json")
        with open(part_config, "w") as file:
            json.dump(part_metadata, file, indent=4)
        try:
            load_partition_book(part_config, 0)
        except Exception as e:
            # Only a TypeError (unsupported partition book format) is
            # acceptable here; re-raise anything else.
            if not isinstance(e, TypeError):
                raise