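"""Tests for the chunked-graph distributed partitioning pipeline: dataset
chunking, random partitioning, and data dispatch on a synthetic mag240m."""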
import json
import os
import tempfile
import unittest

import numpy as np
import pytest
import torch
from chunk_graph import chunk_graph
from create_chunked_dataset import create_chunked_dataset

import dgl
from dgl.data.utils import load_graphs, load_tensors
from dgl.distributed.partition import RESERVED_FIELD_DTYPE


def _verify_partition_data_types(part_g):
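    """Check reserved fields in a partition have the dtypes DGL expects."""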
    for k, dtype in RESERVED_FIELD_DTYPE.items():
        if k in part_g.ndata:
            assert part_g.ndata[k].dtype == dtype
        if k in part_g.edata:
            assert part_g.edata[k].dtype == dtype


@pytest.mark.parametrize("num_chunks", [1, 8])
def test_chunk_graph(num_chunks):
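    """Chunk a synthetic mag240m dataset and validate the on-disk layout."""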

    with tempfile.TemporaryDirectory() as root_dir:

        g = create_chunked_dataset(root_dir, num_chunks, include_edge_data=True)

        num_cite_edges = g.number_of_edges("cites")
        num_write_edges = g.number_of_edges("writes")
        num_affiliate_edges = g.number_of_edges("affiliated_with")

        num_institutions = g.number_of_nodes("institution")
        num_authors = g.number_of_nodes("author")
        num_papers = g.number_of_nodes("paper")

        # check metadata.json
        output_dir = os.path.join(root_dir, "chunked-data")
        json_file = os.path.join(output_dir, "metadata.json")
        assert os.path.isfile(json_file)
        with open(json_file, "rb") as f:
            meta_data = json.load(f)
        assert meta_data["graph_name"] == "mag240m"
        assert len(meta_data["num_nodes_per_chunk"][0]) == num_chunks

        # check edge_index
        output_edge_index_dir = os.path.join(output_dir, "edge_index")
        for utype, etype, vtype in g.canonical_etypes:
            fname = ":".join([utype, etype, vtype])
            for i in range(num_chunks):
                chunk_f_name = os.path.join(
                    output_edge_index_dir, fname + str(i) + ".txt"
                )
                assert os.path.isfile(chunk_f_name)
                with open(chunk_f_name, "r") as f:
                    header = f.readline()
                    num1, num2 = header.rstrip().split(" ")
                    assert isinstance(int(num1), int)
                    assert isinstance(int(num2), int)

        # check node_data
        output_node_data_dir = os.path.join(output_dir, "node_data", "paper")
        for feat in ["feat", "label", "year"]:
            for i in range(num_chunks):
                chunk_f_name = "{}-{}.npy".format(feat, i)
                chunk_f_name = os.path.join(output_node_data_dir, chunk_f_name)
                assert os.path.isfile(chunk_f_name)
                feat_array = np.load(chunk_f_name)
                assert feat_array.shape[0] == num_papers // num_chunks

        # check edge_data
        num_edges = {
            "paper:cites:paper": num_cite_edges,
            "author:writes:paper": num_write_edges,
            "paper:rev_writes:author": num_write_edges,
        }
        output_edge_data_dir = os.path.join(output_dir, "edge_data")
        for etype, feat in [
            ["paper:cites:paper", "count"],
            ["author:writes:paper", "year"],
            ["paper:rev_writes:author", "year"],
        ]:
            output_edge_sub_dir = os.path.join(output_edge_data_dir, etype)
            for i in range(num_chunks):
                chunk_f_name = "{}-{}.npy".format(feat, i)
                chunk_f_name = os.path.join(output_edge_sub_dir, chunk_f_name)
                assert os.path.isfile(chunk_f_name)
                feat_array = np.load(chunk_f_name)
                assert feat_array.shape[0] == num_edges[etype] // num_chunks


@pytest.mark.parametrize("num_chunks", [1, 2, 3, 4, 8])
@pytest.mark.parametrize("num_parts", [1, 2, 3, 4, 8])
def test_part_pipeline(num_chunks, num_parts):
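    """Run chunking, random partitioning, and data dispatch, then verify."""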
    if num_chunks < num_parts:
        # num_parts must be less than or equal to num_chunks.
        return

    include_edge_data = num_chunks == num_parts
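    # Edge data is generated and verified only when chunks map 1:1 onto
    # partitions; otherwise the pipeline is exercised without edge features.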

    with tempfile.TemporaryDirectory() as root_dir:

        g = create_chunked_dataset(
            root_dir, num_chunks, include_edge_data=include_edge_data
        )

        all_ntypes = g.ntypes
        all_etypes = g.etypes

        num_cite_edges = g.number_of_edges("cites")
        num_write_edges = g.number_of_edges("writes")
        num_affiliate_edges = g.number_of_edges("affiliated_with")

        num_institutions = g.number_of_nodes("institution")
        num_authors = g.number_of_nodes("author")
        num_papers = g.number_of_nodes("paper")

        # Step 1: graph partition
        in_dir = os.path.join(root_dir, "chunked-data")
        output_dir = os.path.join(root_dir, "parted_data")
        os.system(
            "python3 tools/partition_algo/random_partition.py "
            "--in_dir {} --out_dir {} --num_partitions {}".format(
                in_dir, output_dir, num_parts
            )
        )
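
        # Each {ntype}.txt written by random_partition.py is expected to
        # hold one partition id per line, one line per node of that type.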
        for ntype in ["author", "institution", "paper"]:
            fname = os.path.join(output_dir, "{}.txt".format(ntype))
            with open(fname, "r") as f:
                header = f.readline().rstrip()
                assert isinstance(int(header), int)

        # Step 2: data dispatch
        partition_dir = os.path.join(root_dir, "parted_data")
        out_dir = os.path.join(root_dir, "partitioned")
        ip_config = os.path.join(root_dir, "ip_config.txt")
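        # dispatch_data.py reads one host per line from ip_config; distinct
        # loopback addresses let the whole test pipeline run locally.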
        with open(ip_config, "w") as f:
            for i in range(num_parts):
                f.write(f"127.0.0.{i + 1}\n")

        cmd = "python3 tools/dispatch_data.py"
        cmd += f" --in-dir {in_dir}"
        cmd += f" --partitions-dir {partition_dir}"
        cmd += f" --out-dir {out_dir}"
        cmd += f" --ip-config {ip_config}"
        cmd += " --ssh-port 22"
        cmd += " --process-group-timeout 60"
        cmd += " --save-orig-nids"
        cmd += " --save-orig-eids"
        os.system(cmd)

        # check metadata.json
        meta_fname = os.path.join(out_dir, "metadata.json")
        with open(meta_fname, "rb") as f:
            meta_data = json.load(f)

        for etype in all_etypes:
            assert len(meta_data["edge_map"][etype]) == num_parts
        assert meta_data["etypes"].keys() == set(all_etypes)
        assert meta_data["graph_name"] == "mag240m"

        for ntype in all_ntypes:
            assert len(meta_data["node_map"][ntype]) == num_parts
        assert meta_data["ntypes"].keys() == set(all_ntypes)
        assert meta_data["num_edges"] == g.num_edges()
        assert meta_data["num_nodes"] == g.num_nodes()
        assert meta_data["num_parts"] == num_parts

        edge_dict = {}
        edge_data_gold = {}

        if include_edge_data:
            # Create Id Map here.
            num_edges = 0
            for utype, etype, vtype in g.canonical_etypes:
                fname = ":".join([utype, etype, vtype])
                edge_dict[fname] = np.array(
                    [num_edges, num_edges + g.number_of_edges(etype)]
                ).reshape(1, 2)
                num_edges += g.number_of_edges(etype)

            assert num_edges == g.number_of_edges()
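            # IdMap translates a homogeneous edge id into an
            # (etype_id, per-type edge id) pair; with the contiguous ranges
            # built above, e.g. [0, 10) and [10, 20), global id 13 maps to
            # etype 1, per-type id 3.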
            id_map = dgl.distributed.id_map.IdMap(edge_dict)
            orig_etype_id, orig_type_eid = id_map(np.arange(num_edges))

            # check edge_data
            num_edges = {
                "paper:cites:paper": num_cite_edges,
                "author:writes:paper": num_write_edges,
                "paper:rev_writes:author": num_write_edges,
            }
            output_dir = os.path.join(root_dir, "chunked-data")
            output_edge_data_dir = os.path.join(output_dir, "edge_data")
            for etype, feat in [
                ["paper:cites:paper", "count"],
                ["author:writes:paper", "year"],
                ["paper:rev_writes:author", "year"],
            ]:
                output_edge_sub_dir = os.path.join(output_edge_data_dir, etype)
                features = []
                for i in range(num_chunks):
                    chunk_f_name = "{}-{}.npy".format(feat, i)
                    chunk_f_name = os.path.join(
                        output_edge_sub_dir, chunk_f_name
                    )
                    assert os.path.isfile(chunk_f_name)
                    feat_array = np.load(chunk_f_name)
                    assert feat_array.shape[0] == num_edges[etype] // num_chunks
                    features.append(feat_array)
                edge_data_gold[etype + "/" + feat] = np.concatenate(features)

        for i in range(num_parts):
            sub_dir = "part-" + str(i)
            assert meta_data[sub_dir][
                "node_feats"
            ] == "part{}/node_feat.dgl".format(i)
            assert meta_data[sub_dir][
                "edge_feats"
            ] == "part{}/edge_feat.dgl".format(i)
            assert meta_data[sub_dir][
                "part_graph"
            ] == "part{}/graph.dgl".format(i)

            # check data
            sub_dir = os.path.join(out_dir, "part" + str(i))

            # graph.dgl
            fname = os.path.join(sub_dir, "graph.dgl")
            assert os.path.isfile(fname)
            g_list, data_dict = load_graphs(fname)
            part_g = g_list[0]
            assert isinstance(part_g, dgl.DGLGraph)
            _verify_partition_data_types(part_g)

            # node_feat.dgl
            fname = os.path.join(sub_dir, "node_feat.dgl")
            assert os.path.isfile(fname)
            tensor_dict = load_tensors(fname)
            all_tensors = [
                "paper/feat",
                "paper/label",
                "paper/year",
                "paper/orig_ids",
            ]
            assert tensor_dict.keys() == set(all_tensors)
            for key in all_tensors:
                assert isinstance(tensor_dict[key], torch.Tensor)
            ndata_paper_orig_ids = tensor_dict["paper/orig_ids"]

            # orig_nids.dgl
            fname = os.path.join(sub_dir, "orig_nids.dgl")
            assert os.path.isfile(fname)
            orig_nids = load_tensors(fname)
            assert len(orig_nids.keys()) == 3
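            # 3 == number of node types: author, institution, paper.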
            assert torch.equal(ndata_paper_orig_ids, orig_nids["paper"])

            # orig_eids.dgl
            fname = os.path.join(sub_dir, "orig_eids.dgl")
            assert os.path.isfile(fname)
            orig_eids = load_tensors(fname)
            assert len(orig_eids.keys()) == 4
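            # 4 == number of canonical edge types in the graph.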

            if include_edge_data:

                # Read edge_feat.dgl
                fname = os.path.join(sub_dir, "edge_feat.dgl")
                assert os.path.isfile(fname)
                tensor_dict = load_tensors(fname)
                all_tensors = [
                    "paper:cites:paper/count",
                    "author:writes:paper/year",
                    "paper:rev_writes:author/year",
                ]
                assert tensor_dict.keys() == set(all_tensors)
                for key in all_tensors:
                    assert isinstance(tensor_dict[key], torch.Tensor)

                # Compare the data stored as edge features in this partition
                # with the data from the original graph. Here orig_eids is
                # assumed to map each canonical etype name to the global
                # (homogeneous) edge ids owned by this partition.
                for key in all_tensors:
                    part_data = tensor_dict[key]

                    # key has the form "src:etype:dst/feat_name".
                    etype_name, _feat_name = key.split("/")
                    tokens = etype_name.split(":")
                    assert len(tokens) == 3
                    idx = g.canonical_etypes.index(tuple(tokens))

                    # Translate the global edge ids back to per-type ids and
                    # gather the gold data collected from the chunks above.
                    global_eids = orig_eids[etype_name].numpy()
                    assert np.all(orig_etype_id[global_eids] == idx)
                    gold_type_ids = orig_type_eid[global_eids]
                    gold_data = edge_data_gold[key][gold_type_ids]
                    assert np.all(gold_data == part_data.numpy())