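"""Tests for the chunked dataset format and the distributed partition
pipeline: chunking, random partitioning, data dispatch, and verification
of the resulting partitions."""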
import json
import os
import tempfile
import unittest

import numpy as np
import pytest
import torch
from chunk_graph import chunk_graph
from create_chunked_dataset import create_chunked_dataset

import dgl
from dgl.data.utils import load_graphs, load_tensors
from dgl.distributed.partition import (
    RESERVED_FIELD_DTYPE,
    load_partition,
    _get_inner_node_mask,
    _get_inner_edge_mask,
)


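# Check that every reserved node/edge field present in a partition uses the
# dtype declared in RESERVED_FIELD_DTYPE.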
def _verify_partition_data_types(part_g):
    for k, dtype in RESERVED_FIELD_DTYPE.items():
        if k in part_g.ndata:
            assert part_g.ndata[k].dtype == dtype
        if k in part_g.edata:
            assert part_g.edata[k].dtype == dtype


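# Compare the node/edge features stored with a partition against those of the
# original graph, using the saved original node/edge ID mappings.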
def _verify_graph_feats(
    g, gpb, part, node_feats, edge_feats, orig_nids, orig_eids
):
    for ntype in g.ntypes:
        ntype_id = g.get_ntype_id(ntype)
        inner_node_mask = _get_inner_node_mask(part, ntype_id)
        inner_nids = part.ndata[dgl.NID][inner_node_mask]
        ntype_ids, inner_type_nids = gpb.map_to_per_ntype(inner_nids)
        partid = gpb.nid2partid(inner_type_nids, ntype)
        assert np.all(ntype_ids.numpy() == ntype_id)
        assert np.all(partid.numpy() == gpb.partid)

        orig_id = orig_nids[ntype][inner_type_nids]
        local_nids = gpb.nid2localnid(inner_type_nids, gpb.partid, ntype)

        for name in g.nodes[ntype].data:
            if name in [dgl.NID, "inner_node"]:
                continue
            true_feats = g.nodes[ntype].data[name][orig_id]
            ndata = node_feats[ntype + "/" + name][local_nids]
            assert torch.equal(ndata, true_feats)

    for etype in g.etypes:
        etype_id = g.get_etype_id(etype)
        inner_edge_mask = _get_inner_edge_mask(part, etype_id)
        inner_eids = part.edata[dgl.EID][inner_edge_mask]
        etype_ids, inner_type_eids = gpb.map_to_per_etype(inner_eids)
        partid = gpb.eid2partid(inner_type_eids, etype)
        assert np.all(etype_ids.numpy() == etype_id)
        assert np.all(partid.numpy() == gpb.partid)

        orig_id = orig_eids[etype][inner_type_eids]
        local_eids = gpb.eid2localeid(inner_type_eids, gpb.partid, etype)

        for name in g.edges[etype].data:
            if name in [dgl.EID, "inner_edge"]:
                continue
            true_feats = g.edges[etype].data[name][orig_id]
            edata = edge_feats[etype + "/" + name][local_eids]
            assert torch.equal(edata, true_feats)


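# Verify the on-disk layout produced by create_chunked_dataset: metadata.json,
# per-chunk edge index files, and per-chunk node/edge feature arrays.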
@pytest.mark.parametrize("num_chunks", [1, 8])
def test_chunk_graph(num_chunks):
    with tempfile.TemporaryDirectory() as root_dir:
        g = create_chunked_dataset(root_dir, num_chunks)

        num_cite_edges = g.number_of_edges("cites")
        num_write_edges = g.number_of_edges("writes")
        num_affiliate_edges = g.number_of_edges("affiliated_with")

        num_institutions = g.number_of_nodes("institution")
        num_authors = g.number_of_nodes("author")
        num_papers = g.number_of_nodes("paper")

        # check metadata.json
        output_dir = os.path.join(root_dir, "chunked-data")
        json_file = os.path.join(output_dir, "metadata.json")
        assert os.path.isfile(json_file)
        with open(json_file, "rb") as f:
            meta_data = json.load(f)
        assert meta_data["graph_name"] == "mag240m"
        assert len(meta_data["num_nodes_per_chunk"][0]) == num_chunks

        # check edge_index
        output_edge_index_dir = os.path.join(output_dir, "edge_index")
        for utype, etype, vtype in g.canonical_etypes:
            fname = ":".join([utype, etype, vtype])
            for i in range(num_chunks):
                chunk_f_name = os.path.join(
                    output_edge_index_dir, fname + str(i) + ".txt"
                )
                assert os.path.isfile(chunk_f_name)
                with open(chunk_f_name, "r") as f:
                    header = f.readline()
                    num1, num2 = header.rstrip().split(" ")
                    assert isinstance(int(num1), int)
                    assert isinstance(int(num2), int)

        # check node_data
        output_node_data_dir = os.path.join(output_dir, "node_data", "paper")
        for feat in ["feat", "label", "year"]:
            for i in range(num_chunks):
                chunk_f_name = "{}-{}.npy".format(feat, i)
                chunk_f_name = os.path.join(output_node_data_dir, chunk_f_name)
                assert os.path.isfile(chunk_f_name)
                feat_array = np.load(chunk_f_name)
                assert feat_array.shape[0] == num_papers // num_chunks

        # check edge_data
        num_edges = {
            "paper:cites:paper": num_cite_edges,
            "author:writes:paper": num_write_edges,
            "paper:rev_writes:author": num_write_edges,
        }
        output_edge_data_dir = os.path.join(output_dir, "edge_data")
        for etype, feat in [
            ["paper:cites:paper", "count"],
            ["author:writes:paper", "year"],
            ["paper:rev_writes:author", "year"],
        ]:
            output_edge_sub_dir = os.path.join(output_edge_data_dir, etype)
            for i in range(num_chunks):
                chunk_f_name = "{}-{}.npy".format(feat, i)
                chunk_f_name = os.path.join(output_edge_sub_dir, chunk_f_name)
                assert os.path.isfile(chunk_f_name)
                feat_array = np.load(chunk_f_name)
                assert feat_array.shape[0] == num_edges[etype] // num_chunks


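# End-to-end pipeline test: chunk the dataset, partition it randomly, dispatch
# the data to per-partition folders, then load and verify every partition.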
@pytest.mark.parametrize("num_chunks", [1, 2, 3, 4, 8])
@pytest.mark.parametrize("num_parts", [1, 2, 3, 4, 8])
def test_part_pipeline(num_chunks, num_parts):
    if num_chunks < num_parts:
        # num_parts should be less than or equal to num_chunks
        return

    with tempfile.TemporaryDirectory() as root_dir:

        g = create_chunked_dataset(root_dir, num_chunks)

        # Step1: graph partition
        in_dir = os.path.join(root_dir, "chunked-data")
        output_dir = os.path.join(root_dir, "parted_data")
        os.system(
            "python3 tools/partition_algo/random_partition.py "
            "--in_dir {} --out_dir {} --num_partitions {}".format(
                in_dir, output_dir, num_parts
            )
        )
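        # random_partition.py writes one <ntype>.txt file per node type;
        # check that the first line of each file parses as an integer.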
        for ntype in ["author", "institution", "paper"]:
            fname = os.path.join(output_dir, "{}.txt".format(ntype))
            with open(fname, "r") as f:
                header = f.readline().rstrip()
                assert isinstance(int(header), int)

        # Step2: data dispatch
        partition_dir = os.path.join(root_dir, "parted_data")
        out_dir = os.path.join(root_dir, "partitioned")
        ip_config = os.path.join(root_dir, "ip_config.txt")
        with open(ip_config, "w") as f:
            for i in range(num_parts):
                f.write(f"127.0.0.{i + 1}\n")

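        # Build and run the dispatch command; --save-orig-nids/--save-orig-eids
        # keep the original node/edge IDs so the partitions can be verified.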
        cmd = "python3 tools/dispatch_data.py"
        cmd += f" --in-dir {in_dir}"
        cmd += f" --partitions-dir {partition_dir}"
        cmd += f" --out-dir {out_dir}"
        cmd += f" --ip-config {ip_config}"
        cmd += " --ssh-port 22"
        cmd += " --process-group-timeout 60"
        cmd += " --save-orig-nids"
        cmd += " --save-orig-eids"
        os.system(cmd)

        # read original node/edge IDs
        def read_orig_ids(fname):
            orig_ids = {}
            for i in range(num_parts):
                ids_path = os.path.join(out_dir, f"part{i}", fname)
                part_ids = load_tensors(ids_path)
                for type, data in part_ids.items():
                    if type not in orig_ids:
                        orig_ids[type] = data
                    else:
                        orig_ids[type] = torch.cat((orig_ids[type], data))
            return orig_ids

        orig_nids = read_orig_ids("orig_nids.dgl")
        orig_eids = read_orig_ids("orig_eids.dgl")

        # load partitions and verify
        part_config = os.path.join(out_dir, "metadata.json")
        for i in range(num_parts):
            part_g, node_feats, edge_feats, gpb, _, _, _ = load_partition(
                part_config, i
            )
            _verify_partition_data_types(part_g)
            _verify_graph_feats(
                g, gpb, part_g, node_feats, edge_feats, orig_nids, orig_eids
            )