# test_dist_part.py
import json
import os
import tempfile

import numpy as np
import pytest
import torch
from utils import create_chunked_dataset

from distpartitioning import array_readwriter
from distpartitioning.utils import generate_read_list

import dgl
from dgl.data.utils import load_graphs, load_tensors
from dgl.distributed.partition import (RESERVED_FIELD_DTYPE,
                                       _etype_tuple_to_str,
                                       _get_inner_edge_mask,
                                       _get_inner_node_mask, load_partition)


def _verify_partition_data_types(part_g):
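    """Check that every reserved node/edge field present in the partition
    uses the dtype declared in RESERVED_FIELD_DTYPE."""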
    for k, dtype in RESERVED_FIELD_DTYPE.items():
        if k in part_g.ndata:
            assert part_g.ndata[k].dtype == dtype
        if k in part_g.edata:
            assert part_g.edata[k].dtype == dtype

def _verify_partition_formats(part_g, formats):
    # Verify saved graph formats
    if formats is None:
        assert "coo" in part_g.formats()["created"]
    else:
        formats = formats.split(',')
        for format in formats:
            assert format in part_g.formats()["created"]


def _verify_graph_feats(
    g, gpb, part, node_feats, edge_feats, orig_nids, orig_eids
):
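    """Compare node/edge features stored with a partition against the original
    graph, mapping inner node/edge IDs back to original IDs through the graph
    partition book (gpb) and the orig_nids/orig_eids lookup tables."""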
    for ntype in g.ntypes:
        ntype_id = g.get_ntype_id(ntype)
        inner_node_mask = _get_inner_node_mask(part, ntype_id)
        inner_nids = part.ndata[dgl.NID][inner_node_mask]
        ntype_ids, inner_type_nids = gpb.map_to_per_ntype(inner_nids)
        partid = gpb.nid2partid(inner_type_nids, ntype)
        assert np.all(ntype_ids.numpy() == ntype_id)
        assert np.all(partid.numpy() == gpb.partid)

        orig_id = orig_nids[ntype][inner_type_nids]
        local_nids = gpb.nid2localnid(inner_type_nids, gpb.partid, ntype)

        for name in g.nodes[ntype].data:
            if name in [dgl.NID, "inner_node"]:
                continue
            true_feats = g.nodes[ntype].data[name][orig_id]
            ndata = node_feats[ntype + "/" + name][local_nids]
            assert np.array_equal(ndata.numpy(), true_feats.numpy())

    for etype in g.canonical_etypes:
        etype_id = g.get_etype_id(etype)
        inner_edge_mask = _get_inner_edge_mask(part, etype_id)
        inner_eids = part.edata[dgl.EID][inner_edge_mask]
        etype_ids, inner_type_eids = gpb.map_to_per_etype(inner_eids)
        partid = gpb.eid2partid(inner_type_eids, etype)
        assert np.all(etype_ids.numpy() == etype_id)
        assert np.all(partid.numpy() == gpb.partid)

        orig_id = orig_eids[_etype_tuple_to_str(etype)][inner_type_eids]
        local_eids = gpb.eid2localeid(inner_type_eids, gpb.partid, etype)

        for name in g.edges[etype].data:
            if name in [dgl.EID, "inner_edge"]:
                continue
            true_feats = g.edges[etype].data[name][orig_id]
            edata = edge_feats[_etype_tuple_to_str(etype) + "/" + name][local_eids]
            assert np.array_equal(edata.numpy(), true_feats.numpy())


def _test_chunk_graph(
    num_chunks,
    data_fmt='numpy',
    num_chunks_nodes=None,
    num_chunks_edges=None,
    num_chunks_node_data=None,
    num_chunks_edge_data=None
):
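    """Chunk a synthetic dataset into `num_chunks` chunks and verify the
    on-disk layout: metadata.json, per-chunk edge_index files, and per-chunk
    node/edge feature files."""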
    with tempfile.TemporaryDirectory() as root_dir:

        g = create_chunked_dataset(root_dir, num_chunks,
                data_fmt=data_fmt,
                num_chunks_nodes=num_chunks_nodes,
                num_chunks_edges=num_chunks_edges,
                num_chunks_node_data=num_chunks_node_data,
                num_chunks_edge_data=num_chunks_edge_data
            )

        # check metadata.json
        output_dir = os.path.join(root_dir, "chunked-data")
        json_file = os.path.join(output_dir, "metadata.json")
        assert os.path.isfile(json_file)
        with open(json_file, "rb") as f:
            meta_data = json.load(f)
        assert meta_data["graph_name"] == "mag240m"
        assert len(meta_data["num_nodes_per_chunk"][0]) == num_chunks

        # check edge_index
        output_edge_index_dir = os.path.join(output_dir, "edge_index")
        for c_etype in g.canonical_etypes:
            c_etype_str = _etype_tuple_to_str(c_etype)
            if num_chunks_edges is None:
                n_chunks = num_chunks
            else:
                n_chunks = num_chunks_edges
            for i in range(n_chunks):
                fname = os.path.join(
                    output_edge_index_dir, f'{c_etype_str}{i}.txt'
                )
                assert os.path.isfile(fname)
                with open(fname, "r") as f:
                    header = f.readline()
                    num1, num2 = header.rstrip().split(" ")
                    assert isinstance(int(num1), int)
                    assert isinstance(int(num2), int)

        # check node/edge_data
        suffix = 'npy' if data_fmt == 'numpy' else 'parquet'
        reader_fmt_meta = {"name": data_fmt}
        def test_data(
            sub_dir, feat, expected_data, expected_shape, num_chunks
        ):
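            """Read every chunk of a feature, check the per-chunk row count,
            and compare the concatenated result with the in-memory tensor."""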
            data = []
            for i in range(num_chunks):
                fname = os.path.join(sub_dir, f'{feat}-{i}.{suffix}')
                assert os.path.isfile(fname), f'{fname} cannot be found.'
                feat_array = array_readwriter.get_array_parser(
                    **reader_fmt_meta
                ).read(fname)
                assert feat_array.shape[0] == expected_shape
                data.append(feat_array)
            data = np.concatenate(data, 0)
            assert torch.equal(torch.from_numpy(data), expected_data)

        output_node_data_dir = os.path.join(output_dir, "node_data")
        for ntype in g.ntypes:
            sub_dir = os.path.join(output_node_data_dir, ntype)
            if isinstance(num_chunks_node_data, int):
                chunks_data = num_chunks_node_data
            elif isinstance(num_chunks_node_data, dict):
                chunks_data = num_chunks_node_data.get(ntype, num_chunks)
            else:
                chunks_data = num_chunks
            for feat, data in g.nodes[ntype].data.items():
                if isinstance(chunks_data, dict):
                    n_chunks = chunks_data.get(feat, num_chunks)
                else:
                    n_chunks = chunks_data
                test_data(sub_dir, feat, data, g.num_nodes(ntype) // n_chunks,
                    n_chunks)

        output_edge_data_dir = os.path.join(output_dir, "edge_data")
        for c_etype in g.canonical_etypes:
            c_etype_str = _etype_tuple_to_str(c_etype)
            sub_dir = os.path.join(output_edge_data_dir, c_etype_str)
            if isinstance(num_chunks_edge_data, int):
                chunks_data = num_chunks_edge_data
            elif isinstance(num_chunks_edge_data, dict):
                chunks_data = num_chunks_edge_data.get(c_etype, num_chunks)
            else:
                chunks_data = num_chunks
            for feat, data in g.edges[c_etype].data.items():
                if isinstance(chunks_data, dict):
                    n_chunks = chunks_data.get(feat, num_chunks)
                else:
                    n_chunks = chunks_data
                test_data(sub_dir, feat, data, g.num_edges(c_etype) // n_chunks,
                    n_chunks)


@pytest.mark.parametrize("num_chunks", [1, 8])
@pytest.mark.parametrize("data_fmt", ['numpy', 'parquet'])
def test_chunk_graph_basics(num_chunks, data_fmt):
    _test_chunk_graph(num_chunks, data_fmt=data_fmt)


@pytest.mark.parametrize(
    "num_chunks, "
    "num_chunks_nodes, "
    "num_chunks_edges, "
    "num_chunks_node_data, "
    "num_chunks_edge_data",
    [
        [1, None, None, None, None],
        [8, None, None, None, None],
        [4, 4, 4, 8, 12],
        [4, 4, 4, {'paper': 10}, {('author', 'writes', 'paper'): 24}],
        [4, 4, 4, {'paper': {'feat': 10}},
            {('author', 'writes', 'paper'): {'year': 24}}],
    ]
)
def test_chunk_graph_arbitrary_chunks(
    num_chunks,
    num_chunks_nodes,
    num_chunks_edges,
    num_chunks_node_data,
    num_chunks_edge_data
):
    _test_chunk_graph(
        num_chunks,
        num_chunks_nodes=num_chunks_nodes,
        num_chunks_edges=num_chunks_edges,
        num_chunks_node_data=num_chunks_node_data,
        num_chunks_edge_data=num_chunks_edge_data
    )


def _test_pipeline(
    num_chunks,
    num_parts,
    world_size,
    graph_formats=None,
    data_fmt='numpy',
    num_chunks_nodes=None,
    num_chunks_edges=None,
    num_chunks_node_data=None,
    num_chunks_edge_data=None
):
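    """Run the full pipeline (chunking, random partitioning, data dispatch)
    and verify the resulting partitions against the original graph."""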
    if num_chunks < num_parts:
        # num_parts should be less than or equal to num_chunks
        return

    if num_parts % world_size != 0:
        # num_parts should be a multiple of world_size
        return

    with tempfile.TemporaryDirectory() as root_dir:

        g = create_chunked_dataset(root_dir, num_chunks,
                data_fmt=data_fmt,
                num_chunks_nodes=num_chunks_nodes,
                num_chunks_edges=num_chunks_edges,
                num_chunks_node_data=num_chunks_node_data,
                num_chunks_edge_data=num_chunks_edge_data
            )

        # Step1: graph partition
        in_dir = os.path.join(root_dir, "chunked-data")
        output_dir = os.path.join(root_dir, "parted_data")
        os.system(
            "python3 tools/partition_algo/random_partition.py "
            "--in_dir {} --out_dir {} --num_partitions {}".format(
                in_dir, output_dir, num_parts
            )
        )
        for ntype in ["author", "institution", "paper"]:
            fname = os.path.join(output_dir, "{}.txt".format(ntype))
            with open(fname, "r") as f:
                header = f.readline().rstrip()
                assert isinstance(int(header), int)

        # Step2: data dispatch
        partition_dir = os.path.join(root_dir, 'parted_data')
        out_dir = os.path.join(root_dir, 'partitioned')
        ip_config = os.path.join(root_dir, 'ip_config.txt')
        with open(ip_config, 'w') as f:
            for i in range(world_size):
                f.write(f'127.0.0.{i + 1}\n')

        cmd = "python3 tools/dispatch_data.py"
        cmd += f" --in-dir {in_dir}"
        cmd += f" --partitions-dir {partition_dir}"
        cmd += f" --out-dir {out_dir}"
        cmd += f" --ip-config {ip_config}"
        cmd += " --ssh-port 22"
        cmd += " --process-group-timeout 60"
        cmd += " --save-orig-nids"
        cmd += " --save-orig-eids"
        cmd += f" --graph-formats {graph_formats}" if graph_formats else ""
        os.system(cmd)

        # read original node/edge IDs
        def read_orig_ids(fname):
            orig_ids = {}
            for i in range(num_parts):
                ids_path = os.path.join(out_dir, f"part{i}", fname)
                part_ids = load_tensors(ids_path)
                for type, data in part_ids.items():
                    if type not in orig_ids:
                        orig_ids[type] = data
                    else:
                        orig_ids[type] = torch.cat((orig_ids[type], data))
            return orig_ids

        orig_nids = read_orig_ids("orig_nids.dgl")
        orig_eids = read_orig_ids("orig_eids.dgl")

        # load partitions and verify
        part_config = os.path.join(out_dir, "metadata.json")
        for i in range(num_parts):
            part_g, node_feats, edge_feats, gpb, _, _, _ = load_partition(
                part_config, i
            )
            _verify_partition_data_types(part_g)
            _verify_partition_formats(part_g, graph_formats)
            _verify_graph_feats(
                g, gpb, part_g, node_feats, edge_feats, orig_nids, orig_eids
            )


@pytest.mark.parametrize("num_chunks, num_parts, world_size",
    [[4, 4, 4], [8, 4, 2], [8, 4, 4], [9, 6, 3], [11, 11, 1], [11, 4, 1]]
)
def test_pipeline_basics(num_chunks, num_parts, world_size):
    _test_pipeline(num_chunks, num_parts, world_size)


@pytest.mark.parametrize(
    "num_chunks, "
    "num_parts, "
    "world_size, "
    "num_chunks_node_data, "
    "num_chunks_edge_data",
    [
        [8, 4, 2, 20, 25],
        [9, 7, 5, 3, 11],
        [8, 8, 4, 3, 5],
        [8, 4, 2, {'paper': {'feat': 11, 'year': 1}},
            {('author', 'writes', 'paper'): {'year': 24}}],
    ]
)
def test_pipeline_arbitrary_chunks(
    num_chunks,
    num_parts,
    world_size,
    num_chunks_node_data,
    num_chunks_edge_data,
):
    _test_pipeline(
        num_chunks,
        num_parts,
        world_size,
        num_chunks_node_data=num_chunks_node_data,
        num_chunks_edge_data=num_chunks_edge_data,
    )


@pytest.mark.parametrize(
    "graph_formats", [None, "csc", "coo,csc", "coo,csc,csr"]
)
def test_pipeline_formats(graph_formats):
    _test_pipeline(4, 4, 4, graph_formats)


@pytest.mark.parametrize(
    "data_fmt", ["numpy", "parquet"]
)
def test_pipeline_feature_format(data_fmt):
    _test_pipeline(4, 4, 4, data_fmt=data_fmt)


def test_utils_generate_read_list():
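    """generate_read_list(num_files, num_workers) should split file IDs
    0..num_files-1 across the workers as evenly as possible."""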
    read_list = generate_read_list(10, 4)
    assert np.array_equal(read_list[0], np.array([0, 1, 2]))
    assert np.array_equal(read_list[1], np.array([3, 4, 5]))
    assert np.array_equal(read_list[2], np.array([6, 7]))
    assert np.array_equal(read_list[3], np.array([8, 9]))