# See the __main__ block for usage of chunk_graph().
import json
import logging
import os

import dgl
import torch

from utils import array_readwriter, setdir
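# NOTE: `setdir` is assumed to be a context manager that creates the given
# directory if needed and temporarily chdirs into it; relative paths used
# below are therefore interpreted inside the current output (sub)directory.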


def chunk_numpy_array(arr, fmt_meta, chunk_sizes, path_fmt):
    """Split ``arr`` along its first dimension into chunks of the given
    sizes and write each chunk to its own file with ``array_readwriter``.

    Returns the list of absolute paths of the chunk files.
    """
    paths = []
    offset = 0

    for j, n in enumerate(chunk_sizes):
        path = os.path.abspath(path_fmt % j)
        arr_chunk = arr[offset : offset + n]
        logging.info("Chunking %d-%d" % (offset, offset + n))
        array_readwriter.get_array_parser(**fmt_meta).write(path, arr_chunk)
        offset += n
        paths.append(path)

    return paths
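
# Example (hypothetical call): chunk_numpy_array(arr, {"name": "numpy"},
# [3, 3, 2], "feat-%d.npy") writes rows 0-2, 3-5 and 6-7 of `arr` to
# feat-0.npy, feat-1.npy and feat-2.npy and returns their absolute paths.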


def _chunk_graph(g, name, ndata_paths, edata_paths, num_chunks, output_path):
    """Do the actual chunking; assumes the current working directory is
    already the output directory (see ``chunk_graph`` below)."""
    # First deal with ndata and edata that are homogeneous
    # (i.e. not a dict-of-dict).
    if len(g.ntypes) == 1 and not isinstance(
        next(iter(ndata_paths.values())), dict
    ):
        ndata_paths = {g.ntypes[0]: ndata_paths}
    if len(g.etypes) == 1 and not isinstance(
        next(iter(edata_paths.values())), dict
    ):
        edata_paths = {g.etypes[0]: edata_paths}
    # Then convert all edge types to canonical edge types
    etypestrs = {etype: ":".join(etype) for etype in g.canonical_etypes}
    edata_paths = {
        ":".join(g.to_canonical_etype(k)): v for k, v in edata_paths.items()
    }
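    # For example (hypothetical schema), a non-canonical key "writes" becomes
    # the canonical triple ("author", "writes", "paper") and is keyed as
    # "author:writes:paper".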

    metadata = {}

    metadata["graph_name"] = name
    metadata["node_type"] = g.ntypes

    # Compute the number of nodes per chunk per node type
    metadata["num_nodes_per_chunk"] = num_nodes_per_chunk = []
    for ntype in g.ntypes:
        num_nodes = g.num_nodes(ntype)
        num_nodes_list = []
        for i in range(num_chunks):
            n = num_nodes // num_chunks + (i < num_nodes % num_chunks)
            num_nodes_list.append(n)
        num_nodes_per_chunk.append(num_nodes_list)
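    # The split is as even as possible: e.g. 10 nodes over 4 chunks gives
    # per-chunk counts [3, 3, 2, 2].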
    num_nodes_per_chunk_dict = dict(zip(g.ntypes, num_nodes_per_chunk))

    metadata["edge_type"] = [etypestrs[etype] for etype in g.canonical_etypes]

    # Compute the number of edges per chunk per edge type
    metadata["num_edges_per_chunk"] = num_edges_per_chunk = []
    for etype in g.canonical_etypes:
        num_edges = g.num_edges(etype)
        num_edges_list = []
        for i in range(num_chunks):
            n = num_edges // num_chunks + (i < num_edges % num_chunks)
            num_edges_list.append(n)
        num_edges_per_chunk.append(num_edges_list)
    num_edges_per_chunk_dict = dict(
        zip(g.canonical_etypes, num_edges_per_chunk)
    )

    # Split edge index
    metadata["edges"] = {}
    with setdir("edge_index"):
        for etype in g.canonical_etypes:
            etypestr = etypestrs[etype]
            logging.info("Chunking edge index for %s" % etypestr)
            edges_meta = {}
            fmt_meta = {"name": "csv", "delimiter": " "}
            edges_meta["format"] = fmt_meta

            srcdst = torch.stack(g.edges(etype=etype), 1)
            edges_meta["data"] = chunk_numpy_array(
                srcdst.numpy(),
                fmt_meta,
                num_edges_per_chunk_dict[etype],
                etypestr + "%d.txt",
            )
            metadata["edges"][etypestr] = edges_meta

    # Chunk node data
    metadata["node_data"] = {}
    with setdir("node_data"):
        for ntype, ndata_per_type in ndata_paths.items():
            ndata_meta = {}
            with setdir(ntype):
                for key, path in ndata_per_type.items():
                    logging.info(
                        "Chunking node data for type %s key %s" % (ntype, key)
                    )
                    ndata_key_meta = {}
                    reader_fmt_meta = writer_fmt_meta = {"name": "numpy"}
                    arr = array_readwriter.get_array_parser(
                        **reader_fmt_meta
                    ).read(path)
                    ndata_key_meta["format"] = writer_fmt_meta
                    ndata_key_meta["data"] = chunk_numpy_array(
                        arr,
                        writer_fmt_meta,
                        num_nodes_per_chunk_dict[ntype],
                        key + "-%d.npy",
                    )
                    ndata_meta[key] = ndata_key_meta

            metadata["node_data"][ntype] = ndata_meta

    # Chunk edge data
    metadata["edge_data"] = {}
    with setdir("edge_data"):
        for etypestr, edata_per_type in edata_paths.items():
            edata_meta = {}
            with setdir(etypestr):
                for key, path in edata_per_type.items():
                    logging.info(
                        "Chunking edge data for type %s key %s"
                        % (etypestr, key)
                    )
                    edata_key_meta = {}
                    reader_fmt_meta = writer_fmt_meta = {"name": "numpy"}
                    arr = array_readwriter.get_array_parser(
                        **reader_fmt_meta
                    ).read(path)
                    edata_key_meta["format"] = writer_fmt_meta
                    etype = tuple(etypestr.split(":"))
                    edata_key_meta["data"] = chunk_numpy_array(
                        arr,
                        writer_fmt_meta,
                        num_edges_per_chunk_dict[etype],
                        key + "-%d.npy",
                    )
                    edata_meta[key] = edata_key_meta

            metadata["edge_data"][etypestr] = edata_meta

    metadata_path = "metadata.json"
    with open(metadata_path, "w") as f:
        json.dump(metadata, f, sort_keys=True, indent=4)
    logging.info("Saved metadata in %s" % os.path.abspath(metadata_path))


def chunk_graph(g, name, ndata_paths, edata_paths, num_chunks, output_path):
    """
    Split the graph into multiple chunks.

    A directory will be created at :attr:`output_path` containing the
    metadata, the chunked edge index, and the chunked node/edge data.

    Parameters
    ----------
    g : DGLGraph
        The graph.
    name : str
        The name of the graph, to be used later in DistDGL training.
    ndata_paths : dict[str, pathlike] or dict[ntype, dict[str, pathlike]]
        The dictionary of paths pointing to the corresponding numpy array file for each
        node data key.
    edata_paths : dict[etype, pathlike] or dict[etype, dict[str, pathlike]]
        The dictionary of paths pointing to the corresponding numpy array file
        for each edge data key. ``etype`` could be canonical or non-canonical.
    num_chunks : int
        The number of chunks to split each node/edge set into.
    output_path : pathlike
        The output directory where the chunked graph is saved.
    """
    for ntype, ndata in ndata_paths.items():
        for key in ndata.keys():
            ndata[key] = os.path.abspath(ndata[key])
    for etype, edata in edata_paths.items():
        for key in edata.keys():
            edata[key] = os.path.abspath(edata[key])
    with setdir(output_path):
        _chunk_graph(g, name, ndata_paths, edata_paths, num_chunks, output_path)


if __name__ == "__main__":
    logging.basicConfig(level="INFO")
    input_dir = "/data"
    output_dir = "/chunked-data"
    (g,), _ = dgl.load_graphs(os.path.join(input_dir, "graph.dgl"))
    chunk_graph(
        g,
        "mag240m",
        {
            "paper": {
                "feat": os.path.join(input_dir, "paper/feat.npy"),
                "label": os.path.join(input_dir, "paper/label.npy"),
                "year": os.path.join(input_dir, "paper/year.npy"),
            }
        },
        {
            "cites": {"count": os.path.join(input_dir, "cites/count.npy")},
            "writes": {"year": os.path.join(input_dir, "writes/year.npy")},
            # The same data file can be reused if two edge types indeed
            # share the features.
            "rev_writes": {"year": os.path.join(input_dir, "writes/year.npy")},
        },
        4,
        output_dir,
    )
# The generated metadata follows the format of
# tools/sample-config/mag240m-metadata.json.