import json
import os
import tempfile

import dgl

import numpy as np
import pyarrow.parquet as pq
import pytest
import torch

from dgl.data.utils import load_graphs, load_tensors
from dgl.distributed.partition import (
    _etype_tuple_to_str,
    _get_inner_edge_mask,
    _get_inner_node_mask,
    load_partition,
    RESERVED_FIELD_DTYPE,
)

from distpartitioning import array_readwriter
from distpartitioning.utils import generate_read_list
from pytest_utils import create_chunked_dataset

from tools.verification_utils import (
    verify_graph_feats,
    verify_partition_data_types,
    verify_partition_formats,
)


def _test_chunk_graph(
    num_chunks,
    data_fmt="numpy",
    edges_fmt="csv",
    vector_rows=False,
    num_chunks_nodes=None,
    num_chunks_edges=None,
    num_chunks_node_data=None,
    num_chunks_edge_data=None,
):
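    """Create a chunked dataset on disk and sanity-check its layout.

    Verifies metadata.json, the per-chunk edge index files, and the
    node/edge feature chunks against the in-memory graph.
    """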
    with tempfile.TemporaryDirectory() as root_dir:
        g = create_chunked_dataset(
            root_dir,
            num_chunks,
            data_fmt=data_fmt,
            edges_fmt=edges_fmt,
            vector_rows=vector_rows,
            num_chunks_nodes=num_chunks_nodes,
            num_chunks_edges=num_chunks_edges,
            num_chunks_node_data=num_chunks_node_data,
            num_chunks_edge_data=num_chunks_edge_data,
        )

        # check metadata.json
        output_dir = os.path.join(root_dir, "chunked-data")
        json_file = os.path.join(output_dir, "metadata.json")
        assert os.path.isfile(json_file)
        with open(json_file, "rb") as f:
            meta_data = json.load(f)
        assert meta_data["graph_name"] == "mag240m"
        assert len(meta_data["num_nodes_per_chunk"][0]) == num_chunks

        # check edge_index
        output_edge_index_dir = os.path.join(output_dir, "edge_index")
        for c_etype in g.canonical_etypes:
            c_etype_str = _etype_tuple_to_str(c_etype)
            if num_chunks_edges is None:
                n_chunks = num_chunks
            else:
                n_chunks = num_chunks_edges
            for i in range(n_chunks):
                fname = os.path.join(
                    output_edge_index_dir, f"{c_etype_str}{i}.txt"
                )
                assert os.path.isfile(fname)
                if edges_fmt == "csv":
                    with open(fname, "r") as f:
                        header = f.readline()
                        num1, num2 = header.rstrip().split(" ")
                        assert isinstance(int(num1), int)
                        assert isinstance(int(num2), int)
                elif edges_fmt == "parquet":
                    metadata = pq.read_metadata(fname)
                    assert metadata.num_columns == 2
                else:
                    assert False, f"Invalid edges_fmt: {edges_fmt}"

        # check node/edge_data
        suffix = "npy" if data_fmt == "numpy" else "parquet"
        reader_fmt_meta = {"name": data_fmt}

        def test_data(sub_dir, feat, expected_data, expected_shape, num_chunks):
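            # Read every chunk of this feature from disk, check each chunk's
            # row count, then compare the concatenation with the in-memory
            # tensor.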
            data = []
            for i in range(num_chunks):
                fname = os.path.join(sub_dir, f"{feat}-{i}.{suffix}")
                assert os.path.isfile(fname), f"{fname} cannot be found."
                feat_array = array_readwriter.get_array_parser(
                    **reader_fmt_meta
                ).read(fname)
                assert feat_array.shape[0] == expected_shape
                data.append(feat_array)
            data = np.concatenate(data, 0)
            assert torch.equal(torch.from_numpy(data), expected_data)

        output_node_data_dir = os.path.join(output_dir, "node_data")
        for ntype in g.ntypes:
            sub_dir = os.path.join(output_node_data_dir, ntype)
            if isinstance(num_chunks_node_data, int):
                chunks_data = num_chunks_node_data
            elif isinstance(num_chunks_node_data, dict):
                chunks_data = num_chunks_node_data.get(ntype, num_chunks)
            else:
                chunks_data = num_chunks
            for feat, data in g.nodes[ntype].data.items():
                if isinstance(chunks_data, dict):
                    n_chunks = chunks_data.get(feat, num_chunks)
                else:
                    n_chunks = chunks_data
                test_data(
                    sub_dir,
                    feat,
                    data,
                    g.num_nodes(ntype) // n_chunks,
                    n_chunks,
                )

        output_edge_data_dir = os.path.join(output_dir, "edge_data")
        for c_etype in g.canonical_etypes:
            c_etype_str = _etype_tuple_to_str(c_etype)
            sub_dir = os.path.join(output_edge_data_dir, c_etype_str)
            if isinstance(num_chunks_edge_data, int):
                chunks_data = num_chunks_edge_data
            elif isinstance(num_chunks_edge_data, dict):
                chunks_data = num_chunks_edge_data.get(c_etype, num_chunks)
            else:
                chunks_data = num_chunks
            for feat, data in g.edges[c_etype].data.items():
                if isinstance(chunks_data, dict):
                    n_chunks = chunks_data.get(feat, num_chunks)
                else:
                    n_chunks = chunks_data
                test_data(
                    sub_dir,
                    feat,
                    data,
                    g.num_edges(c_etype) // n_chunks,
                    n_chunks,
                )


@pytest.mark.parametrize("num_chunks", [1, 8])
@pytest.mark.parametrize("data_fmt", ["numpy", "parquet"])
@pytest.mark.parametrize("edges_fmt", ["csv", "parquet"])
def test_chunk_graph_basics(num_chunks, data_fmt, edges_fmt):
    _test_chunk_graph(num_chunks, data_fmt=data_fmt, edges_fmt=edges_fmt)


@pytest.mark.parametrize("num_chunks", [1, 8])
@pytest.mark.parametrize("vector_rows", [True, False])
def test_chunk_graph_vector_rows(num_chunks, vector_rows):
    _test_chunk_graph(
        num_chunks,
        data_fmt="parquet",
        edges_fmt="parquet",
        vector_rows=vector_rows,
    )


@pytest.mark.parametrize(
    "num_chunks, "
    "num_chunks_nodes, "
    "num_chunks_edges, "
    "num_chunks_node_data, "
    "num_chunks_edge_data",
    [
        [1, None, None, None, None],
        [8, None, None, None, None],
        [4, 4, 4, 8, 12],
        [4, 4, 4, {"paper": 10}, {("author", "writes", "paper"): 24}],
        [
            4,
            4,
            4,
            {"paper": {"feat": 10}},
            {("author", "writes", "paper"): {"year": 24}},
        ],
    ],
)
def test_chunk_graph_arbitrary_chunks(
    num_chunks,
    num_chunks_nodes,
    num_chunks_edges,
    num_chunks_node_data,
    num_chunks_edge_data,
):
    _test_chunk_graph(
        num_chunks,
        num_chunks_nodes=num_chunks_nodes,
        num_chunks_edges=num_chunks_edges,
        num_chunks_node_data=num_chunks_node_data,
        num_chunks_edge_data=num_chunks_edge_data,
    )


def _test_pipeline(
    num_chunks,
    num_parts,
    world_size,
    graph_formats=None,
    data_fmt="numpy",
    num_chunks_nodes=None,
    num_chunks_edges=None,
    num_chunks_node_data=None,
    num_chunks_edge_data=None,
    use_verify_partitions=False,
):
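    """Run the chunk -> partition -> dispatch pipeline and verify the result.

    Steps: create a chunked dataset, assign nodes to partitions with
    tools/partition_algo/random_partition.py, dispatch the data with
    tools/dispatch_data.py, then load and verify each partition.
    """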

    if num_parts % world_size != 0:
        # num_parts should be a multiple of world_size
        return

    with tempfile.TemporaryDirectory() as root_dir:
        g = create_chunked_dataset(
            root_dir,
            num_chunks,
            data_fmt=data_fmt,
            num_chunks_nodes=num_chunks_nodes,
            num_chunks_edges=num_chunks_edges,
            num_chunks_node_data=num_chunks_node_data,
            num_chunks_edge_data=num_chunks_edge_data,
        )

        # Step 1: graph partition
        in_dir = os.path.join(root_dir, "chunked-data")
        output_dir = os.path.join(root_dir, "parted_data")
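        # random_partition.py is expected to write one "<ntype>.txt" file per
        # node type under output_dir with integer partition assignments; the
        # loop below spot-checks the first line of each file.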
        os.system(
            "python3 tools/partition_algo/random_partition.py "
            "--in_dir {} --out_dir {} --num_partitions {}".format(
                in_dir, output_dir, num_parts
            )
        )
        for ntype in ["author", "institution", "paper"]:
            fname = os.path.join(output_dir, "{}.txt".format(ntype))
            with open(fname, "r") as f:
                header = f.readline().rstrip()
                assert isinstance(int(header), int)

        # Step 2: data dispatch
        partition_dir = os.path.join(root_dir, "parted_data")
        out_dir = os.path.join(root_dir, "partitioned")
        ip_config = os.path.join(root_dir, "ip_config.txt")
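        # Emulate a multi-worker cluster by listing one loopback address per
        # worker; dispatch_data.py reads this file via --ip-config.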
        with open(ip_config, "w") as f:
            for i in range(world_size):
                f.write(f"127.0.0.{i + 1}\n")

        cmd = "python3 tools/dispatch_data.py"
        cmd += f" --in-dir {in_dir}"
        cmd += f" --partitions-dir {partition_dir}"
        cmd += f" --out-dir {out_dir}"
        cmd += f" --ip-config {ip_config}"
        cmd += " --ssh-port 22"
        cmd += " --process-group-timeout 60"
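        # Keep the original node/edge IDs so the partitions can be mapped back
        # to the input graph during verification below.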
        cmd += " --save-orig-nids"
        cmd += " --save-orig-eids"
        cmd += f" --graph-formats {graph_formats}" if graph_formats else ""
        os.system(cmd)

        # Optionally validate the partitions with verify_partitions.py and return early.
        if use_verify_partitions:
            cmd = "python3 tools/verify_partitions.py "
            cmd += f" --orig-dataset-dir {in_dir}"
            cmd += f" --part-graph {out_dir}"
            cmd += f" --partitions-dir {output_dir}"
            os.system(cmd)
            return

        # read original node/edge IDs
        def read_orig_ids(fname):
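            # Concatenate the per-partition ID tensors into a single tensor
            # per node/edge type.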
            orig_ids = {}
            for i in range(num_parts):
                ids_path = os.path.join(out_dir, f"part{i}", fname)
                part_ids = load_tensors(ids_path)
                for key, data in part_ids.items():
                    if key not in orig_ids:
                        orig_ids[key] = data
                    else:
                        orig_ids[key] = torch.cat((orig_ids[key], data))
            return orig_ids

        orig_nids = read_orig_ids("orig_nids.dgl")
        orig_eids = read_orig_ids("orig_eids.dgl")

        # load partitions and verify
        part_config = os.path.join(out_dir, "metadata.json")
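        # Load each partition together with its partition book and verify
        # data types, graph formats, and features against the original graph.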
        for i in range(num_parts):
            part_g, node_feats, edge_feats, gpb, _, _, _ = load_partition(
                part_config, i
            )
            verify_partition_data_types(part_g)
            verify_partition_formats(part_g, graph_formats)
            verify_graph_feats(
                g, gpb, part_g, node_feats, edge_feats, orig_nids, orig_eids
            )


@pytest.mark.parametrize(
    "num_chunks, num_parts, world_size",
    [[4, 4, 4], [8, 4, 2], [8, 4, 4], [9, 6, 3], [11, 11, 1], [11, 4, 1]],
)
def test_pipeline_basics(num_chunks, num_parts, world_size):
    _test_pipeline(num_chunks, num_parts, world_size)
    _test_pipeline(
        num_chunks, num_parts, world_size, use_verify_partitions=True
    )


@pytest.mark.parametrize(
    "num_chunks, "
    "num_parts, "
    "world_size, "
    "num_chunks_node_data, "
    "num_chunks_edge_data",
    [
        # Test cases where no. of chunks more than
        # no. of partitions
        [8, 4, 4, 8, 8],
        [8, 4, 2, 8, 8],
        [9, 7, 5, 9, 9],
        [8, 8, 4, 8, 8],
        # Test cases where no. of chunks smaller
        # than no. of partitions
        [7, 8, 4, 7, 7],
        [1, 8, 4, 1, 1],
        [1, 4, 4, 1, 1],
        [3, 4, 4, 3, 3],
        [1, 4, 2, 1, 1],
        [3, 4, 2, 3, 3],
        [1, 5, 3, 1, 1],
    ],
)
def test_pipeline_arbitrary_chunks(
    num_chunks,
    num_parts,
    world_size,
    num_chunks_node_data,
    num_chunks_edge_data,
):
    _test_pipeline(
        num_chunks,
        num_parts,
        world_size,
        num_chunks_node_data=num_chunks_node_data,
        num_chunks_edge_data=num_chunks_edge_data,
    )


@pytest.mark.parametrize(
    "graph_formats", [None, "csc", "coo,csc", "coo,csc,csr"]
)
def test_pipeline_formats(graph_formats):
    _test_pipeline(4, 4, 4, graph_formats)


@pytest.mark.parametrize("data_fmt", ["numpy", "parquet"])
def test_pipeline_feature_format(data_fmt):
    _test_pipeline(4, 4, 4, data_fmt=data_fmt)


def test_utils_generate_read_list():
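    # generate_read_list(10, 4) should split work items 0..9 into 4 nearly
    # even contiguous groups.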
    read_list = generate_read_list(10, 4)
    assert np.array_equal(read_list[0], np.array([0, 1, 2]))
    assert np.array_equal(read_list[1], np.array([3, 4, 5]))
    assert np.array_equal(read_list[2], np.array([6, 7]))
    assert np.array_equal(read_list[3], np.array([8, 9]))