"src/diffusers/pipelines/latte/pipeline_latte.py" did not exist on "a8523bffa844752f8080e2ee675f91c32e392cf0"
test_partition.py 25.2 KB
Newer Older
1
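"""Tests for DGL's distributed graph partitioning: ``partition_graph``,
``load_partition``/``load_partition_feats``, and the partition book classes."""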
import os

import backend as F
import torch as th
import dgl
import json
import numpy as np
import pytest
import tempfile
from dgl import function as fn
from dgl.distributed import (
    load_partition,
    load_partition_book,
    load_partition_feats,
    partition_graph,
)
from dgl.distributed.graph_partition_book import (
    DEFAULT_ETYPE,
    DEFAULT_NTYPE,
    EdgePartitionPolicy,
    HeteroDataName,
    NodePartitionPolicy,
    RangePartitionBook,
    _etype_tuple_to_str,
)
from dgl.distributed.partition import (
    RESERVED_FIELD_DTYPE,
    _get_inner_edge_mask,
    _get_inner_node_mask,
)
from scipy import sparse as spsp
from utils import reset_envs


def _verify_partition_data_types(part_g):
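    """Check that reserved node/edge fields use the dtypes declared in
    RESERVED_FIELD_DTYPE."""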
    for k, dtype in RESERVED_FIELD_DTYPE.items():
        if k in part_g.ndata:
            assert part_g.ndata[k].dtype == dtype
        if k in part_g.edata:
            assert part_g.edata[k].dtype == dtype


def _verify_partition_formats(part_g, formats):
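    """Check that the partitioned graph was saved with the requested sparse
    formats (COO by default)."""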
    # verify saved graph formats
    if formats is None:
        assert "coo" in part_g.formats()["created"]
    else:
        for fmt in formats:
            assert fmt in part_g.formats()["created"]


def create_random_graph(n):
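    """Create a random homogeneous directed graph with ``n`` nodes from a
    sparse random adjacency matrix (density 0.001, fixed random seed)."""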
    arr = (
        spsp.random(n, n, density=0.001, format="coo", random_state=100) != 0
    ).astype(np.int64)
    return dgl.from_scipy(arr)


def create_random_hetero():
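    """Create a random heterogeneous graph with three node types and four
    canonical edge types."""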
    num_nodes = {"n1": 1000, "n2": 1010, "n3": 1020}
    etypes = [
        ("n1", "r1", "n2"),
        ("n2", "r1", "n1"),
        ("n1", "r2", "n3"),
        ("n2", "r3", "n3"),
    ]
    edges = {}
    for etype in etypes:
        src_ntype, _, dst_ntype = etype
        arr = spsp.random(
            num_nodes[src_ntype],
            num_nodes[dst_ntype],
            density=0.001,
            format="coo",
            random_state=100,
        )
        edges[etype] = (arr.row, arr.col)
    return dgl.heterograph(edges, num_nodes)


def verify_hetero_graph(g, parts):
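    """Check that the partitions jointly cover every node and edge of ``g``
    exactly once and that inner node/edge IDs form contiguous ranges."""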
    num_nodes = {ntype: 0 for ntype in g.ntypes}
    num_edges = {etype: 0 for etype in g.canonical_etypes}
    for part in parts:
        assert len(g.ntypes) == len(F.unique(part.ndata[dgl.NTYPE]))
        assert len(g.canonical_etypes) == len(F.unique(part.edata[dgl.ETYPE]))
        for ntype in g.ntypes:
            ntype_id = g.get_ntype_id(ntype)
            inner_node_mask = _get_inner_node_mask(part, ntype_id)
            num_inner_nodes = F.sum(F.astype(inner_node_mask, F.int64), 0)
            num_nodes[ntype] += num_inner_nodes
        for etype in g.canonical_etypes:
            etype_id = g.get_etype_id(etype)
            inner_edge_mask = _get_inner_edge_mask(part, etype_id)
            num_inner_edges = F.sum(F.astype(inner_edge_mask, F.int64), 0)
            num_edges[etype] += num_inner_edges
    # Verify that the per-type node counts are correct.
    for ntype in g.ntypes:
        print(
            "node {}: {}, {}".format(
                ntype, g.number_of_nodes(ntype), num_nodes[ntype]
            )
        )
        assert g.number_of_nodes(ntype) == num_nodes[ntype]
    # Verify that the per-type edge counts are correct.
    for etype in g.canonical_etypes:
        print(
            "edge {}: {}, {}".format(
                etype, g.number_of_edges(etype), num_edges[etype]
            )
        )
        assert g.number_of_edges(etype) == num_edges[etype]

    nids = {ntype: [] for ntype in g.ntypes}
    eids = {etype: [] for etype in g.canonical_etypes}
    for part in parts:
        _, _, eid = part.edges(form="all")
        etype_arr = F.gather_row(part.edata[dgl.ETYPE], eid)
        eid_type = F.gather_row(part.edata[dgl.EID], eid)
        for etype in g.canonical_etypes:
            etype_id = g.get_etype_id(etype)
            eids[etype].append(F.boolean_mask(eid_type, etype_arr == etype_id))
            # Make sure edge IDs fall into a contiguous range.
            inner_edge_mask = _get_inner_edge_mask(part, etype_id)
            inner_eids = np.sort(
                F.asnumpy(F.boolean_mask(part.edata[dgl.EID], inner_edge_mask))
            )
            assert np.all(
                inner_eids == np.arange(inner_eids[0], inner_eids[-1] + 1)
            )

        for ntype in g.ntypes:
            ntype_id = g.get_ntype_id(ntype)
            # Make sure inner node IDs fall into a contiguous range.
            inner_node_mask = _get_inner_node_mask(part, ntype_id)
            inner_nids = F.boolean_mask(part.ndata[dgl.NID], inner_node_mask)
            assert np.all(
                F.asnumpy(
                    inner_nids
                    == F.arange(
                        F.as_scalar(inner_nids[0]),
                        F.as_scalar(inner_nids[-1]) + 1,
                    )
                )
            )
            nids[ntype].append(inner_nids)

    for ntype in nids:
        nids_type = F.cat(nids[ntype], 0)
        uniq_ids = F.unique(nids_type)
        # We should get all nodes.
        assert len(uniq_ids) == g.number_of_nodes(ntype)
    for etype in eids:
        eids_type = F.cat(eids[etype], 0)
        uniq_ids = F.unique(eids_type)
        assert len(uniq_ids) == g.number_of_edges(etype)
    # TODO(zhengda) this doesn't check 'part_id'


def verify_graph_feats(
    g, gpb, part, node_feats, edge_feats, orig_nids, orig_eids
):
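    """Check that the node/edge features stored with a partition match the
    features of the original graph ``g`` under the recorded ID mappings."""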
    for ntype in g.ntypes:
        ntype_id = g.get_ntype_id(ntype)
        inner_node_mask = _get_inner_node_mask(part, ntype_id)
        inner_nids = F.boolean_mask(part.ndata[dgl.NID], inner_node_mask)
        ntype_ids, inner_type_nids = gpb.map_to_per_ntype(inner_nids)
        partid = gpb.nid2partid(inner_type_nids, ntype)
        assert np.all(F.asnumpy(ntype_ids) == ntype_id)
        assert np.all(F.asnumpy(partid) == gpb.partid)

        orig_id = orig_nids[ntype][inner_type_nids]
        local_nids = gpb.nid2localnid(inner_type_nids, gpb.partid, ntype)

        for name in g.nodes[ntype].data:
            if name in [dgl.NID, "inner_node"]:
                continue
            true_feats = F.gather_row(g.nodes[ntype].data[name], orig_id)
            ndata = F.gather_row(node_feats[ntype + "/" + name], local_nids)
            assert np.all(F.asnumpy(ndata == true_feats))

    for etype in g.canonical_etypes:
        etype_id = g.get_etype_id(etype)
        inner_edge_mask = _get_inner_edge_mask(part, etype_id)
        inner_eids = F.boolean_mask(part.edata[dgl.EID], inner_edge_mask)
        etype_ids, inner_type_eids = gpb.map_to_per_etype(inner_eids)
        partid = gpb.eid2partid(inner_type_eids, etype)
        assert np.all(F.asnumpy(etype_ids) == etype_id)
        assert np.all(F.asnumpy(partid) == gpb.partid)

        orig_id = orig_eids[etype][inner_type_eids]
        local_eids = gpb.eid2localeid(inner_type_eids, gpb.partid, etype)

        for name in g.edges[etype].data:
            if name in [dgl.EID, "inner_edge"]:
                continue
            true_feats = F.gather_row(g.edges[etype].data[name], orig_id)
            edata = F.gather_row(
                edge_feats[_etype_tuple_to_str(etype) + "/" + name], local_eids
            )
            assert np.all(F.asnumpy(edata == true_feats))


def check_hetero_partition(
    hg,
    part_method,
    num_parts=4,
    num_trainers_per_machine=1,
    load_feats=True,
    graph_formats=None,
):
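    """Partition a heterogeneous graph, reload each partition, and verify the
    graph structure, features, and ID mappings against the original graph."""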
    test_ntype = "n1"
    test_etype = ("n1", "r1", "n2")
    hg.nodes[test_ntype].data["labels"] = F.arange(0, hg.num_nodes(test_ntype))
    hg.nodes[test_ntype].data["feats"] = F.tensor(
        np.random.randn(hg.num_nodes(test_ntype), 10), F.float32
    )
    hg.edges[test_etype].data["feats"] = F.tensor(
        np.random.randn(hg.num_edges(test_etype), 10), F.float32
    )
    hg.edges[test_etype].data["labels"] = F.arange(0, hg.num_edges(test_etype))
    num_hops = 1

    orig_nids, orig_eids = partition_graph(
        hg,
        "test",
        num_parts,
        "/tmp/partition",
        num_hops=num_hops,
        part_method=part_method,
        return_mapping=True,
        num_trainers_per_machine=num_trainers_per_machine,
        graph_formats=graph_formats,
    )
    assert len(orig_nids) == len(hg.ntypes)
    assert len(orig_eids) == len(hg.canonical_etypes)
    for ntype in hg.ntypes:
        assert len(orig_nids[ntype]) == hg.number_of_nodes(ntype)
    for etype in hg.canonical_etypes:
        assert len(orig_eids[etype]) == hg.number_of_edges(etype)
    parts = []
    shuffled_labels = []
    shuffled_elabels = []
    for i in range(num_parts):
        part_g, node_feats, edge_feats, gpb, _, ntypes, etypes = load_partition(
            "/tmp/partition/test.json", i, load_feats=load_feats
        )
        _verify_partition_data_types(part_g)
        _verify_partition_formats(part_g, graph_formats)
        if not load_feats:
            assert not node_feats
            assert not edge_feats
            node_feats, edge_feats = load_partition_feats(
                "/tmp/partition/test.json", i
            )
        if num_trainers_per_machine > 1:
            for ntype in hg.ntypes:
                name = ntype + "/trainer_id"
                assert name in node_feats
                part_ids = F.floor_div(
                    node_feats[name], num_trainers_per_machine
                )
                assert np.all(F.asnumpy(part_ids) == i)

            for etype in hg.canonical_etypes:
                name = _etype_tuple_to_str(etype) + "/trainer_id"
                assert name in edge_feats
                part_ids = F.floor_div(
                    edge_feats[name], num_trainers_per_machine
                )
                assert np.all(F.asnumpy(part_ids) == i)
        # Verify the mapping between the reshuffled IDs and the original IDs.
        # These are partition-local IDs.
        part_src_ids, part_dst_ids = part_g.edges()
        # These are reshuffled global homogeneous IDs.
        part_src_ids = F.gather_row(part_g.ndata[dgl.NID], part_src_ids)
        part_dst_ids = F.gather_row(part_g.ndata[dgl.NID], part_dst_ids)
        part_eids = part_g.edata[dgl.EID]
        # These are reshuffled per-type IDs.
        src_ntype_ids, part_src_ids = gpb.map_to_per_ntype(part_src_ids)
        dst_ntype_ids, part_dst_ids = gpb.map_to_per_ntype(part_dst_ids)
        etype_ids, part_eids = gpb.map_to_per_etype(part_eids)
        # These are original per-type IDs.
        for etype_id, etype in enumerate(hg.canonical_etypes):
            part_src_ids1 = F.boolean_mask(part_src_ids, etype_ids == etype_id)
            src_ntype_ids1 = F.boolean_mask(
                src_ntype_ids, etype_ids == etype_id
            )
            part_dst_ids1 = F.boolean_mask(part_dst_ids, etype_ids == etype_id)
            dst_ntype_ids1 = F.boolean_mask(
                dst_ntype_ids, etype_ids == etype_id
            )
            part_eids1 = F.boolean_mask(part_eids, etype_ids == etype_id)
            assert np.all(F.asnumpy(src_ntype_ids1 == src_ntype_ids1[0]))
            assert np.all(F.asnumpy(dst_ntype_ids1 == dst_ntype_ids1[0]))
            src_ntype = hg.ntypes[F.as_scalar(src_ntype_ids1[0])]
            dst_ntype = hg.ntypes[F.as_scalar(dst_ntype_ids1[0])]
            orig_src_ids1 = F.gather_row(orig_nids[src_ntype], part_src_ids1)
            orig_dst_ids1 = F.gather_row(orig_nids[dst_ntype], part_dst_ids1)
            orig_eids1 = F.gather_row(orig_eids[etype], part_eids1)
            orig_eids2 = hg.edge_ids(orig_src_ids1, orig_dst_ids1, etype=etype)
            assert len(orig_eids1) == len(orig_eids2)
            assert np.all(F.asnumpy(orig_eids1) == F.asnumpy(orig_eids2))
        parts.append(part_g)
        verify_graph_feats(
            hg, gpb, part_g, node_feats, edge_feats, orig_nids, orig_eids
        )

        shuffled_labels.append(node_feats[test_ntype + "/labels"])
        shuffled_elabels.append(
            edge_feats[_etype_tuple_to_str(test_etype) + "/labels"]
        )
    verify_hetero_graph(hg, parts)

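    # Verify that the original labels can be reconstructed from the shuffled
    # per-partition labels via the returned ID mappings.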
    shuffled_labels = F.asnumpy(F.cat(shuffled_labels, 0))
    shuffled_elabels = F.asnumpy(F.cat(shuffled_elabels, 0))
    orig_labels = np.zeros(shuffled_labels.shape, dtype=shuffled_labels.dtype)
    orig_elabels = np.zeros(
        shuffled_elabels.shape, dtype=shuffled_elabels.dtype
    )
    orig_labels[F.asnumpy(orig_nids[test_ntype])] = shuffled_labels
    orig_elabels[F.asnumpy(orig_eids[test_etype])] = shuffled_elabels
    assert np.all(orig_labels == F.asnumpy(hg.nodes[test_ntype].data["labels"]))
    assert np.all(
        orig_elabels == F.asnumpy(hg.edges[test_etype].data["labels"])
    )


def check_partition(
    g,
    part_method,
    num_parts=4,
    num_trainers_per_machine=1,
    load_feats=True,
    graph_formats=None,
):
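    """Partition a homogeneous graph, reload each partition, and verify the
    partition book metadata, ID mappings, features, and message passing."""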
    g.ndata["labels"] = F.arange(0, g.number_of_nodes())
    g.ndata["feats"] = F.tensor(
        np.random.randn(g.number_of_nodes(), 10), F.float32
    )
    g.edata["feats"] = F.tensor(
        np.random.randn(g.number_of_edges(), 10), F.float32
    )
    g.update_all(fn.copy_u("feats", "msg"), fn.sum("msg", "h"))
    g.update_all(fn.copy_e("feats", "msg"), fn.sum("msg", "eh"))
    num_hops = 2

    orig_nids, orig_eids = partition_graph(
        g,
        "test",
        num_parts,
        "/tmp/partition",
        num_hops=num_hops,
        part_method=part_method,
        return_mapping=True,
        num_trainers_per_machine=num_trainers_per_machine,
        graph_formats=graph_formats,
    )
    part_sizes = []
    shuffled_labels = []
    shuffled_edata = []
    for i in range(num_parts):
        part_g, node_feats, edge_feats, gpb, _, _, _ = load_partition(
            "/tmp/partition/test.json", i, load_feats=load_feats
        )
        _verify_partition_data_types(part_g)
        _verify_partition_formats(part_g, graph_formats)
        if not load_feats:
            assert not node_feats
            assert not edge_feats
            node_feats, edge_feats = load_partition_feats(
                "/tmp/partition/test.json", i
            )
        if num_trainers_per_machine > 1:
            for ntype in g.ntypes:
                name = ntype + "/trainer_id"
                assert name in node_feats
                part_ids = F.floor_div(
                    node_feats[name], num_trainers_per_machine
                )
                assert np.all(F.asnumpy(part_ids) == i)

            for etype in g.canonical_etypes:
                name = _etype_tuple_to_str(etype) + "/trainer_id"
                assert name in edge_feats
                part_ids = F.floor_div(
                    edge_feats[name], num_trainers_per_machine
                )
                assert np.all(F.asnumpy(part_ids) == i)

        # Check the metadata
        assert gpb._num_nodes() == g.number_of_nodes()
        assert gpb._num_edges() == g.number_of_edges()

        assert gpb.num_partitions() == num_parts
        gpb_meta = gpb.metadata()
        assert len(gpb_meta) == num_parts
        assert len(gpb.partid2nids(i)) == gpb_meta[i]["num_nodes"]
        assert len(gpb.partid2eids(i)) == gpb_meta[i]["num_edges"]
        part_sizes.append((gpb_meta[i]["num_nodes"], gpb_meta[i]["num_edges"]))

        nid = F.boolean_mask(part_g.ndata[dgl.NID], part_g.ndata["inner_node"])
        local_nid = gpb.nid2localnid(nid, i)
        assert F.dtype(local_nid) in (F.int64, F.int32)
        assert np.all(F.asnumpy(local_nid) == np.arange(0, len(local_nid)))
        eid = F.boolean_mask(part_g.edata[dgl.EID], part_g.edata["inner_edge"])
        local_eid = gpb.eid2localeid(eid, i)
        assert F.dtype(local_eid) in (F.int64, F.int32)
        assert np.all(F.asnumpy(local_eid) == np.arange(0, len(local_eid)))

        # Check the node map.
        local_nodes = F.boolean_mask(
            part_g.ndata[dgl.NID], part_g.ndata["inner_node"]
        )
        llocal_nodes = F.nonzero_1d(part_g.ndata["inner_node"])
        local_nodes1 = gpb.partid2nids(i)
        assert F.dtype(local_nodes1) in (F.int32, F.int64)
        assert np.all(
            np.sort(F.asnumpy(local_nodes)) == np.sort(F.asnumpy(local_nodes1))
        )
        assert np.all(F.asnumpy(llocal_nodes) == np.arange(len(llocal_nodes)))

        # Check the edge map.
        local_edges = F.boolean_mask(
            part_g.edata[dgl.EID], part_g.edata["inner_edge"]
        )
        llocal_edges = F.nonzero_1d(part_g.edata["inner_edge"])
        local_edges1 = gpb.partid2eids(i)
        assert F.dtype(local_edges1) in (F.int32, F.int64)
        assert np.all(
            np.sort(F.asnumpy(local_edges)) == np.sort(F.asnumpy(local_edges1))
        )
        assert np.all(F.asnumpy(llocal_edges) == np.arange(len(llocal_edges)))

        # Verify the mapping between the reshuffled IDs and the original IDs.
        part_src_ids, part_dst_ids = part_g.edges()
        part_src_ids = F.gather_row(part_g.ndata[dgl.NID], part_src_ids)
        part_dst_ids = F.gather_row(part_g.ndata[dgl.NID], part_dst_ids)
        part_eids = part_g.edata[dgl.EID]
        orig_src_ids = F.gather_row(orig_nids, part_src_ids)
        orig_dst_ids = F.gather_row(orig_nids, part_dst_ids)
        orig_eids1 = F.gather_row(orig_eids, part_eids)
        orig_eids2 = g.edge_ids(orig_src_ids, orig_dst_ids)
        assert F.shape(orig_eids1)[0] == F.shape(orig_eids2)[0]
        assert np.all(F.asnumpy(orig_eids1) == F.asnumpy(orig_eids2))

        local_orig_nids = orig_nids[part_g.ndata[dgl.NID]]
        local_orig_eids = orig_eids[part_g.edata[dgl.EID]]
        part_g.ndata["feats"] = F.gather_row(
            g.ndata["feats"], local_orig_nids
        )
        part_g.edata["feats"] = F.gather_row(
            g.edata["feats"], local_orig_eids
        )
        local_nodes = orig_nids[local_nodes]
        local_edges = orig_eids[local_edges]

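        # One round of message passing on the partition should reproduce the
        # full-graph result for the inner nodes, since the halo introduced by
        # num_hops keeps their complete in-neighborhoods.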
        part_g.update_all(fn.copy_u("feats", "msg"), fn.sum("msg", "h"))
        part_g.update_all(fn.copy_e("feats", "msg"), fn.sum("msg", "eh"))
        assert F.allclose(
            F.gather_row(g.ndata["h"], local_nodes),
            F.gather_row(part_g.ndata["h"], llocal_nodes),
        )
        assert F.allclose(
            F.gather_row(g.ndata["eh"], local_nodes),
            F.gather_row(part_g.ndata["eh"], llocal_nodes),
        )

        for name in ["labels", "feats"]:
            assert "_N/" + name in node_feats
            assert node_feats["_N/" + name].shape[0] == len(local_nodes)
            true_feats = F.gather_row(g.ndata[name], local_nodes)
            ndata = F.gather_row(node_feats["_N/" + name], local_nid)
            assert np.all(F.asnumpy(true_feats) == F.asnumpy(ndata))
        for name in ["feats"]:
            efeat_name = _etype_tuple_to_str(DEFAULT_ETYPE) + "/" + name
            assert efeat_name in edge_feats
            assert edge_feats[efeat_name].shape[0] == len(local_edges)
            true_feats = F.gather_row(g.edata[name], local_edges)
            edata = F.gather_row(edge_feats[efeat_name], local_eid)
            assert np.all(F.asnumpy(true_feats) == F.asnumpy(edata))

        # This only works if node/edge IDs are shuffled.
        shuffled_labels.append(node_feats["_N/labels"])
        shuffled_edata.append(edge_feats["_N:_E:_N/feats"])

    # Verify that we can reconstruct node/edge data for original IDs.
    shuffled_labels = F.asnumpy(F.cat(shuffled_labels, 0))
    shuffled_edata = F.asnumpy(F.cat(shuffled_edata, 0))
    orig_labels = np.zeros(
        shuffled_labels.shape, dtype=shuffled_labels.dtype
    )
    orig_edata = np.zeros(shuffled_edata.shape, dtype=shuffled_edata.dtype)
    orig_labels[F.asnumpy(orig_nids)] = shuffled_labels
    orig_edata[F.asnumpy(orig_eids)] = shuffled_edata
    assert np.all(orig_labels == F.asnumpy(g.ndata["labels"]))
    assert np.all(orig_edata == F.asnumpy(g.edata["feats"]))

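    # Build the expected node/edge -> partition mappings from the partition
    # sizes and compare them with the partition book.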
    node_map = []
    edge_map = []
    for i, (num_nodes, num_edges) in enumerate(part_sizes):
        node_map.append(np.ones(num_nodes) * i)
        edge_map.append(np.ones(num_edges) * i)
    node_map = np.concatenate(node_map)
    edge_map = np.concatenate(edge_map)
    nid2pid = gpb.nid2partid(F.arange(0, len(node_map)))
    assert F.dtype(nid2pid) in (F.int32, F.int64)
    assert np.all(F.asnumpy(nid2pid) == node_map)
    eid2pid = gpb.eid2partid(F.arange(0, len(edge_map)))
    assert F.dtype(eid2pid) in (F.int32, F.int64)
    assert np.all(F.asnumpy(eid2pid) == edge_map)


@pytest.mark.parametrize("part_method", ["metis", "random"])
@pytest.mark.parametrize("num_parts", [1, 4])
@pytest.mark.parametrize("num_trainers_per_machine", [1, 4])
@pytest.mark.parametrize("load_feats", [True, False])
@pytest.mark.parametrize(
    "graph_formats", [None, ["csc"], ["coo", "csc"], ["coo", "csc", "csr"]]
)
def test_partition(
    part_method,
    num_parts,
    num_trainers_per_machine,
    load_feats,
    graph_formats,
):
    os.environ["DGL_DIST_DEBUG"] = "1"
    if part_method == "random" and num_parts > 1:
        num_trainers_per_machine = 1
    g = create_random_graph(1000)
    check_partition(
        g,
        part_method,
        num_parts,
        num_trainers_per_machine,
        load_feats,
        graph_formats,
    )
    hg = create_random_hetero()
    check_hetero_partition(
        hg,
        part_method,
        num_parts,
        num_trainers_per_machine,
        load_feats,
        graph_formats,
    )
    reset_envs()

def test_RangePartitionBook():
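    """Exercise RangePartitionBook and the node/edge partition policies for
    both homogeneous and canonical-etype initialization."""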
    part_id = 1
    num_parts = 2

    # homogeneous
    node_map = {DEFAULT_NTYPE: F.tensor([[0, 1000], [1000, 2000]])}
    edge_map = {DEFAULT_ETYPE: F.tensor([[0, 5000], [5000, 10000]])}
    ntypes = {DEFAULT_NTYPE: 0}
    etypes = {DEFAULT_ETYPE: 0}
    gpb = RangePartitionBook(
        part_id, num_parts, node_map, edge_map, ntypes, etypes
    )
    assert gpb.etypes == [DEFAULT_ETYPE[1]]
    assert gpb.canonical_etypes == [DEFAULT_ETYPE]
    assert gpb.to_canonical_etype(DEFAULT_ETYPE[1]) == DEFAULT_ETYPE

    node_policy = NodePartitionPolicy(gpb, DEFAULT_NTYPE)
    assert node_policy.type_name == DEFAULT_NTYPE
    edge_policy = EdgePartitionPolicy(gpb, DEFAULT_ETYPE)
    assert edge_policy.type_name == DEFAULT_ETYPE

    # Initializing heterogeneous maps via plain etype strings is not supported.
    node_map = {
        "node1": F.tensor([[0, 1000], [1000, 2000]]),
        "node2": F.tensor([[0, 1000], [1000, 2000]]),
    }
    edge_map = {"edge1": F.tensor([[0, 5000], [5000, 10000]])}
    ntypes = {"node1": 0, "node2": 1}
    etypes = {"edge1": 0}
    expect_except = False
    try:
        RangePartitionBook(
            part_id, num_parts, node_map, edge_map, ntypes, etypes
        )
    except AssertionError:
        expect_except = True
    assert expect_except
    expect_except = False
    try:
        EdgePartitionPolicy(gpb, "edge1")
    except AssertionError:
        expect_except = True
    assert expect_except

    # heterogeneous, init via canonical etype
    node_map = {
        "node1": F.tensor([[0, 1000], [1000, 2000]]),
        "node2": F.tensor([[0, 1000], [1000, 2000]]),
    }
    edge_map = {
        ("node1", "edge1", "node2"): F.tensor([[0, 5000], [5000, 10000]])
    }
    ntypes = {"node1": 0, "node2": 1}
    etypes = {("node1", "edge1", "node2"): 0}
    c_etype = list(etypes.keys())[0]
    gpb = RangePartitionBook(
        part_id, num_parts, node_map, edge_map, ntypes, etypes
    )
    assert gpb.etypes == ["edge1"]
    assert gpb.canonical_etypes == [c_etype]
    assert gpb.to_canonical_etype("edge1") == c_etype
    assert gpb.to_canonical_etype(c_etype) == c_etype
    expect_except = False
    try:
        gpb.to_canonical_etype(("node1", "edge2", "node2"))
    except Exception:
        expect_except = True
    assert expect_except
    expect_except = False
    try:
        gpb.to_canonical_etype("edge2")
    except Exception:
        expect_except = True
    assert expect_except

    # NodePartitionPolicy
    node_policy = NodePartitionPolicy(gpb, "node1")
    assert node_policy.type_name == "node1"
    assert node_policy.policy_str == "node~node1"
    assert node_policy.part_id == part_id
    assert node_policy.is_node
    assert node_policy.get_data_name("x").is_node()
    local_ids = th.arange(0, 1000)
    global_ids = local_ids + 1000
    assert th.equal(node_policy.to_local(global_ids), local_ids)
    assert th.all(node_policy.to_partid(global_ids) == part_id)
    assert node_policy.get_part_size() == 1000
    assert node_policy.get_size() == 2000

    # EdgePartitionPolicy
    edge_policy = EdgePartitionPolicy(gpb, c_etype)
    assert edge_policy.type_name == c_etype
    assert edge_policy.policy_str == "edge~node1:edge1:node2"
    assert edge_policy.part_id == part_id
    assert not edge_policy.is_node
    assert not edge_policy.get_data_name("x").is_node()
    local_ids = th.arange(0, 5000)
    global_ids = local_ids + 5000
    assert th.equal(edge_policy.to_local(global_ids), local_ids)
    assert th.all(edge_policy.to_partid(global_ids) == part_id)
    assert edge_policy.get_part_size() == 5000
    assert edge_policy.get_size() == 10000

    expect_except = False
    try:
        HeteroDataName(False, "edge1", "feat")
    except Exception:
        expect_except = True
    assert expect_except
    data_name = HeteroDataName(False, c_etype, "feat")
    assert data_name.get_type() == c_etype


def test_UnknownPartitionBook():
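    """Node/edge maps stored as plain dicts are an unknown partition-book
    format; ``load_partition_book`` is expected to raise a TypeError."""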
    node_map = {"_N": {0: 0, 1: 1, 2: 2}}
    edge_map = {"_N:_E:_N": {0: 0, 1: 1, 2: 2}}

    part_metadata = {
        "num_parts": 1,
        "num_nodes": len(node_map),
        "num_edges": len(edge_map),
        "node_map": node_map,
        "edge_map": edge_map,
        "graph_name": "test_graph",
    }

    with tempfile.TemporaryDirectory() as test_dir:
        part_config = os.path.join(test_dir, "test_graph.json")
        with open(part_config, "w") as file:
            json.dump(part_metadata, file, indent=4)
        try:
            load_partition_book(part_config, 0)
        except Exception as e:
            if not isinstance(e, TypeError):
                raise e
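

if __name__ == "__main__":
    # Minimal ad-hoc runner (a sketch): these tests are normally driven by
    # pytest, which supplies the parametrize grids above.
    test_partition("metis", 4, 1, True, None)
    test_RangePartitionBook()
    test_UnknownPartitionBook()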