import json
import os
import tempfile
import unittest

import dgl
import numpy as np
import pytest
import torch

from chunk_graph import chunk_graph
from create_chunked_dataset import create_chunked_dataset
from dgl.data.utils import load_graphs, load_tensors


@pytest.mark.parametrize("num_chunks", [1, 8])
def test_chunk_graph(num_chunks):
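    """Chunk the synthetic 'mag240m' dataset and verify the on-disk layout:
    metadata.json, per-chunk edge index files, and per-chunk node and edge
    feature arrays.
    """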
    with tempfile.TemporaryDirectory() as root_dir:
        g = create_chunked_dataset(root_dir, num_chunks, include_edge_data=True)

        num_cite_edges = g.number_of_edges('cites')
        num_write_edges = g.number_of_edges('writes')
        num_affiliate_edges = g.number_of_edges('affiliated_with')

        num_institutions = g.number_of_nodes('institution')
        num_authors = g.number_of_nodes('author')
        num_papers = g.number_of_nodes('paper')

        # check metadata.json
        output_dir = os.path.join(root_dir, 'chunked-data')
        json_file = os.path.join(output_dir, 'metadata.json')
        assert os.path.isfile(json_file)
        with open(json_file, 'rb') as f:
            meta_data = json.load(f)
        assert meta_data['graph_name'] == 'mag240m'
        assert len(meta_data['num_nodes_per_chunk'][0]) == num_chunks

        # check edge_index
        output_edge_index_dir = os.path.join(output_dir, 'edge_index')
        for utype, etype, vtype in g.canonical_etypes:
            fname = ':'.join([utype, etype, vtype])
            for i in range(num_chunks):
                chunk_f_name = os.path.join(
                    output_edge_index_dir, fname + str(i) + '.txt'
                )
                assert os.path.isfile(chunk_f_name)
                with open(chunk_f_name, 'r') as f:
                    header = f.readline()
                    num1, num2 = header.rstrip().split(' ')
                    # int() raises ValueError if a field is not numeric.
                    assert isinstance(int(num1), int)
                    assert isinstance(int(num2), int)

        # check node_data
        output_node_data_dir = os.path.join(output_dir, 'node_data', 'paper')
        for feat in ['feat', 'label', 'year']:
            for i in range(num_chunks):
                chunk_f_name = '{}-{}.npy'.format(feat, i)
                chunk_f_name = os.path.join(output_node_data_dir, chunk_f_name)
                assert os.path.isfile(chunk_f_name)
                feat_array = np.load(chunk_f_name)
                assert feat_array.shape[0] == num_papers // num_chunks

        # check edge_data
        edge_data_gold = {}
        num_edges = {
            'paper:cites:paper': num_cite_edges,
            'author:writes:paper': num_write_edges,
            'paper:rev_writes:author': num_write_edges,
        }
        output_edge_data_dir = os.path.join(output_dir, 'edge_data')
        for etype, feat in [
            ['paper:cites:paper', 'count'],
            ['author:writes:paper', 'year'],
            ['paper:rev_writes:author', 'year'],
        ]:
            output_edge_sub_dir = os.path.join(output_edge_data_dir, etype)
            features = []
            for i in range(num_chunks):
                chunk_f_name = '{}-{}.npy'.format(feat, i)
                chunk_f_name = os.path.join(output_edge_sub_dir, chunk_f_name)
                assert os.path.isfile(chunk_f_name)
                feat_array = np.load(chunk_f_name)
                assert feat_array.shape[0] == num_edges[etype] // num_chunks
                features.append(feat_array)
            edge_data_gold[etype + '/' + feat] = np.concatenate(features)


@pytest.mark.parametrize("num_chunks", [1, 2, 3, 4, 8])
@pytest.mark.parametrize("num_parts", [1, 2, 3, 4, 8])
def test_part_pipeline(num_chunks, num_parts):
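    """End-to-end test of the distributed partitioning pipeline: chunk the
    dataset, partition it randomly, dispatch the data, and verify the
    partitioned metadata, graph structure, and features.
    """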
    if num_chunks < num_parts:
        # num_parts must be less than or equal to num_chunks.
        return

    # Edge features are only written and validated when chunks map 1:1
    # onto partitions.
    include_edge_data = num_chunks == num_parts

    with tempfile.TemporaryDirectory() as root_dir:
        g = create_chunked_dataset(
            root_dir, num_chunks, include_edge_data=include_edge_data
        )

        all_ntypes = g.ntypes
        all_etypes = g.etypes

        num_cite_edges = g.number_of_edges('cites')
        num_write_edges = g.number_of_edges('writes')
        num_affiliate_edges = g.number_of_edges('affiliated_with')

        num_institutions = g.number_of_nodes('institution')
        num_authors = g.number_of_nodes('author')
        num_papers = g.number_of_nodes('paper')

        # Step1: graph partition
        in_dir = os.path.join(root_dir, 'chunked-data')
        output_dir = os.path.join(root_dir, 'parted_data')
        os.system(
            'python3 tools/partition_algo/random_partition.py '
            '--in_dir {} --out_dir {} --num_partitions {}'.format(
                in_dir, output_dir, num_parts
            )
        )
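        # The random partitioner writes one <ntype>.txt file per node type;
        # its first line must parse as an integer.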
        for ntype in ['author', 'institution', 'paper']:
            fname = os.path.join(output_dir, '{}.txt'.format(ntype))
            with open(fname, 'r') as f:
                header = f.readline().rstrip()
                assert isinstance(int(header), int)

        # Step2: data dispatch
        partition_dir = os.path.join(root_dir, 'parted_data')
        out_dir = os.path.join(root_dir, 'partitioned')
        ip_config = os.path.join(root_dir, 'ip_config.txt')
        with open(ip_config, 'w') as f:
            for i in range(num_parts):
                f.write(f'127.0.0.{i + 1}\n')

        cmd = 'python3 tools/dispatch_data.py'
        cmd += f' --in-dir {in_dir}'
        cmd += f' --partitions-dir {partition_dir}'
        cmd += f' --out-dir {out_dir}'
        cmd += f' --ip-config {ip_config}'
        cmd += ' --process-group-timeout 60'
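        # Keep the original node/edge IDs so the partitioned results can be
        # mapped back to the input graph in the checks below.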
        cmd += ' --save-orig-nids'
        cmd += ' --save-orig-eids'
        os.system(cmd)

        # check metadata.json
        meta_fname = os.path.join(out_dir, 'metadata.json')
        with open(meta_fname, 'rb') as f:
            meta_data = json.load(f)

        for etype in all_etypes:
            assert len(meta_data['edge_map'][etype]) == num_parts
        assert meta_data['etypes'].keys() == set(all_etypes)
        assert meta_data['graph_name'] == 'mag240m'

        for ntype in all_ntypes:
            assert len(meta_data['node_map'][ntype]) == num_parts
        assert meta_data['ntypes'].keys() == set(all_ntypes)
        assert meta_data['num_edges'] == g.num_edges()
        assert meta_data['num_nodes'] == g.num_nodes()
        assert meta_data['num_parts'] == num_parts

        edge_dict = {}
        edge_data_gold = {}

        if include_edge_data:
            # Create Id Map here.
            num_edges = 0
            for utype, etype, vtype in g.canonical_etypes:
                fname = ':'.join([utype, etype, vtype])
                edge_dict[fname] = np.array(
                    [num_edges, num_edges + g.number_of_edges(etype)]
                ).reshape(1, 2)
                num_edges += g.number_of_edges(etype)

            assert num_edges == g.number_of_edges()
            id_map = dgl.distributed.id_map.IdMap(edge_dict)
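            # IdMap maps a global homogeneous edge ID to its
            # (etype ID, per-type edge ID) pair.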
            orig_etype_id, orig_type_eid = id_map(np.arange(num_edges))

            # check edge_data
            num_edges = {
                'paper:cites:paper': num_cite_edges,
                'author:writes:paper': num_write_edges,
                'paper:rev_writes:author': num_write_edges,
            }
            output_dir = os.path.join(root_dir, 'chunked-data')
            output_edge_data_dir = os.path.join(output_dir, 'edge_data')
            for etype, feat in [
                ['paper:cites:paper', 'count'],
                ['author:writes:paper', 'year'],
                ['paper:rev_writes:author', 'year'],
            ]:
                output_edge_sub_dir = os.path.join(output_edge_data_dir, etype)
                features = []
                for i in range(num_chunks):
                    chunk_f_name = '{}-{}.npy'.format(feat, i)
                    chunk_f_name = os.path.join(
                        output_edge_sub_dir, chunk_f_name
                    )
                    assert os.path.isfile(chunk_f_name)
                    feat_array = np.load(chunk_f_name)
                    assert feat_array.shape[0] == num_edges[etype] // num_chunks
                    features.append(feat_array)
                edge_data_gold[etype + '/' + feat] = np.concatenate(features)

        for i in range(num_parts):
            sub_dir = 'part-' + str(i)
            assert meta_data[sub_dir][
                'node_feats'
            ] == 'part{}/node_feat.dgl'.format(i)
            assert meta_data[sub_dir][
                'edge_feats'
            ] == 'part{}/edge_feat.dgl'.format(i)
            assert meta_data[sub_dir][
                'part_graph'
            ] == 'part{}/graph.dgl'.format(i)

            # check data
            sub_dir = os.path.join(out_dir, 'part' + str(i))

            # graph.dgl
            fname = os.path.join(sub_dir, 'graph.dgl')
            assert os.path.isfile(fname)
            g_list, data_dict = load_graphs(fname)
            part_g = g_list[0]
            assert isinstance(part_g, dgl.DGLGraph)

            # node_feat.dgl
            fname = os.path.join(sub_dir, 'node_feat.dgl')
            assert os.path.isfile(fname)
            tensor_dict = load_tensors(fname)
            all_tensors = [
                'paper/feat',
                'paper/label',
                'paper/year',
                'paper/orig_ids',
            ]
            assert tensor_dict.keys() == set(all_tensors)
            for key in all_tensors:
                assert isinstance(tensor_dict[key], torch.Tensor)
            ndata_paper_orig_ids = tensor_dict['paper/orig_ids']

            # orig_nids.dgl
            fname = os.path.join(sub_dir, 'orig_nids.dgl')
            assert os.path.isfile(fname)
            orig_nids = load_tensors(fname)
            assert len(orig_nids.keys()) == 3
            assert torch.equal(ndata_paper_orig_ids, orig_nids['paper'])

            # orig_eids.dgl
            fname = os.path.join(sub_dir, 'orig_eids.dgl')
            assert os.path.isfile(fname)
            orig_eids = load_tensors(fname)
            assert len(orig_eids.keys()) == 4

            if include_edge_data:
                # Read edge_feat.dgl
                fname = os.path.join(sub_dir, 'edge_feat.dgl')
                assert os.path.isfile(fname)
                tensor_dict = load_tensors(fname)
                all_tensors = [
                    'paper:cites:paper/count',
                    'author:writes:paper/year',
                    'paper:rev_writes:author/year',
                ]
                assert tensor_dict.keys() == set(all_tensors)
                for key in all_tensors:
                    assert isinstance(tensor_dict[key], torch.Tensor)

                # Compare the data stored as edge features in this partition
                # with the data from the original graph. edge_dict preserves
                # insertion order, so an etype's position in it matches the
                # etype IDs produced by id_map.
                canonical_etypes = list(edge_dict.keys())
                for key in all_tensors:
                    # Keys have the form '<utype>:<etype>:<vtype>/<feat>'.
                    etype, _ = key.split('/')
                    tokens = etype.split(':')
                    assert len(tokens) == 3

                    idx = canonical_etypes.index(etype)
                    part_data = tensor_dict[key]

                    # Assumes orig_eids is keyed by the canonical etype and
                    # stores the global homogeneous IDs of this partition's
                    # edges; map them to per-type IDs to index the gold data.
                    part_global_eids = orig_eids[etype].numpy()
                    assert np.all(orig_etype_id[part_global_eids] == idx)
                    gold_type_ids = orig_type_eid[part_global_eids]
                    gold_data = edge_data_gold[key][gold_type_ids]
                    assert np.all(gold_data == part_data.numpy())